use std::{collections::BTreeSet, fmt, net::SocketAddr};
use data_encoding::HEXLOWER;
use n0_error::stack_error;
use serde::{Deserialize, Serialize};
use crate::{EndpointId, PublicKey, RelayUrl};
/// Addressing information for an endpoint: its [`EndpointId`] plus a set of
/// [`TransportAddr`]s.
///
/// The addresses live in a [`BTreeSet`], so duplicates are collapsed and iteration
/// follows the derived `Ord` of [`TransportAddr`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EndpointAddr {
/// The endpoint's identifier.
pub id: EndpointId,
/// Known transport addresses for this endpoint; may be empty.
pub addrs: BTreeSet<TransportAddr>,
}
/// A single transport-level address.
///
/// Variant order is part of the serialized format: serde encodes enum variants by
/// index in formats such as postcard. New variants must therefore be appended at the
/// end, never inserted or reordered.
#[derive(
derive_more::Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[non_exhaustive]
pub enum TransportAddr {
/// A relay URL.
// Debug prints the URL through its `Display` impl rather than its `Debug` impl.
#[debug("Relay({_0})")]
Relay(RelayUrl),
/// A direct IP socket address.
Ip(SocketAddr),
/// An address for a custom transport; see [`CustomAddr`].
Custom(CustomAddr),
}
impl TransportAddr {
pub fn is_relay(&self) -> bool {
matches!(self, Self::Relay(_))
}
pub fn is_ip(&self) -> bool {
matches!(self, Self::Ip(_))
}
pub fn is_custom(&self) -> bool {
matches!(self, Self::Custom(_))
}
}
impl fmt::Display for TransportAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Relay(url) => write!(f, "relay:{url}"),
Self::Ip(addr) => write!(f, "ip:{addr}"),
Self::Custom(addr) => write!(f, "custom:{addr}"),
}
}
}
impl EndpointAddr {
pub fn new(id: PublicKey) -> Self {
EndpointAddr {
id,
addrs: Default::default(),
}
}
pub fn from_parts(id: PublicKey, addrs: impl IntoIterator<Item = TransportAddr>) -> Self {
Self {
id,
addrs: addrs.into_iter().collect(),
}
}
pub fn with_relay_url(mut self, relay_url: RelayUrl) -> Self {
self.addrs.insert(TransportAddr::Relay(relay_url));
self
}
pub fn with_ip_addr(mut self, addr: SocketAddr) -> Self {
self.addrs.insert(TransportAddr::Ip(addr));
self
}
pub fn with_addrs(mut self, addrs: impl IntoIterator<Item = TransportAddr>) -> Self {
for addr in addrs.into_iter() {
self.addrs.insert(addr);
}
self
}
pub fn is_empty(&self) -> bool {
self.addrs.is_empty()
}
pub fn ip_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
self.addrs.iter().filter_map(|addr| match addr {
TransportAddr::Ip(addr) => Some(addr),
_ => None,
})
}
pub fn relay_urls(&self) -> impl Iterator<Item = &RelayUrl> {
self.addrs.iter().filter_map(|addr| match addr {
TransportAddr::Relay(url) => Some(url),
_ => None,
})
}
}
impl From<EndpointId> for EndpointAddr {
fn from(endpoint_id: EndpointId) -> Self {
EndpointAddr::new(endpoint_id)
}
}
/// An address for a custom transport: a numeric `id` plus an opaque byte payload.
///
/// Its `Display` form is `<id as lowercase hex>_<data as lowercase hex>`, and
/// `FromStr` parses the same form. `to_vec`/`from_bytes` use an 8-byte little-endian
/// id followed by the raw payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CustomAddr {
/// Numeric identifier. NOTE(review): presumably selects which custom transport
/// interprets `data` — confirm against the transport implementations.
id: u64,
/// Opaque payload bytes.
data: CustomAddrBytes,
}
impl fmt::Display for CustomAddr {
    /// Renders as `<id in lowercase hex>_<payload in lowercase hex>`, e.g. `2a_0102`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let payload_hex = HEXLOWER.encode(self.data.as_bytes());
        write!(f, "{:x}_{}", self.id, payload_hex)
    }
}
impl std::str::FromStr for CustomAddr {
type Err = CustomAddrParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let Some((id_str, data_str)) = s.split_once('_') else {
return Err(CustomAddrParseError::MissingSeparator);
};
let Ok(id) = u64::from_str_radix(id_str, 16) else {
return Err(CustomAddrParseError::InvalidId);
};
let Ok(data) = HEXLOWER.decode(data_str.as_bytes()) else {
return Err(CustomAddrParseError::InvalidData);
};
Ok(Self::from_parts(id, &data))
}
}
/// Error returned when parsing a [`CustomAddr`] from its `<hex id>_<hex data>`
/// string form fails.
#[stack_error(derive)]
#[allow(missing_docs)]
pub enum CustomAddrParseError {
/// The input contains no `_` between the id and the data.
#[error("missing '_' separator")]
MissingSeparator,
/// The part before `_` is not a valid hexadecimal `u64`.
#[error("invalid id")]
InvalidId,
/// The part after `_` is not valid lowercase hex (for example, odd length or a non-hex character).
#[error("invalid data")]
InvalidData,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum CustomAddrBytes {
Inline { size: u8, data: [u8; 30] },
Heap(Box<[u8]>),
}
impl fmt::Debug for CustomAddrBytes {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !f.alternate() {
write!(f, "[{}]", HEXLOWER.encode(self.as_bytes()))
} else {
let bytes = self.as_bytes();
match self {
Self::Inline { .. } => write!(f, "Inline[{}]", HEXLOWER.encode(bytes)),
Self::Heap(_) => write!(f, "Heap[{}]", HEXLOWER.encode(bytes)),
}
}
}
}
impl From<(u64, &[u8])> for CustomAddr {
fn from((id, data): (u64, &[u8])) -> Self {
Self::from_parts(id, data)
}
}
impl CustomAddrBytes {
fn len(&self) -> usize {
match self {
Self::Inline { size, .. } => *size as usize,
Self::Heap(data) => data.len(),
}
}
fn as_bytes(&self) -> &[u8] {
match self {
Self::Inline { size, data } => &data[..*size as usize],
Self::Heap(data) => data,
}
}
fn copy_from_slice(data: &[u8]) -> Self {
if data.len() <= 30 {
let mut inline = [0u8; 30];
inline[..data.len()].copy_from_slice(data);
Self::Inline {
size: data.len() as u8,
data: inline,
}
} else {
Self::Heap(data.to_vec().into_boxed_slice())
}
}
}
impl CustomAddr {
    /// Creates a custom address from an `id` and an opaque payload, which is copied.
    pub fn from_parts(id: u64, data: &[u8]) -> Self {
        let data = CustomAddrBytes::copy_from_slice(data);
        Self { id, data }
    }

