use std::fmt;
use std::net::SocketAddr;
use std::str::FromStr;
use thiserror::Error;
/// Errors returned when text cannot be parsed into an [`Endpoint`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum EndpointParseError {
    /// The input is empty or does not have a usable `host:port` / `[host]:port` shape.
    ///
    /// `input` carries the offending text after surrounding whitespace was trimmed.
    #[error("invalid endpoint format: {input}")]
    InvalidFormat { input: String },

    /// The input has no port component.
    ///
    /// `input` carries the offending text after surrounding whitespace was trimmed.
    #[error("missing port in endpoint: {input}")]
    MissingPort { input: String },

    /// The port component is not a valid `u16`.
    ///
    /// `port` carries the port text exactly as it appeared in the input.
    #[error("invalid port '{port}' in endpoint")]
    InvalidPort { port: String },
}
/// A network endpoint: a host (IP literal or name) paired with a 16-bit port.
///
/// Equality and hashing compare the host text and the port verbatim; no
/// normalization of the host is applied.
#[must_use]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Endpoint {
    /// Host text. Values produced by parsing `[host]:port` or by converting a
    /// socket address are stored without the surrounding brackets.
    host: String,
    /// Port number.
    port: u16,
}
impl Endpoint {
    /// Creates an endpoint from a host and a port.
    ///
    /// The host is stored exactly as given; it is not validated or normalized.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        Self {
            host: host.into(),
            port,
        }
    }

    /// Creates an endpoint from a socket address.
    ///
    /// The host becomes the textual form of the IP address (IPv6 addresses are
    /// stored without brackets) and the port is copied over.
    pub fn from_socket_addr(addr: SocketAddr) -> Self {
        Self {
            host: addr.ip().to_string(),
            port: addr.port(),
        }
    }

    /// Parses an endpoint from `host:port` or `[host]:port`.
    ///
    /// Surrounding whitespace is ignored. Hosts that contain `:` (IPv6
    /// literals) must use the bracketed form, because an unbracketed
    /// `::1:8080` cannot be split into host and port unambiguously.
    ///
    /// # Errors
    ///
    /// * [`EndpointParseError::InvalidFormat`] if the input is empty, the host
    ///   is empty, a bracketed host is not closed or is followed by anything
    ///   other than `:port`, or an unbracketed host contains `:`.
    /// * [`EndpointParseError::MissingPort`] if there is no port component.
    /// * [`EndpointParseError::InvalidPort`] if the port is not a plain decimal
    ///   number in `0..=65535`.
    pub fn parse(s: &str) -> Result<Self, EndpointParseError> {
        let s = s.trim();
        let invalid_format = || EndpointParseError::InvalidFormat {
            input: s.to_string(),
        };
        let missing_port = || EndpointParseError::MissingPort {
            input: s.to_string(),
        };

        if s.is_empty() {
            return Err(invalid_format());
        }

        let (host, port_str) = if s.starts_with('[') {
            // Bracketed form: the host is everything up to the first `]`.
            let bracket_end = s.find(']').ok_or_else(|| invalid_format())?;
            let host = &s[1..bracket_end];
            let rest = &s[bracket_end + 1..];
            if rest.is_empty() {
                // `[host]` alone is well-formed but lacks the port.
                return Err(missing_port());
            }
            let port_str = rest.strip_prefix(':').ok_or_else(|| invalid_format())?;
            (host, port_str)
        } else {
            let colon_pos = s.rfind(':').ok_or_else(|| missing_port())?;
            let host = &s[..colon_pos];
            if host.contains(':') {
                // Without brackets, `::1` would otherwise be read as host ":" and
                // port 1; refuse to guess where the address ends.
                return Err(invalid_format());
            }
            (host, &s[colon_pos + 1..])
        };

        if host.is_empty() {
            return Err(invalid_format());
        }

        let port = Self::parse_port(port_str)?;
        Ok(Self {
            host: host.to_string(),
            port,
        })
    }

    /// Parses the port component, reporting the raw text on failure.
    fn parse_port(port_str: &str) -> Result<u16, EndpointParseError> {
        port_str
            .parse::<u16>()
            .ok()
            // `u16::from_str` tolerates a leading `+`, which is not valid port syntax.
            .filter(|_| !port_str.starts_with('+'))
            .ok_or_else(|| EndpointParseError::InvalidPort {
                port: port_str.to_string(),
            })
    }

    /// Returns the host text (without brackets for parsed IPv6 literals).
    pub fn host(&self) -> &str {
        &self.host
    }

    /// Returns the port number.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Converts the endpoint into a [`SocketAddr`] if the host is an IP literal.
    ///
    /// No name resolution is performed: the host and port are formatted and fed
    /// to `SocketAddr`'s parser, so a host such as `localhost` yields `None`.
    pub fn to_socket_addr(&self) -> Option<SocketAddr> {
        // A host containing `:` is treated as an IPv6 literal and needs brackets.
        let literal = if self.host.contains(':') {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        };
        literal.parse().ok()
    }
}
/// Formats the endpoint as `host:port`, wrapping the host in brackets when it
/// contains a `:` so the result stays unambiguous.
impl fmt::Display for Endpoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the delimiters once, then emit a single formatted write.
        let (open, close) = if self.host.contains(':') {
            ("[", "]")
        } else {
            ("", "")
        };
        write!(f, "{}{}{}:{}", open, self.host, close, self.port)
    }
}
impl FromStr for Endpoint {
type Err = EndpointParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s)
}
}
impl From<SocketAddr> for Endpoint {
fn from(addr: SocketAddr) -> Self {
Self::from_socket_addr(addr)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Endpoint` construction, parsing, formatting and the
    //! derived equality/hash behavior.

    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    #[test]
    fn test_endpoint_parse_ipv4() {
        let endpoint = Endpoint::parse("127.0.0.1:8080").unwrap();
        assert_eq!(endpoint.host(), "127.0.0.1");
        assert_eq!(endpoint.port(), 8080);
        assert!(endpoint.to_socket_addr().is_some());
    }

    #[test]
    fn test_endpoint_parse_ipv6() {
        let endpoint = Endpoint::parse("[::1]:9000").unwrap();
        assert_eq!(endpoint.host(), "::1");
        assert_eq!(endpoint.port(), 9000);
        assert!(endpoint.to_socket_addr().is_some());
    }

    #[test]
    fn test_endpoint_parse_hostname() {
        let endpoint = Endpoint::parse("localhost:8080").unwrap();
        assert_eq!(endpoint.host(), "localhost");
        assert_eq!(endpoint.port(), 8080);
        // Host names are not resolved, so there is no socket address.
        assert!(endpoint.to_socket_addr().is_none());
    }

    #[test]
    fn test_endpoint_parse_fqdn() {
        let endpoint = Endpoint::parse("alice.prod.internal:8080").unwrap();
        assert_eq!(endpoint.host(), "alice.prod.internal");
        assert_eq!(endpoint.port(), 8080);
    }

    #[test]
    fn test_endpoint_parse_trims_whitespace() {
        let endpoint = Endpoint::parse("  127.0.0.1:8080\n").unwrap();
        assert_eq!(endpoint, Endpoint::new("127.0.0.1", 8080));
    }

    #[test]
    fn test_endpoint_display() {
        let endpoint = Endpoint::parse("192.168.1.1:443").unwrap();
        assert_eq!(endpoint.to_string(), "192.168.1.1:443");

        let endpoint_v6 = Endpoint::parse("[::1]:9000").unwrap();
        assert_eq!(endpoint_v6.to_string(), "[::1]:9000");
    }

    #[test]
    fn test_endpoint_display_roundtrip() {
        for text in ["192.168.1.1:443", "[::1]:9000", "localhost:8080"] {
            let endpoint = Endpoint::parse(text).unwrap();
            assert_eq!(endpoint.to_string(), text);
            assert_eq!(Endpoint::parse(&endpoint.to_string()), Ok(endpoint));
        }
    }

    #[test]
    fn test_endpoint_from_socket_addr() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 3000);
        let endpoint = Endpoint::from(addr);
        assert_eq!(endpoint.host(), "10.0.0.1");
        assert_eq!(endpoint.port(), 3000);
    }

    #[test]
    fn test_endpoint_new() {
        let endpoint = Endpoint::new("my-service", 8080);
        assert_eq!(endpoint.host(), "my-service");
        assert_eq!(endpoint.port(), 8080);
    }

    #[test]
    fn test_endpoint_from_str() {
        let endpoint: Endpoint = "127.0.0.1:80".parse().unwrap();
        assert_eq!(endpoint.port(), 80);
    }

    #[test]
    fn test_endpoint_parse_invalid() {
        // Pin the exact variant so a misclassified failure is caught, not just
        // the fact that parsing failed.
        assert_eq!(
            Endpoint::parse(""),
            Err(EndpointParseError::InvalidFormat {
                input: String::new()
            })
        );
        assert_eq!(
            Endpoint::parse("127.0.0.1"),
            Err(EndpointParseError::MissingPort {
                input: "127.0.0.1".to_string()
            })
        );
        assert_eq!(
            Endpoint::parse(":8080"),
            Err(EndpointParseError::InvalidFormat {
                input: ":8080".to_string()
            })
        );
        assert_eq!(
            Endpoint::parse("host:not_a_port"),
            Err(EndpointParseError::InvalidPort {
                port: "not_a_port".to_string()
            })
        );
    }

    #[test]
    fn test_endpoint_parse_port_out_of_range() {
        assert_eq!(
            Endpoint::parse("host:65536"),
            Err(EndpointParseError::InvalidPort {
                port: "65536".to_string()
            })
        );
        assert_eq!(Endpoint::parse("host:65535").unwrap().port(), 65535);
    }

    #[test]
    fn test_endpoint_equality() {
        let e1 = Endpoint::parse("127.0.0.1:8080").unwrap();
        let e2 = Endpoint::parse("127.0.0.1:8080").unwrap();
        let e3 = Endpoint::parse("127.0.0.1:8081").unwrap();
        assert_eq!(e1, e2);
        assert_ne!(e1, e3);
    }

    #[test]
    fn test_endpoint_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(Endpoint::parse("127.0.0.1:8080").unwrap());
        assert!(set.contains(&Endpoint::parse("127.0.0.1:8080").unwrap()));
        assert!(!set.contains(&Endpoint::parse("127.0.0.1:8081").unwrap()));
    }
}