1use std::{collections::BTreeSet, fmt, net::SocketAddr};
10
11use data_encoding::HEXLOWER;
12use n0_error::stack_error;
13use serde::{Deserialize, Serialize};
14
15use crate::{EndpointId, PublicKey, RelayUrl};
16
/// Addressing information for an endpoint: its id plus the transport addresses it is known by.
///
/// The addresses live in a `BTreeSet`, so duplicates are collapsed and iteration is in the
/// sorted order of [`TransportAddr`].
// NOTE: field order is part of the serde encoding and of the derived `Ord` (id first, then
// addrs) — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EndpointAddr {
    /// The endpoint's id.
    pub id: EndpointId,
    /// Transport addresses (relay URLs, IP socket addresses, custom addresses) for the endpoint.
    pub addrs: BTreeSet<TransportAddr>,
}
48
/// A single address at which an endpoint may be reachable.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a wildcard arm when
/// matching on it.
// NOTE: serde formats such as postcard encode the variant *index*, and the derived `Ord`
// also follows declaration order. Append new variants at the end only; never reorder or
// remove existing ones, or previously serialized data will decode incorrectly.
#[derive(
    derive_more::Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[non_exhaustive]
pub enum TransportAddr {
    /// A relay URL. Its `Debug` output embeds the URL's `Display` form, e.g. `Relay(<url>)`.
    #[debug("Relay({_0})")]
    Relay(RelayUrl),
    /// A direct IP socket address.
    Ip(SocketAddr),
    /// An address for a custom transport; see [`CustomAddr`].
    Custom(CustomAddr),
}
63
64impl TransportAddr {
65 pub fn is_relay(&self) -> bool {
67 matches!(self, Self::Relay(_))
68 }
69
70 pub fn is_ip(&self) -> bool {
72 matches!(self, Self::Ip(_))
73 }
74
75 pub fn is_custom(&self) -> bool {
77 matches!(self, Self::Custom(_))
78 }
79}
80
81impl fmt::Display for TransportAddr {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 Self::Relay(url) => write!(f, "relay:{url}"),
85 Self::Ip(addr) => write!(f, "ip:{addr}"),
86 Self::Custom(addr) => write!(f, "custom:{addr}"),
87 }
88 }
89}
90
91impl EndpointAddr {
92 pub fn new(id: PublicKey) -> Self {
97 EndpointAddr {
98 id,
99 addrs: Default::default(),
100 }
101 }
102
103 pub fn from_parts(id: PublicKey, addrs: impl IntoIterator<Item = TransportAddr>) -> Self {
105 Self {
106 id,
107 addrs: addrs.into_iter().collect(),
108 }
109 }
110
111 pub fn with_relay_url(mut self, relay_url: RelayUrl) -> Self {
113 self.addrs.insert(TransportAddr::Relay(relay_url));
114 self
115 }
116
117 pub fn with_ip_addr(mut self, addr: SocketAddr) -> Self {
119 self.addrs.insert(TransportAddr::Ip(addr));
120 self
121 }
122
123 pub fn with_addrs(mut self, addrs: impl IntoIterator<Item = TransportAddr>) -> Self {
125 for addr in addrs.into_iter() {
126 self.addrs.insert(addr);
127 }
128 self
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.addrs.is_empty()
134 }
135
136 pub fn ip_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
138 self.addrs.iter().filter_map(|addr| match addr {
139 TransportAddr::Ip(addr) => Some(addr),
140 _ => None,
141 })
142 }
143
144 pub fn relay_urls(&self) -> impl Iterator<Item = &RelayUrl> {
148 self.addrs.iter().filter_map(|addr| match addr {
149 TransportAddr::Relay(url) => Some(url),
150 _ => None,
151 })
152 }
153}
154
155impl From<EndpointId> for EndpointAddr {
156 fn from(endpoint_id: EndpointId) -> Self {
157 EndpointAddr::new(endpoint_id)
158 }
159}
160
/// Address for a custom transport: a numeric `id` plus opaque `data` bytes.
///
/// The string form (see the `Display` / `FromStr` impls) is `<id>_<data>`, with the id in
/// lowercase hex and the data as lowercase hex bytes, e.g. `2a_0102`. Payloads of up to
/// 30 bytes are stored inline without a heap allocation.
// NOTE: field order is part of the serde encoding and of the derived `Ord` (id first,
// then data) — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CustomAddr {
    /// Numeric id; rendered as lowercase hex in the string form.
    id: u64,
    /// Raw payload bytes; this type never interprets them.
    data: CustomAddrBytes,
}
192
193impl fmt::Display for CustomAddr {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 write!(f, "{:x}_{}", self.id, HEXLOWER.encode(self.data.as_bytes()))
196 }
197}
198
199impl std::str::FromStr for CustomAddr {
200 type Err = CustomAddrParseError;
201
202 fn from_str(s: &str) -> Result<Self, Self::Err> {
203 let Some((id_str, data_str)) = s.split_once('_') else {
204 return Err(CustomAddrParseError::MissingSeparator);
205 };
206 let Ok(id) = u64::from_str_radix(id_str, 16) else {
207 return Err(CustomAddrParseError::InvalidId);
208 };
209 let Ok(data) = HEXLOWER.decode(data_str.as_bytes()) else {
210 return Err(CustomAddrParseError::InvalidData);
211 };
212 Ok(Self::from_parts(id, &data))
213 }
214}
215
/// Error returned when parsing a [`CustomAddr`] from its `<id>_<data>` string form fails.
#[stack_error(derive)]
#[allow(missing_docs)]
pub enum CustomAddrParseError {
    /// The input contains no `_` between the id and the data.
    #[error("missing '_' separator")]
    MissingSeparator,
    /// The part before the first `_` is not a valid base-16 `u64`.
    #[error("invalid id")]
    InvalidId,
    /// The part after the first `_` is not valid hex as accepted by `HEXLOWER`
    /// (e.g. non-hex characters or an odd number of digits).
    #[error("invalid data")]
    InvalidData,
}
234
235#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
236enum CustomAddrBytes {
237 Inline { size: u8, data: [u8; 30] },
238 Heap(Box<[u8]>),
239}
240
241impl fmt::Debug for CustomAddrBytes {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 if !f.alternate() {
244 write!(f, "[{}]", HEXLOWER.encode(self.as_bytes()))
245 } else {
246 let bytes = self.as_bytes();
247 match self {
248 Self::Inline { .. } => write!(f, "Inline[{}]", HEXLOWER.encode(bytes)),
249 Self::Heap(_) => write!(f, "Heap[{}]", HEXLOWER.encode(bytes)),
250 }
251 }
252 }
253}
254
255impl From<(u64, &[u8])> for CustomAddr {
256 fn from((id, data): (u64, &[u8])) -> Self {
257 Self::from_parts(id, data)
258 }
259}
260
261impl CustomAddrBytes {
262 fn len(&self) -> usize {
263 match self {
264 Self::Inline { size, .. } => *size as usize,
265 Self::Heap(data) => data.len(),
266 }
267 }
268
269 fn as_bytes(&self) -> &[u8] {
270 match self {
271 Self::Inline { size, data } => &data[..*size as usize],
272 Self::Heap(data) => data,
273 }
274 }
275
276 fn copy_from_slice(data: &[u8]) -> Self {
277 if data.len() <= 30 {
278 let mut inline = [0u8; 30];
279 inline[..data.len()].copy_from_slice(data);
280 Self::Inline {
281 size: data.len() as u8,
282 data: inline,
283 }
284 } else {
285 Self::Heap(data.to_vec().into_boxed_slice())
286 }
287 }
288}
289
290impl CustomAddr {
291 pub fn from_parts(id: u64, data: &[u8]) -> Self {
293 Self {
294 id,
295 data: CustomAddrBytes::copy_from_slice(data),
296 }
297 }
298
299 pub fn id(&self) -> u64 {
307 self.id
308 }
309
310 pub fn data(&self) -> &[u8] {
317 self.data.as_bytes()
318 }
319
320 pub fn to_vec(&self) -> Vec<u8> {
324 let mut out = vec![0u8; 8 + self.data.len()];
325 out[..8].copy_from_slice(&self.id().to_le_bytes());
326 out[8..].copy_from_slice(self.data());
327 out
328 }
329
330 pub fn from_bytes(data: &[u8]) -> Result<Self, &'static str> {
334 if data.len() < 8 {
335 return Err("data too short");
336 }
337 let id = u64::from_le_bytes(data[..8].try_into().expect("data length checked above"));
338 let data = &data[8..];
339 Ok(Self::from_parts(id, data))
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in for a future version of [`TransportAddr`] with an extra trailing variant.
    ///
    /// Its leading variants keep the same order as `TransportAddr`, so variant indices line
    /// up on the wire.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
    #[non_exhaustive]
    enum NewAddrType {
        Relay(RelayUrl),
        Ip(SocketAddr),
        Cool(u16),
    }

    #[test]
    fn test_roundtrip_new_addr_type() {
        let old = vec![
            TransportAddr::Ip("127.0.0.1:9".parse().unwrap()),
            TransportAddr::Relay("https://example.com".parse().unwrap()),
        ];
        let old_ser = postcard::to_stdvec(&old).unwrap();
        let old_back: Vec<TransportAddr> = postcard::from_bytes(&old_ser).unwrap();
        assert_eq!(old, old_back);

        let new = vec![
            NewAddrType::Ip("127.0.0.1:9".parse().unwrap()),
            NewAddrType::Relay("https://example.com".parse().unwrap()),
            NewAddrType::Cool(4),
        ];
        let new_ser = postcard::to_stdvec(&new).unwrap();
        let new_back: Vec<NewAddrType> = postcard::from_bytes(&new_ser).unwrap();

        assert_eq!(new, new_back);

        // Data written with the old enum must still decode with the extended enum.
        let old_new_back: Vec<NewAddrType> = postcard::from_bytes(&old_ser).unwrap();

        assert_eq!(
            old_new_back,
            vec![
                NewAddrType::Ip("127.0.0.1:9".parse().unwrap()),
                NewAddrType::Relay("https://example.com".parse().unwrap()),
            ]
        );
    }

    #[test]
    fn test_custom_addr_roundtrip() {
        // Short payload (inline storage).
        let addr = CustomAddr::from_parts(1, &[0xa1, 0xb2, 0xc3, 0xd4, 0xe5, 0xf6]);
        let s = addr.to_string();
        assert_eq!(s, "1_a1b2c3d4e5f6");
        let parsed: CustomAddr = s.parse().unwrap();
        assert_eq!(addr, parsed);

        // Payload larger than the inline capacity (heap storage).
        let addr = CustomAddr::from_parts(42, &[0xab; 32]);
        let s = addr.to_string();
        assert_eq!(
            s,
            "2a_abababababababababababababababababababababababababababababababab"
        );
        let parsed: CustomAddr = s.parse().unwrap();
        assert_eq!(addr, parsed);

        // Empty payload.
        let addr = CustomAddr::from_parts(0, &[]);
        let s = addr.to_string();
        assert_eq!(s, "0_");
        let parsed: CustomAddr = s.parse().unwrap();
        assert_eq!(addr, parsed);

        // Multi-digit hex id.
        let addr = CustomAddr::from_parts(0xdeadbeef, &[0x01, 0x02]);
        let s = addr.to_string();
        assert_eq!(s, "deadbeef_0102");
        let parsed: CustomAddr = s.parse().unwrap();
        assert_eq!(addr, parsed);
    }

    #[test]
    fn test_custom_addr_parse_errors() {
        // No '_' separator at all.
        assert!(matches!(
            "abc123".parse::<CustomAddr>(),
            Err(CustomAddrParseError::MissingSeparator)
        ));
        // Id is not hexadecimal.
        assert!(matches!(
            "xyz_0102".parse::<CustomAddr>(),
            Err(CustomAddrParseError::InvalidId)
        ));
        // Empty id.
        assert!(matches!(
            "_0102".parse::<CustomAddr>(),
            Err(CustomAddrParseError::InvalidId)
        ));
        // Data contains non-hex characters.
        assert!(matches!(
            "1_ghij".parse::<CustomAddr>(),
            Err(CustomAddrParseError::InvalidData)
        ));
        // Data has an odd number of hex digits.
        assert!(matches!(
            "1_abc".parse::<CustomAddr>(),
            Err(CustomAddrParseError::InvalidData)
        ));
    }

    #[test]
    fn test_custom_addr_to_vec_from_bytes() {
        let addr = CustomAddr::from_parts(7, &[1, 2, 3]);
        let bytes = addr.to_vec();
        // 8-byte little-endian id followed by the payload.
        assert_eq!(bytes, vec![7u8, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3]);
        assert_eq!(CustomAddr::from_bytes(&bytes).unwrap(), addr);

        // Exactly the id prefix: empty payload.
        let empty = CustomAddr::from_bytes(&[0u8; 8]).unwrap();
        assert_eq!(empty, CustomAddr::from_parts(0, &[]));

        // Shorter than the id prefix.
        assert!(CustomAddr::from_bytes(&[0u8; 7]).is_err());
    }

    #[test]
    fn test_custom_addr_inline_heap_boundary() {
        // Lengths around the 30-byte inline capacity must behave identically.
        for len in [0usize, 1, 29, 30, 31, 64] {
            let bytes: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let addr = CustomAddr::from_parts(len as u64, &bytes);
            assert_eq!(addr.id(), len as u64);
            assert_eq!(addr.data(), &bytes[..]);
            assert_eq!(addr, CustomAddr::from_parts(len as u64, &bytes));
            assert_eq!(CustomAddr::from_bytes(&addr.to_vec()).unwrap(), addr);
            let parsed: CustomAddr = addr.to_string().parse().unwrap();
            assert_eq!(addr, parsed);
        }
    }
}