use crate::errors::Error;
use core::ops::Deref;
/// IPv4 ranges that `Ip::is_public` reports as not public.
///
/// Each entry's address already has all host bits cleared, i.e. it is the
/// network address of its range. Entries are built with the `const fn`
/// `CIDR::new`, so an entry with a mask above 32 fails to compile.
///
/// NOTE(review): no entry covers multicast (224.0.0.0/4) or the reserved
/// 240.0.0.0/4 block (apart from 255.255.255.255), so those addresses are
/// reported as public — confirm this is intended.
const PRIVATE_IP4_SUBNETS: [CIDR; 12] = [
// 0.0.0.0/8 — "this network" (RFC 1122)
CIDR::new(Ip::new([0, 0, 0, 0]), 8),
// 10.0.0.0/8 — private use (RFC 1918)
CIDR::new(Ip::new([10, 0, 0, 0]), 8),
// 100.64.0.0/10 — shared address space / carrier-grade NAT (RFC 6598)
CIDR::new(Ip::new([100, 64, 0, 0]), 10),
// 127.0.0.0/8 — loopback (RFC 1122)
CIDR::new(Ip::new([127, 0, 0, 0]), 8),
// 169.254.0.0/16 — link-local (RFC 3927)
CIDR::new(Ip::new([169, 254, 0, 0]), 16),
// 172.16.0.0/12 — private use (RFC 1918)
CIDR::new(Ip::new([172, 16, 0, 0]), 12),
// 192.0.2.0/24 — documentation, TEST-NET-1 (RFC 5737)
CIDR::new(Ip::new([192, 0, 2, 0]), 24),
// 192.168.0.0/16 — private use (RFC 1918)
CIDR::new(Ip::new([192, 168, 0, 0]), 16),
// 198.18.0.0/15 — benchmarking (RFC 2544)
CIDR::new(Ip::new([198, 18, 0, 0]), 15),
// 198.51.100.0/24 — documentation, TEST-NET-2 (RFC 5737)
CIDR::new(Ip::new([198, 51, 100, 0]), 24),
// 203.0.113.0/24 — documentation, TEST-NET-3 (RFC 5737)
CIDR::new(Ip::new([203, 0, 113, 0]), 24),
// 255.255.255.255/32 — limited broadcast (RFC 919)
CIDR::new(Ip::new([255, 255, 255, 255]), 32),
];
/// An IPv4 address stored as its four octets, first octet first
/// (`[192, 168, 0, 1]` is `192.168.0.1`).
///
/// `Hash` is derived alongside `Eq` so addresses can be used as keys in
/// hash maps and sets.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub struct Ip {
    octets: [u8; 4],
}
/// An IPv4 CIDR block: an address plus a prefix length.
///
/// Every field is `Copy`, so the block itself is a small `Copy` value,
/// consistent with [`Ip`].
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct CIDR {
    /// Address the block was created or parsed with. Host bits are not
    /// cleared, so this need not be the network address.
    ip: Ip,
    /// Prefix length in bits; `CIDR::new` and `CIDR::parse` keep it <= 32.
    mask: u8,
}
impl Ip {
#[inline]
pub const fn new(octets: [u8; 4]) -> Ip {
Ip { octets }
}
#[inline]
pub const fn as_octets(&self) -> [u8; 4] {
self.octets
}
#[inline]
pub const fn as_bits(&self) -> u32 {
(self.octets[0] as u32) << 24
| (self.octets[1] as u32) << 16
| (self.octets[2] as u32) << 8
| (self.octets[3] as u32)
}
#[inline]
pub const fn is_unicast(&self) -> bool {
self.octets[0] < 224
}
pub fn is_public(&self) -> bool {
for pip in &PRIVATE_IP4_SUBNETS {
if self.as_bits() & pip.as_bitmask() == pip.as_prefix().as_bits() {
return false;
}
}
true
}
pub fn parse(input: &[u8]) -> Result<Ip, Error> {
if input.len() < 7 {
return Err(Error::InputTooShort);
}
if input.len() > 15 {
return Err(Error::InputTooLong);
}
let mut sections = 0;
let mut last_char_was_dot = true;
let mut octets = [0, 0, 0, 0];
let mut accumulator: u8 = 0;
for c in input {
match c {
b'0'..=b'9' => {
if accumulator == 0 && !last_char_was_dot {
return Err(Error::LeadingZero);
}
accumulator = if let Some(a) = accumulator.checked_mul(10) {
a
} else {
return Err(Error::OctetOverflow);
};
accumulator = if let Some(a) = accumulator.checked_add(c - b'0') {
a
} else {
return Err(Error::OctetOverflow);
};
last_char_was_dot = false;
}
b'.' => {
if last_char_was_dot {
return Err(Error::MissingOctet);
}
octets[sections] = accumulator;
sections += 1;
accumulator = 0;
if sections > 3 {
return Err(Error::TooManyOctets);
}
last_char_was_dot = true;
}
_ => return Err(Error::IllegalCharacter),
}
}
octets[sections] = accumulator;
if sections < 3 {
Err(Error::InsufficientOctets)
} else if last_char_was_dot {
Err(Error::MissingOctet)
} else {
Ok(Ip { octets })
}
}
}
impl Deref for Ip {
type Target = [u8; 4];
fn deref(&self) -> &Self::Target {
&self.octets
}
}
impl CIDR {
#[inline]
pub const fn new(ip: Ip, mask: u8) -> CIDR {
if mask > 32 {
panic!("CIDR mask can't be higher than 32");
}
CIDR { ip, mask }
}
#[inline]
pub const fn as_prefix(&self) -> Ip {
self.ip
}
#[inline]
pub const fn as_mask(&self) -> u8 {
self.mask
}
#[inline]
pub const fn is_unicast(&self) -> bool {
self.ip.is_unicast()
}
#[inline]
pub const fn as_bitmask(&self) -> u32 {
!((2_u64.pow(32 - self.mask as u32) - 1) as u32)
}
#[inline]
pub fn is_public(&self) -> bool {
self.ip.is_public()
}
pub fn parse(input: &[u8]) -> Result<CIDR, Error> {
if input.len() > 18 {
return Err(Error::InputTooLong);
}
let sep_pos = if let Some(pos) = input.iter().position(|c| c == &b'/') {
pos
} else {
return Err(Error::MissingMask);
};
let ip = Ip::parse(&input[..sep_pos])?;
let mask_input = &input[sep_pos + 1..];
if mask_input.len() > 2 {
return Err(Error::InputTooLong);
}
if mask_input.is_empty() {
return Err(Error::InputTooShort);
}
let mut mask: u8 = 0;
for c in mask_input {
match c {
b'0'..=b'9' => {
mask *= 10;
mask += c - b'0';
}
_ => return Err(Error::IllegalCharacter),
}
}
if mask > 32 {
return Err(Error::MaskOverflow);
}
if mask == 0 && mask_input.len() > 1 {
return Err(Error::LeadingZero);
}
Ok(CIDR { ip, mask })
}
#[inline]
pub const fn contains(&self, ip: Ip) -> bool {
let mask_bits = self.as_bitmask();
ip.as_bits() & mask_bits == self.ip.as_bits() & mask_bits
}
}
#[cfg(test)]
mod tests {
    use super::{Ip, CIDR};
    use crate::errors::Error;

