use alloc::{str::FromStr, string::String};
use core::{
fmt::{self, Debug, Display, Write},
net::IpAddr,
ops::Deref,
};
use bytestring::ByteString;
use percent_encoding::{NON_ALPHANUMERIC, percent_decode_str, percent_encode};
use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
use url::Url;
/// A parsed NATS server address (`nats://`, `tls://`, `ws://` or `wss://` URL).
///
/// Credentials are stored already percent-decoded; an empty string means the
/// credential was not supplied (the `username()`/`password()` accessors map it
/// to `None`).
#[derive(PartialEq, Eq, Clone)]
pub struct ServerAddr {
    /// Whether the scheme demanded TLS (`tls`, `wss`) or not (`nats`, `ws`).
    protocol: Protocol,
    /// TCP for `nats`/`tls`, WebSocket for `ws`/`wss`.
    transport: Transport,
    /// IP literal or DNS name taken from the URL's host component.
    host: Host,
    /// Explicit port, or the scheme default when the URL omitted one.
    port: u16,
    /// Percent-decoded username; empty when absent.
    username: ByteString,
    /// Percent-decoded password; empty when absent.
    password: ByteString,
}
/// Security requirement implied by the URL scheme.
///
/// `nats://` and `ws://` map to [`Protocol::PossiblyPlain`]; `tls://` and
/// `wss://` map to [`Protocol::TLS`]. Variant order is significant because
/// the type derives `PartialOrd`/`Ord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Protocol {
    /// The scheme (`nats`, `ws`) does not itself require TLS.
    PossiblyPlain,
    /// The scheme (`tls`, `wss`) requires TLS.
    TLS,
}
/// Wire transport implied by the URL scheme.
///
/// `nats://` and `tls://` map to [`Transport::TCP`]; `ws://` and `wss://` map
/// to [`Transport::Websocket`]. Variant order is significant because the type
/// derives `PartialOrd`/`Ord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Transport {
    /// Raw TCP connection (`nats`, `tls`).
    TCP,
    /// WebSocket connection (`ws`, `wss`).
    Websocket,
}
/// Host component of a [`ServerAddr`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Host {
    /// The host was an IPv4 or IPv6 literal.
    Ip(IpAddr),
    /// Any host that did not parse as an IP literal, kept verbatim.
    Dns(ByteString),
}
impl ServerAddr {
    /// Returns whether the scheme required TLS (`tls`, `wss`) or not (`nats`, `ws`).
    pub fn protocol(&self) -> Protocol {
        self.protocol
    }

    /// Returns the transport chosen by the scheme: TCP for `nats`/`tls`,
    /// WebSocket for `ws`/`wss`.
    pub fn transport(&self) -> Transport {
        self.transport
    }

    /// Returns the host, either an IP literal or a DNS name.
    pub fn host(&self) -> &Host {
        &self.host
    }

    /// Returns the port; if the URL omitted one, this is the scheme default.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// `true` when the stored port equals the default for this
    /// protocol/transport pair, in which case `Display` omits it.
    fn is_default_port(&self) -> bool {
        let scheme_default = protocol_transport_to_port(self.protocol, self.transport);
        scheme_default == self.port
    }

    /// Returns the percent-decoded username, or `None` when it is empty.
    pub fn username(&self) -> Option<&str> {
        (!self.username.is_empty()).then_some(&*self.username)
    }

