use crate::error::*;
use data_encoding::HEXLOWER_PERMISSIVE;
use serde::de::{Deserialize, Deserializer, Unexpected, Visitor};
use serde::ser::{Serialize, Serializer};
use std::cmp::Ordering;
use std::fmt;
use std::marker::PhantomData;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
/// Serde visitor that turns a string into a `T` by delegating to `T`'s
/// `FromStr` implementation.
///
/// The struct itself carries no trait bounds; the `FromStr<Err = Error>`
/// requirement lives on the `Visitor`/`Default` impls that actually need it.
struct ParseableVisitor<T> {
    // Zero-sized marker tying the visitor to its target type.
    phantom: PhantomData<T>,
}
impl<'de, T: FromStr<Err = Error>> Visitor<'de> for ParseableVisitor<T> {
    type Value = T;

    /// Description used by serde when the input is not a string.
    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a parseable value")
    }

    /// Parses `v` with `T::from_str`; a parse failure is reported as an
    /// `invalid_value` error whose "expected" text is the parse error message.
    fn visit_str<E: ::serde::de::Error>(self, v: &str) -> ::std::result::Result<Self::Value, E> {
        v.parse::<T>().map_err(|parse_err| {
            let reason = parse_err.to_string();
            E::invalid_value(Unexpected::Str(v), &reason.as_str())
        })
    }
}
impl<T: FromStr<Err = Error>> Default for ParseableVisitor<T> {
fn default() -> Self {
ParseableVisitor {
phantom: PhantomData,
}
}
}
/// Adds one to a big-endian byte string in place, carrying into the more
/// significant bytes. An all-`0xff` input wraps around to all zeros.
fn increment_ip_bytes(bytes: &mut [u8]) {
    let mut position = bytes.len();
    while position > 0 {
        position -= 1;
        let (sum, carried) = bytes[position].overflowing_add(1);
        bytes[position] = sum;
        if !carried {
            break;
        }
    }
}
/// Returns the address numerically following `ip` within the same family,
/// or `None` if `ip` is already the highest address of that family
/// (`255.255.255.255` or `ffff:...:ffff`).
pub fn increment_ip(ip: IpAddr) -> Option<IpAddr> {
    // Treat the address as one big-endian integer; `checked_add` fails exactly
    // when every bit is set, which is the wrap-around case.
    match ip {
        IpAddr::V4(v4) => u32::from(v4)
            .checked_add(1)
            .map(|next| IpAddr::V4(Ipv4Addr::from(next))),
        IpAddr::V6(v6) => u128::from(v6)
            .checked_add(1)
            .map(|next| IpAddr::V6(Ipv6Addr::from(next))),
    }
}
/// Orders two addresses of the same family; returns `None` when one is IPv4
/// and the other IPv6.
pub fn compare_ips(a: IpAddr, b: IpAddr) -> Option<Ordering> {
    if a.is_ipv4() != b.is_ipv4() {
        return None;
    }
    // Within a single family, `IpAddr`'s ordering is that of the wrapped address.
    Some(a.cmp(&b))
}
/// Returns the smaller of two addresses of the same family, or `None` when
/// the families differ.
pub fn min_ip(a: IpAddr, b: IpAddr) -> Option<IpAddr> {
    let same_family = match (a, b) {
        (IpAddr::V4(_), IpAddr::V4(_)) | (IpAddr::V6(_), IpAddr::V6(_)) => true,
        _ => false,
    };
    if same_family {
        Some(::std::cmp::min(a, b))
    } else {
        None
    }
}
/// Returns the larger of two addresses of the same family, or `None` when
/// the families differ.
pub fn max_ip(a: IpAddr, b: IpAddr) -> Option<IpAddr> {
    let same_family = match (a, b) {
        (IpAddr::V4(_), IpAddr::V4(_)) | (IpAddr::V6(_), IpAddr::V6(_)) => true,
        _ => false,
    };
    if same_family {
        Some(::std::cmp::max(a, b))
    } else {
        None
    }
}
/// A 48-bit hardware (MAC) address.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct HardwareAddr {
    // The six address bytes, in the same order they are displayed.
    address: [u8; 6],
}
impl HardwareAddr {
    /// Borrows the six raw address bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.address[..]
    }
}
impl fmt::Display for HardwareAddr {
    /// Formats as six two-digit lowercase hex bytes separated by colons,
    /// e.g. `00:1a:2b:3c:4d:5e`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        for (idx, byte) in self.address.iter().enumerate() {
            if idx != 0 {
                f.write_str(":")?;
            }
            write!(f, "{:02x}", byte)?;
        }
        Ok(())
    }
}
impl FromStr for HardwareAddr {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let normalized = s.replace(":", "");
let normalized = normalized.replace("-", "");
let normalized = normalized.replace(".", "");
let address_vec = match HEXLOWER_PERMISSIVE.decode(normalized.as_bytes()) {
Ok(data) => data,
Err(e) => return Err(Error::HexDecode(e)),
};
if address_vec.len() != 6 {
return Err(Error::InvalidArgument(format!(
"invalid MAC address '{}', expected 6 bytes found {}",
s,
address_vec.len()
)));
}
let mut address = [0_u8; 6];
address.copy_from_slice(&address_vec);
Ok(HardwareAddr { address })
}
}
impl Serialize for HardwareAddr {
    /// Serializes as the string produced by this type's `Display` impl.
    fn serialize<S: Serializer>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> {
        let text = self.to_string();
        serializer.serialize_str(&text)
    }
}
impl<'de> Deserialize<'de> for HardwareAddr {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> ::std::result::Result<Self, D::Error> {
deserializer.deserialize_str(ParseableVisitor::<HardwareAddr>::default())
}
}
/// Combines `ip` with `mask` byte by byte, in place.
///
/// When `invert` is true each mask byte is bitwise-negated first. When `set`
/// is true the (possibly negated) mask is OR-ed into `ip`; otherwise `ip` is
/// AND-ed with it. Only the overlapping prefix of the two slices is touched.
fn apply_ip_mask_bytes(ip: &mut [u8], mask: &[u8], invert: bool, set: bool) {
    debug_assert!(ip.len() == mask.len());
    // XOR with 0xff is bitwise negation; XOR with 0x00 leaves the byte as is.
    let flip: u8 = if invert { 0xff } else { 0x00 };
    for (byte, &raw_mask) in ip.iter_mut().zip(mask) {
        let effective = raw_mask ^ flip;
        *byte = if set { *byte | effective } else { *byte & effective };
    }
}
/// Applies a 16-byte `mask` to `ip` and returns an address of the same family.
///
/// When `invert` is true the mask is bitwise-negated first; when `set` is true
/// the mask is OR-ed into the address, otherwise the address is AND-ed with it.
/// IPv4 addresses use only the last four mask bytes.
fn apply_ip_mask(ip: IpAddr, mask: &[u8], invert: bool, set: bool) -> IpAddr {
    debug_assert!(mask.len() == 16);
    // IPv4 is widened to its IPv6-compatible form (`::a.b.c.d`), so the IPv4
    // octets line up with mask bytes 12..16.
    let mut octets: [u8; 16] = match ip {
        IpAddr::V4(v4) => v4.to_ipv6_compatible().octets(),
        IpAddr::V6(v6) => v6.octets(),
    };
    for (byte, &mask_byte) in octets.iter_mut().zip(mask.iter()) {
        let effective = if invert { !mask_byte } else { mask_byte };
        if set {
            *byte |= effective;
        } else {
            *byte &= effective;
        }
    }
    match ip {
        IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::new(octets[12], octets[13], octets[14], octets[15])),
        IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::from(octets)),
    }
}
/// An IP network: an address together with a 16-byte mask.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct IpNet {
    // Network address; `FromStr` stores it already AND-ed with `mask`.
    ip: IpAddr,
    // Mask over an IPv6-sized address. As built by `FromStr`, IPv4 networks
    // have 0xff in the first 12 bytes and their real mask in the last 4.
    mask: [u8; 16],
}
impl IpNet {
    /// Returns the stored network address (the parsed address with the mask
    /// applied, when built by `FromStr`).
    pub fn get_ip(&self) -> IpAddr {
        self.ip
    }

