1use std::collections::HashMap;
23use std::net::SocketAddr;
24use std::pin::Pin;
25use std::sync::Arc;
26use std::task::{Context, Poll};
27
28use futures_util::stream::{SelectAll, Stream, StreamExt};
29
30use crate::bootstrap_cache::BootstrapCache;
31use crate::link_transport::BoxStream;
32use crate::nat_traversal_api::PeerId;
33
/// Error reported by an [`AddressLookup`] service while resolving a peer.
///
/// `retryable` is `true` for errors built with [`LookupError::transient`] and
/// `false` for errors built with [`LookupError::structural`].
#[derive(Debug, Clone)]
pub struct LookupError {
    /// Human-readable description of the failure.
    pub message: String,
    /// Whether the failure was classified as transient (`true`) or structural (`false`).
    pub retryable: bool,
}

impl LookupError {
    /// Shared constructor behind [`transient`](Self::transient) and
    /// [`structural`](Self::structural).
    fn with_retryable(message: impl Into<String>, retryable: bool) -> Self {
        Self {
            message: message.into(),
            retryable,
        }
    }

    /// Builds an error marked as retryable (`retryable == true`).
    pub fn transient(message: impl Into<String>) -> Self {
        Self::with_retryable(message, true)
    }

    /// Builds an error marked as not retryable (`retryable == false`); the
    /// built-in lookups use this when a peer is absent from their source.
    pub fn structural(message: impl Into<String>) -> Self {
        Self::with_retryable(message, false)
    }
}

