use std::{
fmt,
hash::Hash,
net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV6},
sync::Arc,
};
use n0_error::{e, stack_error};
use portable_atomic::{AtomicU64, Ordering};
use rustc_hash::FxHashMap;
use tracing::trace;
/// First octet of every mapped address. `0xfd` places the addresses in the IPv6
/// unique-local range `fd00::/8`.
const ADDR_PREFIXL: u8 = 0xfd;
/// 40-bit global ID written after [`ADDR_PREFIXL`]. Together they form the shared
/// `fd15:70a:510b::/48` prefix.
const ADDR_GLOBAL_ID: [u8; 5] = [0x15, 0x07, 0x0a, 0x51, 0x0b];
/// 16-bit subnet ID for relay mapped addresses (`fd15:70a:510b:1::/64`).
const RELAY_MAPPED_SUBNET: [u8; 2] = [0x00, 0x01];
/// 16-bit subnet ID for custom mapped addresses (`fd15:70a:510b:3::/64`).
const CUSTOM_MAPPED_SUBNET: [u8; 2] = [0x00, 0x03];
/// 16-bit subnet ID for endpoint-id mapped addresses (`fd15:70a:510b::/64`).
const ENDPOINT_ID_SUBNET: [u8; 2] = [0x00, 0x00];
pub const DEFAULT_FAKE_ADDR: SocketAddrV6 = SocketAddrV6::new(
Ipv6Addr::new(
u16::from_be_bytes([ADDR_PREFIXL, 21]),
u16::from_be_bytes([7, 10]),
u16::from_be_bytes([81, 11]),
u16::from_be_bytes([0, 0]),
u16::MAX,
u16::MAX,
u16::MAX,
u16::MAX,
),
MAPPED_PORT,
0,
0,
);
/// Port attached to every mapped IP when it is turned into a socket address.
const MAPPED_PORT: u16 = 12_345;
/// Source of interface IDs for `RelayMappedAddr::generate`; the first value handed out is 1.
static RELAY_ADDR_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Source of interface IDs for `EndpointIdMappedAddr::generate`; the first value handed out is 1.
static ENDPOINT_ID_ADDR_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Source of interface IDs for `CustomMappedAddr::generate`; the first value handed out is 1.
static CUSTOM_ADDR_COUNTER: AtomicU64 = AtomicU64::new(1);
/// An address type that can mint fresh values of itself and present them as a
/// [`SocketAddr`].
pub(crate) trait MappedAddr {
    /// The socket-address form of this value.
    fn private_socket_addr(&self) -> SocketAddr;

    /// Produces a new value of this type.
    fn generate() -> Self;
}
/// A socket address classified by which mapped-address range (if any) its IP falls into.
///
/// The `From<SocketAddr>` impl for this type performs the classification.
#[derive(Debug, Clone)]
pub(crate) enum MultipathMappedAddr {
    /// An IPv6 address inside the endpoint-id mapped subnet.
    Mixed(EndpointIdMappedAddr),
    /// An IPv6 address inside the relay mapped subnet.
    Relay(RelayMappedAddr),
    /// Any other socket address. When produced by `From<SocketAddr>`, this covers every IPv4
    /// address and every IPv6 address outside the mapped subnets.
    Ip(SocketAddr),
    /// An IPv6 address inside the custom mapped subnet.
    Custom(CustomMappedAddr),
}
impl From<SocketAddr> for MultipathMappedAddr {
    /// Classifies `value` by its IP.
    ///
    /// IPv4 addresses are always [`MultipathMappedAddr::Ip`]. IPv6 addresses are tested
    /// against the endpoint-id, relay and custom subnets, in that order; the first match wins
    /// and anything unmatched falls back to `Ip`.
    fn from(value: SocketAddr) -> Self {
        let ip = match value {
            SocketAddr::V4(_) => return Self::Ip(value),
            SocketAddr::V6(v6) => *v6.ip(),
        };
        EndpointIdMappedAddr::try_from(ip)
            .map(Self::Mixed)
            .or_else(|_| RelayMappedAddr::try_from(ip).map(Self::Relay))
            .or_else(|_| CustomMappedAddr::try_from(ip).map(Self::Custom))
            .unwrap_or(Self::Ip(value))
    }
}
/// An IPv6 address from the endpoint-id mapped subnet `fd15:70a:510b::/64`.
///
/// Fresh values come from `MappedAddr::generate`; existing IPs are recognised through
/// `TryFrom<Ipv6Addr>`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub(crate) struct EndpointIdMappedAddr(Ipv6Addr);
impl MappedAddr for EndpointIdMappedAddr {
fn generate() -> Self {
let mut addr = [0u8; 16];
addr[0] = ADDR_PREFIXL;
addr[1..6].copy_from_slice(&ADDR_GLOBAL_ID);
addr[6..8].copy_from_slice(&ENDPOINT_ID_SUBNET);
let counter = ENDPOINT_ID_ADDR_COUNTER.fetch_add(1, Ordering::Relaxed);
addr[8..16].copy_from_slice(&counter.to_be_bytes());
Self(Ipv6Addr::from(addr))
}
fn private_socket_addr(&self) -> SocketAddr {
SocketAddr::new(IpAddr::from(self.0), MAPPED_PORT)
}
}
impl std::fmt::Display for EndpointIdMappedAddr {
    /// Renders as `EndpointIdMappedAddr(<ipv6>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let ip = self.0;
        write!(f, "EndpointIdMappedAddr({ip})")
    }
}
impl TryFrom<Ipv6Addr> for EndpointIdMappedAddr {
    type Error = EndpointIdMappedAddrError;

