use std::net::{Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
use crate::{NodeId, PimError};
/// IPv4 mesh prefix in CIDR notation, in the format accepted by [`Ipv4Prefix::parse`].
///
/// NOTE(review): presumably the fallback used when no prefix is configured; the
/// consumer of this constant is outside this file.
pub const DEFAULT_MESH_IPV4_PREFIX: &str = "10.77.0.0/16";
/// IPv6 mesh prefix in CIDR notation, in the format accepted by [`Ipv6Prefix::parse`].
///
/// NOTE(review): presumably the fallback used when no prefix is configured; the
/// consumer of this constant is outside this file.
pub const DEFAULT_MESH_IPV6_PREFIX: &str = "fd77::/64";
// Context string handed to `keyed_digest` by `derive_mesh_ipv4`. The digest
// depends on it, so changing it changes every derived IPv4 address.
const CONTEXT_V4: &str = "pim-mesh-v4";
// Context string handed to `keyed_digest` by `derive_mesh_ipv6`. The digest
// depends on it, so changing it changes every derived IPv6 address.
const CONTEXT_V6: &str = "pim-mesh-v6";
/// An IPv4 network prefix: a network address plus a prefix length.
///
/// Values produced by [`Ipv4Prefix::parse`] have the host bits of `network`
/// cleared and `prefix_len` in `0..=32`. Both fields are public, so values
/// built by hand carry no such guarantee.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Ipv4Prefix {
/// Network address of the prefix.
pub network: Ipv4Addr,
/// Number of leading bits of `network` that belong to the prefix.
pub prefix_len: u8,
}
/// An IPv6 network prefix: a network address plus a prefix length.
///
/// Values produced by [`Ipv6Prefix::parse`] have the host bits of `network`
/// cleared and `prefix_len` in `0..=128`. Both fields are public, so values
/// built by hand carry no such guarantee.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Ipv6Prefix {
/// Network address of the prefix.
pub network: Ipv6Addr,
/// Number of leading bits of `network` that belong to the prefix.
pub prefix_len: u8,
}
impl Ipv4Prefix {
    /// Parses an IPv4 prefix written in CIDR notation, e.g. `10.77.0.0/16`.
    ///
    /// Host bits in the address part are cleared, so `10.77.3.4/16` yields the
    /// network `10.77.0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`PimError::Config`] when the `/` separator is missing, when the
    /// address part is not a valid IPv4 address, or when the prefix length is
    /// not an unsigned decimal number in `0..=32`.
    pub fn parse(s: &str) -> Result<Self, PimError> {
        let (ip_str, prefix_str) = s
            .split_once('/')
            .ok_or_else(|| PimError::Config(format!("invalid IPv4 CIDR (missing `/`): {s}")))?;
        let ip: Ipv4Addr = ip_str
            .parse()
            .map_err(|_| PimError::Config(format!("invalid IPv4 in CIDR: {s}")))?;
        let prefix_len = prefix_str
            .parse::<u8>()
            .ok()
            // `u8::from_str` tolerates a leading `+`, which is not valid CIDR
            // notation; reject it instead of silently accepting `/+16`.
            .filter(|_| !prefix_str.starts_with('+'))
            .ok_or_else(|| PimError::Config(format!("invalid prefix length in CIDR: {s}")))?;
        if prefix_len > 32 {
            return Err(PimError::Config(format!(
                "IPv4 prefix length must be 0..=32: {s}"
            )));
        }
        // Clear the host bits so the stored address is always the network address.
        let mask = ipv4_netmask(prefix_len);
        let network = Ipv4Addr::from(u32::from(ip) & mask);
        Ok(Self {
            network,
            prefix_len,
        })
    }

