1use std::{collections::BTreeSet, fmt, net::SocketAddr};
10
11use data_encoding::HEXLOWER;
12use n0_error::stack_error;
13use serde::{Deserialize, Serialize};
14
15use crate::{EndpointId, PublicKey, RelayUrl};
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
42pub struct EndpointAddr {
43 pub id: EndpointId,
45 pub addrs: BTreeSet<TransportAddr>,
47}
48
49#[derive(
51 derive_more::Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash,
52)]
53#[non_exhaustive]
54pub enum TransportAddr {
55 #[debug("Relay({_0})")]
57 Relay(RelayUrl),
58 Ip(SocketAddr),
60 Custom(CustomAddr),
62}
63
64impl TransportAddr {
65 pub fn is_relay(&self) -> bool {
67 matches!(self, Self::Relay(_))
68 }
69
70 pub fn is_ip(&self) -> bool {
72 matches!(self, Self::Ip(_))
73 }
74
75 pub fn is_custom(&self) -> bool {
77 matches!(self, Self::Custom(_))
78 }
79}
80
81impl fmt::Display for TransportAddr {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 Self::Relay(url) => write!(f, "relay:{url}"),
85 Self::Ip(addr) => write!(f, "ip:{addr}"),
86 Self::Custom(addr) => write!(f, "custom:{addr}"),
87 }
88 }
89}
90
91impl EndpointAddr {
92 pub fn new(id: PublicKey) -> Self {
97 EndpointAddr {
98 id,
99 addrs: Default::default(),
100 }
101 }
102
103 pub fn from_parts(id: PublicKey, addrs: impl IntoIterator<Item = TransportAddr>) -> Self {
105 Self {
106 id,
107 addrs: addrs.into_iter().collect(),
108 }
109 }
110
111 pub fn with_relay_url(mut self, relay_url: RelayUrl) -> Self {
113 self.addrs.insert(TransportAddr::Relay(relay_url));
114 self
115 }
116
117 pub fn with_ip_addr(mut self, addr: SocketAddr) -> Self {
119 self.addrs.insert(TransportAddr::Ip(addr));
120 self
121 }
122
123 pub fn with_addrs(mut self, addrs: impl IntoIterator<Item = TransportAddr>) -> Self {
125 for addr in addrs.into_iter() {
126 self.addrs.insert(addr);
127 }
128 self
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.addrs.is_empty()
134 }
135
136 pub fn ip_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
138 self.addrs.iter().filter_map(|addr| match addr {
139 TransportAddr::Ip(addr) => Some(addr),
140 _ => None,
141 })
142 }
143
144 pub fn relay_urls(&self) -> impl Iterator<Item = &RelayUrl> {
148 self.addrs.iter().filter_map(|addr| match addr {
149 TransportAddr::Relay(url) => Some(url),
150 _ => None,
151 })
152 }
153}
154
155impl From<EndpointId> for EndpointAddr {
156 fn from(endpoint_id: EndpointId) -> Self {
157 EndpointAddr::new(endpoint_id)
158 }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
169pub struct CustomAddr {
170 id: u64,
172 data: CustomAddrBytes,
174}
175
176impl fmt::Display for CustomAddr {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 write!(f, "{:x}_{}", self.id, HEXLOWER.encode(self.data.as_bytes()))
179 }
180}
181
182impl std::str::FromStr for CustomAddr {
183 type Err = CustomAddrParseError;
184
185 fn from_str(s: &str) -> Result<Self, Self::Err> {
186 let Some((id_str, data_str)) = s.split_once('_') else {
187 return Err(CustomAddrParseError::MissingSeparator);
188 };
189 let Ok(id) = u64::from_str_radix(id_str, 16) else {
190 return Err(CustomAddrParseError::InvalidId);
191 };
192 let Ok(data) = HEXLOWER.decode(data_str.as_bytes()) else {
193 return Err(CustomAddrParseError::InvalidData);
194 };
195 Ok(Self::from_parts(id, &data))
196 }
197}
198
199#[stack_error(derive)]
205#[allow(missing_docs)]
206pub enum CustomAddrParseError {
207 #[error("missing '_' separator")]
209 MissingSeparator,
210 #[error("invalid id")]
212 InvalidId,
213 #[error("invalid data")]
215 InvalidData,
216}
217
218#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
219enum CustomAddrBytes {
220 Inline { size: u8, data: [u8; 30] },
221 Heap(Box<[u8]>),
222}
223
224impl fmt::Debug for CustomAddrBytes {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 if !f.alternate() {
227 write!(f, "[{}]", HEXLOWER.encode(self.as_bytes()))
228 } else {
229 let bytes = self.as_bytes();
230 match self {
231 Self::Inline { .. } => write!(f, "Inline[{}]", HEXLOWER.encode(bytes)),
232 Self::Heap(_) => write!(f, "Heap[{}]", HEXLOWER.encode(bytes)),
233 }
234 }
235 }
236}
237
238impl From<(u64, &[u8])> for CustomAddr {
239 fn from((id, data): (u64, &[u8])) -> Self {
240 Self::from_parts(id, data)
241 }
242}
243
244impl CustomAddrBytes {
245 pub fn len(&self) -> usize {
246 match self {
247 Self::Inline { size, .. } => *size as usize,
248 Self::Heap(data) => data.len(),
249 }
250 }
251
252 pub fn as_bytes(&self) -> &[u8] {
253 match self {
254 Self::Inline { size, data } => &data[..*size as usize],
255 Self::Heap(data) => data,
256 }
257 }
258
259 pub fn copy_from_slice(data: &[u8]) -> Self {
260 if data.len() <= 30 {
261 let mut inline = [0u8; 30];
262 inline[..data.len()].copy_from_slice(data);
263 Self::Inline {
264 size: data.len() as u8,
265 data: inline,
266 }
267 } else {
268 Self::Heap(data.to_vec().into_boxed_slice())
269 }
270 }
271}
272
273impl CustomAddr {
274 pub fn from_parts(id: u64, data: &[u8]) -> Self {
276 Self {
277 id,
278 data: CustomAddrBytes::copy_from_slice(data),
279 }
280 }
281
282 pub fn id(&self) -> u64 {
290 self.id
291 }
292
293 pub fn data(&self) -> &[u8] {
300 self.data.as_bytes()
301 }
302
303 pub fn as_vec(&self) -> Vec<u8> {
305 let mut out = vec![0u8; 8 + self.data.len()];
306 out[..8].copy_from_slice(&self.id().to_le_bytes());
307 out[8..].copy_from_slice(self.data());
308 out
309 }
310
311 pub fn from_bytes(data: &[u8]) -> Result<Self, &'static str> {
313 if data.len() < 8 {
314 return Err("data too short");
315 }
316 let id = u64::from_le_bytes(data[..8].try_into().expect("data length checked above"));
317 let data = &data[8..];
318 Ok(Self::from_parts(id, data))
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in for a future [`TransportAddr`] that has an extra variant.
    ///
    /// Used to check that postcard bytes written by the current enum still decode after
    /// such a change.
    /// NOTE(review): `TransportAddr` now has `Custom` at index 2, the slot `Cool` takes here,
    /// so this mirror only matches data that contains no custom addresses.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
    #[non_exhaustive]
    enum NewAddrType {
        /// Same position and payload as [`TransportAddr::Relay`].
        Relay(RelayUrl),
        /// Same position and payload as [`TransportAddr::Ip`].
        Ip(SocketAddr),
        /// Variant that only the "new" enum knows about.
        Cool(u16),
    }

    #[test]
    fn test_roundtrip_new_addr_type() {
        let ip: SocketAddr = "127.0.0.1:9".parse().unwrap();
        let relay: RelayUrl = "https://example.com".parse().unwrap();

        // The current enum round-trips through postcard.
        let current = vec![TransportAddr::Ip(ip), TransportAddr::Relay(relay.clone())];
        let current_bytes = postcard::to_stdvec(&current).unwrap();
        let current_decoded: Vec<TransportAddr> = postcard::from_bytes(&current_bytes).unwrap();
        assert_eq!(current, current_decoded);

        // The extended enum round-trips, including its extra variant.
        let extended = vec![
            NewAddrType::Ip(ip),
            NewAddrType::Relay(relay.clone()),
            NewAddrType::Cool(4),
        ];
        let extended_bytes = postcard::to_stdvec(&extended).unwrap();
        let extended_decoded: Vec<NewAddrType> = postcard::from_bytes(&extended_bytes).unwrap();
        assert_eq!(extended, extended_decoded);

        // Bytes written by the current enum decode unchanged into the extended one.
        let upgraded: Vec<NewAddrType> = postcard::from_bytes(&current_bytes).unwrap();
        assert_eq!(upgraded, vec![NewAddrType::Ip(ip), NewAddrType::Relay(relay)]);
    }

    /// Asserts that `CustomAddr::from_parts(id, data)` displays as `expected` and parses back
    /// to an equal value.
    fn assert_custom_roundtrip(id: u64, data: &[u8], expected: &str) {
        let addr = CustomAddr::from_parts(id, data);
        let text = addr.to_string();
        assert_eq!(text, expected);
        let reparsed: CustomAddr = text.parse().unwrap();
        assert_eq!(addr, reparsed);
    }

    #[test]
    fn test_custom_addr_roundtrip() {
        // Short payload, stored inline.
        assert_custom_roundtrip(1, &[0xa1, 0xb2, 0xc3, 0xd4, 0xe5, 0xf6], "1_a1b2c3d4e5f6");
        // 32-byte payload, too large for inline storage.
        assert_custom_roundtrip(
            42,
            &[0xab; 32],
            "2a_abababababababababababababababababababababababababababababababab",
        );
        // Empty payload.
        assert_custom_roundtrip(0, &[], "0_");
        // Multi-digit hex id.
        assert_custom_roundtrip(0xdeadbeef, &[0x01, 0x02], "deadbeef_0102");
    }

    #[test]
    fn test_custom_addr_parse_errors() {
        let rejected = [
            "abc123",   // no '_' separator
            "xyz_0102", // id is not hex
            "1_ghij",   // data is not hex
            "1_abc",    // data has odd length
        ];
        for input in rejected {
            assert!(input.parse::<CustomAddr>().is_err(), "{input:?} should not parse");
        }
    }
}