1#![warn(missing_docs)]
8
9use std::{
10 fmt,
11 net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
12 num::ParseIntError,
13 str::FromStr,
14};
15
/// A domain name held as raw bytes.
///
/// Values built through [`DomainName::new`] are at most 255 bytes long; apart
/// from that length limit the contents are not checked and need not be valid
/// UTF-8.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DomainName {
    bytes: Vec<u8>,
}

impl AsRef<[u8]> for DomainName {
    /// Borrows the stored bytes without copying them.
    fn as_ref(&self) -> &[u8] {
        &self.bytes
    }
}
29
/// Error returned when a byte string cannot be used as a [`DomainName`].
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ParseDomainNameError {
    /// The input exceeds the maximum allowed length.
    TooLong {
        /// Largest permitted length, in bytes.
        max: usize,
        /// Length of the rejected input, in bytes.
        len: usize,
    },
}

impl std::error::Error for ParseDomainNameError {}

impl fmt::Display for ParseDomainNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ParseDomainNameError::TooLong { max, len } => {
                write!(f, "length must be <={}, but {}", max, len)
            }
        }
    }
}
59
60impl DomainName {
61 const MAX_LEN: usize = 0xff;
62
63 pub fn new(bytes: Vec<u8>) -> std::result::Result<Self, ParseDomainNameError> {
67 if bytes.len() > Self::MAX_LEN {
68 return Err(ParseDomainNameError::TooLong {
69 max: Self::MAX_LEN,
70 len: bytes.len(),
71 });
72 }
73 Ok(Self { bytes })
74 }
75}
76
77impl fmt::Display for DomainName {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 String::from_utf8_lossy(self.bytes.as_ref()).fmt(f)
80 }
81}
82
83impl FromStr for DomainName {
84 type Err = ParseDomainNameError;
85
86 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
87 Ok(Self::new(s.as_bytes().to_vec())?)
88 }
89}
90
91#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
96pub enum Addr {
97 Ipv4(Ipv4Addr),
99
100 DomainName(DomainName),
102
103 Ipv6(Ipv6Addr),
105}
106
107impl fmt::Display for Addr {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 match self {
110 Addr::Ipv4(addr) => addr.fmt(f),
111 Addr::DomainName(addr) => addr.fmt(f),
112 Addr::Ipv6(addr) => addr.fmt(f),
113 }
114 }
115}
116
117#[derive(Clone, PartialEq, Eq, Debug)]
128pub enum ParseAddrError {
129 DomainName(ParseDomainNameError),
132}
133
134impl fmt::Display for ParseAddrError {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 match self {
137 Self::DomainName(err) => err.fmt(f),
138 }
139 }
140}
141
142impl From<ParseDomainNameError> for ParseAddrError {
143 fn from(err: ParseDomainNameError) -> Self {
144 Self::DomainName(err)
145 }
146}
147
148impl std::error::Error for ParseAddrError {}
149
150impl FromStr for Addr {
151 type Err = ParseAddrError;
152
153 fn from_str(s: &str) -> Result<Self, Self::Err> {
154 if let Ok(addr) = s.parse::<Ipv4Addr>() {
155 return Ok(Self::Ipv4(addr));
156 }
157 if let Ok(addr) = s.parse::<Ipv6Addr>() {
158 return Ok(Self::Ipv6(addr));
159 }
160 Ok(Self::DomainName(DomainName::new(s.as_bytes().to_vec())?))
161 }
162}
163
164type Port = u16;
165
166#[derive(Clone, PartialEq, Eq, Hash, Debug)]
172pub struct SocketDomainName {
173 domain_name: DomainName,
174 port: Port,
175}
176
177impl SocketDomainName {
178 pub fn new(domain_name: DomainName, port: Port) -> Self {
180 Self { domain_name, port }
181 }
182
183 pub fn domain_name(&self) -> &DomainName {
185 &self.domain_name
186 }
187
188 pub fn port(&self) -> Port {
190 self.port
191 }
192}
193
194impl fmt::Display for SocketDomainName {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(f, "{}:{}", self.domain_name, self.port)
197 }
198}
199
200#[derive(Clone, PartialEq, Eq, Hash, Debug)]
206pub enum SocketAddr {
207 Ipv4(SocketAddrV4),
209
210 DomainName(SocketDomainName),
212
213 Ipv6(SocketAddrV6),
215}
216
217impl SocketAddr {
218 pub fn new(addr: Addr, port: Port) -> Self {
220 match addr {
221 Addr::Ipv4(addr) => Self::Ipv4(SocketAddrV4::new(addr, port)),
222 Addr::DomainName(addr) => Self::DomainName(SocketDomainName::new(addr, port)),
223 Addr::Ipv6(addr) => Self::Ipv6(SocketAddrV6::new(addr, port, 0, 0)),
224 }
225 }
226}
227
228impl fmt::Display for SocketAddr {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 match self {
231 SocketAddr::Ipv4(addr) => addr.fmt(f),
232 SocketAddr::DomainName(addr) => addr.fmt(f),
233 SocketAddr::Ipv6(addr) => addr.fmt(f),
234 }
235 }
236}
237
238#[derive(Clone, PartialEq, Eq, Debug)]
246pub enum ParseSocketAddrError {
247 PortNotFound,
249
250 InvalidPort(ParseIntError),
254
255 InvalidAddr(ParseAddrError),
260}
261
262impl fmt::Display for ParseSocketAddrError {
263 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264 match self {
265 Self::PortNotFound => write!(f, "port not found"),
266 Self::InvalidPort(err) => err.fmt(f),
267 Self::InvalidAddr(err) => err.fmt(f),
268 }
269 }
270}
271
272impl std::error::Error for ParseSocketAddrError {}
273
274impl From<ParseIntError> for ParseSocketAddrError {
275 fn from(err: ParseIntError) -> Self {
276 Self::InvalidPort(err)
277 }
278}
279
280impl From<ParseAddrError> for ParseSocketAddrError {
281 fn from(err: ParseAddrError) -> Self {
282 Self::InvalidAddr(err)
283 }
284}
285
286impl FromStr for SocketAddr {
287 type Err = ParseSocketAddrError;
288
289 fn from_str(s: &str) -> Result<Self, Self::Err> {
290 let colon = s.rfind(':');
291 if colon.is_none() {
292 return Err(Self::Err::PortNotFound);
293 }
294 let colon = colon.unwrap();
295
296 let mut addr_part = &s[..colon];
297 if addr_part.starts_with('[') && addr_part.ends_with(']') {
298 addr_part = &addr_part[1..addr_part.len() - 1];
299 }
300 let port_part = &s[colon + 1..];
301
302 let port = port_part.parse::<Port>()?;
303 let addr = addr_part.parse()?;
304 Ok(Self::new(addr, port))
305 }
306}
307
/// A plain `host:port` with a non-IP host parses to the domain-name variant.
#[test]
fn test_socket_addr_ext_from_str() {
    let test: SocketAddr = "www.example.net:8080".parse().unwrap();
    let domain = SocketDomainName::new(DomainName::new(b"www.example.net".to_vec()).unwrap(), 8080);
    let expected = SocketAddr::DomainName(domain);
    assert_eq!(test, expected);
}