    /// Formats the prefix as `<network>/<prefix_len>`.
    pub fn to_cidr_string(&self) -> String {
        format!("{}/{}", self.network, self.prefix_len)
    }
}
impl FromStr for Ipv4Prefix {
type Err = PimError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s)
}
}
impl Ipv6Prefix {
    /// Parses an IPv6 prefix written in CIDR notation, e.g. `fd77::/64`.
    ///
    /// Host bits in the address part are cleared, so `fd77::1/64` yields the
    /// network `fd77::`.
    ///
    /// # Errors
    ///
    /// Returns [`PimError::Config`] when the `/` separator is missing, when the
    /// address part is not a valid IPv6 address, or when the prefix length is
    /// not an unsigned decimal number in `0..=128`.
    pub fn parse(s: &str) -> Result<Self, PimError> {
        let (ip_str, prefix_str) = s
            .split_once('/')
            .ok_or_else(|| PimError::Config(format!("invalid IPv6 CIDR (missing `/`): {s}")))?;
        let ip: Ipv6Addr = ip_str
            .parse()
            .map_err(|_| PimError::Config(format!("invalid IPv6 in CIDR: {s}")))?;
        let prefix_len = prefix_str
            .parse::<u8>()
            .ok()
            // `u8::from_str` tolerates a leading `+`, which is not valid CIDR
            // notation; reject it instead of silently accepting `/+64`.
            .filter(|_| !prefix_str.starts_with('+'))
            .ok_or_else(|| PimError::Config(format!("invalid prefix length in CIDR: {s}")))?;
        if prefix_len > 128 {
            return Err(PimError::Config(format!(
                "IPv6 prefix length must be 0..=128: {s}"
            )));
        }
        // Clear the host bits so the stored address is always the network address.
        let mask = ipv6_netmask(prefix_len);
        let network = Ipv6Addr::from(u128::from(ip) & mask);
        Ok(Self {
            network,
            prefix_len,
        })
    }

