1use std::{cmp::Ordering, future::Future, io, sync::Arc};
34
35use bytes::Bytes;
36use chrono::{DateTime, Utc};
37use derive_more::Deref;
38use endhost_api_client::client::EndhostApiClient;
39use futures::{
40 FutureExt,
41 future::{self, BoxFuture},
42};
43use scc::{Guard, HashIndex, hash_index::Entry};
44use scion_proto::{
45 address::IsdAsn,
46 path::{self, Path},
47};
48use thiserror::Error;
49use tokio::sync::mpsc;
50use tokio_util::sync::CancellationToken;
51use tracing::{debug, error, trace, warn};
52
53use super::{Shortest, policy::PathPolicy};
54use crate::types::ResFut;
55
/// Errors that can occur while looking up a path to a destination.
#[derive(Debug, Error)]
pub enum PathToError {
    /// Fetching paths failed; carries a human-readable description of the failure.
    #[error("fetching paths: {0}")]
    FetchPaths(String),
    /// No path to the destination was found.
    #[error("no path found")]
    NoPathFound,
}
66
/// Errors returned by [`PathManager::path_wait`].
///
/// This type is `Clone` so that one lookup result can be handed to every waiter sharing an
/// in-flight fetch (the lookup future is stored as a `future::Shared`).
#[derive(Debug, Clone, Error)]
pub enum PathWaitError {
    /// Fetching paths failed; carries the underlying error rendered as a string.
    #[error("path fetch failed: {0}")]
    FetchFailed(String),
    /// No usable path was found (none were fetched, or the policy rejected all of them).
    #[error("no path found")]
    NoPathFound,
}
77
78impl From<PathToError> for PathWaitError {
79 fn from(error: PathToError) -> Self {
80 match error {
81 PathToError::FetchPaths(msg) => PathWaitError::FetchFailed(msg),
82 PathToError::NoPathFound => PathWaitError::NoPathFound,
83 }
84 }
85}
86
/// A path manager that can asynchronously wait until a path is available.
pub trait PathManager: SyncPathManager {
    /// Resolves a path from `src` to `dst`, waiting for one to become available if necessary.
    ///
    /// `now` is the reference time used to judge whether a cached path has expired. The caching
    /// implementation in this module first consults its cache and otherwise fetches a path.
    fn path_wait(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> impl ResFut<'_, Path<Bytes>, PathWaitError>;
}
98
/// Path manager operations that can be called without `.await`.
pub trait SyncPathManager {
    /// Registers an externally obtained `path` from `src` to `dst`, observed at time `now`.
    fn register_path(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>, path: Path<Bytes>);

