Skip to main content

aranya_util/
addr.rs

1// This module was carefully written to ensure no arithmetic path can
2// panic, so we suppress the lint rather than add checked-math noise.
3#![allow(clippy::arithmetic_side_effects)]
//! Network address handling and utilities.
//!
//! This module provides the [`Addr`] type for representing network addresses
//! (hostname/IP and port) and associated functionality.
//!
//! [`Addr`] supports parsing addresses from strings in various formats
//! (e.g., "spideroak.com:80", "192.168.1.1:8080", "\[::1\]:443"), asynchronous DNS
//! resolution via [`Addr::lookup`], and construction from
//! [`std::net::SocketAddr`] or anything convertible into one, such as
//! `(Ipv4Addr, u16)` or `(Ipv6Addr, u16)` pairs (see [`std::net::Ipv4Addr`] and
//! [`std::net::Ipv6Addr`]).
15
16use std::{
17    cmp::Ordering,
18    error, fmt,
19    hash::{Hash, Hasher},
20    io,
21    mem::size_of,
22    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
23    ops::Deref,
24    slice,
25    str::{self, FromStr},
26};
27
28use anyhow::Result;
29use buggy::Bug;
30use serde::{
31    de::{self, Visitor},
32    Deserialize, Deserializer, Serialize, Serializer,
33};
34use tokio::net::{self, ToSocketAddrs};
35use tracing::{debug, instrument};
36
/// Compile-time assertion.
///
/// Expands to an anonymous `const` item whose initializer runs `assert!`
/// during constant evaluation, so a false condition fails the build
/// instead of panicking at runtime.
macro_rules! const_assert {
    ($($arg:tt)*) => {
        const _: () = ::core::assert!($($arg)*);
    };
}
42
43/// A "host:port" network address.
44///
45/// The host can be a domain name, IPv4, or IPv6 address and the
46/// port must be a valid [u16].
47///
48/// `Addr` ensures that the host part is a syntactically valid domain name or IP address.
49/// It provides methods for DNS lookup, conversion to socket addresses, and serde
50/// serialization/deserialization.
51#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
52pub struct Addr {
53    host: Host,
54    port: u16,
55}
56const_assert!(size_of::<Addr>() == 256);
57
58impl Addr {
59    /// Creates a new `Addr` from a host representation and a port number.
60    ///
61    /// The `host` can be a domain name (e.g., "spideroak.com"), an IPv4 address string
62    /// (e.g., "192.168.1.1"), or an IPv6 address string (e.g., "::1").
63    ///
64    /// Returns an error if the host is not a valid domain name or IP address string.
65    ///
66    /// # Errors
67    ///
68    /// Returns `AddrError::InvalidAddr` if the `host` string is not a valid
69    /// domain name or IP address.
70    pub fn new<T>(host: T, port: u16) -> Result<Self, AddrError>
71    where
72        T: AsRef<str>,
73    {
74        let host = host.as_ref();
75        let host = Host::from_domain(host)
76            .or_else(|| host.parse::<Ipv4Addr>().ok().map(Into::into))
77            .or_else(|| host.parse::<Ipv6Addr>().ok().map(Into::into))
78            .ok_or(AddrError::InvalidAddr(
79                "not a valid domain name or IP address",
80            ))?;
81        Ok(Self { host, port })
82    }
83
84    /// Returns a reference to the host part (domain name or IP address) of the `Addr`.
85    pub fn host(&self) -> &str {
86        &self.host
87    }
88
89    /// Returns the port number of the `Addr`.
90    pub fn port(&self) -> u16 {
91        self.port
92    }
93
94    /// Performs an asynchronous DNS lookup for this `Addr`.
95    ///
96    /// Resolves the host part of the address to one or more [`SocketAddr`] values.
97    #[instrument(skip_all, fields(host = %self))]
98    pub async fn lookup(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
99        debug!("performing DNS lookup");
100        net::lookup_host(Into::<(&str, u16)>::into(self)).await
101    }
102
103    /// Converts the `Addr` into a type that implements [`ToSocketAddrs`].
104    pub fn to_socket_addrs(&self) -> impl ToSocketAddrs + '_ {
105        Into::<(&str, u16)>::into(self)
106    }
107}
108
109impl<'a> From<&'a Addr> for (&'a str, u16) {
110    fn from(addr: &'a Addr) -> Self {
111        (&addr.host, addr.port)
112    }
113}
114
115impl<T> From<T> for Addr
116where
117    T: Into<SocketAddr>,
118{
119    fn from(value: T) -> Self {
120        let addr = value.into();
121        Self {
122            host: addr.ip().into(),
123            port: addr.port(),
124        }
125    }
126}
127
128impl FromStr for Addr {
129    type Err = AddrError;
130
131    /// Parses a string into an `Addr`.
132    ///
133    /// The string can be in several forms:
134    /// - `host:port` (e.g., "spideroak.com:80", "192.168.1.1:8080")
135    /// - IPv6 address with port: `[ipv6_addr]:port` (e.g., "\[::1\]:443")
136    /// - A string representation of a `SocketAddr` (which `std::net::SocketAddr::from_str` can parse).
137    ///
138    /// This function first attempts to parse using [`SocketAddr`], then falls
139    /// back to splitting the string at the first ":" character to parse the
140    /// host and port.
141    ///
142    /// # Errors
143    ///
144    /// Returns `AddrError::InvalidAddr` if the string format is invalid or the port
145    /// number is malformed.
146    fn from_str(s: &str) -> Result<Self, Self::Err> {
147        if let Ok(addr) = SocketAddr::from_str(s) {
148            return Ok(addr.into());
149        }
150        match s.split_once(':') {
151            Some((host, port)) => {
152                let port = port
153                    .parse()
154                    .map_err(|_| AddrError::InvalidAddr("invalid port syntax"))?;
155                Self::new(host, port)
156            }
157            None => Err(AddrError::InvalidAddr("missing ':' in `host:port`")),
158        }
159    }
160}
161
162impl fmt::Display for Addr {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        if self.host().contains(':') {
165            let ip = Ipv6Addr::from_str(self.host()).map_err(|_| fmt::Error)?;
166            SocketAddr::from((ip, self.port())).fmt(f)
167        } else {
168            write!(f, "{}:{}", self.host(), self.port())
169        }
170    }
171}
172
173impl Serialize for Addr {
174    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
175    where
176        S: Serializer,
177    {
178        serializer.serialize_str(&self.to_string())
179    }
180}
181
182impl<'de> Deserialize<'de> for Addr {
183    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
184    where
185        D: Deserializer<'de>,
186    {
187        struct AddrVisitor;
188        impl Visitor<'_> for AddrVisitor {
189            type Value = Addr;
190
191            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
192                formatter.write_str("a 'host:port' network address")
193            }
194
195            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
196            where
197                E: de::Error,
198            {
199                value.parse().map_err(E::custom)
200            }
201        }
202        deserializer.deserialize_str(AddrVisitor)
203    }
204}
205
206/// A hostname.
207#[derive(Copy, Clone)]
208struct Host {
209    // NB: `Host` is exactly 254 bytes long. This allows `Addr`
210    // to be exactly 256 bytes long.
211    len: u8,
212    buf: [u8; 253],
213}
214
215impl Host {
216    /// Creates a `Host` from a domain name.
217    fn from_domain(domain: &str) -> Option<Self> {
218        if !is_domain_name(domain) {
219            None
220        } else {
221            Self::try_from_str(domain)
222        }
223    }
224
225    /// Creates a `Host` from an IPv4 address.
226    fn from_ipv4(ip: &Ipv4Addr) -> Self {
227        Self::from_fmt(FmtBuf::fmt_ipv4(ip))
228    }
229
230    /// Creates a `Host` from an IPv6 address.
231    fn from_ipv6(ip: &Ipv6Addr) -> Self {
232        Self::from_fmt(FmtBuf::fmt_ipv6(ip))
233    }
234
235    #[inline(always)]
236    fn try_from_str(s: &str) -> Option<Self> {
237        let mut buf = [0u8; 253];
238        let src = s.as_bytes();
239        buf.get_mut(..src.len())?.copy_from_slice(src);
240        Some(Self {
241            // We copied <= 253 bytes, so `src.len() < u8::MAX`.
242            len: src.len() as u8,
243            buf,
244        })
245    }
246
247    #[inline(always)]
248    fn from_fmt(fmt: FmtBuf) -> Self {
249        debug_assert!(fmt.len < 253);
250
251        // NB: the compiler can prove that `len` is in bounds.
252        let mut buf = [0u8; 253];
253        buf.copy_from_slice(&fmt.buf[..253]);
254        Self { len: fmt.len, buf }
255    }
256
257    #[inline(always)]
258    fn as_bytes(&self) -> &[u8] {
259        // SAFETY: the buffer is always valid and length is
260        // correct.
261        unsafe { slice::from_raw_parts(self.buf.as_ptr(), usize::from(self.len)) }
262    }
263
264    #[inline(always)]
265    fn as_str(&self) -> &str {
266        // SAFETY: `Host` only stores valid UTF-8.
267        unsafe { str::from_utf8_unchecked(self.as_bytes()) }
268    }
269}
270
271impl Eq for Host {}
272impl PartialEq for Host {
273    #[inline]
274    fn eq(&self, other: &Self) -> bool {
275        self.as_str() == other.as_str()
276    }
277}
278
279impl Ord for Host {
280    #[inline]
281    fn cmp(&self, other: &Self) -> Ordering {
282        Ord::cmp(self.as_str(), other.as_str())
283    }
284}
285
286impl PartialOrd for Host {
287    #[inline]
288    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
289        Some(Ord::cmp(self, other))
290    }
291}
292
293impl Hash for Host {
294    #[inline]
295    fn hash<H>(&self, state: &mut H)
296    where
297        H: Hasher,
298    {
299        Hash::hash(self.as_str(), state)
300    }
301}
302
303impl fmt::Debug for Host {
304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305        f.write_str(self.as_str())
306    }
307}
308
309impl Deref for Host {
310    type Target = str;
311
312    #[inline]
313    fn deref(&self) -> &Self::Target {
314        self.as_str()
315    }
316}
317
318impl<T> From<T> for Host
319where
320    T: Into<IpAddr>,
321{
322    #[inline]
323    fn from(ip: T) -> Self {
324        match ip.into() {
325            IpAddr::V4(addr) => Self::from_ipv4(&addr),
326            IpAddr::V6(addr) => Self::from_ipv6(&addr),
327        }
328    }
329}
330
/// Reports whether `s` is a syntactically valid domain name.
///
/// Ported from Go's `isDomainName`; see
/// <https://github.com/golang/go/blob/a66a3bf494f652bc4fb209d861cbdba1dea71303/src/net/dnsclient.go#L78>.
///
/// Rules enforced here:
/// - `"."` (the root) is valid; otherwise `s` must be 1..=253 bytes long.
/// - Only ASCII letters, digits, `_`, `-` and `.` are allowed.
/// - Labels may not be empty or exceed 63 bytes, except that a single
///   trailing `.` is permitted.
/// - A label may neither start nor end with `-`.
/// - At least one byte must be a letter, `_` or `-`, so all-numeric
///   strings such as `"1.2.3.4"` are rejected.
///
/// NOTE(review): the 253-byte cap matches the size of `Host`'s buffer; the
/// Go reference may treat a 254-byte name ending in `.` differently —
/// confirm this divergence is intended.
fn is_domain_name(s: &str) -> bool {
    if s == "." {
        return true;
    }
    if s.is_empty() || s.len() > 253 {
        return false;
    }

    // Act as if the name were preceded by a dot so the first label is
    // validated exactly like every later one.
    let mut prev = b'.';
    let mut label_len = 0usize;
    let mut has_non_digit = false;

    for &b in s.as_bytes() {
        if b.is_ascii_alphabetic() || b == b'_' {
            has_non_digit = true;
            label_len += 1;
        } else if b.is_ascii_digit() {
            label_len += 1;
        } else if b == b'-' {
            // A hyphen cannot begin a label.
            if prev == b'.' {
                return false;
            }
            has_non_digit = true;
            label_len += 1;
        } else if b == b'.' {
            // Reject empty labels, labels ending in '-', and over-long labels.
            if prev == b'.' || prev == b'-' || label_len == 0 || label_len > 63 {
                return false;
            }
            label_len = 0;
        } else {
            return false;
        }
        prev = b;
    }

    // End-of-name checks for the last label; an empty last label (from a
    // trailing '.') is allowed.
    prev != b'-' && label_len <= 63 && has_non_digit
}
380
381/// Used to format IP addresses.
382struct FmtBuf {
383    /// The number of bytes written.
384    len: u8,
385    /// The size of this buffer lets the compiler prove that all
386    /// writes are in bounds without panicking.
387    ///
388    /// Contents are `buf[..len]`.
389    buf: [u8; 256],
390}
391
392impl FmtBuf {
393    /// Creates a new `FmtBuf`.
394    #[inline(always)]
395    const fn new() -> Self {
396        Self {
397            len: 0,
398            buf: [0u8; 256],
399        }
400    }
401
402    /// The number of bytes that can still be written to the
403    /// buffer.
404    #[inline(always)]
405    fn available(&self) -> usize {
406        self.buf.len() - usize::from(self.len)
407    }
408
409    /// Returns the used portion of the buffer.
410    #[inline(always)]
411    #[cfg(test)]
412    #[allow(clippy::indexing_slicing)]
413    fn as_bytes(&self) -> &[u8] {
414        // NB: the compiler can prove that `len` is in bounds.
415        &self.buf[..usize::from(self.len)]
416    }
417
418    /// Writes `c` to the buffer.
419    #[inline(always)]
420    #[allow(clippy::indexing_slicing)]
421    fn write(&mut self, c: u8) {
422        debug_assert!(self.available() > 0);
423
424        // NB: the compiler can prove that `self.idx` is in
425        // bounds.
426        self.buf[usize::from(self.len)] = c;
427        self.len += 1;
428    }
429
430    /// Writes `s` to the buffer.
431    #[inline(always)]
432    fn write_str(&mut self, s: &str) {
433        debug_assert!(self.available() >= s.len());
434
435        for c in s.as_bytes() {
436            self.write(*c);
437        }
438    }
439
440    /// Writes `x` as a base-10 integer to the buffer.
441    #[inline(always)]
442    fn itoa10(&mut self, x: u8) {
443        if x >= 100 {
444            self.write(base10(x / 100))
445        }
446        if x >= 10 {
447            self.write(base10(x / 10 % 10))
448        }
449        self.write(base10(x % 10))
450    }
451
452    /// Writes `x` as a base-16 integer to the buffer.
453    #[inline(always)]
454    fn itoa16(&mut self, x: u16) {
455        if x >= 0x1000 {
456            self.write(base16((x >> 12) as u8));
457        }
458        if x >= 0x100 {
459            self.write(base16(((x >> 8) & 0xf) as u8));
460        }
461        if x >= 0x10 {
462            self.write(base16(((x >> 4) & 0x0f) as u8));
463        }
464        self.write(base16((x & 0x0f) as u8));
465    }
466
467    /// Formats `ip` in its dotted quad notation.
468    fn fmt_ipv4(ip: &Ipv4Addr) -> Self {
469        let octets = ip.octets();
470
471        let mut buf = Self::new();
472        buf.itoa10(octets[0]);
473        buf.write(b'.');
474        buf.itoa10(octets[1]);
475        buf.write(b'.');
476        buf.itoa10(octets[2]);
477        buf.write(b'.');
478        buf.itoa10(octets[3]);
479        buf
480    }
481
482    /// Formats `ip` per [RFC
483    /// 5952](https://tools.ietf.org/html/rfc5952).
484    fn fmt_ipv6(ip: &Ipv6Addr) -> Self {
485        let mut buf = Self::new();
486
487        if let Some(ip) = ip.to_ipv4_mapped() {
488            let octets = ip.octets();
489            buf.write_str("::ffff:");
490            buf.itoa10(octets[0]);
491            buf.write(b'.');
492            buf.itoa10(octets[1]);
493            buf.write(b'.');
494            buf.itoa10(octets[2]);
495            buf.write(b'.');
496            buf.itoa10(octets[3]);
497            return buf;
498        }
499
500        let segments = ip.segments();
501
502        let zeros = {
503            #[derive(Copy, Clone, Default)]
504            struct Span {
505                start: usize,
506                len: usize,
507            }
508            impl Span {
509                const fn contains(&self, idx: usize) -> bool {
510                    self.start <= idx && idx < self.start + self.len
511                }
512            }
513
514            let mut max = Span::default();
515            let mut cur = Span::default();
516
517            for (i, &seg) in segments.iter().enumerate() {
518                if seg == 0 {
519                    if cur.len == 0 {
520                        cur.start = i;
521                    }
522                    cur.len += 1;
523
524                    if cur.len >= 2 && cur.len > max.len {
525                        max = cur;
526                    }
527                } else {
528                    cur = Span::default();
529                }
530            }
531            max
532        };
533
534        // TODO(eric): if we make this a little simpler we can
535        // probably convince the compiler to elide all bounds
536        // checks. That would let us make the internal buffer
537        // 253 bytes.
538        let mut iter = segments.iter().enumerate();
539        while let Some((i, &seg)) = iter.next() {
540            if zeros.contains(i) {
541                buf.write_str("::");
542
543                if let Some((_, &seg)) = iter.nth(zeros.len - 1) {
544                    buf.itoa16(seg);
545                }
546            } else {
547                if i > 0 {
548                    buf.write(b':')
549                }
550                buf.itoa16(seg);
551            }
552        }
553        buf
554    }
555}
556
/// Converts `x`, which must be in `0..=9`, to its ASCII decimal digit
/// (`b'0'..=b'9'`).
const fn base10(x: u8) -> u8 {
    debug_assert!(x <= 9);

    x + b'0'
}