    /// Returns all 16 mask bytes. When built by `FromStr`, an IPv4 network has
    /// `0xff` in the first 12 bytes and its IPv4 mask in the last 4.
    pub fn get_mask(&self) -> &[u8] {
        &self.mask
    }

    /// Combines `ip` with this network's mask through `apply_ip_mask`.
    ///
    /// When `invert` is true the mask is bitwise-negated first. When `set` is
    /// true the mask is OR-ed into the address; otherwise the address is
    /// AND-ed with it.
    pub fn apply_mask(&self, ip: IpAddr, invert: bool, set: bool) -> IpAddr {
        apply_ip_mask(ip, &self.mask, invert, set)
    }

    /// Counts the one bits in the mask, skipping the 12 padding bytes of an
    /// IPv4 network. For a canonical mask this is the prefix length.
    pub fn get_one_bits(&self) -> usize {
        self.mask
            .iter()
            .skip(if self.ip.is_ipv4() { 12 } else { 0 })
            .map(|b| b.count_ones() as usize)
            .sum::<usize>()
    }

    /// Returns true when the mask is a run of one bits followed only by zero
    /// bits, i.e. it can be written as a `/N` prefix length.
    pub fn is_canonical(&self) -> bool {
        // The first byte that is not all ones is the only one allowed to mix
        // ones and zeros.
        let boundary_byte = match self.mask.iter().position(|b| *b != 0xff_u8) {
            None => return true,
            Some(idx) => idx,
        };
        // Within that byte, every zero bit must sit below every one bit.
        let boundary = self.mask[boundary_byte];
        if boundary.count_zeros() != boundary.trailing_zeros() {
            return false;
        }
        // Every byte after it must be all zeros.
        self.mask
            .iter()
            .skip(boundary_byte + 1)
            .all(|byte| *byte == 0x00_u8)
    }