    /// Returns the percent-decoded password, or `None` when it is empty.
    pub fn password(&self) -> Option<&str> {
        (!self.password.is_empty()).then_some(&*self.password)
    }
}
impl FromStr for ServerAddr {
type Err = ServerAddrError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let url = value.parse::<Url>().map_err(ServerAddrError::InvalidUrl)?;
let (protocol, transport) = match url.scheme() {
"nats" => (Protocol::PossiblyPlain, Transport::TCP),
"tls" => (Protocol::TLS, Transport::TCP),
"ws" => (Protocol::PossiblyPlain, Transport::Websocket),
"wss" => (Protocol::TLS, Transport::Websocket),
_ => return Err(ServerAddrError::InvalidScheme),
};
let host = match url.host() {
Some(url::Host::Ipv4(addr)) => Host::Ip(IpAddr::V4(addr)),
Some(url::Host::Ipv6(addr)) => Host::Ip(IpAddr::V6(addr)),
Some(url::Host::Domain(host)) => {
let host = host
.strip_prefix('[')
.and_then(|host| host.strip_suffix(']'))
.unwrap_or(host);
match host.parse::<IpAddr>() {
Ok(ip) => Host::Ip(ip),
Err(_) => Host::Dns(host.into()),
}
}
None => return Err(ServerAddrError::MissingHost),
};
let port = if let Some(port) = url.port() {
port
} else {
protocol_transport_to_port(protocol, transport)
};
let username = percent_decode_str(url.username())
.decode_utf8()
.map_err(|_| ServerAddrError::UsernameInvalidUtf8)?
.deref()
.into();
let password = percent_decode_str(url.password().unwrap_or_default())
.decode_utf8()
.map_err(|_| ServerAddrError::PasswordInvalidUtf8)?
.deref()
.into();
Ok(Self {
protocol,
transport,
host,
port,
username,
password,
})
}
}
impl Debug for ServerAddr {
    /// Debug output that never reveals credentials.
    ///
    /// Each credential is shown as `"<none>"` when empty and `"<redacted>"`
    /// otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fn redact(secret: &str) -> &'static str {
            if secret.is_empty() {
                "<none>"
            } else {
                "<redacted>"
            }
        }

        f.debug_struct("ServerAddr")
            .field("protocol", &self.protocol)
            .field("transport", &self.transport)
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &redact(&self.username))
            .field("password", &redact(&self.password))
            .finish()
    }
}
impl Display for ServerAddr {
    /// Writes the address back in URL form (`scheme://[user[:pass]@]host[:port]`).
    ///
    /// Credentials are percent-encoded with `NON_ALPHANUMERIC` and written in
    /// clear text (the `Debug` impl is the redacting one). IPv6 hosts are
    /// bracketed. The port is omitted when it equals the scheme default. The
    /// `Serialize` impl is built on this output, so anything dropped here is
    /// lost on a serialize/deserialize round trip.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match (self.protocol, self.transport) {
            (Protocol::PossiblyPlain, Transport::TCP) => "nats",
            (Protocol::TLS, Transport::TCP) => "tls",
            (Protocol::PossiblyPlain, Transport::Websocket) => "ws",
            (Protocol::TLS, Transport::Websocket) => "wss",
        })?;
        f.write_str("://")?;

        // Emit userinfo whenever *either* credential is set. Gating it on the
        // username alone dropped a password-only address such as
        // `nats://:secret@host`, which parses with an empty username.
        let username = self.username();
        let password = self.password();
        if username.is_some() || password.is_some() {
            if let Some(username) = username {
                Display::fmt(&percent_encode(username.as_bytes(), NON_ALPHANUMERIC), f)?;
            }
            if let Some(password) = password {
                write!(
                    f,
                    ":{}",
                    percent_encode(password.as_bytes(), NON_ALPHANUMERIC)
                )?;
            }
            f.write_char('@')?;
        }

        match &self.host {
            Host::Ip(IpAddr::V4(addr)) => Display::fmt(addr, f)?,
            Host::Ip(IpAddr::V6(addr)) => write!(f, "[{addr}]")?,
            Host::Dns(record) => Display::fmt(record, f)?,
        }

        if !self.is_default_port() {
            write!(f, ":{}", self.port)?;
        }
        Ok(())
    }
}
impl<'de> Deserialize<'de> for ServerAddr {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let val = String::deserialize(deserializer)?;
val.parse().map_err(de::Error::custom)
}
}
impl Serialize for ServerAddr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.collect_str(self)
}
}
#[derive(Debug, thiserror::Error)]
pub enum ServerAddrError {
#[error("invalid Url")]
InvalidUrl(#[source] url::ParseError),
#[error("invalid Url scheme")]
InvalidScheme,
#[error("missing host")]
MissingHost,
#[error("username is not utf-8")]
UsernameInvalidUtf8,
#[error("password is not utf-8")]
PasswordInvalidUtf8,
}
/// Default port assumed when a URL does not specify one.
///
/// TCP (`nats`, `tls`) always uses 4222. WebSocket uses 80 for `ws` and 443
/// for `wss`.
fn protocol_transport_to_port(protocol: Protocol, transport: Transport) -> u16 {
    match transport {
        // Same port whether or not the scheme requires TLS.
        Transport::TCP => 4222,
        Transport::Websocket => match protocol {
            Protocol::PossiblyPlain => 80,
            Protocol::TLS => 443,
        },
    }
}
#[cfg(test)]
mod tests {
    use alloc::{format, string::ToString};
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::{Host, Protocol, ServerAddr, ServerAddrError, Transport};

