use crate::error::AntiSSRFError;
use ipnetwork::{IpNetwork, Ipv6Network};
use std::fmt;
use std::net::IpAddr;
use std::str::FromStr;
#[derive(Debug, Clone, PartialEq)]
pub struct CIDRBlock {
network: Ipv6Network,
original: String,
}
impl CIDRBlock {
    /// Parses an IPv4 or IPv6 CIDR string such as `"10.0.0.0/8"` or `"fe80::/10"`.
    ///
    /// IPv4 ranges are converted to their IPv4-mapped IPv6 form, with the prefix length
    /// shifted by 96 bits, so every block is stored and compared in one address space.
    /// Host bits below the prefix are cleared for both families. For example,
    /// `"fe80::1/10"` and `"fe80::/10"` produce the same internal network. The input text
    /// itself is kept unchanged for `Display`.
    ///
    /// # Errors
    ///
    /// Returns [`AntiSSRFError::InvalidCIDR`] (carrying the parser's message) when
    /// `ipnetwork` rejects `cidr`, e.g. a malformed address or an out-of-range prefix.
    pub fn parse(cidr: &str) -> Result<Self, AntiSSRFError> {
        let parsed =
            IpNetwork::from_str(cidr).map_err(|e| AntiSSRFError::InvalidCIDR(e.to_string()))?;

        // Reduce both families to (network address, prefix) in IPv6 space. Taking
        // `.network()` in each branch clears host bits, so the stored representation is
        // canonical regardless of how the caller wrote the range.
        let (addr, prefix) = match parsed {
            IpNetwork::V4(v4) => (v4.network().to_ipv6_mapped(), v4.prefix() + 96),
            IpNetwork::V6(v6) => (v6.network(), v6.prefix()),
        };

        // `prefix` is at most 32 + 96 = 128 here, so this should not fail. The error is
        // still mapped rather than unwrapped so that parsing never panics.
        let network = Ipv6Network::new(addr, prefix)
            .map_err(|e| AntiSSRFError::InvalidCIDR(e.to_string()))?;

        Ok(Self {
            network,
            original: cidr.to_string(),
        })
    }

    /// Returns `true` if `addr` lies inside this block.
    ///
    /// An IPv4 address is compared via its IPv4-mapped IPv6 form, so `10.1.2.3` and
    /// `::ffff:10.1.2.3` always give the same answer.
    pub fn contains(&self, addr: IpAddr) -> bool {
        let ipv6_addr = match addr {
            IpAddr::V4(v4) => v4.to_ipv6_mapped(),
            IpAddr::V6(v6) => v6,
        };
        self.network.contains(ipv6_addr)
    }

    /// Returns `true` if every address in `other` is also inside `self`.
    ///
    /// `other` is a subset exactly when its prefix is at least as long as ours and its
    /// network address falls inside our range. Identical blocks therefore contain each
    /// other.
    pub fn contains_cidr(&self, other: &CIDRBlock) -> bool {
        other.network.prefix() >= self.network.prefix()
            && self.network.contains(other.network.network())
    }
}
impl fmt::Display for CIDRBlock {
    // Echo the caller's original spelling (e.g. "10.0.0.0/8"), not the IPv6-mapped form
    // used internally for matching.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.original)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    // --- parsing -----------------------------------------------------------------------

    #[test]
    fn parse_valid_ipv4() {
        let block = CIDRBlock::parse("10.0.0.0/8").unwrap();
        assert_eq!(block.to_string(), "10.0.0.0/8");
    }

    #[test]
    fn parse_valid_ipv6() {
        let block = CIDRBlock::parse("fe80::/10").unwrap();
        assert_eq!(block.to_string(), "fe80::/10");
    }

    #[test]
    fn parse_invalid_cidr() {
        assert!(matches!(
            CIDRBlock::parse("invalid"),
            Err(AntiSSRFError::InvalidCIDR(_))
        ));
    }

    #[test]
    fn parse_invalid_prefix() {
        assert!(matches!(
            CIDRBlock::parse("10.0.0.0/33"),
            Err(AntiSSRFError::InvalidCIDR(_))
        ));
    }