    /// Returns the mask rendered as an address of the network's family.
    pub fn netmask(&self) -> IpAddr {
        // `FromStr` stores `ip` already AND-ed with the mask, so OR-ing the mask
        // in yields exactly the mask bits.
        self.apply_mask(self.ip, false, true)
    }

    /// Returns the network address with every host bit (each zero bit of the
    /// mask) set.
    pub fn broadcast(&self) -> IpAddr {
        self.apply_mask(self.ip, true, true)
    }

    /// Returns true when `ip`, masked with this network's mask, equals the
    /// network address. An address of the other family never matches.
    ///
    /// With `strict`, the network address and the broadcast address are also
    /// excluded.
    pub fn contains(&self, ip: IpAddr, strict: bool) -> bool {
        let contains = self.apply_mask(ip, false, false) == self.ip;
        if strict {
            contains && ip != self.ip && ip != self.broadcast()
        } else {
            contains
        }
    }

    /// Returns the address following `ip` if that address is still in this
    /// network (using the same `strict` rule as `contains`).
    ///
    /// Returns `None` otherwise, or when `ip` is the highest address of its
    /// family.
    pub fn increment_in(&self, ip: IpAddr, strict: bool) -> Option<IpAddr> {
        increment_ip(ip).filter(|next_ip| self.contains(*next_ip, strict))
    }

    /// Returns the first address of the range: the network address plus one.
    ///
    /// For /31 and /32 (IPv4) or /127 and /128 (IPv6) it returns the network
    /// address itself. It returns `None` only if the increment overflows.
    pub fn first(&self) -> Option<IpAddr> {
        let ones = self.get_one_bits();
        if self.ip.is_ipv4() && ones >= 31 {
            return Some(self.ip);
        }
        if self.ip.is_ipv6() && ones >= 127 {
            return Some(self.ip);
        }
        increment_ip(self.ip)
    }

