use crate::IpBitwiseExt;
use std::net::Ipv4Addr;
use std::str::FromStr;
use std::ops::Not;
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
/// An IPv4 subnet mask: a run of one bits followed only by zero bits.
///
/// The four octets are kept most significant first, i.e. in the order a
/// dotted-quad mask such as `255.255.255.0` is written.
#[repr(align(4))]
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct Ipv4Mask {
    mask: [u8; 4],
}

impl Ipv4Mask {
    /// Returns the mask whose first `len` bits are set.
    ///
    /// # Panics
    ///
    /// Panics when `len` is greater than 32: the lookup table only holds the
    /// 33 masks for prefix lengths `0..=32`.
    pub const fn new(len: u8) -> Self {
        // One precomputed entry per prefix length, indexed by that length.
        #[rustfmt::skip]
        const PREFIX_TABLE: [[u8; 4]; 33] = [
            [0, 0, 0, 0],
            [128, 0, 0, 0], [192, 0, 0, 0], [224, 0, 0, 0], [240, 0, 0, 0],
            [248, 0, 0, 0], [252, 0, 0, 0], [254, 0, 0, 0], [255, 0, 0, 0],
            [255, 128, 0, 0], [255, 192, 0, 0], [255, 224, 0, 0], [255, 240, 0, 0],
            [255, 248, 0, 0], [255, 252, 0, 0], [255, 254, 0, 0], [255, 255, 0, 0],
            [255, 255, 128, 0], [255, 255, 192, 0], [255, 255, 224, 0], [255, 255, 240, 0],
            [255, 255, 248, 0], [255, 255, 252, 0], [255, 255, 254, 0], [255, 255, 255, 0],
            [255, 255, 255, 128], [255, 255, 255, 192], [255, 255, 255, 224], [255, 255, 255, 240],
            [255, 255, 255, 248], [255, 255, 255, 252], [255, 255, 255, 254], [255, 255, 255, 255],
        ];
        Self {
            mask: PREFIX_TABLE[len as usize],
        }
    }

    /// Builds a mask from four octets given most significant first.
    ///
    /// Returns `None` under the same conditions as [`Ipv4Mask::from_u32`].
    pub fn from_bytes(bytes: [u8; 4]) -> Option<Self> {
        let value = u32::from_be_bytes(bytes);
        Self::from_u32(value)
    }

    /// Builds a mask from its 32-bit value.
    ///
    /// Returns `None` unless `x` is a block of one bits followed only by zero
    /// bits (all-zero and all-one values are accepted).
    pub fn from_u32(x: u32) -> Option<Self> {
        // Both strategies give the same verdict once the trailing zeros are
        // added: the total is 32 only when nothing but ones sits above them.
        let prefix_ones = if cfg!(target_feature = "popcnt") {
            x.count_ones()
        } else {
            (!x).leading_zeros()
        };
        if prefix_ones + x.trailing_zeros() != 32 {
            return None;
        }
        Some(Self {
            mask: x.to_be_bytes(),
        })
    }

    /// Returns the four octets of the mask, most significant first.
    pub const fn octets(self) -> [u8; 4] {
        self.mask
    }

    /// Returns the mask as a 32-bit integer whose top byte is the first octet.
    pub const fn as_u32(self) -> u32 {
        u32::from_be_bytes(self.mask)
    }

    /// Returns the prefix length, i.e. the number of one bits in the mask.
    pub const fn len(self) -> u8 {
        let bits = self.as_u32();
        // Every stored mask is a contiguous prefix, so a population count and
        // a leading-ones count agree; pick whichever the target does cheaply.
        #[cfg(target_feature = "popcnt")]
        let prefix = bits.count_ones();
        #[cfg(not(target_feature = "popcnt"))]
        let prefix = (!bits).leading_zeros();
        prefix as u8
    }
}
/// Formats the mask as a dotted quad (`{}`) or as a `/len` suffix (`{:#}`).
impl Display for Ipv4Mask {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        if !f.alternate() {
            // Default form: the four octets, e.g. `255.255.255.0`.
            let [a, b, c, d] = self.octets();
            return write!(f, "{}.{}.{}.{}", a, b, c, d);
        }
        // Alternate form: the prefix length, e.g. `/24`.
        write!(f, "/{}", self.len())
    }
}
/// Debug output is the `Display` output; the formatter is forwarded as-is, so
/// `{:#?}` selects the same alternate form that `{:#}` does.
impl Debug for Ipv4Mask {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        <Self as Display>::fmt(self, f)
    }
}
/// `!mask` yields the bitwise complement of the mask as four raw octets, in
/// the same order as [`Ipv4Mask::octets`].
impl Not for Ipv4Mask {
    type Output = [u8; 4];

    fn not(self) -> [u8; 4] {
        // Complementing is independent per bit, so it can be done per octet.
        let [a, b, c, d] = self.octets();
        [!a, !b, !c, !d]
    }
}
/// Parses a dotted-quad mask such as `255.255.255.0`.
///
/// The text is first parsed as an `Ipv4Addr`; its octets must then be accepted
/// by `Ipv4Mask::from_bytes`. Either failure yields `InvalidIpv4Mask`.
impl FromStr for Ipv4Mask {
    type Err = InvalidIpv4Mask;