impl std::fmt::Display for LookupError {
    /// Renders as `lookup error: <message>`; the retryable flag is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("lookup error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for LookupError {}
75
76pub trait AddressLookup: Send + Sync + 'static {
86 fn name(&self) -> &'static str;
88
89 fn lookup(&self, peer_id: PeerId) -> BoxStream<'static, Result<SocketAddr, LookupError>>;
91}
92
93#[derive(Default, Clone)]
102pub struct LookupRegistry {
103 services: Vec<Arc<dyn AddressLookup>>,
104}
105
106impl std::fmt::Debug for LookupRegistry {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("LookupRegistry")
109 .field("service_count", &self.services.len())
110 .field(
111 "services",
112 &self.services.iter().map(|s| s.name()).collect::<Vec<_>>(),
113 )
114 .finish()
115 }
116}
117
118impl LookupRegistry {
119 pub fn new() -> Self {
121 Self::default()
122 }
123
124 pub fn add_service<S>(&mut self, service: S)
126 where
127 S: AddressLookup,
128 {
129 self.services.push(Arc::new(service));
130 }
131
132 pub fn add_service_arc(&mut self, service: Arc<dyn AddressLookup>) {
135 self.services.push(service);
136 }
137
138 pub fn len(&self) -> usize {
140 self.services.len()
141 }
142
143 pub fn is_empty(&self) -> bool {
145 self.services.is_empty()
146 }
147
148 pub fn service_names(&self) -> Vec<&'static str> {
150 self.services.iter().map(|s| s.name()).collect()
151 }
152
153 pub fn lookup(&self, peer_id: PeerId) -> ParallelLookupStream {
166 let mut inner = SelectAll::new();
167
168 for service in &self.services {
169 let service = Arc::clone(service);
170 inner.push(NamedLookup::new(service, peer_id));
171 }
172
173 ParallelLookupStream { inner }
174 }
175}
176
177pub struct ParallelLookupStream {
180 inner: SelectAll<NamedLookup>,
181}
182
183impl Stream for ParallelLookupStream {
184 type Item = (&'static str, Result<SocketAddr, LookupError>);
185
186 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187 self.inner.poll_next_unpin(cx)
188 }
189}
190
191impl std::fmt::Debug for ParallelLookupStream {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 f.debug_struct("ParallelLookupStream")
194 .field("active_lookups", &self.inner.len())
195 .finish()
196 }
197}
198
199struct NamedLookup {
208 name: &'static str,
209 stream: BoxStream<'static, Result<SocketAddr, LookupError>>,
210}
211
212impl NamedLookup {
213 fn new(service: Arc<dyn AddressLookup>, peer_id: PeerId) -> Self {
214 let name = service.name();
215 let stream = service.lookup(peer_id);
216 Self { name, stream }
217 }
218}
219
220impl Stream for NamedLookup {
221 type Item = (&'static str, Result<SocketAddr, LookupError>);
222
223 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224 let name = self.name;
225 match self.stream.as_mut().poll_next(cx) {
226 Poll::Ready(Some(item)) => Poll::Ready(Some((name, item))),
227 Poll::Ready(None) => Poll::Ready(None),
228 Poll::Pending => Poll::Pending,
229 }
230 }
231}
232
233#[derive(Clone)]
241pub struct BootstrapCacheLookup {
242 cache: Arc<BootstrapCache>,
243}
244
245impl BootstrapCacheLookup {
246 pub fn new(cache: Arc<BootstrapCache>) -> Self {
248 Self { cache }
249 }
250}
251
252impl std::fmt::Debug for BootstrapCacheLookup {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 f.debug_struct("BootstrapCacheLookup").finish()
255 }
256}
257
258impl AddressLookup for BootstrapCacheLookup {
259 fn name(&self) -> &'static str {
260 "bootstrap-cache"
261 }
262
263 fn lookup(&self, peer_id: PeerId) -> BoxStream<'static, Result<SocketAddr, LookupError>> {
264 let cache = Arc::clone(&self.cache);
265 let stream = futures_util::stream::once(async move {
266 let peer = cache.get_peer(&peer_id).await;
267 match peer {
268 Some(p) if !p.addresses.is_empty() => Ok(p.addresses),
269 Some(_) => Err(LookupError::structural(format!(
270 "bootstrap cache: peer {:?} present but has no addresses",
271 peer_id
272 ))),
273 None => Err(LookupError::structural(format!(
274 "bootstrap cache: peer {:?} not present",
275 peer_id
276 ))),
277 }
278 })
279 .flat_map(|res| match res {
280 Ok(addrs) => {
281 let items: Vec<Result<SocketAddr, LookupError>> =
282 addrs.into_iter().map(Ok).collect();
283 futures_util::stream::iter(items).boxed()
284 }
285 Err(e) => futures_util::stream::iter(vec![Err(e)]).boxed(),
286 });
287 Box::pin(stream)
288 }
289}
290
291#[derive(Default, Clone)]
298pub struct MdnsLookup {
299 inner: Arc<tokio::sync::RwLock<HashMap<PeerId, Vec<SocketAddr>>>>,
300}
301
302impl MdnsLookup {
303 pub fn new() -> Self {
305 Self::default()
306 }
307
308 pub async fn upsert(&self, peer_id: PeerId, addrs: Vec<SocketAddr>) {
310 let mut guard = self.inner.write().await;
311 guard.insert(peer_id, addrs);
312 }
313
314 pub async fn forget(&self, peer_id: &PeerId) {
316 let mut guard = self.inner.write().await;
317 guard.remove(peer_id);
318 }
319}
320
321impl std::fmt::Debug for MdnsLookup {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 f.debug_struct("MdnsLookup").finish()
324 }
325}
326
327impl AddressLookup for MdnsLookup {
328 fn name(&self) -> &'static str {
329 "mdns"
330 }
331
332 fn lookup(&self, peer_id: PeerId) -> BoxStream<'static, Result<SocketAddr, LookupError>> {
333 let inner = Arc::clone(&self.inner);
334 let stream = futures_util::stream::once(async move {
335 let guard = inner.read().await;
336 match guard.get(&peer_id).cloned() {
337 Some(addrs) if !addrs.is_empty() => Ok(addrs),
338 Some(_) => Err(LookupError::structural(format!(
339 "mdns: peer {:?} present with empty address list",
340 peer_id
341 ))),
342 None => Err(LookupError::structural(format!(
343 "mdns: peer {:?} not present",
344 peer_id
345 ))),
346 }
347 })
348 .flat_map(|res| match res {
349 Ok(addrs) => {
350 let items: Vec<Result<SocketAddr, LookupError>> =
351 addrs.into_iter().map(Ok).collect();
352 futures_util::stream::iter(items).boxed()
353 }
354 Err(e) => futures_util::stream::iter(vec![Err(e)]).boxed(),
355 });
356 Box::pin(stream)
357 }
358}
359
360#[derive(Clone)]
365pub struct HardcodedLookup {
366 name: &'static str,
367 table: Arc<HashMap<PeerId, Vec<SocketAddr>>>,
368}
369
370impl HardcodedLookup {
371 pub fn from_map(name: &'static str, table: HashMap<PeerId, Vec<SocketAddr>>) -> Self {
373 Self {
374 name,
375 table: Arc::new(table),
376 }
377 }
378
379 pub fn from_pairs(
381 name: &'static str,
382 pairs: impl IntoIterator<Item = (PeerId, Vec<SocketAddr>)>,
383 ) -> Self {
384 let mut table: HashMap<PeerId, Vec<SocketAddr>> = HashMap::new();
385 for (peer, addrs) in pairs {
386 table.entry(peer).or_default().extend(addrs);
387 }
388 Self::from_map(name, table)
389 }
390}
391
392impl std::fmt::Debug for HardcodedLookup {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 f.debug_struct("HardcodedLookup")
395 .field("name", &self.name)
396 .field("entries", &self.table.len())
397 .finish()
398 }
399}
400
401impl AddressLookup for HardcodedLookup {
402 fn name(&self) -> &'static str {
403 self.name
404 }
405
406 fn lookup(&self, peer_id: PeerId) -> BoxStream<'static, Result<SocketAddr, LookupError>> {
407 let addrs = self.table.get(&peer_id).cloned().unwrap_or_default();
408 if addrs.is_empty() {
409 let err = LookupError::structural(format!(
410 "hardcoded[{}]: peer {:?} not present",
411 self.name, peer_id
412 ));
413 Box::pin(futures_util::stream::iter(vec![Err(err)]))
414 } else {
415 let items: Vec<Result<SocketAddr, LookupError>> = addrs.into_iter().map(Ok).collect();
416 Box::pin(futures_util::stream::iter(items))
417 }
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bootstrap_cache::BootstrapCacheConfig;
    use futures_util::StreamExt;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};
    use tempfile::TempDir;

    fn addr(port: u16) -> SocketAddr {
        format!("127.0.0.1:{port}").parse().unwrap()
    }

    fn peer(byte: u8) -> PeerId {
        PeerId([byte; 32])
    }

    /// Collects every item of `s` until it terminates.
    async fn drain<S: Stream>(s: S) -> Vec<S::Item> {
        s.collect().await
    }

    #[tokio::test]
    async fn hardcoded_lookup_yields_addresses() {
        let p = peer(1);
        let lookup =
            HardcodedLookup::from_pairs("test-static", [(p, vec![addr(5000), addr(5001)])]);
        assert_eq!(lookup.name(), "test-static");

        let items = drain(lookup.lookup(p)).await;
        assert_eq!(items.len(), 2);
        assert!(matches!(&items[0], Ok(a) if a.port() == 5000));
        assert!(matches!(&items[1], Ok(a) if a.port() == 5001));
    }

    #[tokio::test]
    async fn hardcoded_lookup_missing_peer_yields_structural_error() {
        let lookup = HardcodedLookup::from_pairs("static", []);
        let items = drain(lookup.lookup(peer(7))).await;
        assert_eq!(items.len(), 1);
        assert!(matches!(&items[0], Err(e) if !e.retryable));
    }

    #[tokio::test]
    async fn hardcoded_lookup_empty_address_list_yields_structural_error() {
        let p = peer(8);
        let lookup = HardcodedLookup::from_pairs("static", [(p, Vec::new())]);
        let items = drain(lookup.lookup(p)).await;
        assert_eq!(items.len(), 1);
        assert!(matches!(&items[0], Err(e) if !e.retryable));
    }

    #[tokio::test]
    async fn mdns_lookup_yields_after_upsert() {
        let lookup = MdnsLookup::new();
        let p = peer(2);
        lookup.upsert(p, vec![addr(6000)]).await;

        let items = drain(lookup.lookup(p)).await;
        assert_eq!(items.len(), 1);
        assert!(matches!(&items[0], Ok(a) if a.port() == 6000));
    }

    #[tokio::test]
    async fn mdns_lookup_missing_yields_error() {
        let lookup = MdnsLookup::new();
        let items = drain(lookup.lookup(peer(9))).await;
        assert_eq!(items.len(), 1);
        // Absence from the table is structural, not transient.
        assert!(matches!(&items[0], Err(e) if !e.retryable));
    }

    #[tokio::test]
    async fn mdns_lookup_forget_removes_peer() {
        let lookup = MdnsLookup::new();
        let p = peer(4);
        lookup.upsert(p, vec![addr(6001)]).await;
        lookup.forget(&p).await;

        let items = drain(lookup.lookup(p)).await;
        assert_eq!(items.len(), 1);
        assert!(items[0].is_err());
    }

    fn cache_config(dir: &TempDir) -> BootstrapCacheConfig {
        BootstrapCacheConfig::builder()
            .cache_dir(dir.path().to_path_buf())
            .build()
    }

    #[tokio::test]
    async fn bootstrap_cache_lookup_yields_seeded_addresses() {
        let dir = TempDir::new().expect("tempdir");
        let cache = Arc::new(
            BootstrapCache::open(cache_config(&dir))
                .await
                .expect("open cache"),
        );
        let p = peer(3);
        cache.add_seed(p, vec![addr(7000), addr(7001)]).await;

        let lookup = BootstrapCacheLookup::new(Arc::clone(&cache));
        assert_eq!(lookup.name(), "bootstrap-cache");

        let items = drain(lookup.lookup(p)).await;
        let oks: Vec<_> = items
            .iter()
            .filter_map(|x| x.as_ref().ok().copied())
            .collect();
        assert!(oks.contains(&addr(7000)));
        assert!(oks.contains(&addr(7001)));
    }

    #[tokio::test]
    async fn bootstrap_cache_lookup_missing_yields_error() {
        let dir = TempDir::new().expect("tempdir");
        let cache = Arc::new(
            BootstrapCache::open(cache_config(&dir))
                .await
                .expect("open cache"),
        );
        let lookup = BootstrapCacheLookup::new(cache);
        let items = drain(lookup.lookup(peer(99))).await;
        assert_eq!(items.len(), 1);
        // Absence from the cache is structural, not transient.
        assert!(matches!(&items[0], Err(e) if !e.retryable));
    }

    /// One failing service must not hide the results of the others.
    #[tokio::test]
    async fn discovery_parallel_error_tolerance() {
        let p = peer(42);

        let svc_a = HardcodedLookup::from_pairs("svc-a", [(p, vec![addr(8000)])]);
        let svc_b = HardcodedLookup::from_pairs("svc-b", [(p, vec![addr(8001)])]);

        struct ErrorOnly;
        impl AddressLookup for ErrorOnly {
            fn name(&self) -> &'static str {
                "svc-c-error"
            }
            fn lookup(
                &self,
                _peer_id: PeerId,
            ) -> BoxStream<'static, Result<SocketAddr, LookupError>> {
                Box::pin(futures_util::stream::iter(vec![Err(
                    LookupError::transient("synthetic"),
                )]))
            }
        }

        let mut reg = LookupRegistry::new();
        reg.add_service(svc_a);
        reg.add_service(svc_b);
        reg.add_service(ErrorOnly);
        assert_eq!(reg.len(), 3);
        assert_eq!(reg.service_names(), vec!["svc-a", "svc-b", "svc-c-error"]);

        let items = drain(reg.lookup(p)).await;
        assert_eq!(items.len(), 3);

        let oks: Vec<SocketAddr> = items
            .iter()
            .filter_map(|(_, r)| r.as_ref().ok().copied())
            .collect();
        assert_eq!(oks.len(), 2);
        assert!(oks.contains(&addr(8000)));
        assert!(oks.contains(&addr(8001)));

        let errs: Vec<&'static str> = items
            .iter()
            .filter_map(|(name, r)| r.is_err().then_some(*name))
            .collect();
        assert_eq!(errs, vec!["svc-c-error"]);
    }

    #[tokio::test]
    async fn empty_registry_yields_no_items() {
        let reg = LookupRegistry::new();
        assert!(reg.is_empty());
        let items = drain(reg.lookup(peer(0))).await;
        assert!(items.is_empty());
    }

    /// A fast service's result must arrive without waiting for a slow one.
    #[tokio::test]
    async fn registry_is_concurrent() {
        let p = peer(7);

        struct SlowOk {
            calls: Arc<AtomicUsize>,
        }
        impl AddressLookup for SlowOk {
            fn name(&self) -> &'static str {
                "slow"
            }
            fn lookup(
                &self,
                _peer_id: PeerId,
            ) -> BoxStream<'static, Result<SocketAddr, LookupError>> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Box::pin(futures_util::stream::once(async {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok::<_, LookupError>(addr(9999))
                }))
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let mut reg = LookupRegistry::new();
        reg.add_service(HardcodedLookup::from_pairs("fast", [(p, vec![addr(8000)])]));
        reg.add_service(SlowOk {
            calls: Arc::clone(&calls),
        });

        let start = Instant::now();
        let mut stream = reg.lookup(p);
        // Every service is asked exactly once, as soon as the fan-out is built.
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        let first = stream.next().await.expect("first item");
        assert_eq!(first.0, "fast");
        assert!(first.1.is_ok());
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(150),
            "fanout did not happen in parallel: elapsed = {:?}",
            elapsed
        );
    }

    /// Every address of a multi-address service must be surfaced, not just the first.
    #[tokio::test]
    async fn registry_surfaces_all_addresses_per_service() {
        let p = peer(11);

        let svc_a =
            HardcodedLookup::from_pairs("multi-a", [(p, vec![addr(7100), addr(7101), addr(7102)])]);
        let svc_b = HardcodedLookup::from_pairs("single-b", [(p, vec![addr(7200)])]);

        struct ErrorOnly;
        impl AddressLookup for ErrorOnly {
            fn name(&self) -> &'static str {
                "err-c"
            }
            fn lookup(
                &self,
                _peer_id: PeerId,
            ) -> BoxStream<'static, Result<SocketAddr, LookupError>> {
                Box::pin(futures_util::stream::iter(vec![Err(
                    LookupError::transient("synthetic"),
                )]))
            }
        }

        let mut reg = LookupRegistry::new();
        reg.add_service(svc_a);
        reg.add_service(svc_b);
        reg.add_service(ErrorOnly);

        let items = drain(reg.lookup(p)).await;
        assert_eq!(items.len(), 5, "expected 5 items (3+1+1), got: {:?}", items);

        let addrs_of = |service: &str| -> Vec<SocketAddr> {
            items
                .iter()
                .filter(|(name, _)| *name == service)
                .filter_map(|(_, r)| r.as_ref().ok().copied())
                .collect()
        };

        let multi_a_addrs = addrs_of("multi-a");
        assert_eq!(
            multi_a_addrs.len(),
            3,
            "multi-a must surface all 3 addresses, got: {:?}",
            multi_a_addrs
        );
        assert!(multi_a_addrs.contains(&addr(7100)));
        assert!(multi_a_addrs.contains(&addr(7101)));
        assert!(multi_a_addrs.contains(&addr(7102)));

        assert_eq!(addrs_of("single-b"), vec![addr(7200)]);

        let err_c_count = items
            .iter()
            .filter(|(name, r)| *name == "err-c" && r.is_err())
            .count();
        assert_eq!(err_c_count, 1);
    }

    /// A service whose stream ends immediately contributes nothing and does not block others.
    #[tokio::test]
    async fn registry_handles_empty_stream() {
        struct EmptySvc;
        impl AddressLookup for EmptySvc {
            fn name(&self) -> &'static str {
                "empty"
            }
            fn lookup(
                &self,
                _peer_id: PeerId,
            ) -> BoxStream<'static, Result<SocketAddr, LookupError>> {
                Box::pin(futures_util::stream::empty())
            }
        }

        let mut reg = LookupRegistry::new();
        reg.add_service(EmptySvc);
        reg.add_service(HardcodedLookup::from_pairs(
            "static",
            [(peer(1), vec![addr(5000)])],
        ));

        let items = drain(reg.lookup(peer(1))).await;
        assert_eq!(items.len(), 1);
        assert!(items[0].1.is_ok());
        assert_eq!(items[0].0, "static");
    }
}