    /// Formats the prefix as `<network>/<prefix_len>`.
    pub fn to_cidr_string(&self) -> String {
        format!("{}/{}", self.network, self.prefix_len)
    }
}
impl FromStr for Ipv6Prefix {
type Err = PimError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s)
}
}
/// Returns the 32-bit netmask whose top `prefix_len` bits are set.
///
/// A length of 0 gives an all-zero mask; any length of 32 or more gives an
/// all-ones mask.
fn ipv4_netmask(prefix_len: u8) -> u32 {
    match prefix_len {
        0 => 0,
        // Shifting all-ones left by the host-bit count leaves exactly
        // `prefix_len` leading ones; the shift amount stays within 1..=31.
        1..=31 => u32::MAX << (32 - prefix_len),
        _ => u32::MAX,
    }
}
/// Returns the 128-bit netmask whose top `prefix_len` bits are set.
///
/// A length of 0 gives an all-zero mask; any length of 128 or more gives an
/// all-ones mask.
fn ipv6_netmask(prefix_len: u8) -> u128 {
    match prefix_len {
        0 => 0,
        // Shifting all-ones left by the host-bit count leaves exactly
        // `prefix_len` leading ones; the shift amount stays within 1..=127.
        1..=127 => u128::MAX << (128 - prefix_len),
        _ => u128::MAX,
    }
}
/// Computes a 32-byte digest of `node_id` under a key derived from `context`.
///
/// The key comes from `blake3::derive_key` applied to `context` and the fixed
/// key material `b"pim-mesh-derivation"`; the node ID bytes are then hashed
/// with `blake3::keyed_hash` under that key.
fn keyed_digest(context: &str, node_id: &NodeId) -> [u8; 32] {
    let derived_key = blake3::derive_key(context, b"pim-mesh-derivation");
    let hash = blake3::keyed_hash(&derived_key, node_id.as_bytes());
    let digest: [u8; 32] = hash.into();
    digest
}
/// Deterministically maps `node_id` to an IPv4 address inside `prefix`.
///
/// The host part is taken from the first four bytes (big-endian) of
/// `keyed_digest(CONTEXT_V4, node_id)`, masked to the host bits of the prefix.
/// When the prefix leaves at least two host bits, the all-zeros and all-ones
/// host values (the network and broadcast addresses) are avoided by moving
/// them to `1` and `host_mask - 1` respectively.
///
/// If the prefix leaves no host bits, `prefix.network` is returned unchanged.
pub fn derive_mesh_ipv4(node_id: &NodeId, prefix: Ipv4Prefix) -> Ipv4Addr {
    // A `prefix_len` above 32 saturates to zero host bits.
    let host_bits = 32u32.saturating_sub(u32::from(prefix.prefix_len));
    if host_bits == 0 {
        return prefix.network;
    }
    // `host_bits` is in 1..=32 here, so the shift amount is in 0..=31.
    let host_mask = u32::MAX >> (32 - host_bits);

    let digest = keyed_digest(CONTEXT_V4, node_id);
    let head: [u8; 4] = digest[..4].try_into().expect("blake3 digest >= 4 bytes");
    let candidate = u32::from_be_bytes(head) & host_mask;

    // With a single host bit both values are kept as-is.
    let host = if host_bits < 2 {
        candidate
    } else if candidate == 0 {
        1
    } else if candidate == host_mask {
        host_mask - 1
    } else {
        candidate
    };

    // Re-mask the network in case the caller built `prefix` with host bits set.
    let network_part = u32::from(prefix.network) & !host_mask;
    Ipv4Addr::from(network_part | host)
}
/// Deterministically maps `node_id` to an IPv6 address inside `prefix`.
///
/// The host part is taken from the first sixteen bytes (big-endian) of
/// `keyed_digest(CONTEXT_V6, node_id)`, masked to the host bits of the prefix.
/// When the prefix leaves at least two host bits, the all-zeros and all-ones
/// host values are avoided by moving them to `1` and `host_mask - 1`
/// respectively.
///
/// If the prefix leaves no host bits, `prefix.network` is returned unchanged.
pub fn derive_mesh_ipv6(node_id: &NodeId, prefix: Ipv6Prefix) -> Ipv6Addr {
    // A `prefix_len` above 128 saturates to zero host bits.
    let host_bits = 128u32.saturating_sub(u32::from(prefix.prefix_len));
    if host_bits == 0 {
        return prefix.network;
    }
    // `host_bits` is in 1..=128 here, so the shift amount is in 0..=127.
    let host_mask = u128::MAX >> (128 - host_bits);

    let digest = keyed_digest(CONTEXT_V6, node_id);
    let head: [u8; 16] = digest[..16].try_into().expect("blake3 digest >= 16 bytes");
    let candidate = u128::from_be_bytes(head) & host_mask;

    // With a single host bit both values are kept as-is.
    let host = if host_bits < 2 {
        candidate
    } else if candidate == 0 {
        1
    } else if candidate == host_mask {
        host_mask - 1
    } else {
        candidate
    };

    // Re-mask the network in case the caller built `prefix` with host bits set.
    let network_part = u128::from(prefix.network) & !host_mask;
    Ipv6Addr::from(network_part | host)
}
/// Returns `true` when `claimed` is exactly the address that
/// [`derive_mesh_ipv4`] produces for `node_id` within `prefix`.
pub fn verify_mesh_ipv4(node_id: &NodeId, claimed: Ipv4Addr, prefix: Ipv4Prefix) -> bool {
    let expected = derive_mesh_ipv4(node_id, prefix);
    expected == claimed
}
/// Returns `true` when `claimed` is exactly the address that
/// [`derive_mesh_ipv6`] produces for `node_id` within `prefix`.
pub fn verify_mesh_ipv6(node_id: &NodeId, claimed: Ipv6Addr, prefix: Ipv6Prefix) -> bool {
    let expected = derive_mesh_ipv6(node_id, prefix);
    expected == claimed
}
#[cfg(test)]
mod tests;