use core::fmt;
use core::str::FromStr;
use super::IpAddr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Error returned when building a [`Cidr`] with [`Cidr::new`] or parsing one
/// from a string fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum CidrError {
/// The text was not of the form `<address>/<prefix>`: it contained no `/`,
/// or more than one.
InvalidFormat,
/// The prefix part was malformed, or the prefix length exceeded the maximum
/// for the address family (32 for IPv4, 128 for IPv6).
InvalidPrefixLength,
/// The address part could not be parsed as an IP address.
InvalidIpAddr,
}
impl fmt::Display for CidrError {
    /// Writes a short, lowercase, human-readable description of the error.
    ///
    /// Width/alignment flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::InvalidFormat => "invalid CIDR format",
            Self::InvalidPrefixLength => "invalid prefix length",
            Self::InvalidIpAddr => "invalid IP address",
        };
        f.write_str(message)
    }
}
// `CidrError` as a standard error type. Gated on the `std` feature because it
// names `std::error::Error`; the message text comes from the `Display` impl
// for `CidrError`, and there is no underlying `source()`.
#[cfg(feature = "std")]
impl std::error::Error for CidrError {}
/// An IP address paired with a prefix length, written `address/prefix`
/// (e.g. `192.168.1.0/24` or `2001:db8::/32`).
///
/// The address is stored exactly as supplied; host bits are not cleared, so
/// `192.168.1.100/24` and `192.168.1.0/24` compare unequal. Use
/// `Cidr::network_address` to obtain the masked address.
///
/// Invariant: `prefix_length <= 32` for IPv4 and `<= 128` for IPv6. It is
/// enforced by `Cidr::new` and by the `FromStr` impl, which delegates to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// NOTE(review): the derived `Deserialize` (serde feature) assigns the fields
// directly instead of going through `Cidr::new`, so a deserialized value can
// violate the prefix invariant above; consider `#[serde(try_from = ...)]`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Cidr {
/// The address as supplied, host bits included.
address: IpAddr,
/// Number of leading network bits.
prefix_length: u8,
}
impl Cidr {
    /// Creates a CIDR block from an address and a prefix length.
    ///
    /// The address is stored unchanged: host bits are *not* cleared. Use
    /// [`Cidr::network_address`] for the masked form.
    ///
    /// # Errors
    ///
    /// Returns [`CidrError::InvalidPrefixLength`] if `prefix_length` is greater
    /// than 32 for an IPv4 address or greater than 128 for an IPv6 address.
    #[inline]
    pub const fn new(address: IpAddr, prefix_length: u8) -> Result<Self, CidrError> {
        let max_prefix = if address.as_inner().is_ipv4() { 32 } else { 128 };
        if prefix_length > max_prefix {
            return Err(CidrError::InvalidPrefixLength);
        }
        Ok(Self {
            address,
            prefix_length,
        })
    }

    /// Returns the address exactly as it was supplied (host bits included).
    #[must_use]
    #[inline]
    pub const fn address(&self) -> IpAddr {
        self.address
    }

    /// Returns the number of leading network bits.
    #[must_use]
    #[inline]
    pub const fn prefix_length(&self) -> u8 {
        self.prefix_length
    }

    /// Returns the first address of the block: the stored address with every
    /// host bit cleared (e.g. `192.168.1.0` for `192.168.1.100/24`, and
    /// `0.0.0.0` / `::` for a `/0` block).
    #[must_use]
    pub fn network_address(&self) -> IpAddr {
        let network = match self.address.as_inner() {
            core::net::IpAddr::V4(ip) => {
                let host = Self::ipv4_host_mask(self.prefix_length);
                let bits = u32::from_be_bytes(ip.octets()) & !host;
                core::net::IpAddr::V4(core::net::Ipv4Addr::from(bits))
            }
            core::net::IpAddr::V6(ip) => {
                let host = Self::ipv6_host_mask(self.prefix_length);
                let bits = u128::from_be_bytes(ip.octets()) & !host;
                core::net::IpAddr::V6(core::net::Ipv6Addr::from(bits))
            }
        };
        IpAddr::new(network)
    }

