use core::cmp::Ordering;
use core::fmt;
use core::hash::{Hash, Hasher};
use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use core::str::FromStr;
use super::error::{CidrParseKind, Error, Result};
/// An IPv4 network in CIDR notation (e.g. `192.168.0.0/16`).
///
/// The constructors in this file clear every host bit beyond the prefix, so
/// `octets` holds only the network part of the address.
#[derive(Copy, Clone)]
pub struct Ipv4Cidr {
    // Network address, stored big-endian (read back via `u32::from_be_bytes`).
    octets: [u8; 4],
    // Prefix length; `new`/`from_ip` reject values above 32.
    bits: u8,
}
impl Ipv4Cidr {
    /// Largest valid prefix length for an IPv4 network.
    pub const MAX_BITS: u8 = 32;

    /// Returns the netmask for a prefix length of `bits` as a `u32`
    /// (e.g. `24` -> `0xFFFF_FF00`).
    ///
    /// # Panics
    ///
    /// Panics if `bits` is greater than [`Self::MAX_BITS`].
    pub const fn mask_of(bits: u8) -> u32 {
        if bits == 0 {
            // `u32::MAX << 32` would overflow the shift, so /0 is special-cased.
            return 0;
        }
        if bits > Self::MAX_BITS {
            panic!("bits must be <= 32");
        }
        u32::MAX << (Self::MAX_BITS - bits)
    }

    /// Creates a network from raw address octets and a prefix length.
    ///
    /// Host bits beyond the prefix are cleared, so `new([10, 1, 2, 3], 8)`
    /// yields `10.0.0.0/8`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OverflowIpv4CidrBit`] if `bits` exceeds [`Self::MAX_BITS`].
    #[inline]
    pub const fn new(octets: [u8; 4], bits: u8) -> Result<Self> {
        if bits > Self::MAX_BITS {
            return Err(Error::OverflowIpv4CidrBit(bits));
        }
        let octets = (u32::from_be_bytes(octets) & Self::mask_of(bits)).to_be_bytes();
        Ok(Self { octets, bits })
    }

    /// Creates a network from anything convertible into an [`Ipv4Addr`].
    ///
    /// Host bits beyond the prefix are cleared.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OverflowIpv4CidrBit`] if `bits` exceeds [`Self::MAX_BITS`].
    #[inline]
    pub fn from_ip<I>(ip: I, bits: u8) -> Result<Self>
    where
        I: Into<Ipv4Addr>,
    {
        // Delegate so validation (and the IPv4-specific error variant) lives in one place;
        // this previously reported `OverflowIpv6CidrBit` for IPv4 input.
        Self::new(ip.into().octets(), bits)
    }

    /// Returns the first address of the network (all host bits zero).
    #[inline]
    pub const fn network_addr(&self) -> Ipv4Addr {
        Ipv4Addr::from_bits(u32::from_be_bytes(self.octets))
    }

    /// Returns the last address of the network (all host bits one).
    #[inline]
    pub const fn broadcast_addr(&self) -> Ipv4Addr {
        Ipv4Addr::from_bits(u32::from_be_bytes(self.octets) | !self.mask())
    }

    /// Returns an iterator over the usable host addresses of the network.
    ///
    /// For prefixes up to /30 the network and broadcast addresses are skipped.
    /// A /31 yields both of its addresses (matching RFC 3021 point-to-point
    /// usage) and a /32 yields its single address.
    #[inline]
    pub const fn hosts(&self) -> Ipv4Hosts {
        let first = u32::from_be_bytes(self.octets);
        let last = first | !self.mask();
        if self.bits >= 31 {
            // `end` is exclusive; `None` (when `last == u32::MAX`) means "run through
            // u32::MAX inclusive", which the iterator handles specially.
            Ipv4Hosts {
                cursor: first,
                end: last.checked_add(1),
            }
        } else {
            // A /30 or shorter spans at least 4 addresses, so `first + 1` cannot
            // overflow, and stopping before `last` skips the broadcast address.
            Ipv4Hosts {
                cursor: first + 1,
                end: Some(last),
            }
        }
    }