    /// Returns the numeric identifier.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Borrows the payload bytes.
    pub fn data(&self) -> &[u8] {
        self.data.as_bytes()
    }

    /// Encodes the address as the 8-byte little-endian `id` followed by the raw payload.
    /// This is the inverse of [`CustomAddr::from_bytes`].
    pub fn to_vec(&self) -> Vec<u8> {
        let payload = self.data();
        let mut encoded = Vec::with_capacity(8 + self.data.len());
        encoded.extend_from_slice(&self.id.to_le_bytes());
        encoded.extend_from_slice(payload);
        encoded
    }

    /// Decodes the format produced by [`CustomAddr::to_vec`].
    ///
    /// The first 8 bytes are the little-endian `id`; everything after them is the payload.
    ///
    /// # Errors
    /// Returns `"data too short"` if `data` is shorter than 8 bytes.
    pub fn from_bytes(data: &[u8]) -> Result<Self, &'static str> {
        if data.len() < 8 {
            return Err("data too short");
        }
        let (id_bytes, payload) = data.split_at(8);
        let id = u64::from_le_bytes(id_bytes.try_into().expect("data length checked above"));
        Ok(Self::from_parts(id, payload))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirror of [`TransportAddr`] with an extra variant appended.
    ///
    /// It checks that data serialized with the old variant set still decodes after a
    /// new variant is added at the end.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
    #[non_exhaustive]
    enum NewAddrType {
        Relay(RelayUrl),
        Ip(SocketAddr),
        Cool(u16),
    }