    fn localhost_v4() -> Host {
        Host::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    fn localhost_v6() -> Host {
        Host::Ip(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }

    /// Parses `input`, checks every non-credential accessor and the `Display`
    /// output, and returns the parsed address for further checks.
    #[track_caller]
    fn assert_parses(
        input: &str,
        protocol: Protocol,
        transport: Transport,
        host: Host,
        port: u16,
        display: &str,
    ) -> ServerAddr {
        let server_addr = input.parse::<ServerAddr>().unwrap();
        assert_eq!(server_addr.protocol(), protocol);
        assert_eq!(server_addr.transport(), transport);
        assert_eq!(server_addr.host(), &host);
        assert_eq!(server_addr.port(), port);
        assert_eq!(server_addr.to_string(), display);
        server_addr
    }

    /// Like `assert_parses`, additionally requiring that no credentials are set.
    #[track_caller]
    fn assert_parses_anonymous(
        input: &str,
        protocol: Protocol,
        transport: Transport,
        host: Host,
        port: u16,
        display: &str,
    ) {
        let server_addr = assert_parses(input, protocol, transport, host, port, display);
        assert_eq!(server_addr.username(), None);
        assert_eq!(server_addr.password(), None);
    }

    #[test]
    fn nats() {
        assert_parses_anonymous(
            "nats://127.0.0.1",
            Protocol::PossiblyPlain,
            Transport::TCP,
            localhost_v4(),
            4222,
            "nats://127.0.0.1",
        );
    }

    #[test]
    fn nats_non_default_port() {
        assert_parses_anonymous(
            "nats://127.0.0.1:4321",
            Protocol::PossiblyPlain,
            Transport::TCP,
            localhost_v4(),
            4321,
            "nats://127.0.0.1:4321",
        );
    }

    #[test]
    fn nats_explicit_default_port_is_omitted() {
        assert_parses_anonymous(
            "nats://127.0.0.1:4222",
            Protocol::PossiblyPlain,
            Transport::TCP,
            localhost_v4(),
            4222,
            "nats://127.0.0.1",
        );
    }

    #[test]
    fn nats_ipv6() {
        assert_parses_anonymous(
            "nats://[::1]",
            Protocol::PossiblyPlain,
            Transport::TCP,
            localhost_v6(),
            4222,
            "nats://[::1]",
        );
    }

    #[test]
    fn tls() {
        assert_parses_anonymous(
            "tls://127.0.0.1",
            Protocol::TLS,
            Transport::TCP,
            localhost_v4(),
            4222,
            "tls://127.0.0.1",
        );
    }

    #[test]
    fn tls_dns_host() {
        assert_parses_anonymous(
            "tls://demo.nats.io:4443",
            Protocol::TLS,
            Transport::TCP,
            Host::Dns("demo.nats.io".into()),
            4443,
            "tls://demo.nats.io:4443",
        );
    }

    #[test]
    fn ws() {
        assert_parses_anonymous(
            "ws://127.0.0.1",
            Protocol::PossiblyPlain,
            Transport::Websocket,
            localhost_v4(),
            80,
            "ws://127.0.0.1",
        );
    }

    #[test]
    fn ws_explicit_default_port_is_omitted() {
        assert_parses_anonymous(
            "ws://127.0.0.1:80",
            Protocol::PossiblyPlain,
            Transport::Websocket,
            localhost_v4(),
            80,
            "ws://127.0.0.1",
        );
    }

    #[test]
    fn wss() {
        assert_parses_anonymous(
            "wss://127.0.0.1",
            Protocol::TLS,
            Transport::Websocket,
            localhost_v4(),
            443,
            "wss://127.0.0.1",
        );
    }

    #[test]
    fn wss_ipv6_non_default_port() {
        assert_parses_anonymous(
            "wss://[::1]:8443",
            Protocol::TLS,
            Transport::Websocket,
            localhost_v6(),
            8443,
            "wss://[::1]:8443",
        );
    }

    #[test]
    fn credentials() {
        let server_addr = assert_parses(
            "nats://alice:s3cret@127.0.0.1",
            Protocol::PossiblyPlain,
            Transport::TCP,
            localhost_v4(),
            4222,
            "nats://alice:s3cret@127.0.0.1",
        );
        assert_eq!(server_addr.username(), Some("alice"));
        assert_eq!(server_addr.password(), Some("s3cret"));
    }

    #[test]
    fn percent_encoded_credentials_round_trip_and_are_redacted() {
        let server_addr = assert_parses(
            "nats://al%20ice:p%40ss@127.0.0.1",
            Protocol::PossiblyPlain,
            Transport::TCP,
            localhost_v4(),
            4222,
            "nats://al%20ice:p%40ss@127.0.0.1",
        );
        assert_eq!(server_addr.username(), Some("al ice"));
        assert_eq!(server_addr.password(), Some("p@ss"));

        let debug = format!("{server_addr:?}");
        assert!(debug.contains("<redacted>"));
        assert!(!debug.contains("al ice"));
        assert!(!debug.contains("p@ss"));
    }

    #[test]
    fn rejects_unknown_scheme() {
        assert!(matches!(
            "http://127.0.0.1".parse::<ServerAddr>(),
            Err(ServerAddrError::InvalidScheme)
        ));
    }

    #[test]
    fn rejects_non_url() {
        assert!(matches!(
            "127.0.0.1:4222".parse::<ServerAddr>(),
            Err(ServerAddrError::InvalidUrl(_))
        ));
    }

    #[test]
    fn rejects_missing_host() {
        assert!(matches!(
            "nats:foo".parse::<ServerAddr>(),
            Err(ServerAddrError::MissingHost)
        ));
    }

    #[test]
    fn rejects_non_utf8_credentials() {
        assert!(matches!(
            "nats://%FF@127.0.0.1".parse::<ServerAddr>(),
            Err(ServerAddrError::UsernameInvalidUtf8)
        ));
        assert!(matches!(
            "nats://user:%FF@127.0.0.1".parse::<ServerAddr>(),
            Err(ServerAddrError::PasswordInvalidUtf8)
        ));
    }
}