    /// Returns the enclosing network one prefix bit shorter, or `None` for a /0.
    #[inline]
    pub fn supernet(&self) -> Option<Ipv4Cidr> {
        let bits = self.bits.checked_sub(1)?;
        // `bits < MAX_BITS` here, so the new prefix is always valid; just re-mask.
        let network = u32::from_be_bytes(self.octets) & Self::mask_of(bits);
        Some(Self {
            octets: network.to_be_bytes(),
            bits,
        })
    }

    /// Returns the netmask of this network as a `u32`.
    #[inline]
    pub const fn mask(&self) -> u32 {
        Self::mask_of(self.bits)
    }

    /// Returns the network address octets (host bits cleared).
    #[inline]
    pub const fn octets(&self) -> [u8; 4] {
        self.octets
    }

    /// Returns the prefix length.
    #[inline]
    pub const fn bits(&self) -> u8 {
        self.bits
    }

    /// Returns `true` if `addr` lies inside this network.
    #[inline]
    pub const fn contains(&self, addr: Ipv4Addr) -> bool {
        addr.to_bits() & self.mask() == u32::from_be_bytes(self.octets)
    }

    /// Returns `true` if `other` is entirely contained in this network
    /// (same or longer prefix, and matching network bits).
    #[inline]
    pub const fn contains_cidr(&self, other: &Self) -> bool {
        self.bits <= other.bits && self.overlaps(other)
    }

    /// Returns `true` if the two networks share at least one address, i.e. they
    /// agree on every bit of the shorter of the two prefixes.
    #[inline]
    pub const fn overlaps(&self, other: &Self) -> bool {
        let shortest = if self.bits < other.bits {
            self.bits
        } else {
            other.bits
        };
        let mask = Self::mask_of(shortest);
        let x = u32::from_be_bytes(self.octets);
        let y = u32::from_be_bytes(other.octets);
        (x & mask) == (y & mask)
    }
}
impl fmt::Display for Ipv4Cidr {
    /// Renders the network as `address/prefix`, e.g. `10.0.0.0/8`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let network = Ipv4Addr::from(self.octets);
        write!(f, "{}/{}", network, self.bits)
    }
}
impl fmt::Debug for Ipv4Cidr {
    /// Wraps the `Display` form, e.g. `Ipv4Cidr(10.0.0.0/8)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Ipv4Cidr({self})"))
    }
}
impl PartialEq for Ipv4Cidr {
    /// Two networks are equal when both the (masked) address and the prefix match.
    fn eq(&self, other: &Self) -> bool {
        (self.octets, self.bits) == (other.octets, other.bits)
    }
}

impl Eq for Ipv4Cidr {}
impl Hash for Ipv4Cidr {
    /// Hashes the same fields `PartialEq` compares, address first then prefix,
    /// keeping `Hash` consistent with `Eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (self.octets, self.bits).hash(state);
    }
}
impl From<Ipv4Addr> for Ipv4Cidr {
    /// Wraps a single address as a host network (`/32`).
    fn from(addr: Ipv4Addr) -> Self {
        // A full-length prefix masks nothing, so the octets are stored verbatim.
        Self {
            octets: addr.octets(),
            bits: Self::MAX_BITS,
        }
    }
}
impl TryFrom<([u8; 4], u8)> for Ipv4Cidr {
type Error = Error;
fn try_from((octets, bits): ([u8; 4], u8)) -> core::result::Result<Self, Self::Error> {
Self::from_ip(octets, bits)
}
}
impl FromStr for Ipv4Cidr {
type Err = Error;
fn from_str(s: &str) -> core::result::Result<Self, Self::Err> {
let (addr, bits) = s
.split_once('/')
.ok_or(Error::CidrParseError(CidrParseKind::Ipv4))?;
let addr = addr
.parse::<Ipv4Addr>()
.map_err(|_| Error::CidrParseError(CidrParseKind::Ipv4))?;
let bits = bits
.parse()
.map_err(|_| Error::CidrParseError(CidrParseKind::Ipv4))?;
Ipv4Cidr::from_ip(addr, bits)
}
}
impl PartialOrd for Ipv4Cidr {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Ipv4Cidr {
    /// Orders by network address, then by prefix length (shorter first).
    fn cmp(&self, other: &Self) -> Ordering {
        // Big-endian byte arrays compare lexicographically exactly like their u32 values.
        self.octets
            .cmp(&other.octets)
            .then_with(|| self.bits.cmp(&other.bits))
    }
}
/// Iterator over the host addresses of an [`Ipv4Cidr`], created by [`Ipv4Cidr::hosts`].
///
/// Derives `Clone` (so a partially consumed range can be forked) and `Debug`
/// (expected of public types).
#[derive(Clone, Debug)]
pub struct Ipv4Hosts {
    // Next address to yield, as a u32.
    cursor: u32,
    // Exclusive upper bound; `None` means the range runs through u32::MAX inclusive.
    end: Option<u32>,
}
impl Ipv4Hosts {
    /// Returns how many addresses the iterator has yet to yield.
    #[inline]
    pub const fn len(&self) -> u32 {
        // With `end == None` the range runs through u32::MAX inclusive, so a cursor of 0
        // would mean 2^32 remaining addresses, which does not fit in a u32.
        debug_assert!(!(self.end.is_none() && self.cursor == 0));
        match self.end {
            Some(end) => end - self.cursor,
            None => u32::MAX - self.cursor + 1,
        }
    }