    #[test]
    fn test_roundtrip_new_addr_type() {
        let old = vec![
            TransportAddr::Ip("127.0.0.1:9".parse().unwrap()),
            TransportAddr::Relay("https://example.com".parse().unwrap()),
        ];
        let old_ser = postcard::to_stdvec(&old).unwrap();
        let old_back: Vec<TransportAddr> = postcard::from_bytes(&old_ser).unwrap();
        assert_eq!(old, old_back);

        let new = vec![
            NewAddrType::Ip("127.0.0.1:9".parse().unwrap()),
            NewAddrType::Relay("https://example.com".parse().unwrap()),
            NewAddrType::Cool(4),
        ];
        let new_ser = postcard::to_stdvec(&new).unwrap();
        let new_back: Vec<NewAddrType> = postcard::from_bytes(&new_ser).unwrap();
        assert_eq!(new, new_back);

        // Old bytes must decode into the extended enum unchanged.
        let old_new_back: Vec<NewAddrType> = postcard::from_bytes(&old_ser).unwrap();
        assert_eq!(
            old_new_back,
            vec![
                NewAddrType::Ip("127.0.0.1:9".parse().unwrap()),
                NewAddrType::Relay("https://example.com".parse().unwrap()),
            ]
        );
    }

    /// Asserts that `addr` displays as `expected` and parses back to an equal value.
    fn assert_display_roundtrip(addr: CustomAddr, expected: &str) {
        let s = addr.to_string();
        assert_eq!(s, expected);
        let parsed: CustomAddr = s.parse().unwrap();
        assert_eq!(addr, parsed);
    }

    #[test]
    fn test_custom_addr_roundtrip() {
        assert_display_roundtrip(
            CustomAddr::from_parts(1, &[0xa1, 0xb2, 0xc3, 0xd4, 0xe5, 0xf6]),
            "1_a1b2c3d4e5f6",
        );
        // 32 bytes: exercises heap storage.
        assert_display_roundtrip(
            CustomAddr::from_parts(42, &[0xab; 32]),
            "2a_abababababababababababababababababababababababababababababababab",
        );
        assert_display_roundtrip(CustomAddr::from_parts(0, &[]), "0_");
        assert_display_roundtrip(
            CustomAddr::from_parts(0xdeadbeef, &[0x01, 0x02]),
            "deadbeef_0102",
        );
    }

    #[test]
    fn test_custom_addr_parse_errors() {
        // Pin the specific error variant, not just "some error".
        assert!(matches!(
            "abc123".parse::<CustomAddr>(),
            Err(CustomAddrParseError::MissingSeparator)
        ));
        assert!(matches!(
            "xyz_0102".parse::<CustomAddr>(),
            Err(CustomAddrParseError::InvalidId)
        ));
        assert!(matches!(
            "1_ghij".parse::<CustomAddr>(),
            Err(CustomAddrParseError::InvalidData)
        ));
        // Odd number of hex digits.
        assert!(matches!(
            "1_abc".parse::<CustomAddr>(),
            Err(CustomAddrParseError::InvalidData)
        ));
    }

    #[test]
    fn test_custom_addr_bytes_roundtrip() {
        const ID: u64 = 0x0102_0304_0506_0708;
        // Lengths on both sides of the 30-byte inline/heap boundary.
        for len in [0usize, 1, 30, 31, 64] {
            let payload: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let addr = CustomAddr::from_parts(ID, &payload);
            let bytes = addr.to_vec();
            assert_eq!(bytes.len(), 8 + len);
            assert_eq!(&bytes[..8], &ID.to_le_bytes()[..]);
            assert_eq!(&bytes[8..], payload.as_slice());
            assert_eq!(CustomAddr::from_bytes(&bytes).unwrap(), addr);
        }
        assert!(CustomAddr::from_bytes(&[0u8; 7]).is_err());
        assert_eq!(
            CustomAddr::from_bytes(&[0u8; 8]).unwrap(),
            CustomAddr::from_parts(0, &[])
        );
    }
}