Skip to main content

reliakit_primitives/
net.rs

1use crate::{PrimitiveError, PrimitiveResult};
2use core::fmt;
3use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4use core::str::FromStr;
5
/// An IP network in CIDR notation: a base address plus a prefix length.
///
/// `Cidr` validates that the prefix length is in range for the address family
/// (`0..=32` for IPv4, `0..=128` for IPv6) at construction. It accepts both IPv4
/// and IPv6 via [`core::net::IpAddr`], stores the address exactly as supplied
/// (host bits are not cleared), and offers membership testing with
/// [`contains`](Cidr::contains) and the canonical network base with
/// [`network`](Cidr::network).
///
/// # Equality and hashing
///
/// The derived `PartialEq`, `Eq`, and `Hash` impls compare the stored address
/// *as supplied* together with the prefix length. Two values that describe the
/// same network but differ in their host bits (for example `10.0.0.1/8` and
/// `10.0.0.0/8`) are therefore **not** equal and hash differently. Compare
/// [`network`](Cidr::network) and [`prefix_len`](Cidr::prefix_len) when
/// network-level equality is what you need.
///
/// This type is allocation-free and `no_std`-friendly.
///
/// # Examples
///
/// ```
/// use reliakit_primitives::Cidr;
///
/// let net: Cidr = "192.168.1.0/24".parse().unwrap();
/// assert!(net.contains("192.168.1.42".parse().unwrap()));
/// assert!(!net.contains("192.168.2.1".parse().unwrap()));
/// assert_eq!(net.prefix_len(), 24);
///
/// // Host bits participate in equality; the canonical networks still match.
/// let a: Cidr = "10.0.0.1/8".parse().unwrap();
/// let b: Cidr = "10.0.0.0/8".parse().unwrap();
/// assert_ne!(a, b);
/// assert_eq!(a.network(), b.network());
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cidr {
    /// Base address exactly as supplied by the caller (host bits kept).
    addr: IpAddr,
    /// Number of leading network bits; validated against the family at construction.
    prefix_len: u8,
}
32
33impl Cidr {
34    /// Creates a `Cidr` from an address and prefix length.
35    ///
36    /// Returns [`PrimitiveError::Invalid`] if `prefix_len` exceeds the maximum
37    /// for the address family (32 for IPv4, 128 for IPv6).
38    pub fn new(addr: IpAddr, prefix_len: u8) -> PrimitiveResult<Self> {
39        let max = match addr {
40            IpAddr::V4(_) => 32,
41            IpAddr::V6(_) => 128,
42        };
43        if prefix_len > max {
44            return Err(PrimitiveError::Invalid {
45                message: "prefix length out of range for the address family",
46            });
47        }
48        Ok(Self { addr, prefix_len })
49    }
50
51    /// Returns the base address exactly as supplied (host bits not cleared).
52    pub const fn address(&self) -> IpAddr {
53        self.addr
54    }
55
56    /// Returns the prefix length (number of leading network bits).
57    pub const fn prefix_len(&self) -> u8 {
58        self.prefix_len
59    }
60
61    /// Returns `true` if the network is IPv4.
62    pub const fn is_ipv4(&self) -> bool {
63        matches!(self.addr, IpAddr::V4(_))
64    }
65
66    /// Returns `true` if the network is IPv6.
67    pub const fn is_ipv6(&self) -> bool {
68        matches!(self.addr, IpAddr::V6(_))
69    }
70
71    /// Returns the canonical network address with host bits cleared.
72    ///
73    /// For `192.168.1.42/24` this returns `192.168.1.0`.
74    pub fn network(&self) -> IpAddr {
75        match self.addr {
76            IpAddr::V4(a) => {
77                IpAddr::V4(Ipv4Addr::from_bits(a.to_bits() & mask_v4(self.prefix_len)))
78            }
79            IpAddr::V6(a) => {
80                IpAddr::V6(Ipv6Addr::from_bits(a.to_bits() & mask_v6(self.prefix_len)))
81            }
82        }
83    }
84
85    /// Returns `true` if `ip` falls within this network.
86    ///
87    /// An address of a different family than the network is never contained.
88    pub fn contains(&self, ip: IpAddr) -> bool {
89        match (self.addr, ip) {
90            (IpAddr::V4(net), IpAddr::V4(probe)) => {
91                let m = mask_v4(self.prefix_len);
92                net.to_bits() & m == probe.to_bits() & m
93            }
94            (IpAddr::V6(net), IpAddr::V6(probe)) => {
95                let m = mask_v6(self.prefix_len);
96                net.to_bits() & m == probe.to_bits() & m
97            }
98            _ => false,
99        }
100    }
101}
102
/// Returns a 32-bit mask whose leading `prefix` bits are set and the rest clear.
///
/// Prefixes of 32 or more saturate to an all-ones mask; a prefix of 0 yields 0.
fn mask_v4(prefix: u8) -> u32 {
    let host_bits = 32 - u32::from(prefix.min(32));
    // Shifting a u32 by 32 is an overflow, which `checked_shl` reports as
    // `None`; that case is exactly prefix 0, i.e. an empty mask.
    u32::MAX.checked_shl(host_bits).unwrap_or(0)
}
111
/// Returns a 128-bit mask whose leading `prefix` bits are set and the rest clear.
///
/// Prefixes of 128 or more saturate to an all-ones mask; a prefix of 0 yields 0.
fn mask_v6(prefix: u8) -> u128 {
    let host_bits = 128 - u32::from(prefix.min(128));
    // A shift by the full width (prefix 0) overflows and comes back as `None`,
    // which maps to the empty mask.
    u128::MAX.checked_shl(host_bits).unwrap_or(0)
}
120
121impl fmt::Display for Cidr {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        write!(f, "{}/{}", self.addr, self.prefix_len)
124    }
125}
126
127impl FromStr for Cidr {
128    type Err = PrimitiveError;
129
130    fn from_str(s: &str) -> Result<Self, Self::Err> {
131        let (addr_part, prefix_part) = s.split_once('/').ok_or(PrimitiveError::Invalid {
132            message: "CIDR must be written as address/prefix",
133        })?;
134        let addr = addr_part
135            .parse::<IpAddr>()
136            .map_err(|_| PrimitiveError::Invalid {
137                message: "invalid IP address in CIDR",
138            })?;
139        let prefix_len = prefix_part
140            .parse::<u8>()
141            .map_err(|_| PrimitiveError::Invalid {
142                message: "invalid prefix length in CIDR",
143            })?;
144        Self::new(addr, prefix_len)
145    }
146}
147
148impl TryFrom<&str> for Cidr {
149    type Error = PrimitiveError;
150
151    fn try_from(value: &str) -> Result<Self, Self::Error> {
152        value.parse()
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::Cidr;
    use crate::PrimitiveErrorKind;
    use core::net::IpAddr;

    /// Parses a test IP literal; panics on malformed input (test-only helper).
    fn ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    #[test]
    fn parses_ipv4_cidr() {
        let net: Cidr = "192.168.1.0/24".parse().unwrap();
        assert!(net.is_ipv4());
        assert_eq!(net.prefix_len(), 24);
        assert_eq!(net.address(), ip("192.168.1.0"));
    }

    #[test]
    fn ipv4_contains_membership() {
        let net: Cidr = "10.0.0.0/8".parse().unwrap();
        assert!(net.contains(ip("10.255.255.255")));
        assert!(net.contains(ip("10.0.0.1")));
        assert!(!net.contains(ip("11.0.0.1")));
    }

    #[test]
    fn host_bits_preserved_but_network_masks() {
        let net: Cidr = "192.168.1.42/24".parse().unwrap();
        assert_eq!(net.address(), ip("192.168.1.42"));
        assert_eq!(net.network(), ip("192.168.1.0"));
    }

    #[test]
    fn prefix_zero_matches_everything() {
        let net: Cidr = "0.0.0.0/0".parse().unwrap();
        assert!(net.contains(ip("8.8.8.8")));
        assert!(net.contains(ip("255.255.255.255")));
    }

    #[test]
    fn prefix_32_is_single_host() {
        let net: Cidr = "192.168.1.5/32".parse().unwrap();
        assert!(net.contains(ip("192.168.1.5")));
        assert!(!net.contains(ip("192.168.1.6")));
    }

    #[test]
    fn non_octet_aligned_prefix() {
        // /9 splits the second octet: 10.0.0.0 - 10.127.255.255.
        let net: Cidr = "10.0.0.0/9".parse().unwrap();
        assert!(net.contains(ip("10.127.255.255")));
        assert!(!net.contains(ip("10.128.0.0")));
    }

    #[test]
    fn parses_ipv6_cidr() {
        let net: Cidr = "2001:db8::/32".parse().unwrap();
        assert!(net.is_ipv6());
        assert_eq!(net.prefix_len(), 32);
        assert!(net.contains(ip("2001:db8:1234::1")));
        assert!(!net.contains(ip("2001:db9::1")));
    }

    #[test]
    fn ipv6_prefix_128_single_host() {
        let net: Cidr = "::1/128".parse().unwrap();
        assert!(net.contains(ip("::1")));
        assert!(!net.contains(ip("::2")));
    }

    #[test]
    fn ipv6_network_masks_host_bits() {
        let net: Cidr = "2001:db8:abcd:12::1/48".parse().unwrap();
        assert_eq!(net.address(), ip("2001:db8:abcd:12::1"));
        assert_eq!(net.network(), ip("2001:db8:abcd::"));
        assert_eq!(net.to_string(), "2001:db8:abcd:12::1/48");
    }

    #[test]
    fn ipv6_prefix_zero_matches_all_ipv6_only() {
        let net: Cidr = "::/0".parse().unwrap();
        assert!(net.contains(ip("ffff::1")));
        assert!(net.contains(ip("::")));
        assert!(!net.contains(ip("0.0.0.0")));
    }

    #[test]
    fn cross_family_never_contained() {
        let v4: Cidr = "10.0.0.0/8".parse().unwrap();
        assert!(!v4.contains(ip("::1")));
        let v6: Cidr = "2001:db8::/32".parse().unwrap();
        assert!(!v6.contains(ip("10.0.0.1")));
    }

    #[test]
    fn equality_compares_address_as_supplied() {
        // Equality/hash include host bits; `network()` is the canonical form.
        let a: Cidr = "10.0.0.1/8".parse().unwrap();
        let b: Cidr = "10.0.0.0/8".parse().unwrap();
        assert_ne!(a, b);
        assert_eq!(a.network(), b.network());
    }

    #[test]
    fn accepts_boundary_prefixes() {
        assert!("0.0.0.0/0".parse::<Cidr>().is_ok());
        assert!("0.0.0.0/32".parse::<Cidr>().is_ok());
        assert!("::/0".parse::<Cidr>().is_ok());
        assert!("::/128".parse::<Cidr>().is_ok());
    }

    #[test]
    fn new_validates_prefix_per_family() {
        assert!(Cidr::new(ip("10.0.0.0"), 32).is_ok());
        assert!(Cidr::new(ip("10.0.0.0"), 33).is_err());
        assert!(Cidr::new(ip("::"), 128).is_ok());
        assert!(Cidr::new(ip("::"), 129).is_err());
    }

    #[test]
    fn rejects_prefix_out_of_range() {
        assert_eq!(
            "192.168.0.0/33".parse::<Cidr>().unwrap_err().kind(),
            PrimitiveErrorKind::InvalidFormat
        );
        assert_eq!(
            "2001:db8::/129".parse::<Cidr>().unwrap_err().kind(),
            PrimitiveErrorKind::InvalidFormat
        );
    }

    #[test]
    fn rejects_malformed() {
        assert!("192.168.0.0".parse::<Cidr>().is_err()); // no prefix
        assert!("not-an-ip/24".parse::<Cidr>().is_err());
        assert!("192.168.0.0/abc".parse::<Cidr>().is_err());
        assert!("192.168.0.0/".parse::<Cidr>().is_err());
    }

    #[test]
    fn display_round_trips() {
        let net: Cidr = "172.16.0.0/12".parse().unwrap();
        assert_eq!(net.to_string(), "172.16.0.0/12");
        let v6: Cidr = "fe80::/10".parse().unwrap();
        assert_eq!(v6.to_string(), "fe80::/10");
    }

    #[test]
    fn try_from_str() {
        assert!(Cidr::try_from("10.0.0.0/8").is_ok());
        assert!(Cidr::try_from("bad").is_err());
    }
}