/// IP literals win over the domain-name fallback, and a bracketed IPv6 host
/// has its brackets removed before parsing.
#[test]
fn test_socket_addr_ext_from_str_ip_literals() {
    assert_eq!(
        "127.0.0.1:80".parse::<SocketAddr>(),
        Ok(SocketAddr::Ipv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 80)))
    );
    assert_eq!(
        "[::1]:443".parse::<SocketAddr>(),
        Ok(SocketAddr::Ipv6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 443, 0, 0)))
    );
}

/// Each error variant is reachable from malformed input.
#[test]
fn test_socket_addr_ext_from_str_errors() {
    assert_eq!(
        "www.example.net".parse::<SocketAddr>(),
        Err(ParseSocketAddrError::PortNotFound)
    );
    assert!(matches!(
        "www.example.net:http".parse::<SocketAddr>(),
        Err(ParseSocketAddrError::InvalidPort(_))
    ));
    assert!(matches!(
        "www.example.net:65536".parse::<SocketAddr>(),
        Err(ParseSocketAddrError::InvalidPort(_))
    ));

    let too_long = format!("{}:80", "a".repeat(256));
    assert_eq!(
        too_long.parse::<SocketAddr>(),
        Err(ParseSocketAddrError::InvalidAddr(ParseAddrError::DomainName(
            ParseDomainNameError::TooLong { max: 255, len: 256 }
        )))
    );
}

/// Formatting a parsed address reproduces the canonical input text.
#[test]
fn test_socket_addr_ext_display_round_trip() {
    for text in &["www.example.net:8080", "127.0.0.1:80", "[::1]:443"] {
        let parsed: SocketAddr = text.parse().unwrap();
        assert_eq!(parsed.to_string(), *text);
    }
}