    /// Returns a cached path from `src` to `dst` that is still valid at `now`, if any.
    ///
    /// `Ok(None)` signals a cache miss.
    fn try_cached_path(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> io::Result<Option<Path<Bytes>>>;
}
115
/// Request sent to the background task asking it to fetch and cache a path ahead of time.
#[derive(Debug, Clone)]
struct PrefetchRequest {
    pub src: IsdAsn,
    pub dst: IsdAsn,
    /// Reference time used for expiry checks and recorded on a newly cached entry.
    pub now: DateTime<Utc>,
}
123
/// A path registration forwarded to the background task.
#[derive(Debug, Clone)]
struct PathRegistration {
    pub src: IsdAsn,
    pub dst: IsdAsn,
    /// Registration time; used as the expiry reference and recorded on the cache entry.
    pub now: DateTime<Utc>,
    /// The path being registered.
    pub path: Path<Bytes>,
}
132
/// A path stored in the cache, plus bookkeeping about where it came from.
#[derive(Debug, Clone)]
struct CachedPath {
    /// The cached path.
    path: scion_proto::path::Path,
    /// Time at which the entry was created (currently not read anywhere).
    #[expect(unused)]
    cached_at: DateTime<Utc>,
    /// `true` if the path came from a registration rather than from a fetch.
    from_registration: bool,
}
141
142impl CachedPath {
143 fn new(path: scion_proto::path::Path, now: DateTime<Utc>, from_registration: bool) -> Self {
144 Self {
145 path,
146 cached_at: now,
147 from_registration,
148 }
149 }
150
151 fn is_expired(&self, now: DateTime<Utc>) -> bool {
152 self.path
153 .expiry_time()
154 .map(|expiry| expiry < now)
155 .unwrap_or(true)
156 }
157}
158
/// A path manager that caches one preferred path per `(src, dst)` pair.
///
/// Prefetch requests and path registrations are forwarded over bounded channels to a background
/// task, which fetches paths and applies the [`PathPolicy`]. Dropping the manager cancels its
/// cancellation token, which stops the background task.
pub struct CachingPathManager<P: PathPolicy = Shortest, F: PathFetcher = PathFetcherImpl> {
    /// Shared cache/fetch state, also held by the background task.
    state: CachingPathManagerState<P, F>,
    /// Sender for prefetch requests handled by the background task.
    prefetch_tx: mpsc::Sender<PrefetchRequest>,
    /// Sender for path registrations handled by the background task.
    registration_tx: mpsc::Sender<PathRegistration>,
    /// Cancelled on drop to stop the background task.
    cancellation_token: CancellationToken,
}
169
/// Errors that can occur while fetching paths.
#[derive(Debug, thiserror::Error)]
pub enum PathFetchError {
    /// Fetching the path segments that paths are built from failed.
    #[error("failed to fetch segments: {0}")]
    FetchSegments(#[from] SegmentFetchError),
}
177
/// Source of candidate end-to-end paths between two ASes.
pub trait PathFetcher {
    /// Fetches candidate paths from `src` to `dst`.
    fn fetch_paths(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
    ) -> impl ResFut<'_, Vec<path::Path>, PathFetchError>;
}
187
/// Type-erased `'static` future that resolves to a path lookup result. It is stored (wrapped in
/// `future::Shared`) in the in-flight table so concurrent requests can await a single fetch.
type BoxedPathLookupResult = BoxFuture<'static, Result<Path<Bytes>, PathWaitError>>;
189
/// State shared between a [`CachingPathManager`] and its background task.
struct CachingPathManagerStateInner<P: PathPolicy, F: PathFetcher> {
    /// Policy used to filter and rank candidate paths.
    policy: P,
    /// Source of candidate paths.
    fetcher: F,
    /// Currently preferred path per `(src, dst)` pair.
    path_cache: HashIndex<(IsdAsn, IsdAsn), CachedPath>,
    /// Fetches in progress per `(src, dst)` pair; shared so that concurrent requests for the
    /// same pair await one fetch.
    inflight: HashIndex<(IsdAsn, IsdAsn), future::Shared<BoxedPathLookupResult>>,
}
200
/// Cheaply clonable handle (an `Arc`) to [`CachingPathManagerStateInner`].
///
/// `Deref` is forwarded through the `Arc`, so fields of the inner state are reachable directly.
#[derive(Deref)]
#[deref(forward)]
struct CachingPathManagerState<P: PathPolicy, F: PathFetcher>(
    Arc<CachingPathManagerStateInner<P, F>>,
);
207
208impl<P: PathPolicy, F: PathFetcher> Clone for CachingPathManagerState<P, F> {
209 fn clone(&self) -> Self {
210 Self(Arc::clone(&self.0))
211 }
212}
213
214impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static>
215 CachingPathManager<P, F>
216{
217 pub fn start(policy: P, fetcher: F) -> Self {
221 let cancellation_token = CancellationToken::new();
222 let (manager, task_future) =
223 Self::start_future(policy, fetcher, cancellation_token.clone());
224
225 tokio::spawn(async move {
227 task_future.await;
228 });
229
230 manager
231 }
232
233 pub fn start_future(
235 policy: P,
236 fetcher: F,
237 cancellation_token: CancellationToken,
238 ) -> (Self, impl std::future::Future<Output = ()>) {
239 let (prefetch_tx, prefetch_rx) = mpsc::channel(1000);
240 let (registration_tx, registration_rx) = mpsc::channel(1000);
241
242 let state = CachingPathManagerState(Arc::new(CachingPathManagerStateInner {
243 policy,
244 fetcher,
245 path_cache: HashIndex::new(),
246 inflight: HashIndex::new(),
247 }));
248
249 let manager = Self {
250 state: state.clone(),
251 prefetch_tx,
252 registration_tx,
253 cancellation_token: cancellation_token.clone(),
254 };
255
256 let task_future = async move {
257 let task =
258 PathManagerTask::new(state, prefetch_rx, registration_rx, cancellation_token);
259 task.run().await
260 };
261
262 (manager, task_future)
263 }
264
265 pub fn try_cached_path(
267 &self,
268 src: IsdAsn,
269 dst: IsdAsn,
270 now: DateTime<Utc>,
271 ) -> io::Result<Option<Path<Bytes>>> {
272 self.state.try_cached_path(src, dst, now)
273 }
274
275 fn prefetch_path_internal(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>) {
276 if let Err(e) = self.prefetch_tx.try_send(PrefetchRequest { src, dst, now }) {
277 trace!(
278 "Failed to send prefetch request - background task may be stopped: {}",
279 e
280 );
281 }
282 }
283
284 fn register_path_internal(
285 &self,
286 src: IsdAsn,
287 dst: IsdAsn,
288 now: DateTime<Utc>,
289 path: Path<Bytes>,
290 ) {
291 if let Err(e) = self.registration_tx.try_send(PathRegistration {
292 src,
293 dst,
294 now,
295 path,
296 }) {
297 warn!(
298 "Failed to send path registration - background task may be stopped: {}",
299 e
300 );
301 }
302 }
303}
304
/// Stops the background task when the manager goes away.
impl<P: PathPolicy, F: PathFetcher> Drop for CachingPathManager<P, F> {
    fn drop(&mut self) {
        // The background task selects on this token and exits its loop once it is cancelled.
        self.cancellation_token.cancel();
        trace!("PathManager dropped, background task cancelled");
    }
}
312
313impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static> SyncPathManager
314 for CachingPathManager<P, F>
315{
316 fn register_path(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>, path: Path<Bytes>) {
317 self.register_path_internal(src, dst, now, path);
318 }
319
320 fn try_cached_path(
324 &self,
325 src: IsdAsn,
326 dst: IsdAsn,
327 now: DateTime<Utc>,
328 ) -> io::Result<Option<Path<Bytes>>> {
329 match self.state.try_cached_path(src, dst, now)? {
330 Some(path) => Ok(Some(path)),
331 None => {
332 self.prefetch_path_internal(src, dst, now);
334 Ok(None)
335 }
336 }
337 }
338}
339
340impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static> PathManager
341 for CachingPathManager<P, F>
342{
343 fn path_wait(
344 &self,
345 src: IsdAsn,
346 dst: IsdAsn,
347 now: DateTime<Utc>,
348 ) -> impl ResFut<'_, Path<Bytes>, PathWaitError> {
349 async move {
350 if let Some(cached) = self.state.cached_path_wait(src, dst, now).await {
352 return Ok(cached);
353 }
354
355 self.state.fetch_and_cache_path(src, dst, now).await
357 }
358 }
359}
360
/// Ability to request that a path be fetched ahead of time.
pub trait PathPrefetcher {
    /// Requests a prefetch of a path from `src` to `dst`.
    fn prefetch_path(&self, src: IsdAsn, dst: IsdAsn);
}
366
367impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static> PathPrefetcher
368 for CachingPathManager<P, F>
369{
370 fn prefetch_path(&self, src: IsdAsn, dst: IsdAsn) {
371 self.prefetch_path_internal(src, dst, Utc::now());
372 }
373}
374
375impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static>
376 CachingPathManagerState<P, F>
377{
378 pub fn try_cached_path(
380 &self,
381 src: IsdAsn,
382 dst: IsdAsn,
383 now: DateTime<Utc>,
384 ) -> io::Result<Option<Path<Bytes>>> {
385 let guard = Guard::new();
386 match self.path_cache.peek(&(src, dst), &guard) {
387 Some(cached) => {
388 if !cached.is_expired(now) {
389 Ok(Some(cached.path.clone()))
390 } else {
391 Ok(None)
392 }
393 }
394 None => Ok(None),
395 }
396 }
397
398 async fn cached_path_wait(
401 &self,
402 src: IsdAsn,
403 dst: IsdAsn,
404 now: DateTime<Utc>,
405 ) -> Option<Path<Bytes>> {
406 let guard = Guard::new();
407 match self.path_cache.peek(&(src, dst), &guard) {
408 Some(cached) => {
409 if !cached.is_expired(now) {
410 Some(cached.path.clone())
411 } else {
412 None
413 }
414 }
415 None => None,
416 }
417 }
418
419 async fn fetch_and_cache_path(
421 &self,
422 src: IsdAsn,
423 dst: IsdAsn,
424 now: DateTime<Utc>,
425 ) -> Result<Path<Bytes>, PathWaitError> {
426 let fut = match self.inflight.entry_sync((src, dst)) {
427 Entry::Occupied(entry) => entry.get().clone(),
428 Entry::Vacant(entry) => {
429 let self_c = self.clone();
430 entry
431 .insert_entry(
432 async move {
433 let result = self_c.do_fetch_and_cache(src, dst, now).await;
434 self_c.inflight.remove_sync(&(src, dst));
435 result
436 }
437 .boxed()
438 .shared(),
439 )
440 .clone()
441 }
442 };
443
444 fut.await
445 }
446
447 async fn do_fetch_and_cache(
449 &self,
450 src: IsdAsn,
451 dst: IsdAsn,
452 now: DateTime<Utc>,
453 ) -> Result<Path<Bytes>, PathWaitError> {
454 let mut paths = self
455 .fetcher
456 .fetch_paths(src, dst)
457 .await
458 .map_err(|e| PathWaitError::FetchFailed(e.to_string()))?;
459
460 paths.retain(|path| {
461 self.policy
462 .predicate(&super::policy::PolicyPath::new(path, false))
463 });
464 paths.sort_by(|a, b| {
465 self.policy.rank(
466 &super::policy::PolicyPath::new(a, false),
467 &super::policy::PolicyPath::new(b, false),
468 )
469 });
470 let preferred_path = paths.into_iter().next().ok_or(PathWaitError::NoPathFound)?;
471 let cached_path = CachedPath::new(preferred_path.clone(), now, false);
472
473 let entry = self.path_cache.entry_sync((src, dst));
474 match entry {
475 Entry::Occupied(mut entry) => {
476 entry.update(cached_path);
477 }
478 Entry::Vacant(entry) => {
479 entry.insert_entry(cached_path);
480 }
481 }
482
483 debug!(src = %src, dst = %dst, "Cached new path");
484 Ok(preferred_path)
485 }
486
487 fn request_inflight(&self, src: IsdAsn, dst: IsdAsn) -> bool {
489 let guard = Guard::new();
490 self.inflight.peek(&(src, dst), &guard).is_some()
491 }
492}
493
/// Background task that processes prefetch requests and path registrations for a
/// [`CachingPathManager`].
struct PathManagerTask<P: PathPolicy, F: PathFetcher> {
    /// Shared cache/fetch state.
    state: CachingPathManagerState<P, F>,
    /// Incoming prefetch requests.
    prefetch_rx: mpsc::Receiver<PrefetchRequest>,
    /// Incoming path registrations.
    registration_rx: mpsc::Receiver<PathRegistration>,
    /// Stops the task's loop when cancelled.
    cancellation_token: CancellationToken,
}
501
502impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static>
503 PathManagerTask<P, F>
504{
505 fn new(
506 state: CachingPathManagerState<P, F>,
507 prefetch_rx: mpsc::Receiver<PrefetchRequest>,
508 registration_rx: mpsc::Receiver<PathRegistration>,
509 cancellation_token: CancellationToken,
510 ) -> Self {
511 Self {
512 state,
513 prefetch_rx,
514 registration_rx,
515 cancellation_token,
516 }
517 }
518
519 async fn run(mut self) {
520 trace!("Starting active path manager task");
521
522 loop {
523 tokio::select! {
524 _ = self.cancellation_token.cancelled() => {
526 debug!("Path manager task cancelled");
527 break;
528 }
529
530 registration = self.registration_rx.recv() => {
532 match registration {
533 Some(reg) => {
534 self.handle_registration(reg).await;
535 }
536 None => {
537 debug!("Registration channel closed");
538 break;
539 }
540 }
541 }
542
543 prefetch = self.prefetch_rx.recv() => {
545 match prefetch {
546 Some(req) => {
547 self.handle_prefetch(req).await;
548 }
549 None => {
550 debug!("Prefetch channel closed");
551 break;
552 }
553 }
554 }
555 }
556 }
557
558 trace!("Path manager task finished");
559 }
560
561 async fn handle_registration(&self, registration: PathRegistration) {
562 trace!(
563 src = %registration.src,
564 dst = %registration.dst,
565 "Handling path registration"
566 );
567
568 let policy_path = super::policy::PolicyPath::new(®istration.path, true);
570 if !self.state.policy.predicate(&policy_path) {
571 debug!(
572 src = %registration.src,
573 dst = %registration.dst,
574 "Registered path rejected by policy"
575 );
576 return;
577 }
578 let entry = self
579 .state
580 .path_cache
581 .entry_sync((registration.src, registration.dst));
582 match entry {
583 Entry::Occupied(mut entry) => {
584 let cached = entry.get();
585 let should_update = if cached.is_expired(registration.now) {
586 true
587 } else {
588 let existing_policy_path =
589 super::policy::PolicyPath::new(&cached.path, cached.from_registration);
590 matches!(
591 self.state.policy.rank(&policy_path, &existing_policy_path),
592 Ordering::Less
593 )
594 };
595 if should_update {
596 entry.update(CachedPath::new(registration.path, registration.now, true));
597 }
598 }
599 Entry::Vacant(entry) => {
600 entry.insert_entry(CachedPath::new(registration.path, registration.now, true));
601 }
602 }
603 }
604
605 async fn handle_prefetch(&self, request: PrefetchRequest) {
609 debug!(
610 src = %request.src,
611 dst = %request.dst,
612 "Handling prefetch request"
613 );
614
615 if self
617 .state
618 .cached_path_wait(request.src, request.dst, request.now)
619 .await
620 .is_some()
621 {
622 debug!("Path already cached, skipping prefetch");
623 return;
624 }
625
626 if self.state.request_inflight(request.src, request.dst) {
628 debug!(
629 src = %request.src,
630 dst = %request.dst,
631 "Path request already in flight, skipping prefetch"
632 );
633 return;
634 }
635
636 match self
640 .state
641 .fetch_and_cache_path(request.src, request.dst, request.now)
642 .await
643 {
644 Ok(_) => {
645 debug!(
646 src = %request.src,
647 dst = %request.dst,
648 "Successfully prefetched path"
649 );
650 }
651 Err(e) => {
652 error!(
653 src = %request.src,
654 dst = %request.dst,
655 error = %e,
656 "Failed to prefetch path"
657 );
658 }
659 }
660 }
661}
662
/// Boxed, thread-safe error type returned by [`SegmentFetcher`] implementations.
pub type SegmentFetchError = Box<dyn std::error::Error + Send + Sync>;
665
/// Path segments returned by a [`SegmentFetcher`], ready to be combined into paths.
pub struct Segments {
    /// Core segments.
    pub core_segments: Vec<path::PathSegment>,
    /// Non-core segments (presumably up- and down-segments; see `split_parts` in the RPC
    /// fetcher — confirm).
    pub non_core_segments: Vec<path::PathSegment>,
}
673
/// Source of path segments between two ASes.
pub trait SegmentFetcher {
    /// Fetches the core and non-core segments relevant for paths from `src` to `dst`.
    fn fetch_segments<'a>(
        &'a self,
        src: IsdAsn,
        dst: IsdAsn,
    ) -> impl Future<Output = Result<Segments, SegmentFetchError>> + Send + 'a;
}
683
/// [`SegmentFetcher`] that lists segments through an [`EndhostApiClient`].
pub struct ConnectRpcSegmentFetcher {
    /// Client used to issue `list_segments` requests.
    client: Arc<dyn EndhostApiClient>,
}
688
689impl ConnectRpcSegmentFetcher {
690 pub fn new(client: Arc<dyn EndhostApiClient>) -> Self {
692 Self { client }
693 }
694}
695
696impl SegmentFetcher for ConnectRpcSegmentFetcher {
697 async fn fetch_segments(
698 &self,
699 src: IsdAsn,
700 dst: IsdAsn,
701 ) -> Result<Segments, SegmentFetchError> {
702 let resp = self
703 .client
704 .list_segments(src, dst, 128, "".to_string())
705 .await?;
706 tracing::debug!(
707 n_core=resp.core_segments.len(),
708 n_up=resp.up_segments.len(),
709 n_down=resp.down_segments.len(),
710 src = %src,
711 dst = %dst,
712 "Received segments from control plane"
713 );
714 let (core_segments, non_core_segments) = resp.split_parts();
715 Ok(Segments {
716 core_segments,
717 non_core_segments,
718 })
719 }
720}
721
/// Default [`PathFetcher`]: fetches segments with `F` and combines them into end-to-end paths.
pub struct PathFetcherImpl<F: SegmentFetcher = ConnectRpcSegmentFetcher> {
    /// Source of the segments that paths are built from.
    segment_fetcher: F,
}
726
727impl<F: SegmentFetcher> PathFetcherImpl<F> {
728 pub fn new(segment_fetcher: F) -> Self {
730 Self { segment_fetcher }
731 }
732}
733
734impl<L: SegmentFetcher + Send + Sync> PathFetcher for PathFetcherImpl<L> {
735 async fn fetch_paths(
736 &self,
737 src: IsdAsn,
738 dst: IsdAsn,
739 ) -> Result<Vec<path::Path>, PathFetchError> {
740 let Segments {
741 core_segments,
742 non_core_segments,
743 } = self.segment_fetcher.fetch_segments(src, dst).await?;
744 trace!(
745 n_core_segments = core_segments.len(),
746 n_non_core_segments = non_core_segments.len(),
747 src = %src,
748 dst = %dst,
749 "Fetched segments"
750 );
751 let paths = path::combinator::combine(src, dst, core_segments, non_core_segments);
752 Ok(paths)
753 }
754}
755
/// Unit tests for the fetch/coalescing logic of `CachingPathManagerState`.
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use bytes::{BufMut, BytesMut};
    use scion_proto::{
        address::IsdAsn,
        packet::ByEndpoint,
        path::{self, DataPlanePath, EncodedStandardPath, Path},
        wire_encoding::WireDecode,
    };
    use tokio::{sync::Barrier, task::yield_now};

