use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};
use arc_swap::ArcSwap;
use candid::Principal;
use futures_util::FutureExt;
use stop_token::StopSource;
use thiserror::Error;
use url::Url;
use crate::{
agent::{
route_provider::{
dynamic_routing::{
health_check::{HealthCheck, HealthChecker, HealthManagerActor},
messages::FetchedNodes,
node::Node,
nodes_fetch::{Fetch, NodesFetchActor, NodesFetcher},
snapshot::{
latency_based_routing::LatencyRoutingSnapshot,
routing_snapshot::RoutingSnapshot,
},
type_aliases::AtomicSwap,
},
RouteProvider, RoutesStats,
},
HttpService,
},
AgentError,
};
/// Domain `ic0.app`; used as the seed node in this file's tests.
#[allow(unused)]
pub const IC0_SEED_DOMAIN: &str = "ic0.app";
/// Principal (textual form) of the mainnet root subnet. `DynamicRouteProviderBuilder::new`
/// passes it to `NodesFetcher` as the subnet whose API nodes are fetched.
const MAINNET_ROOT_SUBNET_ID: &str =
    "tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe";
/// Default fetch period handed to `NodesFetchActor` (override with `with_fetch_period`).
const FETCH_PERIOD: Duration = Duration::from_secs(5);
/// Default retry interval handed to `NodesFetchActor`.
const FETCH_RETRY_INTERVAL: Duration = Duration::from_millis(250);
/// Upper bound on how long startup waits for the first healthy seed before it proceeds
/// to start the nodes fetcher anyway.
const TIMEOUT_AWAIT_HEALTHY_SEED: Duration = Duration::from_millis(1000);
/// Timeout handed to `HealthChecker::new`. That argument only exists on non-wasm targets,
/// so the lint is silenced on wasm only; on native targets an unused warning is a real signal.
#[cfg_attr(target_family = "wasm", allow(unused))]
const HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(1);
/// Default period handed to `HealthManagerActor` (override with `with_check_period`).
const HEALTH_CHECK_PERIOD: Duration = Duration::from_secs(1);
/// Prefix for log messages. It is only referenced through `log!`, which presumably expands
/// to nothing when logging is compiled out — hence the `allow`.
#[allow(unused)]
const DYNAMIC_ROUTE_PROVIDER: &str = "DynamicRouteProvider";
/// A [`RouteProvider`] whose API nodes are discovered and health-checked at runtime.
///
/// Background actors (`NodesFetchActor` and `HealthManagerActor`) update
/// `routing_snapshot`; the routing methods only read the current snapshot. The actors are
/// launched either by [`DynamicRouteProvider::start`] or lazily by the first routing call.
///
/// Build instances with [`DynamicRouteProviderBuilder`].
#[derive(Debug)]
pub struct DynamicRouteProvider {
/// Topology source, handed to the `NodesFetchActor`.
fetcher: Arc<dyn Fetch>,
/// Fetch period handed to the `NodesFetchActor`.
fetch_period: Duration,
/// Retry interval handed to the `NodesFetchActor`.
fetch_retry_interval: Duration,
/// Health probe, handed to the `HealthManagerActor`.
checker: Arc<dyn HealthCheck>,
/// Check period handed to the `HealthManagerActor`.
check_period: Duration,
/// Atomically swappable routing snapshot (`Arc<ArcSwap<_>>`), shared with both actors and
/// read by `route`, `n_ordered_routes` and `routes_stats`.
routing_snapshot: AtomicSwap<LatencyRoutingSnapshot>,
/// Bootstrap nodes sent to the `HealthManagerActor` before the fetcher runs.
seeds: Vec<Node>,
/// Source of the stop tokens given to the background actors.
/// NOTE(review): with `stop_token`, dropping the source signals its tokens, so dropping the
/// provider presumably stops the actors — confirm.
token: StopSource,
/// Set once the background actors have been launched; prevents launching them twice.
started: Arc<AtomicBool>,
}
/// Errors specific to the dynamic route provider.
///
/// NOTE(review): neither variant is constructed in this file; presumably the fetch and
/// health-check components produce them — confirm.
#[derive(Error, Debug)]
pub enum DynamicRouteProviderError {
/// Fetching the list of API nodes failed; the payload describes the failure.
#[error("An error when fetching API nodes: {0}")]
NodesFetchError(String),
/// Checking an API node's health failed; the payload describes the failure.
#[error("An error when checking API node's health: {0}")]
HealthCheckError(String),
}
/// Builder for [`DynamicRouteProvider`].
///
/// Create it with [`DynamicRouteProviderBuilder::new`] (mainnet `NodesFetcher` and HTTP
/// `HealthChecker`) or [`DynamicRouteProviderBuilder::from_components`] (caller-supplied
/// fetcher/checker). Periods start at `FETCH_PERIOD`, `FETCH_RETRY_INTERVAL` and
/// `HEALTH_CHECK_PERIOD`.
pub struct DynamicRouteProviderBuilder {
/// Topology source for the built provider.
fetcher: Arc<dyn Fetch>,
/// Fetch period for the built provider.
fetch_period: Duration,
/// Fetch retry interval for the built provider.
fetch_retry_interval: Duration,
/// Health probe for the built provider.
checker: Arc<dyn HealthCheck>,
/// Health-check period for the built provider.
check_period: Duration,
/// Bootstrap nodes for the built provider.
seeds: Vec<Node>,
/// If set, passed to `LatencyRoutingSnapshot::set_k_top_nodes` in `build`.
k_top_nodes: Option<usize>,
}
impl DynamicRouteProviderBuilder {
    /// Creates a builder wired to mainnet.
    ///
    /// API nodes are fetched by a `NodesFetcher` for the mainnet root subnet, and health is
    /// probed by a `HealthChecker`. Both share `http_client`.
    ///
    /// * `seeds` - nodes used to bootstrap routing before the first fetch.
    /// * `http_client` - transport shared by the fetcher and the health checker.
    /// * `k_top_nodes` - if set, forwarded to `LatencyRoutingSnapshot::set_k_top_nodes`.
    pub fn new(
        seeds: Vec<Node>,
        http_client: Arc<dyn HttpService>,
        k_top_nodes: Option<usize>,
    ) -> Self {
        let root_subnet_id = Principal::from_text(MAINNET_ROOT_SUBNET_ID)
            .expect("MAINNET_ROOT_SUBNET_ID is a valid textual principal");
        let fetcher = Arc::new(NodesFetcher::new(
            Arc::clone(&http_client),
            root_subnet_id,
            None,
        ));
        let checker = Arc::new(HealthChecker::new(
            http_client,
            // The wasm `HealthChecker` takes no timeout argument.
            #[cfg(not(target_family = "wasm"))]
            HEALTH_CHECK_TIMEOUT,
        ));
        Self::from_components(seeds, fetcher, checker, k_top_nodes)
    }