    /// Accepts `value` when its upper 64 bits are the shared prefix followed by
    /// `ENDPOINT_ID_SUBNET`. The lower 64 bits are not inspected.
    fn try_from(value: Ipv6Addr) -> Result<Self, Self::Error> {
        match value.octets() {
            [prefix, g0, g1, g2, g3, g4, s0, s1, ..]
                if prefix == ADDR_PREFIXL
                    && [g0, g1, g2, g3, g4] == ADDR_GLOBAL_ID
                    && [s0, s1] == ENDPOINT_ID_SUBNET =>
            {
                Ok(Self(value))
            }
            _ => Err(e!(EndpointIdMappedAddrError)),
        }
    }
}
/// Error returned by `EndpointIdMappedAddr::try_from` when the IPv6 address is not inside
/// the endpoint-id mapped subnet.
#[stack_error(derive, add_meta)]
#[error("Failed to convert")]
pub(crate) struct EndpointIdMappedAddrError;
/// An IPv6 address from the relay mapped subnet `fd15:70a:510b:1::/64`.
///
/// Fresh values come from `MappedAddr::generate`; existing IPs are recognised through
/// `TryFrom<Ipv6Addr>`. Ordering follows the wrapped [`Ipv6Addr`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub(crate) struct RelayMappedAddr(Ipv6Addr);
impl MappedAddr for RelayMappedAddr {
fn generate() -> Self {
let mut addr = [0u8; 16];
addr[0] = ADDR_PREFIXL;
addr[1..6].copy_from_slice(&ADDR_GLOBAL_ID);
addr[6..8].copy_from_slice(&RELAY_MAPPED_SUBNET);
let counter = RELAY_ADDR_COUNTER.fetch_add(1, Ordering::Relaxed);
addr[8..16].copy_from_slice(&counter.to_be_bytes());
Self(Ipv6Addr::from(addr))
}
fn private_socket_addr(&self) -> SocketAddr {
SocketAddr::new(IpAddr::from(self.0), MAPPED_PORT)
}
}
impl TryFrom<Ipv6Addr> for RelayMappedAddr {
    type Error = RelayMappedAddrError;

    /// Accepts `value` when its upper 64 bits are the shared prefix followed by
    /// `RELAY_MAPPED_SUBNET`. The lower 64 bits are not inspected.
    fn try_from(value: Ipv6Addr) -> std::result::Result<Self, Self::Error> {
        match value.octets() {
            [prefix, g0, g1, g2, g3, g4, s0, s1, ..]
                if prefix == ADDR_PREFIXL
                    && [g0, g1, g2, g3, g4] == ADDR_GLOBAL_ID
                    && [s0, s1] == RELAY_MAPPED_SUBNET =>
            {
                Ok(Self(value))
            }
            _ => Err(e!(RelayMappedAddrError)),
        }
    }
}
/// Error returned by `RelayMappedAddr::try_from` when the IPv6 address is not inside the
/// relay mapped subnet.
#[stack_error(derive, add_meta)]
#[error("Failed to convert")]
pub(crate) struct RelayMappedAddrError;
impl std::fmt::Display for RelayMappedAddr {
    /// Renders as `RelayMappedAddr(<ipv6>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let ip = self.0;
        write!(f, "RelayMappedAddr({ip})")
    }
}
/// An IPv6 address from the custom mapped subnet `fd15:70a:510b:3::/64`.
///
/// Fresh values come from `MappedAddr::generate`; existing IPs are recognised through
/// `TryFrom<Ipv6Addr>`. Ordering follows the wrapped [`Ipv6Addr`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub(crate) struct CustomMappedAddr(Ipv6Addr);
impl MappedAddr for CustomMappedAddr {
fn generate() -> Self {
let mut addr = [0u8; 16];
addr[0] = ADDR_PREFIXL;
addr[1..6].copy_from_slice(&ADDR_GLOBAL_ID);
addr[6..8].copy_from_slice(&CUSTOM_MAPPED_SUBNET);
let counter = CUSTOM_ADDR_COUNTER.fetch_add(1, Ordering::Relaxed);
addr[8..16].copy_from_slice(&counter.to_be_bytes());
Self(Ipv6Addr::from(addr))
}
fn private_socket_addr(&self) -> SocketAddr {
SocketAddr::new(IpAddr::from(self.0), MAPPED_PORT)
}
}
impl TryFrom<Ipv6Addr> for CustomMappedAddr {
    type Error = CustomMappedAddrError;