    /// Returns `true` if the iterator has no addresses left to yield.
    ///
    /// Uses the same stop condition as `next`, and so never does the
    /// subtraction `len` needs.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        match self.end {
            Some(end) => self.cursor >= end,
            // An open-ended range always has at least the cursor itself left.
            None => false,
        }
    }
}
impl Iterator for Ipv4Hosts {
    type Item = Ipv4Addr;

    /// Yields the address at the cursor and advances it.
    fn next(&mut self) -> Option<Self::Item> {
        if matches!(self.end, Some(end) if self.cursor >= end) {
            return None;
        }
        let current = self.cursor;
        match current.checked_add(1) {
            Some(following) => self.cursor = following,
            // Just yielded u32::MAX and the cursor cannot advance: pin `end` to the
            // cursor so every later call hits the stop condition above.
            None => self.end = Some(u32::MAX),
        }
        Some(Ipv4Addr::from_bits(current))
    }

    /// Exact when the remaining count fits in `usize`, otherwise `(usize::MAX, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match usize::try_from(self.len()) {
            Ok(exact) => (exact, Some(exact)),
            Err(_) => (usize::MAX, None),
        }
    }

    /// O(1) count derived from the size hint.
    ///
    /// # Panics
    ///
    /// Panics if the remaining count does not fit in `usize`.
    fn count(self) -> usize
    where
        Self: Sized,
    {
        let (_, upper) = self.size_hint();
        upper.expect("count overflow")
    }
}
/// An IPv6 network in CIDR notation (e.g. `2001:db8::/32`).
///
/// The constructors in this file clear every host bit beyond the prefix, so
/// `octets` holds only the network part of the address.
#[derive(Copy, Clone)]
pub struct Ipv6Cidr {
    // Network address, stored big-endian (read back via `u128::from_be_bytes`).
    octets: [u8; 16],
    // Prefix length; `new`/`from_ip` reject values above 128.
    bits: u8,
}
impl Ipv6Cidr {
    /// Largest valid prefix length for an IPv6 network.
    pub const MAX_BITS: u8 = 128;

    /// Returns the netmask for a prefix length of `bits` as a `u128`.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is greater than [`Self::MAX_BITS`].
    pub const fn mask_of(bits: u8) -> u128 {
        if bits == 0 {
            // `u128::MAX << 128` would overflow the shift, so /0 is special-cased.
            return 0;
        }
        if bits > Self::MAX_BITS {
            panic!("bits must be <= 128");
        }
        u128::MAX << (Self::MAX_BITS - bits)
    }

    /// Creates a network from eight 16-bit address segments (the same layout as
    /// [`Ipv6Addr::new`]) and a prefix length. Host bits beyond the prefix are cleared.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OverflowIpv6CidrBit`] if `bits` exceeds [`Self::MAX_BITS`].
    #[inline]
    pub const fn new(octets: [u16; 8], bits: u8) -> Result<Self> {
        if bits > Self::MAX_BITS {
            return Err(Error::OverflowIpv6CidrBit(bits));
        }
        let [a, b, c, d, e, f, g, h] = octets;
        // Safe const replacement for the former per-segment `to_be()` + `transmute`:
        // `Ipv6Addr::new` places the segments most-significant first.
        let addr = Ipv6Addr::new(a, b, c, d, e, f, g, h).to_bits();
        let octets = (addr & Self::mask_of(bits)).to_be_bytes();
        Ok(Self { octets, bits })
    }

