use std::fmt;
use std::net::IpAddr;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IpAddrMask {
pub addr: IpAddr,
pub prefix_len: u8,
}
impl IpAddrMask {
pub fn new(addr: IpAddr, prefix_len: u8) -> Self {
Self { addr, prefix_len }
}
pub fn from_cidr(s: &str) -> Result<Self, String> {
let parts: Vec<&str> = s.split('/').collect();
if parts.len() != 2 {
return Err(format!("Invalid CIDR notation: {}", s));
}
let addr: IpAddr = parts[0]
.parse()
.map_err(|e| format!("Invalid IP address: {}", e))?;
let prefix_len: u8 = parts[1]
.parse()
.map_err(|e| format!("Invalid prefix length: {}", e))?;
let max_prefix = match addr {
IpAddr::V4(_) => 32,
IpAddr::V6(_) => 128,
};
if prefix_len > max_prefix {
return Err(format!(
"Prefix length {} exceeds maximum {} for {:?}",
prefix_len, max_prefix, addr
));
}
Ok(Self { addr, prefix_len })
}
pub fn is_ipv6(&self) -> bool {
matches!(self.addr, IpAddr::V6(_))
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Protocol {
Hopopt = 0,
Icmp = 1,
Igmp = 2,
Tcp = 6,
Udp = 17,
Gre = 47,
Esp = 50,
Ah = 51,
Icmpv6 = 58,
}
impl Protocol {
pub fn as_u8(self) -> u8 {
self as u8
}
}
impl fmt::Display for Protocol {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Protocol::Hopopt => write!(f, "HOPOPT"),
Protocol::Icmp => write!(f, "ICMP"),
Protocol::Igmp => write!(f, "IGMP"),
Protocol::Tcp => write!(f, "TCP"),
Protocol::Udp => write!(f, "UDP"),
Protocol::Gre => write!(f, "GRE"),
Protocol::Esp => write!(f, "ESP"),
Protocol::Ah => write!(f, "AH"),
Protocol::Icmpv6 => write!(f, "ICMPv6"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ip_addr_mask_new() {
let mask = IpAddrMask::new("10.0.0.1".parse().unwrap(), 24);
assert_eq!(mask.prefix_len, 24);
assert!(!mask.is_ipv6());
}
#[test]
fn test_ip_addr_mask_from_cidr_v4() {
let mask = IpAddrMask::from_cidr("192.168.1.0/24").unwrap();
assert_eq!(mask.addr, "192.168.1.0".parse::<IpAddr>().unwrap());
assert_eq!(mask.prefix_len, 24);
}
#[test]
fn test_ip_addr_mask_from_cidr_v6() {
let mask = IpAddrMask::from_cidr("fe80::1/64").unwrap();
assert!(mask.is_ipv6());
assert_eq!(mask.prefix_len, 64);
}
#[test]
fn test_ip_addr_mask_invalid_cidr() {
assert!(IpAddrMask::from_cidr("not-an-ip/24").is_err());
assert!(IpAddrMask::from_cidr("192.168.1.0").is_err());
assert!(IpAddrMask::from_cidr("192.168.1.0/33").is_err());
assert!(IpAddrMask::from_cidr("::1/129").is_err());
}
#[test]
fn test_protocol_values() {
assert_eq!(Protocol::Hopopt.as_u8(), 0);
assert_eq!(Protocol::Icmp.as_u8(), 1);
assert_eq!(Protocol::Igmp.as_u8(), 2);
assert_eq!(Protocol::Tcp.as_u8(), 6);
assert_eq!(Protocol::Udp.as_u8(), 17);
assert_eq!(Protocol::Gre.as_u8(), 47);
assert_eq!(Protocol::Esp.as_u8(), 50);
assert_eq!(Protocol::Ah.as_u8(), 51);
assert_eq!(Protocol::Icmpv6.as_u8(), 58);
}
#[test]
fn test_ip_addr_mask_boundary_prefixes_v4() {
let zero = IpAddrMask::from_cidr("0.0.0.0/0").unwrap();
assert_eq!(zero.prefix_len, 0);
let host = IpAddrMask::from_cidr("10.0.0.1/32").unwrap();
assert_eq!(host.prefix_len, 32);
}
#[test]
fn test_ip_addr_mask_boundary_prefixes_v6() {
let zero = IpAddrMask::from_cidr("::/0").unwrap();
assert_eq!(zero.prefix_len, 0);
assert!(zero.is_ipv6());
let host = IpAddrMask::from_cidr("::1/128").unwrap();
assert_eq!(host.prefix_len, 128);
}
#[test]
fn test_ip_addr_mask_equality() {
let a = IpAddrMask::new("10.0.0.1".parse().unwrap(), 24);
let b = IpAddrMask::new("10.0.0.1".parse().unwrap(), 24);
assert_eq!(a, b);
let c = IpAddrMask::new("10.0.0.1".parse().unwrap(), 16);
assert_ne!(a, c);
}
#[test]
fn test_ip_addr_mask_multiple_slashes() {
assert!(IpAddrMask::from_cidr("10.0.0.1/24/8").is_err());
}
#[test]
fn test_ip_addr_mask_empty_string() {
assert!(IpAddrMask::from_cidr("").is_err());
}
#[test]
fn test_ip_addr_mask_v4_is_not_ipv6() {
let mask = IpAddrMask::new("192.168.0.1".parse().unwrap(), 24);
assert!(!mask.is_ipv6());
}
#[test]
fn test_ip_addr_mask_v6_is_ipv6() {
let mask = IpAddrMask::new("::1".parse().unwrap(), 128);
assert!(mask.is_ipv6());
}
#[test]
fn test_protocol_copy() {
let p = Protocol::Tcp;
let p2 = p; assert_eq!(p, p2);
}
#[test]
fn test_protocol_display() {
assert_eq!(Protocol::Tcp.to_string(), "TCP");
assert_eq!(Protocol::Udp.to_string(), "UDP");
assert_eq!(Protocol::Icmp.to_string(), "ICMP");
assert_eq!(Protocol::Icmpv6.to_string(), "ICMPv6");
assert_eq!(Protocol::Gre.to_string(), "GRE");
assert_eq!(Protocol::Esp.to_string(), "ESP");
assert_eq!(Protocol::Ah.to_string(), "AH");
assert_eq!(Protocol::Igmp.to_string(), "IGMP");
assert_eq!(Protocol::Hopopt.to_string(), "HOPOPT");
}
}