use std::net::IpAddr;
/// A parsed CIDR block such as `10.0.0.0/8` or `2001:db8::/32`, or a single
/// host address such as `192.168.1.100` (treated as `/32` or `/128`).
#[derive(Debug, Clone)]
pub struct CidrRange {
    /// Address exactly as written in the input. Host bits are not cleared
    /// here; they are ignored by `contains` because both sides are masked.
    network: IpAddr,
    /// Number of leading network bits (0..=32 for IPv4, 0..=128 for IPv6).
    prefix_len: u8,
    /// Mask applied to the `u128` form of an address of the same family as
    /// `network`: within that family's bit width, the top `prefix_len` bits
    /// are set (all zero for `/0`).
    mask: u128,
}

impl CidrRange {
    /// Parses `addr/prefix` CIDR notation or a bare IP address.
    ///
    /// A bare address is a single-host range (`/32` for IPv4, `/128` for
    /// IPv6). Host bits in `addr` are accepted and ignored, so `10.0.0.5/8`
    /// covers the same addresses as `10.0.0.0/8`.
    ///
    /// Returns `None` if the address does not parse, the prefix does not
    /// parse as a `u8`, or the prefix exceeds the family's bit width.
    pub fn parse(s: &str) -> Option<Self> {
        let (addr_str, prefix_str) = match s.split_once('/') {
            Some((addr, prefix)) => (addr, Some(prefix)),
            None => (s, None),
        };
        let network: IpAddr = addr_str.parse().ok()?;
        let max_prefix: u8 = if network.is_ipv4() { 32 } else { 128 };
        let prefix_len = match prefix_str {
            Some(prefix) => {
                let len: u8 = prefix.parse().ok()?;
                if len > max_prefix {
                    return None;
                }
                len
            }
            None => max_prefix,
        };
        // `u128 << 128` overflows, so /0 is special-cased; for every other
        // prefix the shift amount is strictly below 128.
        let mask = if prefix_len == 0 {
            0
        } else {
            u128::MAX << u32::from(max_prefix - prefix_len)
        };
        Some(Self {
            network,
            prefix_len,
            mask,
        })
    }

    /// Returns the prefix length in bits (32 or 128 for a bare address).
    #[must_use]
    pub fn prefix_len(&self) -> u8 {
        self.prefix_len
    }