    /// Creates a builder from an explicit fetcher and health checker.
    ///
    /// Periods start at `FETCH_PERIOD`, `FETCH_RETRY_INTERVAL` and `HEALTH_CHECK_PERIOD`.
    #[allow(unused)]
    pub fn from_components(
        seeds: Vec<Node>,
        fetcher: Arc<dyn Fetch>,
        checker: Arc<dyn HealthCheck>,
        k_top_nodes: Option<usize>,
    ) -> Self {
        Self {
            fetcher,
            fetch_period: FETCH_PERIOD,
            fetch_retry_interval: FETCH_RETRY_INTERVAL,
            checker,
            check_period: HEALTH_CHECK_PERIOD,
            seeds,
            k_top_nodes,
        }
    }

    /// Replaces the topology source.
    #[allow(unused)]
    pub fn with_fetcher(mut self, fetcher: Arc<dyn Fetch>) -> Self {
        self.fetcher = fetcher;
        self
    }

    /// Sets the fetch period handed to the `NodesFetchActor`.
    pub fn with_fetch_period(mut self, period: Duration) -> Self {
        self.fetch_period = period;
        self
    }

    /// Sets the retry interval handed to the `NodesFetchActor`.
    ///
    /// Previously this field could only take its default (`FETCH_RETRY_INTERVAL`).
    #[allow(unused)]
    pub fn with_fetch_retry_interval(mut self, interval: Duration) -> Self {
        self.fetch_retry_interval = interval;
        self
    }

