Skip to main content

bare_types/net/
cidr.rs

1//! CIDR (Classless Inter-Domain Routing) notation for IP networks.
2//!
3//! This module provides a type-safe abstraction for CIDR notation,
4//! which represents IP network prefixes in the form `address/prefix_length`.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use bare_types::net::{Cidr, IpAddr};
10//!
11//! // Parse CIDR notation
12//! let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
13//!
14//! // Check if an IP address is in network
15//! let ip: IpAddr = "192.168.1.100".parse().unwrap();
16//! assert!(cidr.contains(&ip));
17//!
18//! // Get network address
19//! let network = cidr.network_address();
20//! println!("Network: {}", network);
21//!
22//! // Get broadcast address (only for IPv4)
23//! if let Some(broadcast) = cidr.broadcast_address() {
24//!     println!("Broadcast: {}", broadcast);
25//! }
26//! ```
27
28use core::fmt;
29use core::str::FromStr;
30
31use super::IpAddr;
32
33#[cfg(feature = "serde")]
34use serde::{Deserialize, Serialize};
35
/// Reasons why a CIDR value could not be built or parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum CidrError {
    /// The text is not shaped like `address/prefix_length`.
    ///
    /// A well-formed value contains exactly one `/`, e.g. `"192.168.1.0/24"`.
    InvalidFormat,
    /// The prefix length is unusable for the address family.
    ///
    /// Either it could not be read as a number, or it is larger than the
    /// family allows: at most 32 for IPv4 and at most 128 for IPv6.
    InvalidPrefixLength,
    /// The part before the `/` is not a valid IPv4 or IPv6 address.
    InvalidIpAddr,
}

impl fmt::Display for CidrError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidFormat => "invalid CIDR format",
            Self::InvalidPrefixLength => "invalid prefix length",
            Self::InvalidIpAddr => "invalid IP address",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for CidrError {}