    /// Creates a network from anything convertible into an [`Ipv6Addr`].
    ///
    /// Host bits beyond the prefix are cleared.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OverflowIpv6CidrBit`] if `bits` exceeds [`Self::MAX_BITS`].
    #[inline]
    pub fn from_ip<I>(ip: I, bits: u8) -> Result<Self>
    where
        I: Into<Ipv6Addr>,
    {
        // Delegate so validation lives in one place.
        Self::new(ip.into().segments(), bits)
    }

    /// Returns the first address of the network (all host bits zero).
    #[inline]
    pub const fn network_addr(&self) -> Ipv6Addr {
        Ipv6Addr::from_bits(u128::from_be_bytes(self.octets))
    }

    /// Returns the last address of the network (all host bits one).
    #[inline]
    pub const fn broadcast_addr(&self) -> Ipv6Addr {
        Ipv6Addr::from_bits(u128::from_be_bytes(self.octets) | !self.mask())
    }

    /// Returns an iterator over the host addresses of the network.
    ///
    /// Mirrors the IPv4 rule: for prefixes up to /126 the first and last
    /// addresses of the range are skipped; /127 and /128 yield every address.
    #[inline]
    pub const fn hosts(&self) -> Ipv6Hosts {
        let first = u128::from_be_bytes(self.octets);
        let last = first | !self.mask();
        if self.bits >= 127 {
            // `end` is exclusive; `None` (when `last == u128::MAX`) means "run through
            // u128::MAX inclusive", which the iterator handles specially.
            Ipv6Hosts {
                cursor: first,
                end: last.checked_add(1),
            }
        } else {
            // At least 4 addresses here, so `first + 1` cannot overflow.
            Ipv6Hosts {
                cursor: first + 1,
                end: Some(last),
            }
        }
    }

    /// Returns the enclosing network one prefix bit shorter, or `None` for a /0.
    #[inline]
    pub fn supernet(&self) -> Option<Ipv6Cidr> {
        let bits = self.bits.checked_sub(1)?;
        // `bits < MAX_BITS` here, so the new prefix is always valid; just re-mask.
        let network = u128::from_be_bytes(self.octets) & Self::mask_of(bits);
        Some(Self {
            octets: network.to_be_bytes(),
            bits,
        })
    }

    /// Returns the netmask of this network as a `u128`.
    #[inline]
    pub const fn mask(&self) -> u128 {
        Self::mask_of(self.bits)
    }

    /// Returns the network address octets (host bits cleared).
    #[inline]
    pub const fn octets(&self) -> [u8; 16] {
        self.octets
    }

    /// Returns the prefix length.
    #[inline]
    pub const fn bits(&self) -> u8 {
        self.bits
    }

    /// Returns `true` if `addr` lies inside this network.
    #[inline]
    pub const fn contains(&self, addr: Ipv6Addr) -> bool {
        addr.to_bits() & self.mask() == u128::from_be_bytes(self.octets)
    }

    /// Returns `true` if `other` is entirely contained in this network
    /// (same or longer prefix, and matching network bits).
    #[inline]
    pub const fn contains_cidr(&self, other: &Self) -> bool {
        self.bits <= other.bits && self.overlaps(other)
    }

    /// Returns `true` if the two networks share at least one address, i.e. they
    /// agree on every bit of the shorter of the two prefixes.
    #[inline]
    pub const fn overlaps(&self, other: &Self) -> bool {
        let shortest = if self.bits < other.bits {
            self.bits
        } else {
            other.bits
        };
        let mask = Self::mask_of(shortest);
        let x = u128::from_be_bytes(self.octets);
        let y = u128::from_be_bytes(other.octets);
        (x & mask) == (y & mask)
    }
}
impl fmt::Display for Ipv6Cidr {
    /// Renders the network as `address/prefix`, e.g. `2001:db8::/32`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let network = Ipv6Addr::from(self.octets);
        write!(f, "{}/{}", network, self.bits)
    }
}
impl fmt::Debug for Ipv6Cidr {
    /// Wraps the `Display` form, e.g. `Ipv6Cidr(2001:db8::/32)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Ipv6Cidr({self})"))
    }
}
impl PartialEq for Ipv6Cidr {
    /// Two networks are equal when both the (masked) address and the prefix match.
    fn eq(&self, other: &Self) -> bool {
        (self.octets, self.bits) == (other.octets, other.bits)
    }
}

