use std::sync::Arc;
use endhost_api_client::client::EndhostApiClient;
use scion_proto::{
address::IsdAsn,
path::{self, Path},
};
use crate::path::fetcher::traits::{
PathFetchError, PathFetcher, SegmentFetchError, SegmentFetcher, Segments,
};
pub mod traits {
    use std::borrow::Cow;

    use scion_proto::address::IsdAsn;
    use scion_proto::path::{Path, PathSegment};

    use crate::types::ResFut;

    /// Error produced by a [`SegmentFetcher`]: any boxed, thread-safe error.
    pub type SegmentFetchError = Box<dyn std::error::Error + Send + Sync>;

    /// Segments returned by one [`SegmentFetcher::fetch_segments`] call.
    #[derive(Debug)]
    pub struct Segments {
        /// Core segments.
        pub core_segments: Vec<PathSegment>,
        /// All returned segments that are not core segments.
        pub non_core_segments: Vec<PathSegment>,
    }

    /// A source of path segments between a source and a destination `IsdAsn`.
    ///
    /// Declared with `async_trait` (rather than a native `async fn`) so that it
    /// can be used as a trait object, e.g. `Box<dyn SegmentFetcher>`.
    #[async_trait::async_trait]
    pub trait SegmentFetcher: Send + Sync + 'static {
        /// Fetches the segments relevant for building paths from `src` to `dst`.
        async fn fetch_segments(
            &self,
            src: IsdAsn,
            dst: IsdAsn,
        ) -> Result<Segments, SegmentFetchError>;
    }

    /// Errors that a [`PathFetcher`] can report.
    #[derive(Debug, thiserror::Error)]
    pub enum PathFetchError {
        /// Fetching segments failed; wraps the underlying fetcher error.
        #[error("failed to fetch segments: {0}")]
        FetchSegments(#[from] SegmentFetchError),
        /// No paths were found.
        #[error("no paths found")]
        NoPathsFound,
        /// An internal failure, described by the contained message.
        #[error("internal error: {0}")]
        InternalError(Cow<'static, str>),
    }

    /// Resolves paths between a source and a destination `IsdAsn`.
    pub trait PathFetcher: Send + Sync + 'static {
        /// Returns the paths from `src` to `dst`.
        fn fetch_paths(
            &self,
            src: IsdAsn,
            dst: IsdAsn,
        ) -> impl ResFut<'_, Vec<Path>, PathFetchError>;
    }
}
/// [`PathFetcher`] that queries a list of named [`SegmentFetcher`]s and combines
/// the segments they return into paths.
pub struct PathFetcherImpl {
    /// Segment sources, each paired with a name that identifies it in log output.
    segment_fetchers: Vec<(String, Box<dyn SegmentFetcher>)>,
    /// Time limit applied to each individual segment fetch.
    timeout: std::time::Duration,
}

impl PathFetcherImpl {
    /// Creates a path fetcher over `segment_fetchers`.
    ///
    /// `timeout` bounds every single call to a segment fetcher, not the
    /// overall path lookup.
    pub fn new(
        segment_fetchers: Vec<(String, Box<dyn SegmentFetcher>)>,
        timeout: std::time::Duration,
    ) -> PathFetcherImpl {
        PathFetcherImpl {
            timeout,
            segment_fetchers,
        }
    }
}
impl PathFetcher for PathFetcherImpl {
async fn fetch_paths(&self, src: IsdAsn, dst: IsdAsn) -> Result<Vec<Path>, PathFetchError> {
let mut all_core_segments = Vec::new();
let mut all_non_core_segments = Vec::new();
let fetch_tasks: Vec<_> = self
.segment_fetchers
.iter()
.map(|(_, fetcher)| {
tokio::time::timeout(self.timeout, fetcher.fetch_segments(src, dst))
})
.collect();
let results = futures::future::join_all(fetch_tasks).await;
let mut errors = Vec::new();
for (i, result) in results.into_iter().enumerate() {
let fetcher_name = &self.segment_fetchers[i].0;
match result {
Ok(res) => {
match res {
Ok(segments) => {
tracing::info!(
name = %fetcher_name,
n_core_segments = segments.core_segments.len(),
n_non_core_segments = segments.non_core_segments.len(),
%src,
%dst,
"Segment fetcher succeeded"
);
all_core_segments.extend(segments.core_segments);
all_non_core_segments.extend(segments.non_core_segments);
}
Err(e) => {
errors.push((fetcher_name.clone(), e));
}
}
}
Err(e) => {
errors.push((
fetcher_name.clone(),
Box::new(e) as Box<dyn std::error::Error + Send + Sync>,
));
}
}
}
let paths = path::combinator::combine(src, dst, all_core_segments, all_non_core_segments);
for (fetcher_name, error) in errors.iter() {
tracing::warn!(
name = %fetcher_name,
%error,
%src,
%dst,
"Segment fetcher failed"
);
}
if let Some((_name, err)) = errors.into_iter().next()
&& paths.is_empty()
{
return Err(PathFetchError::FetchSegments(err));
}
Ok(paths)
}
}
/// [`SegmentFetcher`] that obtains segments through an [`EndhostApiClient`].
pub struct EndhostApiSegmentFetcher {
    /// Shared client through which `list_segments` requests are issued.
    client: Arc<dyn EndhostApiClient>,
}

impl EndhostApiSegmentFetcher {
    /// Creates a segment fetcher that sends its requests via `client`.
    pub fn new(client: Arc<dyn EndhostApiClient>) -> EndhostApiSegmentFetcher {
        EndhostApiSegmentFetcher { client }
    }
}
#[async_trait::async_trait]
impl SegmentFetcher for EndhostApiSegmentFetcher {
    /// Requests the segments from `src` to `dst` from the endhost API and splits
    /// them into core and non-core segments.
    ///
    /// # Errors
    ///
    /// Any error returned by the client's `list_segments` call is propagated,
    /// boxed as a [`SegmentFetchError`].
    async fn fetch_segments(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
    ) -> Result<Segments, SegmentFetchError> {
        // NOTE(review): the `128` and empty-string arguments look like a page size
        // and a page token. Only one request is made, so if the API paginates,
        // results beyond the first page would be dropped — TODO confirm.
        let response = self
            .client
            .list_segments(src.into(), dst.into(), 128, String::new())
            .await?;

        tracing::trace!(
            n_core = response.segments.core_segments.len(),
            n_up = response.segments.up_segments.len(),
            n_down = response.segments.down_segments.len(),
            %src,
            %dst,
            "Received segments from endhost API"
        );

        let (core, non_core) = response.segments.split_parts();
        let core_segments: Vec<_> = core.into_iter().map(Into::into).collect();
        let non_core_segments: Vec<_> = non_core.into_iter().map(Into::into).collect();

        Ok(Segments {
            core_segments,
            non_core_segments,
        })
    }
}