    /// Returns the last address of an IPv4 block (every host bit set), e.g.
    /// `192.168.1.255` for `192.168.1.0/24` and `255.255.255.255` for `/0`.
    ///
    /// Returns `None` for IPv6 blocks, since IPv6 has no broadcast addresses.
    #[must_use]
    // NOTE(review): clippy (nursery) suggests `const fn` here; left non-const as
    // in the original — confirm whether that is intentional API-wise.
    #[allow(clippy::missing_const_for_fn)]
    pub fn broadcast_address(&self) -> Option<IpAddr> {
        match self.address.as_inner() {
            core::net::IpAddr::V4(ip) => {
                let host = Self::ipv4_host_mask(self.prefix_length);
                let bits = u32::from_be_bytes(ip.octets()) | host;
                Some(IpAddr::new(core::net::IpAddr::V4(
                    core::net::Ipv4Addr::from(bits),
                )))
            }
            core::net::IpAddr::V6(_) => None,
        }
    }

    /// Returns `true` if `ip` lies inside this block: it has the same address
    /// family and agrees with the stored address on the first `prefix_length`
    /// bits. Addresses of the other family are never contained.
    #[inline]
    #[must_use]
    pub fn contains(&self, ip: &IpAddr) -> bool {
        match (self.address.as_inner(), ip.as_inner()) {
            (core::net::IpAddr::V4(base), core::net::IpAddr::V4(candidate)) => {
                let network_mask = !Self::ipv4_host_mask(self.prefix_length);
                let base = u32::from_be_bytes(base.octets());
                let candidate = u32::from_be_bytes(candidate.octets());
                // Equal under the mask <=> no differing bit inside the prefix.
                ((base ^ candidate) & network_mask) == 0
            }
            (core::net::IpAddr::V6(base), core::net::IpAddr::V6(candidate)) => {
                let network_mask = !Self::ipv6_host_mask(self.prefix_length);
                let base = u128::from_be_bytes(base.octets());
                let candidate = u128::from_be_bytes(candidate.octets());
                ((base ^ candidate) & network_mask) == 0
            }
            // Mixed IPv4/IPv6 pair.
            _ => false,
        }
    }

    /// Returns the number of addresses in the block, `2^(bits - prefix_length)`.
    ///
    /// A `/0` IPv6 block holds `2^128` addresses, which does not fit in a
    /// `u128`; that single case saturates to `u128::MAX`.
    #[inline]
    #[must_use]
    #[allow(clippy::missing_const_for_fn)]
    pub fn size(&self) -> u128 {
        let host_mask = if self.address.as_inner().is_ipv4() {
            u128::from(Self::ipv4_host_mask(self.prefix_length))
        } else {
            Self::ipv6_host_mask(self.prefix_length)
        };
        // The host mask equals `size - 1`; saturation only matters for IPv6 /0.
        host_mask.saturating_add(1)
    }

    /// Bits of an IPv4 address not covered by a `prefix_length`-bit prefix.
    ///
    /// `checked_shr` yields `None` for a shift of 32 or more (a `/32`, whose host
    /// part is empty). Building the mask this way avoids the overflowing
    /// `u32::MAX << 32` that a `/0` prefix would otherwise require.
    #[inline]
    fn ipv4_host_mask(prefix_length: u8) -> u32 {
        u32::MAX.checked_shr(u32::from(prefix_length)).unwrap_or(0)
    }

    /// Bits of an IPv6 address not covered by a `prefix_length`-bit prefix.
    ///
    /// `checked_shr` yields `None` for a shift of 128 or more (a `/128`).
    #[inline]
    fn ipv6_host_mask(prefix_length: u8) -> u128 {
        u128::MAX.checked_shr(u32::from(prefix_length)).unwrap_or(0)
    }
}
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Cidr {
    /// Generates an arbitrary address plus a prefix length valid for its family.
    ///
    /// The prefix is drawn as a raw `u8` reduced modulo `max_prefix + 1`, so it
    /// always satisfies the bound checked by `Cidr::new` (0..=32 for IPv4,
    /// 0..=128 for IPv6).
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Fully-qualified calls, so resolution does not depend on
        // `arbitrary::Arbitrary` being imported (this file has no such `use`).
        let address = <IpAddr as arbitrary::Arbitrary<'a>>::arbitrary(u)?;
        let max_prefix: u8 = if address.as_inner().is_ipv4() { 32 } else { 128 };
        let prefix_length = <u8 as arbitrary::Arbitrary<'a>>::arbitrary(u)? % (max_prefix + 1);
        Ok(Self {
            address,
            prefix_length,
        })
    }
}
impl FromStr for Cidr {
    type Err = CidrError;