    #[test]
    fn parse_valid_ips() {
        assert_eq!(Ip::parse(b"1.1.1.1"), Ok(Ip::new([1, 1, 1, 1])));
        assert_eq!(Ip::parse(b"100.200.10.0"), Ok(Ip::new([100, 200, 10, 0])));
        assert_eq!(
            Ip::parse(b"255.255.255.255"),
            Ok(Ip::new([255, 255, 255, 255]))
        );
        assert_eq!(Ip::parse(b"0.0.0.0"), Ok(Ip::new([0, 0, 0, 0])));
    }

    #[test]
    fn reject_invalid_ips() {
        assert_eq!(Ip::parse(b"1.1.1.1."), Err(Error::TooManyOctets));
        assert_eq!(Ip::parse(b"1.1.1.1.1"), Err(Error::TooManyOctets));
        assert_eq!(Ip::parse(b"1.1.1."), Err(Error::InputTooShort));
        assert_eq!(Ip::parse(b"1.1.1"), Err(Error::InputTooShort));
        assert_eq!(Ip::parse(b"100.100.1"), Err(Error::InsufficientOctets));
        assert_eq!(Ip::parse(b"100.100.1."), Err(Error::MissingOctet));
        assert_eq!(Ip::parse(b"255.255.255.256"), Err(Error::OctetOverflow));
        assert_eq!(Ip::parse(b"256.255.255.255"), Err(Error::OctetOverflow));
        assert_eq!(Ip::parse(b"1.10.100.1000"), Err(Error::OctetOverflow));
        assert_eq!(Ip::parse(b"1000.100.10.1"), Err(Error::OctetOverflow));
        assert_eq!(Ip::parse(b"af.fe.ff.ac"), Err(Error::IllegalCharacter));
        assert_eq!(Ip::parse(b"00.0.0.0"), Err(Error::LeadingZero));
        assert_eq!(Ip::parse(b"1.01.2.3"), Err(Error::LeadingZero));
        assert_eq!(Ip::parse(b"1:1:1:1"), Err(Error::IllegalCharacter));
        assert_eq!(Ip::parse(b"1.1.2_.3"), Err(Error::IllegalCharacter));
    }

    #[test]
    fn public_ip() {
        assert!(Ip::parse(b"1.1.1.1").unwrap().is_public());
        assert!(!Ip::parse(b"10.10.100.254").unwrap().is_public());
        assert!(!Ip::parse(b"192.168.1.1").unwrap().is_public());
        assert!(Ip::parse(b"172.10.100.254").unwrap().is_public());
        assert!(!Ip::parse(b"172.20.100.254").unwrap().is_public());
    }