    use super::*;
    use crate::path::policy;

    /// Canned fetch results keyed by `(src, dst)`.
    type PathMap = HashMap<(IsdAsn, IsdAsn), Result<Vec<Path>, PathFetchError>>;
    /// Test double for [`PathFetcher`] that serves canned results and counts calls.
    #[derive(Default)]
    struct MockPathFetcher {
        /// Canned results; unknown pairs yield an empty path list.
        paths: Mutex<PathMap>,
        /// Number of `fetch_paths` calls so far (incremented on entry).
        call_count: AtomicUsize,
        /// If set, each call yields until `call_count` reaches this value.
        call_delay: Option<usize>,
        /// If set, each call waits on this barrier before returning its result.
        barrier: Option<Arc<Barrier>>,
    }

    impl MockPathFetcher {
        /// Creates a fetcher that returns `[path]` for `(src, dst)`.
        fn with_path(src: IsdAsn, dst: IsdAsn, path: Path) -> Self {
            let mut paths = HashMap::new();
            paths.insert((src, dst), Ok(vec![path]));
            Self {
                paths: Mutex::new(paths),
                call_count: AtomicUsize::new(0),
                call_delay: None,
                barrier: None,
            }
        }

        /// Creates a fetcher that fails for `(src, dst)`.
        ///
        /// NOTE: `fetch_paths` does not propagate `error`; it reports a generic "other error".
        fn with_error(src: IsdAsn, dst: IsdAsn, error: &'static str) -> Self {
            let mut paths = HashMap::new();
            paths.insert((src, dst), Err(PathFetchError::FetchSegments(error.into())));
            Self {
                paths: Mutex::new(paths),
                call_count: AtomicUsize::new(0),
                call_delay: None,
                barrier: None,
            }
        }