    /// Parses `address/prefix_length` notation, e.g. `"192.168.1.0/24"` or
    /// `"2001:db8::/32"`.
    ///
    /// # Errors
    ///
    /// * [`CidrError::InvalidFormat`] if the input does not contain exactly one `/`.
    /// * [`CidrError::InvalidIpAddr`] if the part before `/` is not an IP address.
    /// * [`CidrError::InvalidPrefixLength`] if the part after `/` is not an
    ///   unsigned decimal number fitting in a `u8`, or exceeds the maximum for
    ///   the address family.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `split_once` avoids collecting into a heap-allocated `Vec`.
        let (addr_part, prefix_part) = s.split_once('/').ok_or(CidrError::InvalidFormat)?;
        if prefix_part.contains('/') {
            return Err(CidrError::InvalidFormat);
        }
        let address: IpAddr = addr_part.parse().map_err(|_| CidrError::InvalidIpAddr)?;
        // `u8::from_str` accepts a leading `+`, which CIDR notation does not allow.
        if !prefix_part.bytes().all(|b| b.is_ascii_digit()) {
            return Err(CidrError::InvalidPrefixLength);
        }
        let prefix_length: u8 = prefix_part
            .parse()
            .map_err(|_| CidrError::InvalidPrefixLength)?;
        Self::new(address, prefix_length)
    }
}
impl fmt::Display for Cidr {
    /// Renders the block as `address/prefix_length`, e.g. `192.168.1.0/24`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            address,
            prefix_length,
        } = self;
        write!(f, "{address}/{prefix_length}")
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Cidr` construction, parsing, and address arithmetic.
    use super::*;

    /// Parses a CIDR literal; test inputs are always well-formed.
    fn cidr(text: &str) -> Cidr {
        text.parse().unwrap()
    }

    /// Parses an IP address literal; test inputs are always well-formed.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_cidr_creation() {
        let base = ip("192.168.1.0");
        let block = Cidr::new(base, 24).unwrap();
        assert_eq!(block.address(), base);
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_parsing() {
        let block = cidr("192.168.1.0/24");
        assert_eq!(block.address().to_string(), "192.168.1.0");
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_invalid_prefix_length() {
        assert!(Cidr::new(ip("192.168.1.0"), 33).is_err());
    }

    #[test]
    fn test_network_address() {
        let network = cidr("192.168.1.100/24").network_address();
        assert_eq!(network.to_string(), "192.168.1.0");
    }

    #[test]
    fn test_broadcast_address() {
        let broadcast = cidr("192.168.1.0/24").broadcast_address().unwrap();
        assert_eq!(broadcast.to_string(), "192.168.1.255");
    }

    #[test]
    fn test_contains() {
        let block = cidr("192.168.1.0/24");
        assert!(block.contains(&ip("192.168.1.100")));
        assert!(!block.contains(&ip("192.168.2.100")));
    }

    #[test]
    fn test_size() {
        assert_eq!(cidr("192.168.1.0/24").size(), 256);
    }

    #[test]
    fn test_ipv6_cidr() {
        let block = cidr("2001:db8::/32");
        assert_eq!(block.prefix_length(), 32);
        assert!(block.contains(&ip("2001:db8:85a3::8a2e:370:7334")));
    }

    #[test]
    fn test_display() {
        assert_eq!(cidr("192.168.1.0/24").to_string(), "192.168.1.0/24");
    }

    #[test]
    fn test_ipv6_network_address() {
        let network = cidr("2001:db8:85a3::8a2e:370:7334/64").network_address();
        assert_eq!(network.to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_ipv6_contains() {
        let block = cidr("2001:db8::/32");
        assert!(block.contains(&ip("2001:db8:85a3::8a2e:370:7334")));
        assert!(!block.contains(&ip("2001:db9::1")));
    }

    #[test]
    fn test_ipv6_size() {
        assert_eq!(cidr("2001:db8::/32").size(), 1u128 << 96);
    }

    #[test]
    fn test_ipv6_broadcast_none() {
        assert!(cidr("2001:db8::/32").broadcast_address().is_none());
    }

    #[test]
    fn test_ipv6_max_prefix() {
        let block = cidr("2001:db8::/128");
        assert_eq!(block.prefix_length(), 128);
        assert_eq!(block.size(), 1);
    }

    #[test]
    fn test_ipv6_zero_prefix() {
        let block = cidr("2001:db8::/0");
        assert_eq!(block.prefix_length(), 0);
        assert_eq!(block.size(), u128::MAX);
    }

    #[test]
    fn test_ipv4_max_prefix() {
        let block = cidr("192.168.1.0/32");
        assert_eq!(block.prefix_length(), 32);
        assert_eq!(block.size(), 1);
    }

    #[test]
    fn test_ipv4_zero_prefix() {
        let block = cidr("192.168.1.0/0");
        assert_eq!(block.prefix_length(), 0);
        assert_eq!(block.size(), 1u128 << 32);
    }
}