70
71/// A CIDR (Classless Inter-Domain Routing) notation for IP networks.
72///
73/// This type represents an IP network prefix in the form `address/prefix_length`,
74/// where `address` is an IP address and `prefix_length` is the number of
75/// leading bits that represent the network portion.
76///
77/// # Examples
78///
79/// ```rust
80/// use bare_types::net::{Cidr, IpAddr};
81///
82/// // Create CIDR from string
83/// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
84///
85/// // Check if IP is in network
86/// let check_ip: IpAddr = "192.168.1.100".parse().unwrap();
87/// assert!(cidr.contains(&check_ip));
88///
89/// // Get network address
90/// let network = cidr.network_address();
91///
92/// // Get broadcast address (only for IPv4)
93/// let broadcast = cidr.broadcast_address();
94///
95/// // Create from IP address and prefix length
96/// let ip: IpAddr = "10.0.0.0".parse().unwrap();
97/// let cidr2 = Cidr::new(ip, 8).unwrap();
98/// ```
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
100#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
101pub struct Cidr {
102    /// The IP address
103    address: IpAddr,
104    /// The prefix length (number of network bits)
105    prefix_length: u8,
106}
107
108impl Cidr {
109    /// Creates a new CIDR from an IP address and prefix length.
110    ///
111    /// # Errors
112    ///
113    /// Returns `CidrError::InvalidPrefixLength` if the prefix length is invalid:
114    /// - IPv4: 0-32
115    /// - IPv6: 0-128
116    ///
117    /// # Examples
118    ///
119    /// ```rust
120    /// use bare_types::net::{Cidr, IpAddr};
121    ///
122    /// let ip: IpAddr = "192.168.1.0".parse().unwrap();
123    /// let cidr = Cidr::new(ip, 24).unwrap();
124    /// ```
125    pub const fn new(address: IpAddr, prefix_length: u8) -> Result<Self, CidrError> {
126        let max_prefix = if address.as_inner().is_ipv4() {
127            32
128        } else {
129            128
130        };
131
132        if prefix_length > max_prefix {
133            return Err(CidrError::InvalidPrefixLength);
134        }
135
136        Ok(Self {
137            address,
138            prefix_length,
139        })
140    }
141
142    /// Returns the IP address.
143    ///
144    /// # Examples
145    ///
146    /// ```rust
147    /// use bare_types::net::{Cidr, IpAddr};
148    ///
149    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
150    /// let ip = cidr.address();
151    /// ```
152    #[must_use]
153    #[inline]
154    pub const fn address(&self) -> IpAddr {
155        self.address
156    }
157
158    /// Returns the prefix length.
159    ///
160    /// # Examples
161    ///
162    /// ```rust
163    /// use bare_types::net::{Cidr, IpAddr};
164    ///
165    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
166    /// assert_eq!(cidr.prefix_length(), 24);
167    /// ```
168    #[must_use]
169    #[inline]
170    pub const fn prefix_length(&self) -> u8 {
171        self.prefix_length
172    }
173
174    /// Returns the network address.
175    ///
176    /// The network address is the IP address with all host bits set to zero.
177    ///
178    /// # Examples
179    ///
180    /// ```rust
181    /// use bare_types::net::Cidr;
182    ///
183    /// let cidr: Cidr = "192.168.1.100/24".parse().unwrap();
184    /// let network = cidr.network_address();
185    /// assert_eq!(network.to_string(), "192.168.1.0");
186    /// ```
187    #[must_use]
188    pub fn network_address(&self) -> IpAddr {
189        let inner = self.address.as_inner();
190
191        if let core::net::IpAddr::V4(ip) = inner {
192            let octets = ip.octets();
193            let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
194            let mask = u32::MAX << (32 - u32::from(self.prefix_length));
195            let network = ip & mask;
196            let network_octets = network.to_be_bytes();
197            IpAddr::new(core::net::IpAddr::V4(core::net::Ipv4Addr::new(
198                network_octets[0],
199                network_octets[1],
200                network_octets[2],
201                network_octets[3],
202            )))
203        } else if let core::net::IpAddr::V6(ip) = inner {
204            let segments = ip.segments();
205            let mut network_segments = [0u16; 8];
206
207            let full_segments = (self.prefix_length / 16) as usize;
208            let partial_bits = self.prefix_length % 16;
209
210            network_segments[..full_segments].copy_from_slice(&segments[..full_segments]);
211
212            if partial_bits > 0 && full_segments < 8 {
213                let mask = u16::MAX << (16 - u32::from(partial_bits));
214                network_segments[full_segments] = segments[full_segments] & mask;
215            }
216
217            IpAddr::new(core::net::IpAddr::V6(core::net::Ipv6Addr::new(
218                network_segments[0],
219                network_segments[1],
220                network_segments[2],
221                network_segments[3],
222                network_segments[4],
223                network_segments[5],
224                network_segments[6],
225                network_segments[7],
226            )))
227        } else {
228            self.address
229        }
230    }
231
232    /// Returns the broadcast address.
233    ///
234    /// The broadcast address is the IP address with all host bits set to one.
235    /// Only applicable for IPv4 networks.
236    ///
237    /// # Examples
238    ///
239    /// ```rust
240    /// use bare_types::net::Cidr;
241    ///
242    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
243    /// let broadcast = cidr.broadcast_address().unwrap();
244    /// assert_eq!(broadcast.to_string(), "192.168.1.255");
245    /// ```
246    #[must_use]
247    #[allow(clippy::missing_const_for_fn)]
248    pub fn broadcast_address(&self) -> Option<IpAddr> {
249        let inner = self.address.as_inner();
250
251        if let core::net::IpAddr::V4(ip) = inner {
252            let octets = ip.octets();
253            let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
254            let mask = u32::MAX << (32 - u32::from(self.prefix_length));
255            let broadcast = ip | !mask;
256            let broadcast_octets = broadcast.to_be_bytes();
257            Some(IpAddr::new(core::net::IpAddr::V4(
258                core::net::Ipv4Addr::new(
259                    broadcast_octets[0],
260                    broadcast_octets[1],
261                    broadcast_octets[2],
262                    broadcast_octets[3],
263                ),
264            )))
265        } else {
266            None // IPv6 doesn't have broadcast
267        }
268    }
269
270    /// Checks if an IP address is contained in this CIDR network.
271    ///
272    /// # Examples
273    ///
274    /// ```rust
275    /// use bare_types::net::{Cidr, IpAddr};
276    ///
277    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
278    /// let ip1: IpAddr = "192.168.1.100".parse().unwrap();
279    /// let ip2: IpAddr = "192.168.2.100".parse().unwrap();
280    ///
281    /// assert!(cidr.contains(&ip1));
282    /// assert!(!cidr.contains(&ip2));
283    /// ```
284    #[must_use]
285    pub fn contains(&self, ip: &IpAddr) -> bool {
286        let network = self.network_address();
287        let network_inner = network.as_inner();
288        let ip_inner = ip.as_inner();
289
290        match (network_inner, ip_inner) {
291            (core::net::IpAddr::V4(network), core::net::IpAddr::V4(ip)) => {
292                let network_octets = network.octets();
293                let network = u32::from_be_bytes([
294                    network_octets[0],
295                    network_octets[1],
296                    network_octets[2],
297                    network_octets[3],
298                ]);
299                let ip_octets = ip.octets();
300                let ip =
301                    u32::from_be_bytes([ip_octets[0], ip_octets[1], ip_octets[2], ip_octets[3]]);
302                let mask = u32::MAX << (32 - u32::from(self.prefix_length));
303                (network & mask) == (ip & mask)
304            }
305            (core::net::IpAddr::V6(network), core::net::IpAddr::V6(ip)) => {
306                let network_segments = network.segments();
307                let ip_segments = ip.segments();
308
309                let full_segments = (self.prefix_length / 16) as usize;
310                let partial_bits = self.prefix_length % 16;
311
312                for i in 0..full_segments {
313                    if network_segments[i] != ip_segments[i] {
314                        return false;
315                    }
316                }
317
318                if partial_bits > 0 && full_segments < 8 {
319                    let mask = u16::MAX << (16 - u32::from(partial_bits));
320                    if (network_segments[full_segments] & mask)
321                        != (ip_segments[full_segments] & mask)
322                    {
323                        return false;
324                    }
325                }
326
327                true
328            }
329            _ => false,
330        }
331    }
332
333    /// Returns the number of addresses in this network.
334    ///
335    /// # Examples
336    ///
337    /// ```rust
338    /// use bare_types::net::Cidr;
339    ///
340    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
341    /// assert_eq!(cidr.size(), 256);
342    /// ```
343    #[must_use]
344    #[allow(clippy::missing_const_for_fn)]
345    pub fn size(&self) -> u128 {
346        if self.address.as_inner().is_ipv4() {
347            1u128 << (32 - u32::from(self.prefix_length))
348        } else {
349            let shift = 128 - u32::from(self.prefix_length);
350            if shift == 128 {
351                u128::MAX
352            } else {
353                1u128 << shift
354            }
355        }
356    }
357}
358
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Cidr {
    /// Draws an arbitrary address, then a prefix length folded into the valid
    /// range for that address family (`0..=32` for IPv4, `0..=128` for IPv6).
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Draw order (address first, then prefix) decides how input bytes are
        // consumed; keep it stable so existing fuzz corpora stay meaningful.
        let address = IpAddr::arbitrary(u)?;
        let max_prefix: u8 = if address.as_inner().is_ipv6() { 128 } else { 32 };
        let raw_prefix = u8::arbitrary(u)?;

        Ok(Self {
            address,
            prefix_length: raw_prefix % (max_prefix + 1),
        })
    }
}
378
379impl FromStr for Cidr {
380    type Err = CidrError;
381
382    fn from_str(s: &str) -> Result<Self, Self::Err> {
383        let parts: Vec<&str> = s.split('/').collect();
384
385        if parts.len() != 2 {
386            return Err(CidrError::InvalidFormat);
387        }
388
389        let address: IpAddr = parts[0].parse().map_err(|_| CidrError::InvalidIpAddr)?;
390
391        let prefix_length: u8 = parts[1]
392            .parse()
393            .map_err(|_| CidrError::InvalidPrefixLength)?;
394
395        Self::new(address, prefix_length)
396    }
397}
398
399impl fmt::Display for Cidr {
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401        write!(f, "{}/{}", self.address, self.prefix_length)
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a CIDR literal that the test expects to be valid.
    fn cidr(text: &str) -> Cidr {
        text.parse().unwrap()
    }

    /// Parses an IP address literal that the test expects to be valid.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    // ---- construction, parsing and formatting ----

    #[test]
    fn test_cidr_creation() {
        let base = addr("192.168.1.0");
        let block = Cidr::new(base, 24).unwrap();
        assert_eq!(block.address(), base);
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_parsing() {
        let block = cidr("192.168.1.0/24");
        assert_eq!(block.address().to_string(), "192.168.1.0");
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_invalid_prefix_length() {
        assert!(Cidr::new(addr("192.168.1.0"), 33).is_err());
    }

    #[test]
    fn test_display() {
        assert_eq!(cidr("192.168.1.0/24").to_string(), "192.168.1.0/24");
    }

    // ---- IPv4 ----

    #[test]
    fn test_network_address() {
        let network = cidr("192.168.1.100/24").network_address();
        assert_eq!(network.to_string(), "192.168.1.0");
    }

    #[test]
    fn test_broadcast_address() {
        let broadcast = cidr("192.168.1.0/24").broadcast_address().unwrap();
        assert_eq!(broadcast.to_string(), "192.168.1.255");
    }

    #[test]
    fn test_contains() {
        let lan = cidr("192.168.1.0/24");
        assert!(lan.contains(&addr("192.168.1.100")));
        assert!(!lan.contains(&addr("192.168.2.100")));
    }

    #[test]
    fn test_size() {
        assert_eq!(cidr("192.168.1.0/24").size(), 256);
    }

    #[test]
    fn test_ipv4_max_prefix() {
        let host = cidr("192.168.1.0/32");
        assert_eq!(host.prefix_length(), 32);
        assert_eq!(host.size(), 1);
    }

    #[test]
    fn test_ipv4_zero_prefix() {
        let everything = cidr("192.168.1.0/0");
        assert_eq!(everything.prefix_length(), 0);
        assert_eq!(everything.size(), 1u128 << 32);
    }

    // ---- IPv6 ----

    #[test]
    fn test_ipv6_cidr() {
        let block = cidr("2001:db8::/32");
        assert_eq!(block.prefix_length(), 32);
        assert!(block.contains(&addr("2001:db8:85a3::8a2e:370:7334")));
    }

    #[test]
    fn test_ipv6_network_address() {
        let network = cidr("2001:db8:85a3::8a2e:370:7334/64").network_address();
        assert_eq!(network.to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_ipv6_contains() {
        let block = cidr("2001:db8::/32");
        assert!(block.contains(&addr("2001:db8:85a3::8a2e:370:7334")));
        assert!(!block.contains(&addr("2001:db9::1")));
    }

    #[test]
    fn test_ipv6_size() {
        assert_eq!(cidr("2001:db8::/32").size(), 1u128 << 96);
    }

    #[test]
    fn test_ipv6_broadcast_none() {
        assert!(cidr("2001:db8::/32").broadcast_address().is_none());
    }

    #[test]
    fn test_ipv6_max_prefix() {
        let host = cidr("2001:db8::/128");
        assert_eq!(host.prefix_length(), 128);
        assert_eq!(host.size(), 1);
    }

    #[test]
    fn test_ipv6_zero_prefix() {
        let everything = cidr("2001:db8::/0");
        assert_eq!(everything.prefix_length(), 0);
        assert_eq!(everything.size(), u128::MAX);
    }
}