1use std::{cmp::Ordering, future::Future, io, sync::Arc};
35
36use bytes::Bytes;
37use chrono::{DateTime, Utc};
38use derive_more::Deref;
39use endhost_api_client::client::EndhostApiClient;
40use futures::{
41 FutureExt,
42 future::{self, BoxFuture},
43};
44use scc::{Guard, HashIndex, hash_index::Entry};
45use scion_proto::{
46 address::IsdAsn,
47 path::{self, Path},
48};
49use thiserror::Error;
50use tokio::sync::mpsc;
51use tokio_util::sync::CancellationToken;
52use tracing::{error, instrument};
53
54use crate::{
55 path::{PathStrategy, types::PathManagerPath},
56 types::ResFut,
57};
58#[derive(Debug, Error)]
60pub enum PathToError {
61 #[error("fetching paths: {0}")]
63 FetchPaths(String),
64 #[error("no path found")]
66 NoPathFound,
67}
68
69#[derive(Debug, Clone, Error)]
71pub enum PathWaitError {
72 #[error("path fetch failed: {0}")]
74 FetchFailed(String),
75 #[error("no path found")]
77 NoPathFound,
78}
79
80impl From<PathToError> for PathWaitError {
81 fn from(error: PathToError) -> Self {
82 match error {
83 PathToError::FetchPaths(msg) => PathWaitError::FetchFailed(msg),
84 PathToError::NoPathFound => PathWaitError::NoPathFound,
85 }
86 }
87}
88
89pub trait PathManager: SyncPathManager {
91 fn path_wait(
94 &self,
95 src: IsdAsn,
96 dst: IsdAsn,
97 now: DateTime<Utc>,
98 ) -> impl ResFut<'_, Path<Bytes>, PathWaitError>;
99}
100
101pub trait SyncPathManager {
104 fn register_path(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>, path: Path<Bytes>);
106
107 fn try_cached_path(
111 &self,
112 src: IsdAsn,
113 dst: IsdAsn,
114 now: DateTime<Utc>,
115 ) -> io::Result<Option<Path<Bytes>>>;
116}
117
118#[derive(Debug, Clone)]
120struct PrefetchRequest {
121 pub src: IsdAsn,
122 pub dst: IsdAsn,
123 pub now: DateTime<Utc>,
124}
125
126#[derive(Debug, Clone)]
128struct PathRegistration {
129 pub src: IsdAsn,
130 pub dst: IsdAsn,
131 pub now: DateTime<Utc>,
132 pub path: Path<Bytes>,
133}
134
135#[derive(Debug, Clone)]
137struct PathCacheEntry {
138 path: PathManagerPath,
139 #[expect(unused)]
140 cached_at: DateTime<Utc>,
141}
142
143impl PathCacheEntry {
144 fn new(path: PathManagerPath, now: DateTime<Utc>) -> Self {
145 Self {
146 path,
147 cached_at: now,
148 }
149 }
150
151 fn is_expired(&self, now: DateTime<Utc>) -> bool {
152 self.path
153 .scion_path()
154 .expiry_time()
155 .map(|expiry| expiry < now)
156 .unwrap_or(true)
157 }
158}
159
160pub struct CachingPathManager<F: PathFetcher = PathFetcherImpl> {
162 state: CachingPathManagerState<F>,
164 prefetch_tx: mpsc::Sender<PrefetchRequest>,
166 registration_tx: mpsc::Sender<PathRegistration>,
167 cancellation_token: CancellationToken,
169}
170
171#[derive(Debug, thiserror::Error)]
173pub enum PathFetchError {
174 #[error("failed to fetch segments: {0}")]
176 FetchSegments(#[from] SegmentFetchError),
177}
178
179pub trait PathFetcher {
181 fn fetch_paths(
183 &self,
184 src: IsdAsn,
185 dst: IsdAsn,
186 ) -> impl ResFut<'_, Vec<path::Path>, PathFetchError>;
187}
188
189type BoxedPathLookupResult = BoxFuture<'static, Result<Path<Bytes>, PathWaitError>>;
190
191struct CachingPathManagerStateInner<F: PathFetcher> {
192 selection: PathStrategy,
194 fetcher: F,
196 path_cache: HashIndex<(IsdAsn, IsdAsn), PathCacheEntry>,
198 inflight: HashIndex<(IsdAsn, IsdAsn), future::Shared<BoxedPathLookupResult>>,
200}
201
202#[derive(Deref)]
204#[deref(forward)]
205struct CachingPathManagerState<F: PathFetcher>(Arc<CachingPathManagerStateInner<F>>);
206
207impl<F: PathFetcher> Clone for CachingPathManagerState<F> {
208 fn clone(&self) -> Self {
209 Self(Arc::clone(&self.0))
210 }
211}
212
213impl<F: PathFetcher + Send + Sync + 'static> CachingPathManager<F> {
214 pub fn start(path_strategy: PathStrategy, fetcher: F) -> Self {
218 let cancellation_token = CancellationToken::new();
219 let (manager, task_future) =
220 Self::start_future(path_strategy, fetcher, cancellation_token.clone());
221
222 tokio::spawn(async move {
224 task_future.await;
225 });
226
227 manager
228 }
229
230 pub fn start_future(
232 selection: PathStrategy,
233 fetcher: F,
234 cancellation_token: CancellationToken,
235 ) -> (Self, impl std::future::Future<Output = ()>) {
236 let (prefetch_tx, prefetch_rx) = mpsc::channel(1000);
237 let (registration_tx, registration_rx) = mpsc::channel(1000);
238
239 let state = CachingPathManagerState(Arc::new(CachingPathManagerStateInner {
240 selection,
241 fetcher,
242 path_cache: HashIndex::new(),
243 inflight: HashIndex::new(),
244 }));
245
246 let manager = Self {
247 state: state.clone(),
248 prefetch_tx,
249 registration_tx,
250 cancellation_token: cancellation_token.clone(),
251 };
252
253 let task_future = async move {
254 let task =
255 PathManagerTask::new(state, prefetch_rx, registration_rx, cancellation_token);
256 task.run().await
257 };
258
259 (manager, task_future)
260 }
261
262 pub fn try_cached_path(
264 &self,
265 src: IsdAsn,
266 dst: IsdAsn,
267 now: DateTime<Utc>,
268 ) -> io::Result<Option<Path<Bytes>>> {
269 self.state.try_cached_path(src, dst, now)
270 }
271
272 fn prefetch_path_internal(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>) {
273 if let Err(e) = self.prefetch_tx.try_send(PrefetchRequest { src, dst, now }) {
274 tracing::warn!(err=?e, "Prefetch path channel send failed");
275 }
276 }
277
278 fn register_path_internal(
279 &self,
280 src: IsdAsn,
281 dst: IsdAsn,
282 now: DateTime<Utc>,
283 path: Path<Bytes>,
284 ) {
285 if let Err(e) = self.registration_tx.try_send(PathRegistration {
286 src,
287 dst,
288 now,
289 path,
290 }) {
291 tracing::warn!(err=?e, "Register path channel send failed");
292 }
293 }
294}
295
296impl<F: PathFetcher> Drop for CachingPathManager<F> {
297 fn drop(&mut self) {
298 self.cancellation_token.cancel();
299 }
300}
301
302impl<F: PathFetcher + Send + Sync + 'static> SyncPathManager for CachingPathManager<F> {
303 fn register_path(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>, path: Path<Bytes>) {
304 self.register_path_internal(src, dst, now, path);
305 }
306
307 fn try_cached_path(
311 &self,
312 src: IsdAsn,
313 dst: IsdAsn,
314 now: DateTime<Utc>,
315 ) -> io::Result<Option<Path<Bytes>>> {
316 match self.state.try_cached_path(src, dst, now)? {
317 Some(path) => Ok(Some(path)),
318 None => {
319 self.prefetch_path_internal(src, dst, now);
321 Ok(None)
322 }
323 }
324 }
325}
326
327impl<F: PathFetcher + Send + Sync + 'static> PathManager for CachingPathManager<F> {
328 fn path_wait(
329 &self,
330 src: IsdAsn,
331 dst: IsdAsn,
332 now: DateTime<Utc>,
333 ) -> impl ResFut<'_, Path<Bytes>, PathWaitError> {
334 async move {
335 if let Some(cached) = self.state.cached_path_wait(src, dst, now).await {
337 return Ok(cached);
338 }
339
340 self.state.fetch_and_cache_path(src, dst, now).await
342 }
343 }
344}
345
346pub trait PathPrefetcher {
348 fn prefetch_path(&self, src: IsdAsn, dst: IsdAsn);
350}
351
352impl<F: PathFetcher + Send + Sync + 'static> PathPrefetcher for CachingPathManager<F> {
353 fn prefetch_path(&self, src: IsdAsn, dst: IsdAsn) {
354 self.prefetch_path_internal(src, dst, Utc::now());
355 }
356}
357
358impl<F: PathFetcher + Send + Sync + 'static> CachingPathManagerState<F> {
359 pub fn try_cached_path(
361 &self,
362 src: IsdAsn,
363 dst: IsdAsn,
364 now: DateTime<Utc>,
365 ) -> io::Result<Option<Path<Bytes>>> {
366 let guard = Guard::new();
367 match self.path_cache.peek(&(src, dst), &guard) {
368 Some(cached) => {
369 if !cached.is_expired(now) {
370 Ok(Some(cached.path.scion_path().clone()))
371 } else {
372 Ok(None)
373 }
374 }
375 None => Ok(None),
376 }
377 }
378
379 async fn cached_path_wait(
382 &self,
383 src: IsdAsn,
384 dst: IsdAsn,
385 now: DateTime<Utc>,
386 ) -> Option<Path<Bytes>> {
387 let guard = Guard::new();
388 match self.path_cache.peek(&(src, dst), &guard) {
389 Some(cached) => {
390 if !cached.is_expired(now) {
391 Some(cached.path.scion_path().clone())
392 } else {
393 None
394 }
395 }
396 None => None,
397 }
398 }
399
400 async fn fetch_and_cache_path(
402 &self,
403 src: IsdAsn,
404 dst: IsdAsn,
405 now: DateTime<Utc>,
406 ) -> Result<Path<Bytes>, PathWaitError> {
407 let fut = match self.inflight.entry_sync((src, dst)) {
408 Entry::Occupied(entry) => entry.get().clone(),
409 Entry::Vacant(entry) => {
410 let self_c = self.clone();
411 entry
412 .insert_entry(
413 async move {
414 let result = self_c.do_fetch_and_cache(src, dst, now).await;
415 self_c.inflight.remove_sync(&(src, dst));
416 result
417 }
418 .boxed()
419 .shared(),
420 )
421 .clone()
422 }
423 };
424
425 fut.await
426 }
427
428 async fn do_fetch_and_cache(
430 &self,
431 src: IsdAsn,
432 dst: IsdAsn,
433 now: DateTime<Utc>,
434 ) -> Result<Path<Bytes>, PathWaitError> {
435 let mut paths = self
436 .fetcher
437 .fetch_paths(src, dst)
438 .await
439 .map_err(|e| PathWaitError::FetchFailed(e.to_string()))?
440 .into_iter()
441 .map(|p| PathManagerPath::new(p, false))
442 .collect::<Vec<_>>();
443
444 let initial = paths.len();
445
446 self.selection.filter_inplace(&mut paths);
447 self.selection.rank_inplace(&mut paths);
448
449 tracing::info!(
450 src = %src,
451 dst = %dst,
452 n_initial = initial,
453 n_ok = paths.len(),
454 "Fetched and filtered paths",
455 );
456
457 let preferred_path = paths.into_iter().next().ok_or(PathWaitError::NoPathFound)?;
458 let preferred_path_entry = PathCacheEntry::new(preferred_path.clone(), now);
459
460 match self.path_cache.entry_sync((src, dst)) {
461 Entry::Occupied(mut entry) => {
462 entry.update(preferred_path_entry);
463 }
464 Entry::Vacant(entry) => {
465 entry.insert_entry(preferred_path_entry);
466 }
467 }
468
469 Ok(preferred_path.path)
470 }
471
472 fn request_inflight(&self, src: IsdAsn, dst: IsdAsn) -> bool {
474 let guard = Guard::new();
475 self.inflight.peek(&(src, dst), &guard).is_some()
476 }
477}
478
479struct PathManagerTask<F: PathFetcher> {
481 state: CachingPathManagerState<F>,
482 prefetch_rx: mpsc::Receiver<PrefetchRequest>,
483 registration_rx: mpsc::Receiver<PathRegistration>,
484 cancellation_token: CancellationToken,
485}
486
487impl<F: PathFetcher + Send + Sync + 'static> PathManagerTask<F> {
488 fn new(
489 state: CachingPathManagerState<F>,
490 prefetch_rx: mpsc::Receiver<PrefetchRequest>,
491 registration_rx: mpsc::Receiver<PathRegistration>,
492 cancellation_token: CancellationToken,
493 ) -> Self {
494 Self {
495 state,
496 prefetch_rx,
497 registration_rx,
498 cancellation_token,
499 }
500 }
501
502 async fn run(mut self) {
503 tracing::trace!("Starting active path manager task");
504
505 loop {
506 tokio::select! {
507 _ = self.cancellation_token.cancelled() => {
509 tracing::info!("Path manager task cancelled");
510 break;
511 }
512
513 registration = self.registration_rx.recv() => {
515 match registration {
516 Some(reg) => {
517 self.handle_registration(reg).await;
518 }
519 None => {
520 tracing::info!("Registration channel closed");
521 break;
522 }
523 }
524 }
525
526 prefetch = self.prefetch_rx.recv() => {
528 match prefetch {
529 Some(req) => {
530 self.handle_prefetch(req).await;
531 }
532 None => {
533 tracing::info!("Prefetch channel closed");
534 break;
535 }
536 }
537 }
538 }
539 }
540
541 tracing::info!("Path manager task finished");
542 }
543
544 async fn handle_registration(&self, registration: PathRegistration) {
545 tracing::trace!(
546 src = %registration.src,
547 dst = %registration.dst,
548 "Handling path registration"
549 );
550
551 let new_path = PathManagerPath::new(registration.path, true);
552
553 if !self.state.selection.predicate(&new_path) {
555 tracing::debug!(
556 src = %registration.src,
557 dst = %registration.dst,
558 "Registered path rejected by policy"
559 );
560 return;
561 }
562
563 let entry = self
565 .state
566 .path_cache
567 .entry_sync((registration.src, registration.dst));
568
569 match entry {
570 Entry::Occupied(mut entry) => {
571 if entry.is_expired(registration.now)
573 || self.state.selection.rank_order(&new_path, &entry.path) == Ordering::Less
575 {
576 tracing::info!(
577 src = %registration.src,
578 dst = %registration.dst,
579 "Updating active path"
580 );
581 entry.update(PathCacheEntry::new(new_path, registration.now));
582 }
583 }
584 Entry::Vacant(entry) => {
585 entry.insert_entry(PathCacheEntry::new(new_path, registration.now));
586 }
587 }
588 }
589
590 #[instrument(name = "prefetch", fields(src = %request.src, dst = %request.dst), skip_all)]
594 async fn handle_prefetch(&self, request: PrefetchRequest) {
595 tracing::debug!("Handling prefetch request");
596
597 if self
599 .state
600 .cached_path_wait(request.src, request.dst, request.now)
601 .await
602 .is_some()
603 {
604 tracing::debug!("Path already cached, skipping prefetch");
605 return;
606 }
607
608 if self.state.request_inflight(request.src, request.dst) {
610 tracing::debug!("Path request already in flight, skipping prefetch");
611 return;
612 }
613
614 match self
618 .state
619 .fetch_and_cache_path(request.src, request.dst, request.now)
620 .await
621 {
622 Ok(_) => {
623 tracing::debug!("Successfully prefetched path");
624 }
625 Err(e) => {
626 tracing::warn!(
627 error = %e,
628 "Failed to prefetch path"
629 );
630 }
631 }
632 }
633}
634
/// Boxed, thread-safe error type returned by [`SegmentFetcher`] implementations.
pub type SegmentFetchError = Box<dyn std::error::Error + Send + Sync>;
637
638pub struct Segments {
640 pub core_segments: Vec<path::PathSegment>,
642 pub non_core_segments: Vec<path::PathSegment>,
644}
645
646pub trait SegmentFetcher {
648 fn fetch_segments<'a>(
650 &'a self,
651 src: IsdAsn,
652 dst: IsdAsn,
653 ) -> impl Future<Output = Result<Segments, SegmentFetchError>> + Send + 'a;
654}
655
656pub struct ConnectRpcSegmentFetcher {
658 client: Arc<dyn EndhostApiClient>,
659}
660
661impl ConnectRpcSegmentFetcher {
662 pub fn new(client: Arc<dyn EndhostApiClient>) -> Self {
664 Self { client }
665 }
666}
667
668impl SegmentFetcher for ConnectRpcSegmentFetcher {
669 async fn fetch_segments(
670 &self,
671 src: IsdAsn,
672 dst: IsdAsn,
673 ) -> Result<Segments, SegmentFetchError> {
674 let resp = self
675 .client
676 .list_segments(src, dst, 128, "".to_string())
677 .await?;
678
679 tracing::debug!(
680 n_core=resp.core_segments.len(),
681 n_up=resp.up_segments.len(),
682 n_down=resp.down_segments.len(),
683 src = %src,
684 dst = %dst,
685 "Received segments from control plane"
686 );
687
688 let (core_segments, non_core_segments) = resp.split_parts();
689 Ok(Segments {
690 core_segments,
691 non_core_segments,
692 })
693 }
694}
695
696pub struct PathFetcherImpl<F: SegmentFetcher = ConnectRpcSegmentFetcher> {
698 segment_fetcher: F,
699}
700
701impl<F: SegmentFetcher> PathFetcherImpl<F> {
702 pub fn new(segment_fetcher: F) -> Self {
704 Self { segment_fetcher }
705 }
706}
707
708impl<L: SegmentFetcher + Send + Sync> PathFetcher for PathFetcherImpl<L> {
709 async fn fetch_paths(
710 &self,
711 src: IsdAsn,
712 dst: IsdAsn,
713 ) -> Result<Vec<path::Path>, PathFetchError> {
714 let Segments {
715 core_segments,
716 non_core_segments,
717 } = self.segment_fetcher.fetch_segments(src, dst).await?;
718
719 tracing::trace!(
720 n_core_segments = core_segments.len(),
721 n_non_core_segments = non_core_segments.len(),
722 src = %src,
723 dst = %dst,
724 "Fetched segments"
725 );
726
727 let paths = path::combinator::combine(src, dst, core_segments, non_core_segments);
728 Ok(paths)
729 }
730}
731
/// Unit tests for the lookup coalescing and caching in `CachingPathManagerState`.
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use bytes::{BufMut, BytesMut};
    use scion_proto::{
        address::IsdAsn,
        packet::ByEndpoint,
        path::{self, DataPlanePath, EncodedStandardPath, Path},
        wire_encoding::WireDecode,
    };
    use tokio::{sync::Barrier, task::yield_now};

    use super::*;
    use crate::path::ranking::Shortest;

    /// Canned fetch results keyed by `(src, dst)`.
    type PathMap = HashMap<(IsdAsn, IsdAsn), Result<Vec<Path>, PathFetchError>>;
    /// Test double for `PathFetcher` that counts calls and can be gated for concurrency tests.
    #[derive(Default)]
    struct MockPathFetcher {
        /// Results returned per pair; pairs not present yield an empty path list.
        paths: Mutex<PathMap>,
        /// Number of `fetch_paths` calls so far.
        call_count: AtomicUsize,
        /// If set, each call spins until at least this many calls have been made
        /// (not used by the tests below).
        call_delay: Option<usize>,
        /// If set, each call waits on this barrier before returning.
        barrier: Option<Arc<Barrier>>,
    }

    impl MockPathFetcher {
        /// Fetcher that returns exactly `path` for `(src, dst)`.
        fn with_path(src: IsdAsn, dst: IsdAsn, path: Path) -> Self {
            let mut paths = HashMap::new();
            paths.insert((src, dst), Ok(vec![path]));
            Self {
                paths: Mutex::new(paths),
                call_count: AtomicUsize::new(0),
                call_delay: None,
                barrier: None,
            }
        }

        /// Fetcher that fails for `(src, dst)`. The message is stored, but `fetch_paths`
        /// reports a generic error instead (see there).
        fn with_error(src: IsdAsn, dst: IsdAsn, error: &'static str) -> Self {
            let mut paths = HashMap::new();
            paths.insert((src, dst), Err(PathFetchError::FetchSegments(error.into())));
            Self {
                paths: Mutex::new(paths),
                call_count: AtomicUsize::new(0),
                call_delay: None,
                barrier: None,
            }
        }

        /// Makes every fetch wait on `barrier` before returning.
        fn with_barrier(mut self, barrier: Arc<Barrier>) -> Self {
            self.barrier = Some(barrier);
            self
        }
    }

    impl PathFetcher for MockPathFetcher {
        fn fetch_paths(
            &self,
            src: IsdAsn,
            dst: IsdAsn,
        ) -> impl ResFut<'_, Vec<path::Path>, PathFetchError> {
            async move {
                // Counted before any waiting, so tests can observe that a fetch has started.
                self.call_count.fetch_add(1, Ordering::Relaxed);
                if let Some(delay) = self.call_delay {
                    while self.call_count.load(Ordering::SeqCst) < delay {
                        yield_now().await;
                    }
                }
                if let Some(barrier) = &self.barrier {
                    barrier.wait().await;
                }
                // `PathFetchError` is not `Clone`, so stored errors are replaced by a fresh one.
                match self.paths.lock().unwrap().get(&(src, dst)) {
                    Some(Ok(paths)) => Ok(paths.clone()),
                    None => Ok(vec![]),
                    Some(Err(_)) => Err(PathFetchError::FetchSegments("other error".into())),
                }
            }
        }
    }

    /// Builds a minimal standard SCION path from `src` to `dst`.
    ///
    /// The meta header `0x0000_2000` encodes Seg0Len = 2, i.e. one info field plus two hop
    /// fields (8 + 2 * 12 = 32 bytes), which are all zero here. NOTE(review): the `None`
    /// argument presumably means "no metadata", so the path has no expiry time.
    fn test_path(src: IsdAsn, dst: IsdAsn) -> Path {
        let mut path_raw = BytesMut::with_capacity(36);
        path_raw.put_u32(0x0000_2000);
        path_raw.put_slice(&[0_u8; 32]);
        let dp_path =
            DataPlanePath::Standard(EncodedStandardPath::decode(&mut path_raw.freeze()).unwrap());

        Path::new(
            dp_path,
            ByEndpoint {
                source: src,
                destination: dst,
            },
            None,
        )
    }

    /// Wraps `fetcher` in manager state with no policies and shortest-path ranking.
    fn setup_pm(fetcher: MockPathFetcher) -> CachingPathManagerState<MockPathFetcher> {
        CachingPathManagerState(Arc::new(CachingPathManagerStateInner {
            fetcher,
            path_cache: HashIndex::new(),
            inflight: HashIndex::new(),
            selection: PathStrategy {
                policies: vec![],
                ranking: vec![Arc::new(Shortest)],
            },
        }))
    }

    /// A lone lookup calls the fetcher once, fills the cache and leaves no in-flight entry.
    #[tokio::test]
    async fn fetch_and_cache_path_single_request_success() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let path = test_path(src, dst);
        let fetcher = MockPathFetcher::with_path(src, dst, path.clone());
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(result.is_ok());
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        let guard = Guard::new();
        assert!(state.path_cache.peek(&(src, dst), &guard).is_some());
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }

    /// A second lookup for the same pair joins the in-flight one instead of fetching again.
    #[tokio::test]
    async fn fetch_and_cache_path_concurrent_requests_coalesced() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let path = test_path(src, dst);
        // Two parties: task1's fetcher and this test body.
        let barrier = Arc::new(Barrier::new(2));
        let fetcher =
            MockPathFetcher::with_path(src, dst, path.clone()).with_barrier(barrier.clone());
        let state = setup_pm(fetcher);

        let state_clone = state.clone();
        let task1 =
            tokio::spawn(
                async move { state_clone.fetch_and_cache_path(src, dst, Utc::now()).await },
            );
        // Wait until task1 is inside the fetcher, which means its in-flight entry is installed.
        while state.fetcher.call_count.load(Ordering::SeqCst) < 1 {
            yield_now().await;
        }

        let state_clone2 = state.clone();
        let task2 = tokio::spawn(async move {
            state_clone2
                .fetch_and_cache_path(src, dst, Utc::now())
                .await
        });

        // Release task1's fetcher. NOTE(review): this relies on the current-thread test runtime
        // scheduling task2 (which joins the in-flight future) before task1 finishes and
        // removes the in-flight entry.
        barrier.wait().await;

        let (res1, res2) = future::join(task1, task2).await;

        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        res1.unwrap().unwrap();
        res2.unwrap().unwrap();
        let guard = Guard::new();
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }

    /// A fetcher error surfaces as `FetchFailed`; nothing is cached and no entry stays in flight.
    #[tokio::test]
    async fn fetch_and_cache_path_fetch_error() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let fetcher = MockPathFetcher::with_error(src, dst, "error");
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(matches!(result, Err(PathWaitError::FetchFailed(_))));
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        let guard = Guard::new();
        assert!(state.path_cache.peek(&(src, dst), &guard).is_none());
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }

    /// An empty fetch result (the default mock has no entries) yields `NoPathFound`.
    #[tokio::test]
    async fn fetch_and_cache_path_no_path_found() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let fetcher = MockPathFetcher::default();
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(matches!(result, Err(PathWaitError::NoPathFound)));
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
    }

    /// Lookups for different pairs are not coalesced: each one calls the fetcher.
    #[tokio::test]
    async fn fetch_and_cache_path_concurrent_requests_different_keys() {
        let src1 = IsdAsn(0x1_ff00_0000_0110);
        let dst1 = IsdAsn(0x1_ff00_0000_0111);
        let src2 = IsdAsn(0x1_ff00_0000_0120);
        let dst2 = IsdAsn(0x1_ff00_0000_0121);
        let path1 = test_path(src1, dst1);
        let path2 = test_path(src2, dst2);

        let mut paths = HashMap::new();
        paths.insert((src1, dst1), Ok(vec![path1.clone()]));
        paths.insert((src2, dst2), Ok(vec![path2.clone()]));

        // Three parties: both fetches plus this test body.
        let barrier = Arc::new(Barrier::new(3));

        let fetcher = MockPathFetcher {
            paths: Mutex::new(paths),
            ..Default::default()
        }
        .with_barrier(barrier.clone());
        let state = setup_pm(fetcher);

        let state_clone1 = state.clone();
        let task1 = tokio::spawn(async move {
            state_clone1
                .fetch_and_cache_path(src1, dst1, Utc::now())
                .await
        });

        let state_clone2 = state.clone();
        let task2 = tokio::spawn(async move {
            state_clone2
                .fetch_and_cache_path(src2, dst2, Utc::now())
                .await
        });

        // Only returns once both fetches are parked at the barrier, proving they ran concurrently.
        barrier.wait().await;

        let (res1, res2) = future::join(task1, task2).await;

        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 2);
        let got1 = res1.unwrap().unwrap();
        let got2 = res2.unwrap().unwrap();
        assert_eq!(got1.source(), path1.source());
        assert_eq!(got1.destination(), path1.destination());
        assert_eq!(got2.source(), path2.source());
        assert_eq!(got2.destination(), path2.destination());
    }
}
980}