impl Eq for Ipv6Cidr {}
impl Hash for Ipv6Cidr {
    /// Hashes the same fields `PartialEq` compares, address first then prefix,
    /// keeping `Hash` consistent with `Eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (self.octets, self.bits).hash(state);
    }
}
impl From<Ipv6Addr> for Ipv6Cidr {
    /// Wraps a single address as a host network (`/128`).
    fn from(addr: Ipv6Addr) -> Self {
        // A full-length prefix masks nothing, so the octets are stored verbatim.
        Self {
            octets: addr.octets(),
            bits: Self::MAX_BITS,
        }
    }
}
impl TryFrom<([u8; 16], u8)> for Ipv6Cidr {
type Error = Error;
fn try_from((octets, bits): ([u8; 16], u8)) -> core::result::Result<Self, Self::Error> {
Self::from_ip(octets, bits)
}
}
impl FromStr for Ipv6Cidr {
type Err = Error;
fn from_str(s: &str) -> core::result::Result<Self, Self::Err> {
let (addr, bits) = s
.split_once('/')
.ok_or(Error::CidrParseError(CidrParseKind::Ipv6))?;
let addr = addr
.parse::<Ipv6Addr>()
.map_err(|_| Error::CidrParseError(CidrParseKind::Ipv6))?;
let bits = bits
.parse()
.map_err(|_| Error::CidrParseError(CidrParseKind::Ipv6))?;
Ipv6Cidr::from_ip(addr, bits)
}
}
impl PartialOrd for Ipv6Cidr {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Ipv6Cidr {
    /// Orders by network address, then by prefix length (shorter first).
    fn cmp(&self, other: &Self) -> Ordering {
        // Big-endian byte arrays compare lexicographically exactly like their u128 values.
        self.octets
            .cmp(&other.octets)
            .then_with(|| self.bits.cmp(&other.bits))
    }
}
/// Iterator over the host addresses of an [`Ipv6Cidr`], created by [`Ipv6Cidr::hosts`].
///
/// Derives `Clone` (so a partially consumed range can be forked) and `Debug`
/// (expected of public types).
#[derive(Clone, Debug)]
pub struct Ipv6Hosts {
    // Next address to yield, as a u128.
    cursor: u128,
    // Exclusive upper bound; `None` means the range runs through u128::MAX inclusive.
    end: Option<u128>,
}
impl Ipv6Hosts {
    /// Returns how many addresses the iterator has yet to yield.
    #[inline]
    pub const fn len(&self) -> u128 {
        // With `end == None` the range runs through u128::MAX inclusive, so a cursor of 0
        // would mean 2^128 remaining addresses, which does not fit in a u128.
        debug_assert!(!(self.end.is_none() && self.cursor == 0));
        match self.end {
            Some(end) => end - self.cursor,
            None => u128::MAX - self.cursor + 1,
        }
    }

    /// Returns `true` if the iterator has no addresses left to yield.
    ///
    /// Uses the same stop condition as `next`, and so never does the
    /// subtraction `len` needs.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        match self.end {
            Some(end) => self.cursor >= end,
            // An open-ended range always has at least the cursor itself left.
            None => false,
        }
    }
}
impl Iterator for Ipv6Hosts {
    type Item = Ipv6Addr;

    /// Yields the address at the cursor and advances it.
    fn next(&mut self) -> Option<Self::Item> {
        if matches!(self.end, Some(end) if self.cursor >= end) {
            return None;
        }
        let current = self.cursor;
        match current.checked_add(1) {
            Some(following) => self.cursor = following,
            // Just yielded u128::MAX and the cursor cannot advance: pin `end` to the
            // cursor so every later call hits the stop condition above.
            None => self.end = Some(u128::MAX),
        }
        Some(Ipv6Addr::from_bits(current))
    }

