Skip to main content

bare_types/net/
cidr.rs

1//! CIDR (Classless Inter-Domain Routing) notation for IP networks.
2//!
3//! This module provides a type-safe abstraction for CIDR notation,
4//! which represents IP network prefixes in the form `address/prefix_length`.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use bare_types::net::{Cidr, IpAddr};
10//!
11//! // Parse CIDR notation
12//! let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
13//!
14//! // Check if an IP address is in network
15//! let ip: IpAddr = "192.168.1.100".parse().unwrap();
16//! assert!(cidr.contains(&ip));
17//!
18//! // Get network address
19//! let network = cidr.network_address();
20//! println!("Network: {}", network);
21//!
22//! // Get broadcast address (only for IPv4)
23//! if let Some(broadcast) = cidr.broadcast_address() {
24//!     println!("Broadcast: {}", broadcast);
25//! }
26//! ```
27
28use core::fmt;
29use core::str::FromStr;
30
31use super::IpAddr;
32
33#[cfg(feature = "serde")]
34use serde::{Deserialize, Serialize};
35
/// Reasons a [`Cidr`] can fail to be parsed or constructed.
///
/// Returned by `Cidr::new` and by the `FromStr` implementation of [`Cidr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum CidrError {
    /// The input does not have the shape `address/prefix_length`
    /// (for example `192.168.1.0/24`): the `/` separator is missing or
    /// appears more than once.
    InvalidFormat,
    /// The prefix length is not a number, or is too large for the address
    /// family.
    ///
    /// IPv4 accepts prefix lengths `0..=32`; IPv6 accepts `0..=128`.
    InvalidPrefixLength,
    /// The part before the `/` is not a valid IP address.
    InvalidIpAddr,
}
57
58impl fmt::Display for CidrError {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        match self {
61            Self::InvalidFormat => write!(f, "invalid CIDR format"),
62            Self::InvalidPrefixLength => write!(f, "invalid prefix length"),
63            Self::InvalidIpAddr => write!(f, "invalid IP address"),
64        }
65    }
66}
67
// `std::error::Error` lives in `std`, so this impl is only compiled when the
// `std` feature is enabled. The default trait methods are sufficient because
// `CidrError` wraps no underlying source error.
#[cfg(feature = "std")]
impl std::error::Error for CidrError {}
70
71/// A CIDR (Classless Inter-Domain Routing) notation for IP networks.
72///
73/// This type represents an IP network prefix in the form `address/prefix_length`,
74/// where `address` is an IP address and `prefix_length` is the number of
75/// leading bits that represent the network portion.
76///
77/// # Examples
78///
79/// ```rust
80/// use bare_types::net::{Cidr, IpAddr};
81///
82/// // Create CIDR from string
83/// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
84///
85/// // Check if IP is in network
86/// let check_ip: IpAddr = "192.168.1.100".parse().unwrap();
87/// assert!(cidr.contains(&check_ip));
88///
89/// // Get network address
90/// let network = cidr.network_address();
91///
92/// // Get broadcast address (only for IPv4)
93/// let broadcast = cidr.broadcast_address();
94///
95/// // Create from IP address and prefix length
96/// let ip: IpAddr = "10.0.0.0".parse().unwrap();
97/// let cidr2 = Cidr::new(ip, 8).unwrap();
98/// ```
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
100#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
101pub struct Cidr {
102    /// The IP address
103    address: IpAddr,
104    /// The prefix length (number of network bits)
105    prefix_length: u8,
106}
107
108impl Cidr {
109    /// Creates a new CIDR from an IP address and prefix length.
110    ///
111    /// # Errors
112    ///
113    /// Returns `CidrError::InvalidPrefixLength` if the prefix length is invalid:
114    /// - IPv4: 0-32
115    /// - IPv6: 0-128
116    ///
117    /// # Examples
118    ///
119    /// ```rust
120    /// use bare_types::net::{Cidr, IpAddr};
121    ///
122    /// let ip: IpAddr = "192.168.1.0".parse().unwrap();
123    /// let cidr = Cidr::new(ip, 24).unwrap();
124    /// ```
125    pub const fn new(address: IpAddr, prefix_length: u8) -> Result<Self, CidrError> {
126        let max_prefix = if address.as_inner().is_ipv4() {
127            32
128        } else {
129            128
130        };
131
132        if prefix_length > max_prefix {
133            return Err(CidrError::InvalidPrefixLength);
134        }
135
136        Ok(Self {
137            address,
138            prefix_length,
139        })
140    }
141
142    /// Returns the IP address.
143    ///
144    /// # Examples
145    ///
146    /// ```rust
147    /// use bare_types::net::{Cidr, IpAddr};
148    ///
149    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
150    /// let ip = cidr.address();
151    /// ```
152    #[must_use]
153    #[inline]
154    pub const fn address(&self) -> IpAddr {
155        self.address
156    }
157
158    /// Returns the prefix length.
159    ///
160    /// # Examples
161    ///
162    /// ```rust
163    /// use bare_types::net::{Cidr, IpAddr};
164    ///
165    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
166    /// assert_eq!(cidr.prefix_length(), 24);
167    /// ```
168    #[must_use]
169    #[inline]
170    pub const fn prefix_length(&self) -> u8 {
171        self.prefix_length
172    }
173
174    /// Returns the network address.
175    ///
176    /// The network address is the IP address with all host bits set to zero.
177    ///
178    /// # Examples
179    ///
180    /// ```rust
181    /// use bare_types::net::Cidr;
182    ///
183    /// let cidr: Cidr = "192.168.1.100/24".parse().unwrap();
184    /// let network = cidr.network_address();
185    /// assert_eq!(network.to_string(), "192.168.1.0");
186    /// ```
187    #[must_use]
188    pub fn network_address(&self) -> IpAddr {
189        let inner = self.address.as_inner();
190
191        if let core::net::IpAddr::V4(ip) = inner {
192            let octets = ip.octets();
193            let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
194            let mask = u32::MAX << (32 - u32::from(self.prefix_length));
195            let network = ip & mask;
196            let network_octets = network.to_be_bytes();
197            IpAddr::new(core::net::IpAddr::V4(core::net::Ipv4Addr::new(
198                network_octets[0],
199                network_octets[1],
200                network_octets[2],
201                network_octets[3],
202            )))
203        } else if let core::net::IpAddr::V6(ip) = inner {
204            let segments = ip.segments();
205            let mut network_segments = [0u16; 8];
206
207            let full_segments = (self.prefix_length / 16) as usize;
208            let partial_bits = self.prefix_length % 16;
209
210            network_segments[..full_segments].copy_from_slice(&segments[..full_segments]);
211
212            if partial_bits > 0 && full_segments < 8 {
213                let mask = u16::MAX << (16 - u32::from(partial_bits));
214                network_segments[full_segments] = segments[full_segments] & mask;
215            }
216
217            IpAddr::new(core::net::IpAddr::V6(core::net::Ipv6Addr::new(
218                network_segments[0],
219                network_segments[1],
220                network_segments[2],
221                network_segments[3],
222                network_segments[4],
223                network_segments[5],
224                network_segments[6],
225                network_segments[7],
226            )))
227        } else {
228            self.address
229        }
230    }
231
232    /// Returns the broadcast address.
233    ///
234    /// The broadcast address is the IP address with all host bits set to one.
235    /// Only applicable for IPv4 networks.
236    ///
237    /// # Examples
238    ///
239    /// ```rust
240    /// use bare_types::net::Cidr;
241    ///
242    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
243    /// let broadcast = cidr.broadcast_address().unwrap();
244    /// assert_eq!(broadcast.to_string(), "192.168.1.255");
245    /// ```
246    #[must_use]
247    #[allow(clippy::missing_const_for_fn)]
248    pub fn broadcast_address(&self) -> Option<IpAddr> {
249        let inner = self.address.as_inner();
250
251        if let core::net::IpAddr::V4(ip) = inner {
252            let octets = ip.octets();
253            let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
254            let mask = u32::MAX << (32 - u32::from(self.prefix_length));
255            let broadcast = ip | !mask;
256            let broadcast_octets = broadcast.to_be_bytes();
257            Some(IpAddr::new(core::net::IpAddr::V4(
258                core::net::Ipv4Addr::new(
259                    broadcast_octets[0],
260                    broadcast_octets[1],
261                    broadcast_octets[2],
262                    broadcast_octets[3],
263                ),
264            )))
265        } else {
266            None // IPv6 doesn't have broadcast
267        }
268    }
269
270    /// Checks if an IP address is contained in this CIDR network.
271    ///
272    /// # Examples
273    ///
274    /// ```rust
275    /// use bare_types::net::{Cidr, IpAddr};
276    ///
277    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
278    /// let ip1: IpAddr = "192.168.1.100".parse().unwrap();
279    /// let ip2: IpAddr = "192.168.2.100".parse().unwrap();
280    ///
281    /// assert!(cidr.contains(&ip1));
282    /// assert!(!cidr.contains(&ip2));
283    /// ```
284    #[must_use]
285    pub fn contains(&self, ip: &IpAddr) -> bool {
286        let network = self.network_address();
287        let network_inner = network.as_inner();
288        let ip_inner = ip.as_inner();
289
290        match (network_inner, ip_inner) {
291            (core::net::IpAddr::V4(network), core::net::IpAddr::V4(ip)) => {
292                let network_octets = network.octets();
293                let network = u32::from_be_bytes([
294                    network_octets[0],
295                    network_octets[1],
296                    network_octets[2],
297                    network_octets[3],
298                ]);
299                let ip_octets = ip.octets();
300                let ip =
301                    u32::from_be_bytes([ip_octets[0], ip_octets[1], ip_octets[2], ip_octets[3]]);
302                let mask = u32::MAX << (32 - u32::from(self.prefix_length));
303                (network & mask) == (ip & mask)
304            }
305            (core::net::IpAddr::V6(network), core::net::IpAddr::V6(ip)) => {
306                let network_segments = network.segments();
307                let ip_segments = ip.segments();
308
309                let full_segments = (self.prefix_length / 16) as usize;
310                let partial_bits = self.prefix_length % 16;
311
312                for i in 0..full_segments {
313                    if network_segments[i] != ip_segments[i] {
314                        return false;
315                    }
316                }
317
318                if partial_bits > 0 && full_segments < 8 {
319                    let mask = u16::MAX << (16 - u32::from(partial_bits));
320                    if (network_segments[full_segments] & mask)
321                        != (ip_segments[full_segments] & mask)
322                    {
323                        return false;
324                    }
325                }
326
327                true
328            }
329            _ => false,
330        }
331    }
332
333    /// Returns the number of addresses in this network.
334    ///
335    /// # Examples
336    ///
337    /// ```rust
338    /// use bare_types::net::Cidr;
339    ///
340    /// let cidr: Cidr = "192.168.1.0/24".parse().unwrap();
341    /// assert_eq!(cidr.size(), 256);
342    /// ```
343    #[must_use]
344    #[allow(clippy::missing_const_for_fn)]
345    pub fn size(&self) -> u128 {
346        if self.address.as_inner().is_ipv4() {
347            1u128 << (32 - u32::from(self.prefix_length))
348        } else {
349            let shift = 128 - u32::from(self.prefix_length);
350            if shift == 128 {
351                u128::MAX
352            } else {
353                1u128 << shift
354            }
355        }
356    }
357}
358
359impl FromStr for Cidr {
360    type Err = CidrError;
361
362    fn from_str(s: &str) -> Result<Self, Self::Err> {
363        let parts: Vec<&str> = s.split('/').collect();
364
365        if parts.len() != 2 {
366            return Err(CidrError::InvalidFormat);
367        }
368
369        let address: IpAddr = parts[0].parse().map_err(|_| CidrError::InvalidIpAddr)?;
370
371        let prefix_length: u8 = parts[1]
372            .parse()
373            .map_err(|_| CidrError::InvalidPrefixLength)?;
374
375        Self::new(address, prefix_length)
376    }
377}
378
379impl fmt::Display for Cidr {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        write!(f, "{}/{}", self.address, self.prefix_length)
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` as a CIDR, panicking on failure.
    fn cidr(text: &str) -> Cidr {
        text.parse().unwrap()
    }

    /// Parses `text` as an IP address, panicking on failure.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_cidr_creation() {
        let address = ip("192.168.1.0");
        let cidr = Cidr::new(address, 24).unwrap();
        assert_eq!(cidr.address(), address);
        assert_eq!(cidr.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_parsing() {
        let cidr = cidr("192.168.1.0/24");
        assert_eq!(cidr.address().to_string(), "192.168.1.0");
        assert_eq!(cidr.prefix_length(), 24);
    }

    #[test]
    fn test_invalid_prefix_length() {
        assert!(Cidr::new(ip("192.168.1.0"), 33).is_err());
    }

    #[test]
    fn test_parse_rejects_malformed_input() {
        // No separator, too many separators, and empty input.
        for text in ["192.168.1.0", "192.168.1.0/24/8", ""] {
            assert_eq!(
                text.parse::<Cidr>().unwrap_err(),
                CidrError::InvalidFormat,
                "input: {text:?}"
            );
        }
    }

    #[test]
    fn test_parse_rejects_invalid_address() {
        assert_eq!(
            "not-an-ip/24".parse::<Cidr>().unwrap_err(),
            CidrError::InvalidIpAddr
        );
    }

    #[test]
    fn test_parse_rejects_invalid_prefix() {
        // Non-numeric, empty, out of range for the family, and too large for `u8`.
        for text in [
            "192.168.1.0/abc",
            "192.168.1.0/",
            "192.168.1.0/33",
            "192.168.1.0/256",
            "2001:db8::/129",
        ] {
            assert_eq!(
                text.parse::<Cidr>().unwrap_err(),
                CidrError::InvalidPrefixLength,
                "input: {text:?}"
            );
        }
    }

    #[test]
    fn test_network_address() {
        let network = cidr("192.168.1.100/24").network_address();
        assert_eq!(network.to_string(), "192.168.1.0");
    }

    #[test]
    fn test_broadcast_address() {
        let broadcast = cidr("192.168.1.0/24").broadcast_address().unwrap();
        assert_eq!(broadcast.to_string(), "192.168.1.255");
    }

    #[test]
    fn test_contains() {
        let cidr = cidr("192.168.1.0/24");

        assert!(cidr.contains(&ip("192.168.1.100")));
        assert!(!cidr.contains(&ip("192.168.2.100")));
    }

    #[test]
    fn test_contains_rejects_other_address_family() {
        assert!(!cidr("192.168.1.0/24").contains(&ip("::1")));
        assert!(!cidr("2001:db8::/32").contains(&ip("192.168.1.1")));
    }

    #[test]
    fn test_size() {
        assert_eq!(cidr("192.168.1.0/24").size(), 256);
    }

    #[test]
    fn test_ipv6_cidr() {
        let cidr = cidr("2001:db8::/32");
        assert_eq!(cidr.prefix_length(), 32);
        assert!(cidr.contains(&ip("2001:db8:85a3::8a2e:370:7334")));
    }

    #[test]
    fn test_display() {
        assert_eq!(cidr("192.168.1.0/24").to_string(), "192.168.1.0/24");
    }

    #[test]
    fn test_ipv6_network_address() {
        let network = cidr("2001:db8:85a3::8a2e:370:7334/64").network_address();
        assert_eq!(network.to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_ipv6_contains() {
        let cidr = cidr("2001:db8::/32");

        assert!(cidr.contains(&ip("2001:db8:85a3::8a2e:370:7334")));
        assert!(!cidr.contains(&ip("2001:db9::1")));
    }

    #[test]
    fn test_ipv6_size() {
        assert_eq!(cidr("2001:db8::/32").size(), 1u128 << 96);
    }

    #[test]
    fn test_ipv6_broadcast_none() {
        assert!(cidr("2001:db8::/32").broadcast_address().is_none());
    }

    #[test]
    fn test_ipv6_max_prefix() {
        let cidr = cidr("2001:db8::/128");
        assert_eq!(cidr.prefix_length(), 128);
        assert_eq!(cidr.size(), 1);
    }

    #[test]
    fn test_ipv6_zero_prefix() {
        let cidr = cidr("2001:db8::/0");
        assert_eq!(cidr.prefix_length(), 0);
        // 2^128 does not fit in a `u128`, so the size saturates.
        assert_eq!(cidr.size(), u128::MAX);
    }

    #[test]
    fn test_ipv4_max_prefix() {
        let cidr = cidr("192.168.1.0/32");
        assert_eq!(cidr.prefix_length(), 32);
        assert_eq!(cidr.size(), 1);
    }

    #[test]
    fn test_ipv4_zero_prefix() {
        let cidr = cidr("192.168.1.0/0");
        assert_eq!(cidr.prefix_length(), 0);
        assert_eq!(cidr.size(), 1u128 << 32);
    }
}