    /// Replaces the health probe.
    #[allow(unused)]
    pub fn with_checker(mut self, checker: Arc<dyn HealthCheck>) -> Self {
        self.checker = checker;
        self
    }

    /// Sets the check period handed to the `HealthManagerActor`.
    pub fn with_check_period(mut self, period: Duration) -> Self {
        self.check_period = period;
        self
    }

    /// Builds the provider.
    ///
    /// No background work starts here. It is launched by `DynamicRouteProvider::start` or by
    /// the first routing call.
    pub fn build(self) -> DynamicRouteProvider {
        let snapshot = match self.k_top_nodes {
            Some(k) => LatencyRoutingSnapshot::new().set_k_top_nodes(k),
            None => LatencyRoutingSnapshot::new(),
        };
        DynamicRouteProvider {
            fetcher: self.fetcher,
            fetch_period: self.fetch_period,
            fetch_retry_interval: self.fetch_retry_interval,
            checker: self.checker,
            check_period: self.check_period,
            routing_snapshot: Arc::new(ArcSwap::from_pointee(snapshot)),
            seeds: self.seeds,
            token: StopSource::new(),
            started: Arc::new(AtomicBool::new(false)),
        }
    }
}
impl RouteProvider for DynamicRouteProvider {
    /// Returns the routing URL of the next node chosen by the current snapshot.
    ///
    /// Launches the background actors on first use if `start` was never called.
    ///
    /// # Errors
    /// `AgentError::RouteProviderError` when the snapshot yields no node.
    fn route(&self) -> Result<Url, AgentError> {
        self.ensure_started();
        let current = self.routing_snapshot.load();
        match current.next_node() {
            Some(node) => Ok(node.to_routing_url()),
            None => Err(AgentError::RouteProviderError(
                "No healthy API nodes found.".to_string(),
            )),
        }
    }

    /// Returns up to `n` routing URLs, in the order chosen by the current snapshot.
    ///
    /// Launches the background actors on first use if `start` was never called.
    ///
    /// # Errors
    /// `AgentError::RouteProviderError` when the snapshot yields no node at all.
    fn n_ordered_routes(&self, n: usize) -> Result<Vec<Url>, AgentError> {
        self.ensure_started();
        let current = self.routing_snapshot.load();
        let selected = current.next_n_nodes(n);
        if selected.is_empty() {
            Err(AgentError::RouteProviderError(
                "No healthy API nodes found.".to_string(),
            ))
        } else {
            Ok(selected.iter().map(|node| node.to_routing_url()).collect())
        }
    }

    /// Reports the current snapshot's statistics. Unlike the routing methods, this does not
    /// launch the background actors.
    fn routes_stats(&self) -> RoutesStats {
        self.routing_snapshot.load().routes_stats()
    }
}
/// Owned copy of the provider state that the startup sequence and background actors need.
///
/// This lets the sequence run on a spawned task without borrowing the provider.
struct BackgroundTaskConfig {
/// Topology source for the `NodesFetchActor`.
fetcher: Arc<dyn Fetch>,
/// Health probe for the `HealthManagerActor`.
checker: Arc<dyn HealthCheck>,
/// Snapshot shared with the provider and both actors.
routing_snapshot: AtomicSwap<LatencyRoutingSnapshot>,
/// Bootstrap nodes sent to the `HealthManagerActor`.
seeds: Vec<Node>,
/// Fetch period for the `NodesFetchActor`.
fetch_period: Duration,
/// Fetch retry interval for the `NodesFetchActor`.
fetch_retry_interval: Duration,
/// Check period for the `HealthManagerActor`.
check_period: Duration,
/// Stop token handed to both actors.
token: stop_token::StopToken,
}
impl DynamicRouteProvider {
    /// Launches the background actors and waits for the startup sequence to finish.
    ///
    /// The startup sequence waits at most `TIMEOUT_AWAIT_HEALTHY_SEED` for a healthy seed,
    /// then starts the nodes fetcher. If the actors were already launched (by an earlier
    /// `start` or lazily by a routing call), this returns immediately.
    ///
    /// NOTE(review): `started` is set before the sequence is awaited, so if this future is
    /// dropped part-way the flag stays set and `ensure_started` will not relaunch the actors.
    pub async fn start(&self) {
        if self.started.swap(true, Ordering::AcqRel) {
            return;
        }
        self.run().await;
    }