    fn from_str(s: &str) -> Result<Self, InvalidIpv4Mask> {
        match s.parse::<Ipv4Addr>() {
            Ok(addr) => Self::from_bytes(addr.octets()).ok_or(InvalidIpv4Mask),
            Err(_) => Err(InvalidIpv4Mask),
        }
    }
}
/// Error produced when a string cannot be parsed into an `Ipv4Mask`.
///
/// It implements `Display` and `std::error::Error` so it can be printed and
/// propagated through `Box<dyn Error>` like any other error value.
#[derive(Debug)]
pub struct InvalidIpv4Mask;

impl Display for InvalidIpv4Mask {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        f.write_str("invalid IPv4 mask")
    }
}

impl std::error::Error for InvalidIpv4Mask {}
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct MaskedIpv4 {
pub ip: Ipv4Addr,
pub mask: Ipv4Mask,
}
impl MaskedIpv4 {
pub const fn new(ip: Ipv4Addr, mask: Ipv4Mask) -> Self {
Self { ip, mask }
}
pub const fn cidr(ip: Ipv4Addr, mask_len: u8) -> Self {
let mask = Ipv4Mask::new(mask_len);
Self::new(ip, mask)
}
pub fn from_cidr_str(s: &str) -> Option<Self> {
let mut parts = s.split("/");
let ip = parts.next()?.parse::<Ipv4Addr>().ok()?;
let mask_len = parts.next()?.parse::<u8>().ok()?;
if mask_len > 32 {
None
} else {
Some(Self::cidr(ip, mask_len))
}
}
pub fn from_network_str(s: &str) -> Option<Self> {
let mut parts = s.split(" ");
let ip = parts.next()?.parse().ok()?;
let mask = parts.next()?.parse().ok()?;
Some(Self::new(ip, mask))
}
pub fn to_cidr_string(&self) -> String {
format!("{:#}", self)
}
pub fn to_network_string(&self) -> String {
format!("{}", self)
}
pub fn network_address(&self) -> Ipv4Addr {
self.ip.bitand(self.mask)
}
pub fn network(&self) -> MaskedIpv4 {
Self::new(self.network_address(), self.mask)
}
pub fn is_network_address(&self) -> bool {
self.mask.len() <= 30 && self.ip == self.network_address()
}
pub fn broadcast_address(&self) -> Ipv4Addr {
self.ip.bitor(!self.mask)
}
pub fn is_broadcast_address(&self) -> bool {
self.mask.len() <= 30 && self.ip == self.broadcast_address()
}
pub fn network_bits(&self) -> u8 {
self.mask.len()
}
pub fn host_bits(&self) -> u8 {
32 - self.network_bits()
}
pub fn host_count(&self) -> usize {
let host_bits = self.host_bits();
match host_bits {
0 => 1,
1 => 2,
_ => 2usize.checked_shl(host_bits as u32).unwrap() - 2,
}
}
pub fn host_count_u64(&self) -> u64 {
let host_bits = self.host_bits();
match host_bits {
0 => 1,
1 => 2,
_ => (2 << host_bits) - 2,
}
}
pub fn network_count(&self, len: u8) -> usize {
if len < self.mask.len() {
0
} else if len > 32 {
panic!("Invalid mask length > 32")
} else {
let borrowed_bits = len - self.mask.len();
2usize.checked_shl(borrowed_bits as u32).unwrap()
}
}
pub fn network_count_u64(&self, len: u8) -> u64 {
if len < self.mask.len() {
0
} else if len > 32 {
panic!("Invalid mask length > 32")
} else {
let borrowed_bits = len - self.mask.len();
2 << borrowed_bits
}
}
pub fn contains(&self, ip: Ipv4Addr) -> bool {
self.ip.bitand(self.mask) == ip.bitand(self.mask)
}
}
/// Formats as `address mask` (`{}`) or as CIDR `address/len` (`{:#}`).
impl Display for MaskedIpv4 {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        if !f.alternate() {
            // Default form: address and mask separated by a single space.
            return write!(f, "{} {}", self.ip, self.mask);
        }
        // Alternate form: address, slash, prefix length.
        let prefix_len = self.mask.len();
        write!(f, "{}/{}", self.ip, prefix_len)
    }
}
/// Debug output is the `Display` output; the formatter is forwarded as-is, so
/// `{:#?}` selects the same alternate form that `{:#}` does.
impl Debug for MaskedIpv4 {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        <Self as Display>::fmt(self, f)
    }
}
/// Parses either notation: CIDR is tried first, then "address mask".
///
/// Yields `InvalidMaskedIpv4` when neither parser accepts the text.
impl FromStr for MaskedIpv4 {
    type Err = InvalidMaskedIpv4;

    fn from_str(s: &str) -> Result<Self, InvalidMaskedIpv4> {
        if let Some(parsed) = Self::from_cidr_str(s) {
            return Ok(parsed);
        }
        // Only consult the second parser once the CIDR form has been ruled out.
        match Self::from_network_str(s) {
            Some(parsed) => Ok(parsed),
            None => Err(InvalidMaskedIpv4),
        }
    }
}
/// Error produced when a string cannot be parsed into a `MaskedIpv4`.
///
/// It implements `Display` and `std::error::Error` so it can be printed and
/// propagated through `Box<dyn Error>` like any other error value.
#[derive(Debug)]
pub struct InvalidMaskedIpv4;

impl Display for InvalidMaskedIpv4 {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        f.write_str("invalid masked IPv4 address")
    }
}

impl std::error::Error for InvalidMaskedIpv4 {}