    /// Accepts `value` when its upper 64 bits are the shared prefix followed by
    /// `CUSTOM_MAPPED_SUBNET`. The lower 64 bits are not inspected.
    fn try_from(value: Ipv6Addr) -> std::result::Result<Self, Self::Error> {
        match value.octets() {
            [prefix, g0, g1, g2, g3, g4, s0, s1, ..]
                if prefix == ADDR_PREFIXL
                    && [g0, g1, g2, g3, g4] == ADDR_GLOBAL_ID
                    && [s0, s1] == CUSTOM_MAPPED_SUBNET =>
            {
                Ok(Self(value))
            }
            _ => Err(e!(CustomMappedAddrError)),
        }
    }
}
/// Error returned by `CustomMappedAddr::try_from` when the IPv6 address is not inside the
/// custom mapped subnet.
#[stack_error(derive, add_meta)]
#[error("Failed to convert")]
pub(crate) struct CustomMappedAddrError;
impl std::fmt::Display for CustomMappedAddr {
    /// Renders as `CustomMappedAddr(<ipv6>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let ip = self.0;
        write!(f, "CustomMappedAddr({ip})")
    }
}
/// A lazily populated, two-way mapping between keys `K` and mapped addresses `V`.
///
/// Clones share one underlying table through an [`Arc`], and every access locks a
/// [`std::sync::Mutex`].
#[derive(Clone, Debug)]
pub(super) struct AddrMap<K, V> {
    inner: Arc<std::sync::Mutex<AddrMapInner<K, V>>>,
}
impl<K, V> Default for AddrMap<K, V> {
fn default() -> Self {
Self {
inner: Default::default(),
}
}
}
impl<K, V> AddrMap<K, V>
where
    K: Eq + Hash + Clone + fmt::Debug,
    V: MappedAddr + Eq + Hash + Copy + fmt::Debug,
{
    /// Returns the address assigned to `key`.
    ///
    /// The first time a key is seen, a new address is minted with `V::generate()` and
    /// recorded in both the forward and the reverse table.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub(super) fn get(&self, key: &K) -> V {
        let mut guard = self.inner.lock().expect("poisoned");
        if let Some(&existing) = guard.addrs.get(key) {
            return existing;
        }
        let fresh = V::generate();
        guard.addrs.insert(key.clone(), fresh);
        guard.lookup.insert(fresh, key.clone());
        trace!(addr = ?fresh, ?key, "generated new addr");
        fresh
    }

    /// Reverse lookup: the key that `addr` was handed out for, or `None` if `addr` was never
    /// returned by [`Self::get`] on this map.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub(super) fn lookup(&self, addr: &V) -> Option<K> {
        let guard = self.inner.lock().expect("poisoned");
        guard.lookup.get(addr).cloned()
    }
}
/// State guarded by the mutex inside `AddrMap`: both directions of the key/address mapping.
#[derive(Debug)]
struct AddrMapInner<K, V> {
    /// Forward table, key -> mapped address; filled in by `AddrMap::get`.
    addrs: FxHashMap<K, V>,
    /// Reverse table, mapped address -> key; filled in by `AddrMap::get` and read by
    /// `AddrMap::lookup`.
    lookup: FxHashMap<V, K>,
}
impl<K, V> Default for AddrMapInner<K, V> {
fn default() -> Self {
Self {
addrs: Default::default(),
lookup: Default::default(),
}
}
}