    #[test]
    fn parse_invalid_ipv6_prefix() {
        assert!(matches!(
            CIDRBlock::parse("fe80::/129"),
            Err(AntiSSRFError::InvalidCIDR(_))
        ));
    }

    #[test]
    fn ipv6_host_bits_do_not_affect_membership() {
        // Host bits in the input must not narrow the range; Display still echoes them.
        let block = CIDRBlock::parse("fe80::1/10").unwrap();
        assert_eq!(block.to_string(), "fe80::1/10");
        assert!(block.contains(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0xabcd))));
        assert!(block.contains_cidr(&CIDRBlock::parse("fe80::/64").unwrap()));
    }

    // --- single-address membership -----------------------------------------------------

    #[test]
    fn contains_ipv4_in_ipv4_cidr() {
        let block = CIDRBlock::parse("10.0.0.0/8").unwrap();
        assert!(block.contains(IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3))));
        assert!(!block.contains(IpAddr::V4(Ipv4Addr::new(11, 0, 0, 1))));
    }

    #[test]
    fn contains_ipv4_mapped_in_ipv4_cidr() {
        let block = CIDRBlock::parse("10.0.0.0/8").unwrap();
        let mapped = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xFFFF, 0x0a01, 0x0203));
        assert!(block.contains(mapped));
    }

    #[test]
    fn contains_ipv6_in_ipv6_cidr() {
        let block = CIDRBlock::parse("fe80::/10").unwrap();
        assert!(block.contains(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!block.contains(IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn contains_ipv4_in_ipv6_mapped_cidr() {
        let block = CIDRBlock::parse("::ffff:10.0.0.0/104").unwrap();
        assert!(block.contains(IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3))));
    }

    #[test]
    fn ipv4_wildcard_does_not_match_native_ipv6() {
        let block = CIDRBlock::parse("0.0.0.0/0").unwrap();
        assert!(block.contains(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(!block.contains(IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn ipv6_wildcard_matches_ipv4() {
        let block = CIDRBlock::parse("::/0").unwrap();
        assert!(block.contains(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
    }

    // --- range containment -------------------------------------------------------------

    #[test]
    fn contains_cidr_ipv4() {
        let parent = CIDRBlock::parse("10.0.0.0/8").unwrap();
        let child = CIDRBlock::parse("10.1.0.0/16").unwrap();
        assert!(parent.contains_cidr(&child));
        let non_child = CIDRBlock::parse("11.0.0.0/8").unwrap();
        assert!(!parent.contains_cidr(&non_child));
        let same_prefix = CIDRBlock::parse("10.0.0.0/8").unwrap();
        assert!(parent.contains_cidr(&same_prefix));
    }

    #[test]
    fn contains_cidr_larger_prefix_fails() {
        let small = CIDRBlock::parse("10.1.0.0/16").unwrap();
        let large = CIDRBlock::parse("10.0.0.0/8").unwrap();
        assert!(!small.contains_cidr(&large));
    }

    #[test]
    fn contains_cidr_ipv6_and_cross_family() {
        let link_local = CIDRBlock::parse("fe80::/10").unwrap();
        assert!(link_local.contains_cidr(&CIDRBlock::parse("fe80::/64").unwrap()));
        assert!(!link_local.contains_cidr(&CIDRBlock::parse("fc00::/7").unwrap()));

        // The whole IPv4-mapped space contains any IPv4 block, but not the reverse.
        let mapped = CIDRBlock::parse("::ffff:0:0/96").unwrap();
        let ten = CIDRBlock::parse("10.0.0.0/8").unwrap();
        assert!(mapped.contains_cidr(&ten));
        assert!(!ten.contains_cidr(&mapped));
    }

    // --- trait impls -------------------------------------------------------------------

    #[test]
    fn display_preserves_original() {
        let block = CIDRBlock::parse("192.168.0.0/16").unwrap();
        assert_eq!(format!("{}", block), "192.168.0.0/16");
    }

    #[test]
    fn clone_and_eq() {
        let a = CIDRBlock::parse("10.0.0.0/8").unwrap();
        let b = a.clone();
        assert_eq!(a, b);
    }
}