    /// Launches the startup sequence on a spawned task unless it was already launched.
    ///
    /// Unlike `start`, this does not wait; it is called from the synchronous routing methods.
    fn ensure_started(&self) {
        let already_started = self
            .started
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err();
        if already_started {
            return;
        }
        let config = self.background_task_config();
        crate::util::spawn(async move {
            Self::run_background_tasks(config).await;
        });
    }

    /// Runs the startup sequence inline on the caller's task.
    async fn run(&self) {
        Self::run_background_tasks(self.background_task_config()).await;
    }

    /// Copies everything the startup sequence needs into an owned value.
    ///
    /// `start` and `ensure_started` both use this, so they always launch identical actors.
    fn background_task_config(&self) -> BackgroundTaskConfig {
        BackgroundTaskConfig {
            fetcher: Arc::clone(&self.fetcher),
            checker: Arc::clone(&self.checker),
            routing_snapshot: Arc::clone(&self.routing_snapshot),
            seeds: self.seeds.clone(),
            fetch_period: self.fetch_period,
            fetch_retry_interval: self.fetch_retry_interval,
            check_period: self.check_period,
            token: self.token.token(),
        }
    }

    /// Startup sequence:
    ///
    /// 1. spawn the `HealthManagerActor`, which consumes fetched nodes from a watch channel;
    /// 2. publish the seeds on that channel;
    /// 3. wait until a healthy seed is reported, the report channel closes, or
    ///    `TIMEOUT_AWAIT_HEALTHY_SEED` elapses (every outcome is only logged);
    /// 4. spawn the `NodesFetchActor`, which takes over the sending side of the watch channel.
    async fn run_background_tasks(config: BackgroundTaskConfig) {
        log!(info, "{DYNAMIC_ROUTE_PROVIDER}: started ...");
        // Watch channel carrying fetched topology into the HealthManagerActor.
        #[cfg(not(target_family = "wasm"))]
        let (fetch_sender, fetch_receiver) = tokio::sync::watch::channel(None);
        #[cfg(target_family = "wasm")]
        let (fetch_sender, fetch_receiver) = async_watch::channel(None);
        // Channel on which the HealthManagerActor reports healthy seeds; only read once below.
        let (init_sender, init_receiver) = async_channel::bounded(1);
        let health_manager_actor = HealthManagerActor::new(
            Arc::clone(&config.checker),
            config.check_period,
            Arc::clone(&config.routing_snapshot),
            fetch_receiver,
            init_sender,
            config.token.clone(),
        );
        crate::util::spawn(async move { health_manager_actor.run().await });
        // Seed the health manager. `config.seeds` is not used afterwards, so move it
        // rather than cloning.
        if let Err(_err) = fetch_sender.send(Some(FetchedNodes {
            nodes: config.seeds,
        })) {
            log!(
                error,
                "{DYNAMIC_ROUTE_PROVIDER}: failed to send results to HealthManager: {_err:?}"
            );
        }
        // NOTE(review): `std::time::Instant::now()` panics on wasm32-unknown-unknown; confirm
        // this path is not reached there, or use a wasm-compatible clock.
        let _start = Instant::now();
        futures_util::select! {
            _ = crate::util::sleep(TIMEOUT_AWAIT_HEALTHY_SEED).fuse() => {
                log!(
                    warn,
                    "{DYNAMIC_ROUTE_PROVIDER}: no healthy seeds found within {:?}",
                    _start.elapsed()
                );
            }
            init_result = init_receiver.recv().fuse() => {
                // `recv` also completes (with `Err`) when the only sender, owned by the
                // health manager, is dropped without reporting; that is not a healthy seed.
                match init_result {
                    Ok(_) => {
                        log!(
                            info,
                            "{DYNAMIC_ROUTE_PROVIDER}: found healthy seeds within {:?}",
                            _start.elapsed()
                        );
                    }
                    Err(_) => {
                        log!(
                            warn,
                            "{DYNAMIC_ROUTE_PROVIDER}: healthy-seed channel closed without a report after {:?}",
                            _start.elapsed()
                        );
                    }
                }
            }
        }
        // The healthy-seed report is consumed at most once; close the receiver so later
        // sends presumably fail instead of waiting on the full bounded(1) queue.
        init_receiver.close();
        let fetch_actor = NodesFetchActor::new(
            Arc::clone(&config.fetcher),
            config.fetch_period,
            config.fetch_retry_interval,
            fetch_sender,
            Arc::clone(&config.routing_snapshot),
            config.token,
        );
        crate::util::spawn(async move { fetch_actor.run().await });
        log!(
            info,
            "{DYNAMIC_ROUTE_PROVIDER}: NodesFetchActor and HealthManagerActor started successfully"
        );
    }
}
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use candid::Principal;
    use std::{
        sync::{Arc, Once},
        time::{Duration, Instant},
    };
    use tracing::Level;
    use tracing_subscriber::FmtSubscriber;
    use crate::{
        agent::route_provider::{
            dynamic_routing::{
                dynamic_route_provider::{
                    DynamicRouteProviderBuilder, IC0_SEED_DOMAIN, MAINNET_ROOT_SUBNET_ID,
                },
                node::Node,
                test_utils::{
                    assert_routed_domains, route_n_times, NodeHealthCheckerMock, NodesFetcherMock,
                },
            },
            RouteProvider, RoutesStats,
        },
        Agent, AgentError,
    };

    static TRACING_INIT: Once = Once::new();

    /// Installs a TRACE-level `tracing` subscriber that writes to the test output.
    ///
    /// Guarded by `Once` because a global subscriber can only be installed once per process.
    pub fn setup_tracing() {
        TRACING_INIT.call_once(|| {
            FmtSubscriber::builder()
                .with_max_level(Level::TRACE)
                .with_test_writer()
                .init();
        });
    }

    /// The error every routing call is expected to return when no node is healthy.
    fn no_healthy_nodes_error() -> AgentError {
        AgentError::RouteProviderError("No healthy API nodes found.".to_string())
    }

    /// Polls `route_provider` until one batch of 30 `route()` calls succeeds and none of
    /// them is routed to a domain in `excluded_domains`.
    ///
    /// A batch containing any routing error counts as "not yet".
    ///
    /// # Panics
    /// If `excluded_domains` is empty, or if the condition is not met within `timeout`.
    async fn assert_no_routing_via_domains(
        route_provider: Arc<dyn RouteProvider>,
        excluded_domains: Vec<&str>,
        timeout: Duration,
        route_call_interval: Duration,
    ) {
        assert!(
            !excluded_domains.is_empty(),
            "List of excluded domains can't be empty"
        );
        let route_calls = 30;
        let start = Instant::now();
        while start.elapsed() < timeout {
            let routed_domains = (0..route_calls)
                .map(|_| {
                    route_provider.route().map(|url| {
                        let domain = url.domain().expect("no domain name in url");
                        domain.to_string()
                    })
                })
                .collect::<Result<Vec<String>, _>>()
                .unwrap_or_default();
            if !routed_domains.is_empty()
                && !routed_domains
                    .iter()
                    .any(|d| excluded_domains.contains(&d.as_str()))
            {
                return;
            }
            tokio::time::sleep(route_call_interval).await;
        }
        panic!("Expected excluded domains {excluded_domains:?} are still observed in routing over the last {route_calls} calls");
    }

    /// End-to-end test against mainnet (requires network access): starting from the `ic0.app`
    /// seed, routing must move off the seed onto discovered API nodes.
    #[tokio::test]
    async fn test_mainnet() {
        setup_tracing();
        let seed = Node::new(IC0_SEED_DOMAIN).unwrap();
        let http_client = Arc::new(reqwest::Client::new());
        let route_provider =
            DynamicRouteProviderBuilder::new(vec![seed], http_client, None).build();
        route_provider.start().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let route_provider = Arc::new(route_provider) as Arc<dyn RouteProvider>;
        let agent = Agent::builder()
            .with_arc_route_provider(Arc::clone(&route_provider))
            .build()
            .expect("failed to create an agent");
        let subnet_id = Principal::from_text(MAINNET_ROOT_SUBNET_ID).unwrap();
        assert_no_routing_via_domains(
            Arc::clone(&route_provider),
            vec![IC0_SEED_DOMAIN],
            Duration::from_secs(40),
            Duration::from_secs(2),
        )
        .await;
        let api_bns = agent
            .fetch_api_boundary_nodes_by_subnet_id(subnet_id)
            .await
            .expect("failed to fetch api boundary nodes");
        assert!(!api_bns.is_empty());
    }

    /// Walks through a sequence of topology (fetcher) and health (checker) changes and,
    /// after each change settles, checks the routed domains and `RoutesStats`.
    #[tokio::test]
    async fn test_routing_with_topology_and_node_health_updates() {
        setup_tracing();
        let node_1 = Node::new(IC0_SEED_DOMAIN).unwrap();
        let fetcher = Arc::new(NodesFetcherMock::new());
        let fetch_interval = Duration::from_secs(2);
        let checker = Arc::new(NodeHealthCheckerMock::new());
        let check_interval = Duration::from_secs(1);
        fetcher.overwrite_nodes(vec![node_1.clone()]);
        checker.overwrite_healthy_nodes(vec![node_1.clone()]);
        let client = reqwest::Client::builder().build().unwrap();
        let route_provider =
            DynamicRouteProviderBuilder::new(vec![node_1.clone()], Arc::new(client), None)
                .with_fetcher(fetcher.clone())
                .with_checker(checker.clone())
                .with_fetch_period(fetch_interval)
                .with_check_period(check_interval)
                .build();
        route_provider.start().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let route_provider = Arc::new(route_provider);
        // Long enough for one fetch plus two health checks to be reflected in the snapshot.
        let snapshot_update_duration = fetch_interval + 2 * check_interval;

        // Only the seed is known and healthy.
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(6, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_1.domain()]);
        assert_eq!(route_provider.routes_stats(), RoutesStats::new(1, Some(1)));

        // Two more nodes are discovered, all healthy.
        let node_2 = Node::new("api1.com").unwrap();
        let node_3 = Node::new("api2.com").unwrap();
        checker.overwrite_healthy_nodes(vec![node_1.clone(), node_2.clone(), node_3.clone()]);
        fetcher.overwrite_nodes(vec![node_1.clone(), node_2.clone(), node_3.clone()]);
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(30, Arc::clone(&route_provider));
        assert_routed_domains(
            routed_domains,
            vec![node_1.domain(), node_2.domain(), node_3.domain()],
        );
        assert_eq!(route_provider.routes_stats(), RoutesStats::new(3, Some(3)));

        // node_2 becomes unhealthy.
        checker.overwrite_healthy_nodes(vec![node_1.clone(), node_3.clone()]);
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(20, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_1.domain(), node_3.domain()]);
        assert_eq!(route_provider.routes_stats(), RoutesStats::new(3, Some(2)));

        // node_2 recovers.
        checker.overwrite_healthy_nodes(vec![node_1.clone(), node_2.clone(), node_3.clone()]);
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(30, Arc::clone(&route_provider));
        assert_routed_domains(
            routed_domains,
            vec![node_1.domain(), node_2.domain(), node_3.domain()],
        );
        assert_eq!(route_provider.routes_stats(), RoutesStats::new(3, Some(3)));

        // node_4 is added and healthy while node_1 turns unhealthy.
        let node_4 = Node::new("api3.com").unwrap();
        checker.overwrite_healthy_nodes(vec![node_2.clone(), node_3.clone(), node_4.clone()]);
        fetcher.overwrite_nodes(vec![
            node_1.clone(),
            node_2.clone(),
            node_3.clone(),
            node_4.clone(),
        ]);
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(30, Arc::clone(&route_provider));
        assert_routed_domains(
            routed_domains,
            vec![node_2.domain(), node_3.domain(), node_4.domain()],
        );
        assert_eq!(route_provider.routes_stats(), RoutesStats::new(4, Some(3)));

        // node_3 leaves the topology and node_4 turns unhealthy: only node_2 is routable.
        checker.overwrite_healthy_nodes(vec![node_2.clone(), node_3.clone()]);
        fetcher.overwrite_nodes(vec![node_1.clone(), node_2.clone(), node_4.clone()]);
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(3, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_2.domain()]);
        assert_eq!(route_provider.routes_stats(), RoutesStats::new(3, Some(1)));
    }

    /// Routing fails while every seed is unhealthy and succeeds once they recover.
    #[tokio::test]
    async fn test_route_with_initially_unhealthy_seeds_becoming_healthy() {
        setup_tracing();
        let node_1 = Node::new(IC0_SEED_DOMAIN).unwrap();
        let node_2 = Node::new("api1.com").unwrap();
        let fetcher = Arc::new(NodesFetcherMock::new());
        let fetch_interval = Duration::from_secs(2);
        let checker = Arc::new(NodeHealthCheckerMock::new());
        let check_interval = Duration::from_secs(1);
        fetcher.overwrite_nodes(vec![node_1.clone(), node_2.clone()]);
        checker.overwrite_healthy_nodes(vec![]);
        let route_provider = DynamicRouteProviderBuilder::from_components(
            vec![node_1.clone(), node_2.clone()],
            fetcher,
            checker.clone(),
            None,
        )
        .with_fetch_period(fetch_interval)
        .with_check_period(check_interval)
        .build();
        route_provider.start().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let route_provider = Arc::new(route_provider);
        for _ in 0..4 {
            tokio::time::sleep(check_interval).await;
            let result = route_provider.route();
            assert_eq!(result.unwrap_err(), no_healthy_nodes_error());
        }
        checker.overwrite_healthy_nodes(vec![node_1.clone(), node_2.clone()]);
        tokio::time::sleep(3 * check_interval).await;
        let routed_domains = route_n_times(6, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_1.domain(), node_2.domain()]);
    }

    /// Routing works while the only node is healthy and errors once it becomes unhealthy.
    #[tokio::test]
    async fn test_routing_with_no_healthy_nodes_returns_an_error() {
        setup_tracing();
        let node_1 = Node::new(IC0_SEED_DOMAIN).unwrap();
        let fetcher = Arc::new(NodesFetcherMock::new());
        let fetch_interval = Duration::from_secs(2);
        let checker = Arc::new(NodeHealthCheckerMock::new());
        let check_interval = Duration::from_secs(1);
        fetcher.overwrite_nodes(vec![node_1.clone()]);
        checker.overwrite_healthy_nodes(vec![node_1.clone()]);
        let client = reqwest::Client::builder().build().unwrap();
        let route_provider =
            DynamicRouteProviderBuilder::new(vec![node_1.clone()], Arc::new(client), None)
                .with_fetcher(fetcher)
                .with_checker(checker.clone())
                .with_fetch_period(fetch_interval)
                .with_check_period(check_interval)
                .build();
        route_provider.start().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let route_provider = Arc::new(route_provider);
        tokio::time::sleep(2 * check_interval).await;
        let routed_domains = route_n_times(3, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_1.domain()]);
        checker.overwrite_healthy_nodes(vec![]);
        tokio::time::sleep(2 * check_interval).await;
        for _ in 0..4 {
            let result = route_provider.route();
            assert_eq!(result.unwrap_err(), no_healthy_nodes_error());
        }
    }

    /// Without an explicit `start`, routing launches the actors lazily; with no healthy
    /// seed and an empty topology, every call errors.
    #[tokio::test]
    async fn test_route_with_no_healthy_seeds_errors() {
        setup_tracing();
        let node_1 = Node::new(IC0_SEED_DOMAIN).unwrap();
        let fetcher = Arc::new(NodesFetcherMock::new());
        let fetch_interval = Duration::from_secs(2);
        let checker = Arc::new(NodeHealthCheckerMock::new());
        let check_interval = Duration::from_secs(1);
        fetcher.overwrite_nodes(vec![]);
        checker.overwrite_healthy_nodes(vec![]);
        let client = reqwest::Client::builder().build().unwrap();
        let route_provider =
            DynamicRouteProviderBuilder::new(vec![node_1.clone()], Arc::new(client), None)
                .with_fetcher(fetcher)
                .with_checker(checker)
                .with_fetch_period(fetch_interval)
                .with_check_period(check_interval)
                .build();
        for _ in 0..4 {
            tokio::time::sleep(check_interval).await;
            let result = route_provider.route();
            assert_eq!(result.unwrap_err(), no_healthy_nodes_error());
        }
    }

    /// With one healthy and one unhealthy seed, only the healthy one is routed until the
    /// other recovers.
    #[ignore]
    #[tokio::test]
    async fn test_route_with_one_healthy_and_one_unhealthy_seed() {
        setup_tracing();
        let node_1 = Node::new(IC0_SEED_DOMAIN).unwrap();
        let node_2 = Node::new("api1.com").unwrap();
        let fetcher = Arc::new(NodesFetcherMock::new());
        let fetch_interval = Duration::from_secs(2);
        let checker = Arc::new(NodeHealthCheckerMock::new());
        let check_interval = Duration::from_secs(1);
        fetcher.overwrite_nodes(vec![node_1.clone(), node_2.clone()]);
        checker.overwrite_healthy_nodes(vec![node_1.clone()]);
        let route_provider = DynamicRouteProviderBuilder::from_components(
            vec![node_1.clone(), node_2.clone()],
            fetcher,
            checker.clone(),
            None,
        )
        .with_fetch_period(fetch_interval)
        .with_check_period(check_interval)
        .build();
        route_provider.start().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let route_provider = Arc::new(route_provider);
        let routed_domains = route_n_times(3, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_1.domain()]);
        checker.overwrite_healthy_nodes(vec![node_1.clone(), node_2.clone()]);
        tokio::time::sleep(2 * check_interval).await;
        let routed_domains = route_n_times(6, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_1.domain(), node_2.domain()]);
    }

    /// While the fetcher returns an empty list, the healthy seed keeps being routed; once
    /// nodes are fetched, they join the rotation.
    #[tokio::test]
    async fn test_routing_with_an_empty_fetched_list_of_api_nodes() {
        setup_tracing();
        let node_1 = Node::new(IC0_SEED_DOMAIN).unwrap();
        let fetcher = Arc::new(NodesFetcherMock::new());
        let fetch_interval = Duration::from_secs(2);
        let checker = Arc::new(NodeHealthCheckerMock::new());
        let check_interval = Duration::from_secs(1);
        fetcher.overwrite_nodes(vec![]);
        checker.overwrite_healthy_nodes(vec![node_1.clone()]);
        let client = reqwest::Client::builder().build().unwrap();
        let route_provider =
            DynamicRouteProviderBuilder::new(vec![node_1.clone()], Arc::new(client), None)
                .with_fetcher(fetcher.clone())
                .with_checker(checker.clone())
                .with_fetch_period(fetch_interval)
                .with_check_period(check_interval)
                .build();
        route_provider.start().await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let route_provider = Arc::new(route_provider);
        let snapshot_update_duration = fetch_interval + 2 * check_interval;
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(3, Arc::clone(&route_provider));
        assert_routed_domains(routed_domains, vec![node_1.domain()]);
        let node_2 = Node::new("api1.com").unwrap();
        let node_3 = Node::new("api2.com").unwrap();
        fetcher.overwrite_nodes(vec![node_1.clone(), node_2.clone(), node_3.clone()]);
        checker.overwrite_healthy_nodes(vec![node_1.clone(), node_2.clone(), node_3.clone()]);
        tokio::time::sleep(snapshot_update_duration).await;
        let routed_domains = route_n_times(30, Arc::clone(&route_provider));
        assert_routed_domains(
            routed_domains,
            vec![node_1.domain(), node_2.domain(), node_3.domain()],
        );
    }
}