    /// Returns `true` if `ip` lies inside this range.
    ///
    /// An IPv4 address `a.b.c.d` and its IPv4-mapped IPv6 form
    /// `::ffff:a.b.c.d` are treated as the same host:
    /// - A mapped IPv6 address is checked as IPv4 against an IPv4 range.
    /// - An IPv4 address is checked as its mapped form against an IPv6 range.
    /// - Any other IPv6 address never matches an IPv4 range.
    pub fn contains(&self, ip: &IpAddr) -> bool {
        // Bring `ip` into the same family as `network` first. Comparing raw
        // bit patterns across families would, for example, put every IPv4
        // address inside `::/96` and the IPv6 address `::a00:1` inside
        // `10.0.0.0/8`.
        let candidate = match (self.network, *ip) {
            (IpAddr::V4(_), IpAddr::V4(_)) | (IpAddr::V6(_), IpAddr::V6(_)) => *ip,
            (IpAddr::V4(_), IpAddr::V6(v6)) => match v6.to_ipv4_mapped() {
                Some(v4) => IpAddr::V4(v4),
                None => return false,
            },
            (IpAddr::V6(_), IpAddr::V4(v4)) => IpAddr::V6(v4.to_ipv6_mapped()),
        };
        (ip_to_u128(&candidate) & self.mask) == (ip_to_u128(&self.network) & self.mask)
    }
}
pub fn check_ip_against_cidrs<'a>(ip_str: &str, entries: &'a [String]) -> Option<&'a str> {
let ip: IpAddr = ip_str.parse().ok()?;
for entry in entries {
if let Some(cidr) = CidrRange::parse(entry)
&& cidr.contains(&ip)
{
return Some(entry);
}
}
None
}
/// Widens `ip` to a `u128` with the first octet in the most significant
/// position. IPv4 addresses occupy the low 32 bits; the upper 96 bits are zero.
fn ip_to_u128(ip: &IpAddr) -> u128 {
    match *ip {
        IpAddr::V4(addr) => u32::from(addr).into(),
        IpAddr::V6(addr) => addr.into(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_ip_match() {
        let cidr = CidrRange::parse("192.168.1.100").unwrap();
        assert!(cidr.contains(&"192.168.1.100".parse().unwrap()));
        assert!(!cidr.contains(&"192.168.1.101".parse().unwrap()));
    }

    #[test]
    fn cidr_range_match() {
        let cidr = CidrRange::parse("10.0.0.0/8").unwrap();
        assert!(cidr.contains(&"10.0.0.1".parse().unwrap()));
        assert!(cidr.contains(&"10.255.255.255".parse().unwrap()));
        assert!(!cidr.contains(&"11.0.0.1".parse().unwrap()));
    }

    #[test]
    fn cidr_24_match() {
        let cidr = CidrRange::parse("192.168.1.0/24").unwrap();
        assert!(cidr.contains(&"192.168.1.1".parse().unwrap()));
        assert!(cidr.contains(&"192.168.1.254".parse().unwrap()));
        assert!(!cidr.contains(&"192.168.2.1".parse().unwrap()));
    }

    #[test]
    fn check_against_list() {
        let entries = vec!["10.0.0.0/8".into(), "192.168.1.100".into()];
        assert_eq!(
            check_ip_against_cidrs("10.0.0.5", &entries),
            Some("10.0.0.0/8")
        );
        assert_eq!(
            check_ip_against_cidrs("192.168.1.100", &entries),
            Some("192.168.1.100")
        );
        assert_eq!(check_ip_against_cidrs("172.16.0.1", &entries), None);
    }

    #[test]
    fn invalid_cidr_returns_none() {
        assert!(CidrRange::parse("not-an-ip").is_none());
        assert!(CidrRange::parse("10.0.0.0/33").is_none());
    }

    // Malformed prefixes and out-of-range IPv6 prefixes must be rejected too.
    #[test]
    fn more_invalid_cidrs_return_none() {
        assert!(CidrRange::parse("").is_none());
        assert!(CidrRange::parse("10.0.0.0/").is_none());
        assert!(CidrRange::parse("10.0.0.0/abc").is_none());
        assert!(CidrRange::parse("10.0.0.0/256").is_none());
        assert!(CidrRange::parse("10.0.0.0/8/9").is_none());
        assert!(CidrRange::parse("::/129").is_none());
    }

    #[test]
    fn prefix_len_reported() {
        assert_eq!(CidrRange::parse("10.0.0.0/8").unwrap().prefix_len(), 8);
        assert_eq!(CidrRange::parse("192.168.1.1").unwrap().prefix_len(), 32);
        assert_eq!(CidrRange::parse("2001:db8::/48").unwrap().prefix_len(), 48);
        assert_eq!(CidrRange::parse("::1").unwrap().prefix_len(), 128);
    }

    #[test]
    fn ipv6_range_match() {
        let cidr = CidrRange::parse("2001:db8::/32").unwrap();
        assert!(cidr.contains(&"2001:db8::1".parse().unwrap()));
        assert!(cidr.contains(&"2001:db8:ffff:ffff::1".parse().unwrap()));
        assert!(!cidr.contains(&"2001:db9::1".parse().unwrap()));
    }

    #[test]
    fn ipv6_single_address_match() {
        let cidr = CidrRange::parse("::1").unwrap();
        assert!(cidr.contains(&"::1".parse().unwrap()));
        assert!(!cidr.contains(&"::2".parse().unwrap()));
    }

    #[test]
    fn zero_prefix_matches_whole_ipv4_space() {
        let cidr = CidrRange::parse("0.0.0.0/0").unwrap();
        assert!(cidr.contains(&"0.0.0.0".parse().unwrap()));
        assert!(cidr.contains(&"255.255.255.255".parse().unwrap()));
    }

    // Host bits in the written network address are masked away, not rejected.
    #[test]
    fn host_bits_in_network_are_ignored() {
        let cidr = CidrRange::parse("10.1.2.3/8").unwrap();
        assert!(cidr.contains(&"10.200.0.1".parse().unwrap()));
        assert!(!cidr.contains(&"11.1.2.3".parse().unwrap()));
    }

    #[test]
    fn check_against_list_skips_invalid_and_returns_first_match() {
        let entries = vec![
            "garbage".into(),
            "10.0.0.0/8".into(),
            "10.1.0.0/16".into(),
        ];
        assert_eq!(
            check_ip_against_cidrs("10.1.2.3", &entries),
            Some("10.0.0.0/8")
        );
        assert_eq!(check_ip_against_cidrs("not-an-ip", &entries), None);
        assert_eq!(check_ip_against_cidrs("10.1.2.3", &[]), None);
    }
}