/// Converts `x`, which must be in `0..=15`, to its lowercase ASCII
/// hexadecimal digit (`b'0'..=b'9'` or `b'a'..=b'f'`).
const fn base16(x: u8) -> u8 {
    debug_assert!(x <= 15);

    if x < 10 {
        base10(x)
    } else {
        x - 10 + b'a'
    }
}
576
577/// An error returned by [`Addr`].
578#[derive(Debug)]
579pub enum AddrError {
580    /// An internal bug was discovered.
581    Bug(Bug),
582    /// The provided address string is invalid.
583    InvalidAddr(&'static str),
584}
585
586impl error::Error for AddrError {
587    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
588        match self {
589            Self::Bug(err) => Some(err),
590            _ => None,
591        }
592    }
593}
594
595impl fmt::Display for AddrError {
596    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
597        match self {
598            Self::Bug(err) => write!(f, "{err}"),
599            Self::InvalidAddr(msg) => {
600                write!(f, "invalid network address: {msg}")
601            }
602        }
603    }
604}
605
606impl From<Bug> for AddrError {
607    fn from(err: Bug) -> Self {
608        Self::Bug(err)
609    }
610}
611
#[allow(clippy::indexing_slicing, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base10() {
        const DIGITS: &[u8] = b"0123456789";
        for x in 0..=9u8 {
            let want = DIGITS[x as usize];
            let got = base10(x);
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_base16() {
        const DIGITS: &[u8] = b"0123456789abcdef";
        for x in 0..=15u8 {
            let want = DIGITS[x as usize];
            let got = base16(x);
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_addr_parse() {
        let tests = ["127.0.0.1:8080", "[2001:db8::1]:8080"];
        for test in tests {
            let got = Addr::from_str(test).unwrap();
            let want = SocketAddr::from_str(test).unwrap();
            assert_eq!(got, want.into());
        }
    }

    /// `Display` output re-parses to an equal `Addr` and, for these
    /// already-normalized inputs, reproduces the input text exactly.
    #[test]
    fn test_addr_display_roundtrip() {
        let tests = [
            "spideroak.com:80",
            "localhost:8080",
            "127.0.0.1:8080",
            "[2001:db8::1]:8080",
            "[::1]:443",
            "[::ffff:192.0.2.1]:443",
        ];
        for (i, test) in tests.into_iter().enumerate() {
            let addr = Addr::from_str(test).expect("should parse");
            let text = addr.to_string();
            assert_eq!(text, test, "#{i}");
            let again = Addr::from_str(&text).expect("should re-parse");
            assert_eq!(again, addr, "#{i}");
        }
    }

    #[test]
    fn test_addr_new() {
        let addr = Addr::new("spideroak.com", 80).expect("valid domain");
        assert_eq!(addr.host(), "spideroak.com");
        assert_eq!(addr.port(), 80);

        let addr = Addr::new("::1", 443).expect("valid IPv6 address");
        assert_eq!(addr.host(), "::1");
        assert_eq!(addr.to_string(), "[::1]:443");

        let addr = Addr::new("192.168.1.1", 8080).expect("valid IPv4 address");
        assert_eq!(addr, Addr::from(SocketAddr::from(([192, 168, 1, 1], 8080))));

        assert!(Addr::new("", 80).is_err());
        assert!(Addr::new("[::1]", 80).is_err());
        assert!(Addr::new("bad host", 80).is_err());
    }

    #[test]
    fn test_addr_invalid() {
        let tests = [
            "",
            "spideroak.com",
            "spideroak.com:",
            "spideroak.com:65536",
            "spideroak.com:http",
            ":80",
            "bad host:80",
            "-spideroak.com:80",
            "123:80",
            "::1:80",
        ];
        for test in tests {
            assert!(Addr::from_str(test).is_err(), "{test:?}");
        }
    }

    #[test]
    fn test_is_domain_name() {
        for name in [".", "spideroak.com", "spideroak.com.", "a-b.c_d", "1-2"] {
            assert!(is_domain_name(name), "{name:?}");
        }
        for name in [
            "", "..", ".com", "a..b", "-ab", "ab-", "ab-.cd", "ab.-cd", "123", "1.2.3.4", "a b",
        ] {
            assert!(!is_domain_name(name), "{name:?}");
        }
        assert!(is_domain_name(&"a".repeat(63)));
        assert!(!is_domain_name(&"a".repeat(64)));
    }

    #[test]
    fn test_host_ipv4() {
        let ips = [
            Ipv4Addr::UNSPECIFIED,
            Ipv4Addr::LOCALHOST,
            Ipv4Addr::BROADCAST,
            Ipv4Addr::new(127, 0, 0, 1),
            Ipv4Addr::new(1, 1, 1, 1),
            Ipv4Addr::new(1, 2, 3, 4),
            Ipv4Addr::new(4, 3, 2, 1),
            Ipv4Addr::new(127, 127, 127, 127),
            Ipv4Addr::new(100, 10, 1, 0),
        ];
        for (i, ip) in ips.into_iter().enumerate() {
            let want = ip.to_string();
            let got = String::from_utf8(FmtBuf::fmt_ipv4(&ip).as_bytes().to_vec())
                .expect("`FmtBuf` should be valid UTF-8");
            assert_eq!(got, want, "#{i}");
        }
    }

    #[test]
    fn test_host_ipv6() {
        let ips = [
            Ipv6Addr::UNSPECIFIED,
            Ipv6Addr::LOCALHOST,
            Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff),
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0xc000, 0x280),
            Ipv6Addr::new(
                0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666, 0x7777, 0x8888,
            ),
            Ipv6Addr::new(0xae, 0, 0, 0, 0, 0xffff, 0x0102, 0x0304),
            Ipv6Addr::new(1, 0, 0, 0, 0, 0, 0, 0),
            Ipv6Addr::new(1, 0, 0, 4, 0, 0, 0, 8),
            Ipv6Addr::new(1, 0, 0, 4, 5, 0, 0, 8),
            Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
            Ipv6Addr::new(8, 7, 6, 5, 4, 3, 2, 1),
            Ipv6Addr::new(127, 127, 127, 127, 127, 127, 127, 127),
            Ipv6Addr::new(16, 16, 16, 16, 16, 16, 16, 16),
        ];
        for (i, ip) in ips.into_iter().enumerate() {
            let want = ip.to_string();
            let got = String::from_utf8(FmtBuf::fmt_ipv6(&ip).as_bytes().to_vec())
                .expect("`FmtBuf` should be valid UTF-8");
            assert_eq!(got, want, "#{i}");
        }
    }
}