    /// Exact when the remaining count fits in `usize`, otherwise `(usize::MAX, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match usize::try_from(self.len()) {
            Ok(exact) => (exact, Some(exact)),
            Err(_) => (usize::MAX, None),
        }
    }

    /// O(1) count derived from the size hint.
    ///
    /// # Panics
    ///
    /// Panics if the remaining count does not fit in `usize`.
    fn count(self) -> usize
    where
        Self: Sized,
    {
        let (_, upper) = self.size_hint();
        upper.expect("count overflow")
    }
}
/// An IPv4 or IPv6 network in CIDR notation.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub enum Cidr {
    /// An IPv4 network.
    V4(Ipv4Cidr),
    /// An IPv6 network.
    V6(Ipv6Cidr),
}
impl Cidr {
    /// Builds a network of the same family as `ip`, clearing host bits beyond `bits`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Ipv4Cidr::from_ip`] or [`Ipv6Cidr::from_ip`] returns when
    /// `bits` is rejected.
    pub fn new<I>(ip: I, bits: u8) -> Result<Self>
    where
        I: Into<IpAddr>,
    {
        let addr: IpAddr = ip.into();
        match addr {
            IpAddr::V4(v4) => Ipv4Cidr::from_ip(v4, bits).map(Self::V4),
            IpAddr::V6(v6) => Ipv6Cidr::from_ip(v6, bits).map(Self::V6),
        }
    }

    /// First address of the network.
    pub const fn network_addr(&self) -> IpAddr {
        match self {
            Self::V4(net) => IpAddr::V4(net.network_addr()),
            Self::V6(net) => IpAddr::V6(net.network_addr()),
        }
    }

    /// Last address of the network (all host bits set).
    pub const fn broadcast_addr(&self) -> IpAddr {
        match self {
            Self::V4(net) => IpAddr::V4(net.broadcast_addr()),
            Self::V6(net) => IpAddr::V6(net.broadcast_addr()),
        }
    }

    /// Iterator over host addresses, delegating to the family-specific `hosts`.
    pub const fn hosts(&self) -> Hosts {
        match self {
            Self::V4(net) => Hosts::V4(net.hosts()),
            Self::V6(net) => Hosts::V6(net.hosts()),
        }
    }

    /// Prefix length of the wrapped network.
    #[inline]
    pub const fn bits(&self) -> u8 {
        match self {
            Self::V4(net) => net.bits(),
            Self::V6(net) => net.bits(),
        }
    }

    /// Returns `true` if `addr` lies in this network; always `false` across families.
    #[inline]
    pub const fn contains(&self, addr: IpAddr) -> bool {
        match (self, addr) {
            (Self::V4(net), IpAddr::V4(ip)) => net.contains(ip),
            (Self::V6(net), IpAddr::V6(ip)) => net.contains(ip),
            (Self::V4(_), IpAddr::V6(_)) | (Self::V6(_), IpAddr::V4(_)) => false,
        }
    }

    /// Returns `true` if `other` lies entirely in this network; always `false` across families.
    #[inline]
    pub const fn contains_cidr(&self, other: &Cidr) -> bool {
        match (self, other) {
            (Self::V4(outer), Self::V4(inner)) => outer.contains_cidr(inner),
            (Self::V6(outer), Self::V6(inner)) => outer.contains_cidr(inner),
            (Self::V4(_), Self::V6(_)) | (Self::V6(_), Self::V4(_)) => false,
        }
    }