    /// Returns the last address of the range: the broadcast address minus one.
    ///
    /// For a single-address network (/32 or /128) it returns the network
    /// address itself.
    pub fn last(&self) -> IpAddr {
        let max_bits = if self.ip.is_ipv4() { 32 } else { 128 };
        if self.get_one_bits() == max_bits {
            return self.ip;
        }
        // Subtract across the whole address, not just the final octet.
        // - With a non-canonical mask (e.g. `ff00ffff`), the broadcast address
        //   can end in a 0 byte, and decrementing that byte alone would
        //   overflow.
        // - The broadcast is never all zeros here, because that requires an
        //   all-ones mask, which is handled above. So the wrap never happens.
        match self.broadcast() {
            IpAddr::V4(ip) => Ipv4Addr::from(u32::from(ip).wrapping_sub(1)).into(),
            IpAddr::V6(ip) => Ipv6Addr::from(u128::from(ip).wrapping_sub(1)).into(),
        }
    }
}
impl fmt::Display for IpNet {
    /// Formats as `ADDR/PREFIX` when the mask is canonical. Otherwise formats
    /// as `ADDR/HEXMASK`, with 8 lowercase hex digits for IPv4 or 32 for IPv6.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let suffix: String = if self.is_canonical() {
            self.get_one_bits().to_string()
        } else {
            let hex = HEXLOWER_PERMISSIVE.encode(&self.mask);
            // The 12 IPv4 padding bytes account for the first 24 hex digits.
            let padding_digits = if self.ip.is_ipv4() { 24 } else { 0 };
            hex[padding_digits..].to_owned()
        };
        write!(f, "{}/{}", self.ip, suffix)
    }
}
impl FromStr for IpNet {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let (ip, mask): (&str, &str) = s.split_at(match s.find('/') {
None => {
return Err(Error::InvalidArgument(format!(
"invalid IP network specifier '{}'",
s
)));
}
Some(idx) => idx,
});
let ip: IpAddr = ip.parse()?;
let mask: &str = &mask[1..];
let mut mask_vec: Vec<u8> = vec![0xff_u8; if ip.is_ipv4() { 12 } else { 0 }];
let mask_is_hex = (ip.is_ipv4() && mask.len() == 8) || (ip.is_ipv6() && mask.len() == 32);
if mask_is_hex {
mask_vec.extend(
match HEXLOWER_PERMISSIVE.decode(mask.as_bytes()) {
Ok(data) => data,
Err(e) => return Err(Error::HexDecode(e)),
}
.into_iter(),
);
} else {
let ones = u8::from_str_radix(mask, 10)?;
let max_ones = if ip.is_ipv4() { 32 } else { 128 };
if ones > max_ones {
return Err(Error::InvalidArgument(format!(
"invalid IP network prefix length /{}, must be <= {}",
ones, max_ones
)));
}
let mut v = vec![0xff_u8; (ones / 8) as usize];
let extra_ones = ones % 8;
if extra_ones > 0 {
v.push((0xff_u8 >> (8 - extra_ones)) << (8 - extra_ones));
}
mask_vec.extend(v.into_iter());
}
let mut mask = [0_u8; 16];
mask[..mask_vec.len()].copy_from_slice(&mask_vec);
Ok(IpNet {
ip: apply_ip_mask(ip, &mask, false, false),
mask,
})
}
}
impl Serialize for IpNet {
    /// Serializes as the string produced by this type's `Display` impl.
    fn serialize<S: Serializer>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> {
        let text = self.to_string();
        serializer.serialize_str(&text)
    }
}
impl<'de> Deserialize<'de> for IpNet {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> ::std::result::Result<Self, D::Error> {
deserializer.deserialize_str(ParseableVisitor::<IpNet>::default())
}
}