use rfham_core::error::CoreError;
use serde_with::{DeserializeFromStr, SerializeDisplay};
use std::{fmt::Display, net::Ipv4Addr, str::FromStr};
/// An IPv4 network, held as an address plus a netmask.
///
/// Both parts are stored as raw 32-bit patterns. The textual form is
/// `address/prefix-length` (see the `Display` and `FromStr` impls), and the
/// `serde_with` derives below make serde use that same text form.
#[derive(Clone, Debug, PartialEq, Eq, DeserializeFromStr, SerializeDisplay)]
pub struct IpNetwork {
/// Address bits, stored exactly as handed to the constructor.
address: u32,
/// Netmask bits.
mask: u32,
}
/// Renders the network as `<address>/<prefix length>`, e.g. `10.0.0.0/8`.
impl Display for IpNetwork {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address = self.address();
        let prefix_length = self.prefix_length();
        write!(f, "{address}/{prefix_length}")
    }
}
/// Parses CIDR notation such as `"192.168.1.0/24"`.
///
/// The input must be `<IPv4 address>/<prefix length>` with exactly one `/`
/// and a prefix length in `0..=32`.
///
/// # Errors
///
/// Returns [`CoreError::InvalidValueFromStr`], carrying the complete input
/// string, when the separator is missing or repeated, the address does not
/// parse as an [`Ipv4Addr`], or the prefix length is not an integer in
/// `0..=32`.
impl FromStr for IpNetwork {
    type Err = CoreError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Every failure reports the same error, always quoting the whole input.
        let invalid = || CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork");

        // `split_once` needs no intermediate `Vec`. If the input holds a second
        // '/', it lands in the prefix text and the integer parse rejects it.
        let (address, prefix_length) = s.split_once('/').ok_or_else(invalid)?;
        let address: Ipv4Addr = address.parse().map_err(|_| invalid())?;
        let prefix_length: u8 = prefix_length.parse().map_err(|_| invalid())?;

        // Guard here so `from_cidr`'s assertion can never fire on parsed input.
        if prefix_length > 32 {
            return Err(invalid());
        }
        Ok(IpNetwork::from_cidr(address, prefix_length))
    }
}
/// Constructors and accessors.
impl IpNetwork {
    /// Builds a network from an address and a CIDR prefix length.
    ///
    /// The address is kept exactly as given; host bits are not cleared.
    ///
    /// # Panics
    ///
    /// Panics if `prefix_length` is greater than 32.
    pub fn from_cidr(address: Ipv4Addr, prefix_length: u8) -> Self {
        assert!(prefix_length <= 32);
        // Slide an all-ones word left by the number of host bits. A /0 would
        // need a shift by the full width, which `checked_shl` reports as
        // `None`; that case is the empty mask.
        let host_bits = 32 - u32::from(prefix_length);
        let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
        Self {
            address: address.to_bits(),
            mask,
        }
    }

    /// Builds a network from an address and a dotted-quad netmask.
    ///
    /// The mask is stored verbatim and is not checked for contiguity. For a
    /// mask with holes, [`prefix_length`](Self::prefix_length) reports only
    /// the run of leading one bits, while [`contains`](Self::contains) still
    /// uses the full mask.
    pub fn from_mask(address: Ipv4Addr, net_mask: Ipv4Addr) -> Self {
        let address = address.to_bits();
        let mask = net_mask.to_bits();
        Self { address, mask }
    }

    /// The stored address as an [`Ipv4Addr`].
    pub fn address(&self) -> Ipv4Addr {
        Ipv4Addr::from(self.address)
    }

    /// The stored address as its raw 32-bit value.
    pub fn address_u32(&self) -> u32 {
        self.address
    }

    /// The netmask as an [`Ipv4Addr`].
    pub fn mask(&self) -> Ipv4Addr {
        Ipv4Addr::from(self.mask)
    }

    /// The netmask as its raw 32-bit value.
    pub fn mask_u32(&self) -> u32 {
        self.mask
    }

    /// Number of leading one bits in the netmask.
    pub fn prefix_length(&self) -> u8 {
        // A `u32` has at most 32 leading ones, so the cast cannot truncate.
        self.mask.leading_ones() as u8
    }

    /// Returns `true` when `address` agrees with the stored address on every
    /// bit selected by the netmask.
    pub fn contains(&self, address: Ipv4Addr) -> bool {
        let differing_bits = self.address ^ address.to_bits();
        differing_bits & self.mask == 0
    }
}
#[cfg(test)]
mod tests {
    use super::IpNetwork;
    use pretty_assertions::assert_eq;
    use std::{net::Ipv4Addr, str::FromStr};

    /// Parses a dotted-quad literal; every test input is well-formed.
    fn ip(text: &str) -> Ipv4Addr {
        Ipv4Addr::from_str(text).unwrap()
    }

    #[test]
    fn test_cidr_parse_and_display() {
        let net: IpNetwork = "192.168.1.0/24".parse().unwrap();
        assert_eq!(net.to_string(), "192.168.1.0/24");
        assert_eq!(net.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_contains() {
        let net: IpNetwork = "10.0.0.0/8".parse().unwrap();
        assert!(net.contains(ip("10.1.2.3")));
        assert!(net.contains(ip("10.255.255.255")));
        assert!(!net.contains(ip("11.0.0.1")));
    }

    #[test]
    fn test_cidr_host_route() {
        let net: IpNetwork = "203.0.113.5/32".parse().unwrap();
        assert_eq!(net.prefix_length(), 32);
        assert!(net.contains(ip("203.0.113.5")));
        assert!(!net.contains(ip("203.0.113.6")));
    }

    // The /0 prefix is the other boundary of the mask computation.
    #[test]
    fn test_cidr_default_route() {
        let net: IpNetwork = "0.0.0.0/0".parse().unwrap();
        assert_eq!(net.prefix_length(), 0);
        assert_eq!(net.mask(), ip("0.0.0.0"));
        assert!(net.contains(ip("0.0.0.0")));
        assert!(net.contains(ip("255.255.255.255")));
        assert_eq!(net.to_string(), "0.0.0.0/0");
    }

    #[test]
    fn test_from_mask() {
        let addr = ip("192.168.0.0");
        let mask = ip("255.255.255.0");
        let net = IpNetwork::from_mask(addr, mask);
        assert_eq!(net.prefix_length(), 24);
        assert!(net.contains(ip("192.168.0.99")));
    }

    // Both constructors must agree when the mask is a plain prefix.
    #[test]
    fn test_from_cidr_matches_from_mask() {
        let addr = ip("172.16.0.0");
        assert_eq!(
            IpNetwork::from_cidr(addr, 12),
            IpNetwork::from_mask(addr, ip("255.240.0.0"))
        );
    }

    #[test]
    fn test_invalid_cidr_returns_error() {
        let malformed = [
            "notanip/24",
            "192.168.1.0/33",
            "192.168.1.0",
            "",
            "192.168.1.0/",
            "/24",
            "192.168.1.0/24/1",
            "192.168.1.0/-1",
            "192.168.1.0/256",
        ];
        for text in malformed.iter() {
            assert!(
                text.parse::<IpNetwork>().is_err(),
                "{:?} should be rejected",
                text
            );
        }
    }
}