    /// Returns `true` if the networks share an address; always `false` across families.
    #[inline]
    pub const fn overlaps(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::V4(a), Self::V4(b)) => a.overlaps(b),
            (Self::V6(a), Self::V6(b)) => a.overlaps(b),
            (Self::V4(_), Self::V6(_)) | (Self::V6(_), Self::V4(_)) => false,
        }
    }

    /// Returns `true` for the IPv4 variant.
    #[inline]
    pub const fn is_ipv4(&self) -> bool {
        match self {
            Self::V4(_) => true,
            Self::V6(_) => false,
        }
    }

    /// Returns `true` for the IPv6 variant.
    #[inline]
    pub const fn is_ipv6(&self) -> bool {
        // Exactly two variants, so "not IPv4" is "IPv6".
        !self.is_ipv4()
    }
}
impl fmt::Display for Cidr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Cidr::V4(v4) => fmt::Display::fmt(&v4, f),
Cidr::V6(v6) => fmt::Display::fmt(&v6, f),
}
}
}
impl fmt::Debug for Cidr {
    /// Wraps the `Display` form, e.g. `Cidr(10.0.0.0/8)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Cidr({self})"))
    }
}
impl TryFrom<([u8; 4], u8)> for Cidr {
    type Error = Error;

    /// Builds an IPv4 network from `(octets, prefix_len)`.
    fn try_from(parts: ([u8; 4], u8)) -> core::result::Result<Self, Self::Error> {
        Ipv4Cidr::try_from(parts).map(Self::V4)
    }
}

impl TryFrom<([u8; 16], u8)> for Cidr {
    type Error = Error;

    /// Builds an IPv6 network from `(octets, prefix_len)`.
    fn try_from(parts: ([u8; 16], u8)) -> core::result::Result<Self, Self::Error> {
        Ipv6Cidr::try_from(parts).map(Self::V6)
    }
}
impl From<Ipv4Cidr> for Cidr {
fn from(v4: Ipv4Cidr) -> Self {
Cidr::V4(v4)
}
}
impl From<Ipv6Cidr> for Cidr {
fn from(v6: Ipv6Cidr) -> Self {
Cidr::V6(v6)
}
}
// Cross-type equality: a `Cidr` equals a family-specific network only when it
// wraps the same family and the wrapped values compare equal.
impl PartialEq<Ipv4Cidr> for Cidr {
    fn eq(&self, other: &Ipv4Cidr) -> bool {
        matches!(self, Cidr::V4(v4) if v4 == other)
    }
}

impl PartialEq<Cidr> for Ipv4Cidr {
    fn eq(&self, other: &Cidr) -> bool {
        matches!(other, Cidr::V4(v4) if self == v4)
    }
}

impl PartialEq<Ipv6Cidr> for Cidr {
    fn eq(&self, other: &Ipv6Cidr) -> bool {
        matches!(self, Cidr::V6(v6) if v6 == other)
    }
}

impl PartialEq<Cidr> for Ipv6Cidr {
    fn eq(&self, other: &Cidr) -> bool {
        matches!(other, Cidr::V6(v6) if self == v6)
    }
}
impl FromStr for Cidr {
type Err = Error;
fn from_str(s: &str) -> core::result::Result<Self, Self::Err> {
if let Ok(v4) = s.parse::<Ipv4Cidr>() {
return Ok(Cidr::V4(v4));
}
if let Ok(v6) = s.parse::<Ipv6Cidr>() {
return Ok(Cidr::V6(v6));
}
Err(Error::CidrParseError(CidrParseKind::Ip))
}
}
/// Host-address iterator for either address family, returned by `Cidr::hosts`.
pub enum Hosts {
    /// Hosts of an IPv4 network.
    V4(Ipv4Hosts),
    /// Hosts of an IPv6 network.
    V6(Ipv6Hosts),
}
impl Iterator for Hosts {
    type Item = IpAddr;

    /// Yields the next host of the wrapped family iterator as an [`IpAddr`].
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::V4(v4) => v4.next().map(IpAddr::V4),
            Self::V6(v6) => v6.next().map(IpAddr::V6),
        }
    }

    /// Forwards the exact size hint of the wrapped iterator, so `collect` and
    /// friends can preallocate (the trait default would be `(0, None)`).
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::V4(v4) => v4.size_hint(),
            Self::V6(v6) => v6.size_hint(),
        }
    }

    /// Counts the remaining hosts in O(1) using the wrapped iterator's `count`,
    /// instead of the default, which steps through every address.
    ///
    /// # Panics
    ///
    /// Panics if the remaining count does not fit in `usize`, as the wrapped
    /// iterators' `count` does.
    fn count(self) -> usize
    where
        Self: Sized,
    {
        match self {
            Self::V4(v4) => v4.count(),
            Self::V6(v6) => v6.count(),
        }
    }
}