    #[test]
    fn create_new_cidr() {
        assert_eq!(CIDR::new(Ip::new([1, 1, 1, 1]), 15).as_mask(), 15);
    }

    #[test]
    #[should_panic(expected = "CIDR mask can't be higher than 32")]
    fn reject_large_mask() {
        CIDR::new(Ip::new([1, 1, 1, 1]), 33);
    }

    #[test]
    fn accept_valid_cidr() {
        assert_eq!(
            CIDR::parse(b"1.1.1.1/32"),
            Ok(CIDR::new(Ip::new([1, 1, 1, 1]), 32))
        );
        assert_eq!(
            CIDR::parse(b"1.1.1.1/1"),
            Ok(CIDR::new(Ip::new([1, 1, 1, 1]), 1))
        );
        assert_eq!(
            CIDR::parse(b"255.255.255.255/32"),
            Ok(CIDR::new(Ip::new([255, 255, 255, 255]), 32))
        );
        assert_eq!(
            CIDR::parse(b"255.255.255.129/25"),
            Ok(CIDR::new(Ip::new([255, 255, 255, 129]), 25))
        );
        assert_eq!(
            CIDR::parse(b"128.0.0.0/1"),
            Ok(CIDR::new(Ip::new([128, 0, 0, 0]), 1))
        );
        assert_eq!(
            CIDR::parse(b"32.40.50.24/29"),
            Ok(CIDR::new(Ip::new([32, 40, 50, 24]), 29))
        );
        // Host bits are kept as written, not cleared.
        assert_eq!(
            CIDR::parse(b"10.0.0.1/8"),
            Ok(CIDR::new(Ip::new([10, 0, 0, 1]), 8))
        );
    }

    #[test]
    fn reject_invalid_cidr() {
        assert_eq!(CIDR::parse(b"1.1.1.1/33"), Err(Error::MaskOverflow));
        assert_eq!(CIDR::parse(b"255.255.255.255/00"), Err(Error::LeadingZero));
        assert_eq!(CIDR::parse(b"1.1.1.1//1"), Err(Error::IllegalCharacter));
        assert_eq!(CIDR::parse(b"50.40.50.23/160"), Err(Error::InputTooLong));
        assert_eq!(CIDR::parse(b"1.1.1.1/"), Err(Error::InputTooShort));
        assert_eq!(CIDR::parse(b"1.111.1.1/"), Err(Error::InputTooShort));
        assert_eq!(CIDR::parse(b"1.1.1.1/000"), Err(Error::InputTooLong));
        assert_eq!(CIDR::parse(b"1.1.1.1/99"), Err(Error::MaskOverflow));
        assert_eq!(CIDR::parse(b"1.1.1.1"), Err(Error::MissingMask));
        assert_eq!(CIDR::parse(b"100.1.1.1"), Err(Error::MissingMask));
        assert_eq!(CIDR::parse(b"1.1.1.1032"), Err(Error::MissingMask));
        assert_eq!(CIDR::parse(b"1.1.1.1032/"), Err(Error::OctetOverflow));
        assert_eq!(CIDR::parse(b"1.1.1.1/."), Err(Error::IllegalCharacter));
    }

    #[test]
    fn ip_in_cidr() {
        let target = Ip::parse(b"34.0.0.254").unwrap();
        assert!(CIDR::parse(b"34.0.0.1/24").unwrap().contains(target));
        assert!(!CIDR::parse(b"34.0.1.1/24").unwrap().contains(target));
    }

    #[test]
    fn cidr_bitmask_edges() {
        let zero = Ip::new([0, 0, 0, 0]);
        assert_eq!(CIDR::new(zero, 0).as_bitmask(), 0);
        assert_eq!(CIDR::new(zero, 24).as_bitmask(), 0xFFFF_FF00);
        assert_eq!(CIDR::new(zero, 32).as_bitmask(), u32::MAX);
        // A /0 block contains every address.
        assert!(CIDR::new(zero, 0).contains(Ip::new([8, 8, 8, 8])));
        // A /32 block contains only its own address.
        let host = CIDR::new(Ip::new([1, 2, 3, 4]), 32);
        assert!(host.contains(Ip::new([1, 2, 3, 4])));
        assert!(!host.contains(Ip::new([1, 2, 3, 5])));
    }

    #[test]
    fn public_ip_cidr() {
        assert!(CIDR::parse(b"1.1.1.1/32").unwrap().is_public());
        assert!(!CIDR::parse(b"10.10.100.254/24").unwrap().is_public());
        assert!(!CIDR::parse(b"192.168.0.0/16").unwrap().is_public());
        assert!(CIDR::parse(b"172.10.100.254/24").unwrap().is_public());
        assert!(!CIDR::parse(b"172.20.100.254/24").unwrap().is_public());
    }
}