1#![allow(clippy::arithmetic_side_effects)]
4use std::{
17 cmp::Ordering,
18 error, fmt,
19 hash::{Hash, Hasher},
20 io,
21 mem::size_of,
22 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
23 ops::Deref,
24 slice,
25 str::{self, FromStr},
26};
27
28use anyhow::Result;
29use buggy::Bug;
30use serde::{
31 de::{self, Visitor},
32 Deserialize, Deserializer, Serialize, Serializer,
33};
34use tokio::net::{self, ToSocketAddrs};
35use tracing::{debug, instrument};
36
/// Compile-time assertion.
///
/// Expands to an anonymous `const` item whose initializer runs `assert!`
/// with the given arguments, so a false condition fails the build rather
/// than panicking at runtime.
macro_rules! const_assert {
    ($($arg:tt)*) => {
        const _: () = assert!($($arg)*);
    };
}
42
43#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
52pub struct Addr {
53 host: Host,
54 port: u16,
55}
56const_assert!(size_of::<Addr>() == 256);
57
58impl Addr {
59 pub fn new<T>(host: T, port: u16) -> Result<Self, AddrError>
71 where
72 T: AsRef<str>,
73 {
74 let host = host.as_ref();
75 let host = Host::from_domain(host)
76 .or_else(|| host.parse::<Ipv4Addr>().ok().map(Into::into))
77 .or_else(|| host.parse::<Ipv6Addr>().ok().map(Into::into))
78 .ok_or(AddrError::InvalidAddr(
79 "not a valid domain name or IP address",
80 ))?;
81 Ok(Self { host, port })
82 }
83
84 pub fn host(&self) -> &str {
86 &self.host
87 }
88
89 pub fn port(&self) -> u16 {
91 self.port
92 }
93
94 #[instrument(skip_all, fields(host = %self))]
98 pub async fn lookup(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
99 debug!("performing DNS lookup");
100 net::lookup_host(Into::<(&str, u16)>::into(self)).await
101 }
102
103 pub fn to_socket_addrs(&self) -> impl ToSocketAddrs + '_ {
105 Into::<(&str, u16)>::into(self)
106 }
107}
108
109impl<'a> From<&'a Addr> for (&'a str, u16) {
110 fn from(addr: &'a Addr) -> Self {
111 (&addr.host, addr.port)
112 }
113}
114
115impl<T> From<T> for Addr
116where
117 T: Into<SocketAddr>,
118{
119 fn from(value: T) -> Self {
120 let addr = value.into();
121 Self {
122 host: addr.ip().into(),
123 port: addr.port(),
124 }
125 }
126}
127
128impl FromStr for Addr {
129 type Err = AddrError;
130
131 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 if let Ok(addr) = SocketAddr::from_str(s) {
148 return Ok(addr.into());
149 }
150 match s.split_once(':') {
151 Some((host, port)) => {
152 let port = port
153 .parse()
154 .map_err(|_| AddrError::InvalidAddr("invalid port syntax"))?;
155 Self::new(host, port)
156 }
157 None => Err(AddrError::InvalidAddr("missing ':' in `host:port`")),
158 }
159 }
160}
161
162impl fmt::Display for Addr {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 if self.host().contains(':') {
165 let ip = Ipv6Addr::from_str(self.host()).map_err(|_| fmt::Error)?;
166 SocketAddr::from((ip, self.port())).fmt(f)
167 } else {
168 write!(f, "{}:{}", self.host(), self.port())
169 }
170 }
171}
172
173impl Serialize for Addr {
174 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
175 where
176 S: Serializer,
177 {
178 serializer.serialize_str(&self.to_string())
179 }
180}
181
182impl<'de> Deserialize<'de> for Addr {
183 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
184 where
185 D: Deserializer<'de>,
186 {
187 struct AddrVisitor;
188 impl Visitor<'_> for AddrVisitor {
189 type Value = Addr;
190
191 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
192 formatter.write_str("a 'host:port' network address")
193 }
194
195 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
196 where
197 E: de::Error,
198 {
199 value.parse().map_err(E::custom)
200 }
201 }
202 deserializer.deserialize_str(AddrVisitor)
203 }
204}
205
/// Fixed-capacity inline storage for a host name or textual IP address.
///
/// Only the first `len` bytes of `buf` are meaningful. The constructors in
/// `impl Host` fill them from a `&str` or with ASCII from `FmtBuf`.
#[derive(Clone, Copy)]
struct Host {
    /// Number of bytes of `buf` in use.
    len: u8,
    /// Backing storage. 253 bytes is also the longest name
    /// `is_domain_name` accepts.
    buf: [u8; 253],
}
214
215impl Host {
216 fn from_domain(domain: &str) -> Option<Self> {
218 if !is_domain_name(domain) {
219 None
220 } else {
221 Self::try_from_str(domain)
222 }
223 }
224
225 fn from_ipv4(ip: &Ipv4Addr) -> Self {
227 Self::from_fmt(FmtBuf::fmt_ipv4(ip))
228 }
229
230 fn from_ipv6(ip: &Ipv6Addr) -> Self {
232 Self::from_fmt(FmtBuf::fmt_ipv6(ip))
233 }
234
235 #[inline(always)]
236 fn try_from_str(s: &str) -> Option<Self> {
237 let mut buf = [0u8; 253];
238 let src = s.as_bytes();
239 buf.get_mut(..src.len())?.copy_from_slice(src);
240 Some(Self {
241 len: src.len() as u8,
243 buf,
244 })
245 }
246
247 #[inline(always)]
248 fn from_fmt(fmt: FmtBuf) -> Self {
249 debug_assert!(fmt.len < 253);
250
251 let mut buf = [0u8; 253];
253 buf.copy_from_slice(&fmt.buf[..253]);
254 Self { len: fmt.len, buf }
255 }
256
257 #[inline(always)]
258 fn as_bytes(&self) -> &[u8] {
259 unsafe { slice::from_raw_parts(self.buf.as_ptr(), usize::from(self.len)) }
262 }
263
264 #[inline(always)]
265 fn as_str(&self) -> &str {
266 unsafe { str::from_utf8_unchecked(self.as_bytes()) }
268 }
269}
270
271impl Eq for Host {}
272impl PartialEq for Host {
273 #[inline]
274 fn eq(&self, other: &Self) -> bool {
275 self.as_str() == other.as_str()
276 }
277}
278
279impl Ord for Host {
280 #[inline]
281 fn cmp(&self, other: &Self) -> Ordering {
282 Ord::cmp(self.as_str(), other.as_str())
283 }
284}
285
286impl PartialOrd for Host {
287 #[inline]
288 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
289 Some(Ord::cmp(self, other))
290 }
291}
292
293impl Hash for Host {
294 #[inline]
295 fn hash<H>(&self, state: &mut H)
296 where
297 H: Hasher,
298 {
299 Hash::hash(self.as_str(), state)
300 }
301}
302
303impl fmt::Debug for Host {
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 f.write_str(self.as_str())
306 }
307}
308
309impl Deref for Host {
310 type Target = str;
311
312 #[inline]
313 fn deref(&self) -> &Self::Target {
314 self.as_str()
315 }
316}
317
318impl<T> From<T> for Host
319where
320 T: Into<IpAddr>,
321{
322 #[inline]
323 fn from(ip: T) -> Self {
324 match ip.into() {
325 IpAddr::V4(addr) => Self::from_ipv4(&addr),
326 IpAddr::V6(addr) => Self::from_ipv6(&addr),
327 }
328 }
329}
330
/// Reports whether `s` is a syntactically valid domain name.
///
/// Rules enforced:
/// - `"."` (the root) is valid. Otherwise `s` must be 1..=253 bytes.
/// - One trailing `.` is allowed. Any other empty label is not.
/// - Each label is 1..=63 bytes of ASCII letters, digits, `_` or `-`, and
///   does not begin or end with `-`.
/// - At least one byte is a letter, `_` or `-`. Purely numeric names such
///   as `1.2.3.4` are rejected so they are not mistaken for IP addresses.
fn is_domain_name(s: &str) -> bool {
    if s == "." {
        return true;
    }
    if s.is_empty() || s.len() > 253 {
        return false;
    }

    // A single trailing dot marks a fully-qualified name. Drop it before
    // splitting so it does not produce an empty final label.
    let name = s.strip_suffix('.').unwrap_or(s);

    let mut has_non_digit = false;
    for label in name.split('.') {
        if label.is_empty() || label.len() > 63 {
            return false;
        }
        if label.starts_with('-') || label.ends_with('-') {
            return false;
        }
        for byte in label.bytes() {
            match byte {
                b'0'..=b'9' => {}
                b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'-' => has_non_digit = true,
                _ => return false,
            }
        }
    }
    has_non_digit
}
380
381struct FmtBuf {
383 len: u8,
385 buf: [u8; 256],
390}
391
392impl FmtBuf {
393 #[inline(always)]
395 const fn new() -> Self {
396 Self {
397 len: 0,
398 buf: [0u8; 256],
399 }
400 }
401
402 #[inline(always)]
405 fn available(&self) -> usize {
406 self.buf.len() - usize::from(self.len)
407 }
408
409 #[inline(always)]
411 #[cfg(test)]
412 #[allow(clippy::indexing_slicing)]
413 fn as_bytes(&self) -> &[u8] {
414 &self.buf[..usize::from(self.len)]
416 }
417
418 #[inline(always)]
420 #[allow(clippy::indexing_slicing)]
421 fn write(&mut self, c: u8) {
422 debug_assert!(self.available() > 0);
423
424 self.buf[usize::from(self.len)] = c;
427 self.len += 1;
428 }
429
430 #[inline(always)]
432 fn write_str(&mut self, s: &str) {
433 debug_assert!(self.available() >= s.len());
434
435 for c in s.as_bytes() {
436 self.write(*c);
437 }
438 }
439
440 #[inline(always)]
442 fn itoa10(&mut self, x: u8) {
443 if x >= 100 {
444 self.write(base10(x / 100))
445 }
446 if x >= 10 {
447 self.write(base10(x / 10 % 10))
448 }
449 self.write(base10(x % 10))
450 }
451
452 #[inline(always)]
454 fn itoa16(&mut self, x: u16) {
455 if x >= 0x1000 {
456 self.write(base16((x >> 12) as u8));
457 }
458 if x >= 0x100 {
459 self.write(base16(((x >> 8) & 0xf) as u8));
460 }
461 if x >= 0x10 {
462 self.write(base16(((x >> 4) & 0x0f) as u8));
463 }
464 self.write(base16((x & 0x0f) as u8));
465 }
466
467 fn fmt_ipv4(ip: &Ipv4Addr) -> Self {
469 let octets = ip.octets();
470
471 let mut buf = Self::new();
472 buf.itoa10(octets[0]);
473 buf.write(b'.');
474 buf.itoa10(octets[1]);
475 buf.write(b'.');
476 buf.itoa10(octets[2]);
477 buf.write(b'.');
478 buf.itoa10(octets[3]);
479 buf
480 }
481
482 fn fmt_ipv6(ip: &Ipv6Addr) -> Self {
485 let mut buf = Self::new();
486
487 if let Some(ip) = ip.to_ipv4_mapped() {
488 let octets = ip.octets();
489 buf.write_str("::ffff:");
490 buf.itoa10(octets[0]);
491 buf.write(b'.');
492 buf.itoa10(octets[1]);
493 buf.write(b'.');
494 buf.itoa10(octets[2]);
495 buf.write(b'.');
496 buf.itoa10(octets[3]);
497 return buf;
498 }
499
500 let segments = ip.segments();
501
502 let zeros = {
503 #[derive(Copy, Clone, Default)]
504 struct Span {
505 start: usize,
506 len: usize,
507 }
508 impl Span {
509 const fn contains(&self, idx: usize) -> bool {
510 self.start <= idx && idx < self.start + self.len
511 }
512 }
513
514 let mut max = Span::default();
515 let mut cur = Span::default();
516
517 for (i, &seg) in segments.iter().enumerate() {
518 if seg == 0 {
519 if cur.len == 0 {
520 cur.start = i;
521 }
522 cur.len += 1;
523
524 if cur.len >= 2 && cur.len > max.len {
525 max = cur;
526 }
527 } else {
528 cur = Span::default();
529 }
530 }
531 max
532 };
533
534 let mut iter = segments.iter().enumerate();
539 while let Some((i, &seg)) = iter.next() {
540 if zeros.contains(i) {
541 buf.write_str("::");
542
543 if let Some((_, &seg)) = iter.nth(zeros.len - 1) {
544 buf.itoa16(seg);
545 }
546 } else {
547 if i > 0 {
548 buf.write(b':')
549 }
550 buf.itoa16(seg);
551 }
552 }
553 buf
554 }
555}
556
/// Returns the ASCII decimal digit for `x`, which must be in `0..=9`.
const fn base10(x: u8) -> u8 {
    debug_assert!(x <= 9);

    b'0' + x
}
564
565const fn base16(x: u8) -> u8 {
568 debug_assert!(x <= 15);
569
570 if x < 10 {
571 base10(x)
572 } else {
573 x - 10 + b'a'
574 }
575}
576
577#[derive(Debug)]
579pub enum AddrError {
580 Bug(Bug),
582 InvalidAddr(&'static str),
584}
585
586impl error::Error for AddrError {
587 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
588 match self {
589 Self::Bug(err) => Some(err),
590 _ => None,
591 }
592 }
593}
594
595impl fmt::Display for AddrError {
596 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
597 match self {
598 Self::Bug(err) => write!(f, "{err}"),
599 Self::InvalidAddr(msg) => {
600 write!(f, "invalid network address: {msg}")
601 }
602 }
603 }
604}
605
606impl From<Bug> for AddrError {
607 fn from(err: Bug) -> Self {
608 Self::Bug(err)
609 }
610}
611
#[allow(clippy::indexing_slicing, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base10() {
        const DIGITS: &[u8] = b"0123456789";
        for x in 0..=9u8 {
            let want = DIGITS[x as usize];
            let got = base10(x);
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_base16() {
        const DIGITS: &[u8] = b"0123456789abcdef";
        for x in 0..=15u8 {
            let want = DIGITS[x as usize];
            let got = base16(x);
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_addr_parse() {
        let tests = ["127.0.0.1:8080", "[2001:db8::1]:8080"];
        for test in tests {
            let got = Addr::from_str(test).unwrap();
            let want = SocketAddr::from_str(test).unwrap();
            assert_eq!(got, want.into());
        }
    }

    /// Malformed `host:port` strings must be rejected, not panic.
    #[test]
    fn test_addr_parse_invalid() {
        let tests = [
            "",
            "example.com",
            "example.com:",
            "example.com:65536",
            "1.2.3:80",
            "::1:80",
            "exa mple.com:80",
        ];
        for test in tests {
            assert!(Addr::from_str(test).is_err(), "{test:?}");
        }
    }

    /// `Display` output must parse back to the same `Addr`.
    #[test]
    fn test_addr_display_round_trip() {
        let tests = [
            "127.0.0.1:8080",
            "[2001:db8::1]:8080",
            "[::ffff:1.2.3.4]:443",
            "example.com:80",
            "localhost:0",
        ];
        for test in tests {
            let addr = Addr::from_str(test).unwrap();
            let shown = addr.to_string();
            assert_eq!(shown, test);
            assert_eq!(Addr::from_str(&shown).unwrap(), addr);
        }
    }

    /// `Addr::new` keeps domain names, canonicalizes IP text and rejects
    /// everything else.
    #[test]
    fn test_addr_new() {
        let addr = Addr::new("example.com", 80).expect("valid domain");
        assert_eq!(addr.host(), "example.com");
        assert_eq!(addr.port(), 80);

        let addr = Addr::new("10.0.0.1", 1).expect("valid IPv4");
        assert_eq!(addr.host(), "10.0.0.1");

        let addr = Addr::new("0:0:0:0:0:0:0:1", 443).expect("valid IPv6");
        assert_eq!(addr.host(), "::1");

        for bad in ["", "exa mple.com", "1.2.3", "[::1]", "-example.com"] {
            assert!(Addr::new(bad, 80).is_err(), "{bad:?}");
        }
    }

    #[test]
    fn test_is_domain_name() {
        let valid = [
            ".",
            "localhost",
            "example.com",
            "example.com.",
            "a-b.c_d",
            "xn--bcher-kva.example",
            "1x",
        ];
        for name in valid {
            assert!(is_domain_name(name), "{name:?}");
        }

        let invalid = [
            "",
            "..",
            ".example",
            "example..com",
            "-a",
            "a-",
            "a.-b",
            "a-.b",
            "1.2.3.4",
            "123",
            "exa mple",
            "::1",
        ];
        for name in invalid {
            assert!(!is_domain_name(name), "{name:?}");
        }

        // Label and total length limits.
        assert!(is_domain_name(&"a".repeat(63)));
        assert!(!is_domain_name(&"a".repeat(64)));
        let long = format!("{0}.{0}.{0}.{1}", "a".repeat(63), "a".repeat(61));
        assert_eq!(long.len(), 253);
        assert!(is_domain_name(&long));
        assert!(!is_domain_name(&format!("{long}b")));
    }

    #[test]
    fn test_host_ipv4() {
        let ips = [
            Ipv4Addr::UNSPECIFIED,
            Ipv4Addr::LOCALHOST,
            Ipv4Addr::BROADCAST,
            Ipv4Addr::new(127, 0, 0, 1),
            Ipv4Addr::new(1, 1, 1, 1),
            Ipv4Addr::new(1, 2, 3, 4),
            Ipv4Addr::new(4, 3, 2, 1),
            Ipv4Addr::new(127, 127, 127, 127),
            Ipv4Addr::new(100, 10, 1, 0),
        ];
        for (i, ip) in ips.into_iter().enumerate() {
            let want = ip.to_string();
            let got = String::from_utf8(FmtBuf::fmt_ipv4(&ip).as_bytes().to_vec())
                .expect("`FmtBuf` should be valid UTF-8");
            assert_eq!(got, want, "#{i}");
        }
    }

    #[test]
    fn test_host_ipv6() {
        let ips = [
            Ipv6Addr::UNSPECIFIED,
            Ipv6Addr::LOCALHOST,
            Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff),
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0xc000, 0x280),
            Ipv6Addr::new(
                0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666, 0x7777, 0x8888,
            ),
            Ipv6Addr::new(0xae, 0, 0, 0, 0, 0xffff, 0x0102, 0x0304),
            Ipv6Addr::new(1, 0, 0, 0, 0, 0, 0, 0),
            Ipv6Addr::new(1, 0, 0, 4, 0, 0, 0, 8),
            Ipv6Addr::new(1, 0, 0, 4, 5, 0, 0, 8),
            Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
            Ipv6Addr::new(8, 7, 6, 5, 4, 3, 2, 1),
            Ipv6Addr::new(127, 127, 127, 127, 127, 127, 127, 127),
            Ipv6Addr::new(16, 16, 16, 16, 16, 16, 16, 16),
        ];
        for (i, ip) in ips.into_iter().enumerate() {
            let want = ip.to_string();
            let got = String::from_utf8(FmtBuf::fmt_ipv6(&ip).as_bytes().to_vec())
                .expect("`FmtBuf` should be valid UTF-8");
            assert_eq!(got, want, "#{i}");
        }
    }
}