        /// Makes every `fetch_paths` call wait on `barrier` before returning.
        fn with_barrier(mut self, barrier: Arc<Barrier>) -> Self {
            self.barrier = Some(barrier);
            self
        }
    }

    impl PathFetcher for MockPathFetcher {
        fn fetch_paths(
            &self,
            src: IsdAsn,
            dst: IsdAsn,
        ) -> impl ResFut<'_, Vec<path::Path>, PathFetchError> {
            async move {
                self.call_count.fetch_add(1, Ordering::Relaxed);
                if let Some(delay) = self.call_delay {
                    while self.call_count.load(Ordering::SeqCst) < delay {
                        yield_now().await;
                    }
                }
                if let Some(barrier) = &self.barrier {
                    barrier.wait().await;
                }
                // Stored errors are replaced by a generic one (presumably because
                // `PathFetchError` is not `Clone`).
                match self.paths.lock().unwrap().get(&(src, dst)) {
                    Some(Ok(paths)) => Ok(paths.clone()),
                    None => Ok(vec![]),
                    Some(Err(_)) => Err(PathFetchError::FetchSegments("other error".into())),
                }
            }
        }
    }

    /// Builds a minimal standard-path `Path` from `src` to `dst` with `None` metadata.
    fn test_path(src: IsdAsn, dst: IsdAsn) -> Path {
        let mut path_raw = BytesMut::with_capacity(36);
        // Path meta header (Seg0Len = 2) followed by zeroed info/hop fields: 4 + 8 + 2 * 12 bytes.
        path_raw.put_u32(0x0000_2000);
        path_raw.put_slice(&[0_u8; 32]);
        let dp_path =
            DataPlanePath::Standard(EncodedStandardPath::decode(&mut path_raw.freeze()).unwrap());

        Path::new(
            dp_path,
            ByEndpoint {
                source: src,
                destination: dst,
            },
            None,
        )
    }

    /// Builds bare shared state (no background task) around `fetcher` with the `Shortest` policy.
    fn setup_pm(fetcher: MockPathFetcher) -> CachingPathManagerState<Shortest, MockPathFetcher> {
        CachingPathManagerState(Arc::new(CachingPathManagerStateInner {
            policy: policy::Shortest {},
            fetcher,
            path_cache: HashIndex::new(),
            inflight: HashIndex::new(),
        }))
    }

    #[tokio::test]
    async fn fetch_and_cache_path_single_request_success() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let path = test_path(src, dst);
        let fetcher = MockPathFetcher::with_path(src, dst, path.clone());
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(result.is_ok());
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        let guard = Guard::new();
        assert!(state.path_cache.peek(&(src, dst), &guard).is_some());
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }

    #[tokio::test]
    async fn fetch_and_cache_path_concurrent_requests_coalesced() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let path = test_path(src, dst);
        let barrier = Arc::new(Barrier::new(2));
        let fetcher =
            MockPathFetcher::with_path(src, dst, path.clone()).with_barrier(barrier.clone());
        let state = setup_pm(fetcher);

        let state_clone = state.clone();
        let task1 =
            tokio::spawn(
                async move { state_clone.fetch_and_cache_path(src, dst, Utc::now()).await },
            );
        // Wait until task1 has entered the fetcher (it then parks on the barrier).
        while state.fetcher.call_count.load(Ordering::SeqCst) < 1 {
            yield_now().await;
        }

        let state_clone2 = state.clone();
        let task2 = tokio::spawn(async move {
            state_clone2
                .fetch_and_cache_path(src, dst, Utc::now())
                .await
        });

        // Release the parked fetch; the assertions below check that task2 shared it.
        barrier.wait().await;

        let (res1, res2) = future::join(task1, task2).await;

        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        res1.unwrap().unwrap();
        res2.unwrap().unwrap();
        let guard = Guard::new();
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }

    #[tokio::test]
    async fn fetch_and_cache_path_fetch_error() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let fetcher = MockPathFetcher::with_error(src, dst, "error");
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(matches!(result, Err(PathWaitError::FetchFailed(_))));
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        let guard = Guard::new();
        assert!(state.path_cache.peek(&(src, dst), &guard).is_none());
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }

    #[tokio::test]
    async fn fetch_and_cache_path_no_path_found() {
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let fetcher = MockPathFetcher::default();
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(matches!(result, Err(PathWaitError::NoPathFound)));
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn fetch_and_cache_path_concurrent_requests_different_keys() {
        let src1 = IsdAsn(0x1_ff00_0000_0110);
        let dst1 = IsdAsn(0x1_ff00_0000_0111);
        let src2 = IsdAsn(0x1_ff00_0000_0120);
        let dst2 = IsdAsn(0x1_ff00_0000_0121);
        let path1 = test_path(src1, dst1);
        let path2 = test_path(src2, dst2);

        let mut paths = HashMap::new();
        paths.insert((src1, dst1), Ok(vec![path1.clone()]));
        paths.insert((src2, dst2), Ok(vec![path2.clone()]));

        let barrier = Arc::new(Barrier::new(3));

        let fetcher = MockPathFetcher {
            paths: Mutex::new(paths),
            ..Default::default()
        }
        .with_barrier(barrier.clone());
        let state = setup_pm(fetcher);

        let state_clone1 = state.clone();
        let task1 = tokio::spawn(async move {
            state_clone1
                .fetch_and_cache_path(src1, dst1, Utc::now())
                .await
        });

        let state_clone2 = state.clone();
        let task2 = tokio::spawn(async move {
            state_clone2
                .fetch_and_cache_path(src2, dst2, Utc::now())
                .await
        });

        // Barrier of 3: both fetches plus this task must arrive before either fetch returns.
        barrier.wait().await;

        let (res1, res2) = future::join(task1, task2).await;

        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 2);
        let got1 = res1.unwrap().unwrap();
        let got2 = res2.unwrap().unwrap();
        assert_eq!(got1.source(), path1.source());
        assert_eq!(got1.destination(), path1.destination());
        assert_eq!(got2.source(), path2.source());
        assert_eq!(got2.destination(), path2.destination());
    }
}