use async_trait::async_trait;
use bytes::Bytes;
use futures_util::FutureExt;
use http::{Method, Request, StatusCode, Uri};
use std::{
fmt::Debug,
str::FromStr,
sync::Arc,
time::{Duration, Instant},
};
use stop_token::{StopSource, StopToken};
use crate::agent::{
route_provider::dynamic_routing::{
dynamic_route_provider::DynamicRouteProviderError,
messages::{FetchedNodes, NodeHealthState},
node::Node,
snapshot::routing_snapshot::RoutingSnapshot,
type_aliases::{AtomicSwap, ReceiverMpsc, ReceiverWatch, SenderMpsc},
},
HttpService,
};
/// Capacity of the bounded mpsc channel that carries per-node health reports
/// from the check actors to the `HealthManagerActor`.
const CHANNEL_BUFFER: usize = 128;
#[cfg(not(target_family = "wasm"))]
/// Waits for the next value published on the watch channel and returns a
/// clone of it. Errors when the sender side has been dropped.
async fn fetch_receiver_recv(
    rx: &mut tokio::sync::watch::Receiver<Option<FetchedNodes>>,
) -> Result<Option<FetchedNodes>, tokio::sync::watch::error::RecvError> {
    // `changed()` resolves once a new value is published (or errors if the
    // sender is gone); `borrow_and_update()` then marks that value as seen so
    // the next `changed()` waits for a genuinely new publication.
    rx.changed().await?;
    Ok(rx.borrow_and_update().clone())
}
#[cfg(target_family = "wasm")]
/// Wasm variant built on `async_watch`, whose `recv` already yields the
/// latest published value directly.
async fn fetch_receiver_recv(
    rx: &mut async_watch::Receiver<Option<FetchedNodes>>,
) -> Result<Option<FetchedNodes>, async_watch::error::RecvError> {
    rx.recv().await
}
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
/// Abstraction over a single-node health probe, allowing the probing logic to
/// be swapped out (e.g. mocked in tests).
pub trait HealthCheck: Send + Sync + Debug {
    /// Performs one health check of `node`, returning the measured status or
    /// an error if the probe could not be executed.
    async fn check(&self, node: &Node) -> Result<HealthCheckStatus, DynamicRouteProviderError>;
}
/// Outcome of a single node health probe.
///
/// A node is considered healthy iff a request latency was recorded
/// (`latency` is `Some`); `None` marks the node as unhealthy. The extra
/// `Copy`, `Eq`, and `Hash` derives are free for this small value type and
/// let callers pass it by value or use it as a map key.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)]
pub struct HealthCheckStatus {
    // `Some(latency)` for a successful check, `None` for a failed one.
    latency: Option<Duration>,
}

impl HealthCheckStatus {
    /// Creates a status; pass `Some(latency)` for a healthy node, `None` otherwise.
    pub fn new(latency: Option<Duration>) -> Self {
        Self { latency }
    }

    /// Returns `true` when the last health check succeeded (a latency exists).
    pub fn is_healthy(&self) -> bool {
        self.latency.is_some()
    }

    /// Measured latency of the last successful check, or `None` if unhealthy.
    pub fn latency(&self) -> Option<Duration> {
        self.latency
    }
}
/// Default [`HealthCheck`] implementation that probes a node's `/health`
/// endpoint over HTTPS.
#[derive(Debug)]
pub struct HealthChecker {
    // Transport used to issue the health check requests.
    http_client: Arc<dyn HttpService>,
    // Per-request timeout; the wasm build carries no timeout (its request is
    // issued without one, see `check`).
    #[cfg(not(target_family = "wasm"))]
    timeout: Duration,
}
impl HealthChecker {
    /// Creates a health checker over the given HTTP transport.
    ///
    /// On native targets `timeout` bounds each individual health check
    /// request; the wasm build takes no timeout parameter.
    pub fn new(
        http_client: Arc<dyn HttpService>,
        #[cfg(not(target_family = "wasm"))] timeout: Duration,
    ) -> Self {
        Self {
            http_client,
            #[cfg(not(target_family = "wasm"))]
            timeout,
        }
    }
}
// Logging prefix identifying messages emitted by `HealthChecker`.
const HEALTH_CHECKER: &str = "HealthChecker";
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl HealthCheck for HealthChecker {
    /// Probes `https://<node domain>/health` with a GET request.
    ///
    /// The node is reported healthy (with its measured latency) iff the
    /// endpoint answers `204 No Content`; on native targets the request is
    /// additionally bounded by `self.timeout`. Every failure mode — bad
    /// domain, transport error, timeout, unexpected status — is surfaced as
    /// a `DynamicRouteProviderError::HealthCheckError`.
    #[allow(unused_mut)]
    async fn check(&self, node: &Node) -> Result<HealthCheckStatus, DynamicRouteProviderError> {
        // Node domains are externally supplied, so a malformed domain is
        // reported as an error instead of panicking (previously `unwrap`).
        let uri = Uri::from_str(&format!("https://{}/health", node.domain())).map_err(|err| {
            DynamicRouteProviderError::HealthCheckError(format!(
                "Failed to build health check uri for node {:?}: {err}",
                node
            ))
        })?;
        let request = Request::builder()
            .method(Method::GET)
            .uri(uri.clone())
            .body(Bytes::new())
            .map_err(|err| {
                DynamicRouteProviderError::HealthCheckError(format!(
                    "Failed to build GET request to {uri}: {err}"
                ))
            })?;
        let start = Instant::now();
        // Native targets wrap the call in a timeout; the wasm build has no
        // tokio timer, so the request is issued without one.
        #[cfg(not(target_family = "wasm"))]
        let response = tokio::time::timeout(
            self.timeout,
            self.http_client.call(&|| Ok(request.clone()), 1, None),
        )
        .await
        .map_err(|_| {
            DynamicRouteProviderError::HealthCheckError(format!("GET request to {uri} timed out"))
        })?;
        #[cfg(target_family = "wasm")]
        let response = self
            .http_client
            .call(&|| Ok(request.clone()), 1, None)
            .await;
        let response = response.map_err(|err| {
            DynamicRouteProviderError::HealthCheckError(format!(
                "Failed to execute GET request to {uri}: {err}"
            ))
        })?;
        let latency = start.elapsed();
        // Anything other than 204 marks the node unhealthy.
        if response.status() != StatusCode::NO_CONTENT {
            let err_msg = format!(
                "{HEALTH_CHECKER}: Unexpected http status code {} for url={uri} received",
                response.status()
            );
            log!(error, err_msg);
            return Err(DynamicRouteProviderError::HealthCheckError(err_msg));
        }
        Ok(HealthCheckStatus::new(Some(latency)))
    }
}
// Logging prefix identifying messages emitted by `HealthCheckActor`.
// NOTE(review): `#[allow(unused)]` suggests the `log!` macro may compile to
// nothing in some configurations, leaving the const unreferenced — confirm.
#[allow(unused)]
const HEALTH_CHECK_ACTOR: &str = "HealthCheckActor";
/// Actor that owns the periodic health checking of a single node.
struct HealthCheckActor {
    // Prober used for each individual check.
    checker: Arc<dyn HealthCheck>,
    // Delay between consecutive checks of the node.
    period: Duration,
    // The node under observation.
    node: Node,
    // Channel on which check results are reported to the manager.
    sender_channel: SenderMpsc<NodeHealthState>,
    // Cancellation token; when it fires the actor exits.
    token: StopToken,
}
impl HealthCheckActor {
    /// Creates an actor that periodically checks `node` and reports results
    /// on `sender_channel` until `token` is cancelled.
    fn new(
        checker: Arc<dyn HealthCheck>,
        period: Duration,
        node: Node,
        sender_channel: SenderMpsc<NodeHealthState>,
        token: StopToken,
    ) -> Self {
        Self {
            checker,
            period,
            node,
            sender_channel,
            token,
        }
    }
    /// Check loop: probe the node, publish the result, sleep for `period`,
    /// repeat. Exits when the stop token fires or the receiving side of the
    /// results channel is gone.
    async fn run(self) {
        loop {
            // Race the probe against cancellation; a failed probe maps to the
            // default (unhealthy) status rather than aborting the loop.
            let health = futures_util::select! {
                result = self.checker.check(&self.node).fuse() => result.unwrap_or_default(),
                _ = self.token.clone().fuse() => {
                    log!(info, "{HEALTH_CHECK_ACTOR}: was gracefully cancelled for node {:?}", self.node);
                    break;
                }
            };
            let message = NodeHealthState {
                node: self.node.clone(),
                health,
            };
            // A send error means the manager dropped its receiver; stop checking.
            if self.sender_channel.send(message).await.is_err() {
                break;
            }
            // Wait out the period while remaining responsive to cancellation.
            futures_util::select! {
                _ = crate::util::sleep(self.period).fuse() => {
                    continue;
                }
                _ = self.token.clone().fuse() => {
                    log!(info, "{HEALTH_CHECK_ACTOR}: was gracefully cancelled for node {:?}", self.node);
                    break;
                }
            }
        }
    }
}
// Logging prefix identifying messages emitted by `HealthManagerActor`.
#[allow(unused)]
pub(super) const HEALTH_MANAGER_ACTOR: &str = "HealthManagerActor";
/// Actor that owns the health checking pipeline: it reacts to freshly fetched
/// node sets, (re)spawns one `HealthCheckActor` per node, and folds their
/// reports into the shared routing snapshot.
pub(super) struct HealthManagerActor<S> {
    // Prober handed (via `Arc`) to every spawned check actor.
    checker: Arc<dyn HealthCheck>,
    // Health check period handed to every spawned check actor.
    period: Duration,
    // Shared snapshot consulted by the route provider; replaced on updates.
    routing_snapshot: AtomicSwap<S>,
    // Watch channel delivering freshly fetched node sets.
    fetch_receiver: ReceiverWatch<FetchedNodes>,
    // Sender cloned into each check actor for reporting results.
    check_sender: SenderMpsc<NodeHealthState>,
    // Receiving end of the check results channel.
    check_receiver: ReceiverMpsc<NodeHealthState>,
    // Signal fired once, when the first healthy node is observed.
    init_sender: SenderMpsc<bool>,
    // Cancellation token for the manager itself.
    token: StopToken,
    // Stop source whose tokens cancel the currently running check actors.
    nodes_token: StopSource,
    // Whether the first healthy node has already been signalled.
    is_initialized: bool,
}
impl<S> HealthManagerActor<S>
where
    S: RoutingSnapshot,
{
    /// Creates the manager with its internal results channel wired up; no
    /// check actors run until the first node set arrives on `fetch_receiver`.
    pub fn new(
        checker: Arc<dyn HealthCheck>,
        period: Duration,
        routing_snapshot: AtomicSwap<S>,
        fetch_receiver: ReceiverWatch<FetchedNodes>,
        init_sender: SenderMpsc<bool>,
        token: StopToken,
    ) -> Self {
        // Bounded channel on which the per-node check actors report results.
        let (check_sender, check_receiver) = async_channel::bounded(CHANNEL_BUFFER);
        Self {
            checker,
            period,
            routing_snapshot,
            fetch_receiver,
            check_sender,
            check_receiver,
            init_sender,
            token,
            nodes_token: StopSource::new(),
            is_initialized: false,
        }
    }
    /// Event loop: multiplexes fetched node sets, per-node health reports,
    /// and cancellation. Exits when the stop token fires or when the fetch
    /// sender is dropped.
    pub async fn run(mut self) {
        loop {
            futures_util::select! {
                result = fetch_receiver_recv(&mut self.fetch_receiver).fuse() => {
                    let value = match result {
                        Ok(value) => value,
                        Err(_err) => {
                            log!(error, "{HEALTH_MANAGER_ACTOR}: nodes fetch sender has been dropped: {_err:?}");
                            break;
                        }
                    };
                    // The watch channel is created with `None`; skip it.
                    let Some(FetchedNodes { nodes }) = value else { continue };
                    self.handle_fetch_update(nodes).await;
                }
                msg_res = self.check_receiver.recv().fuse() => {
                    // `Err` here only means the channel is closed (shutdown path).
                    if let Ok(msg) = msg_res {
                        self.handle_health_update(msg).await;
                    }
                }
                _ = self.token.clone().fuse() => {
                    self.stop_all_checks().await;
                    self.check_receiver.close();
                    log!(warn, "{HEALTH_MANAGER_ACTOR}: was gracefully cancelled, all nodes health checks stopped");
                    break;
                }
            }
        }
    }
    /// Folds one node health report into a fresh copy of the routing snapshot
    /// and publishes it; signals `init_sender` once, on the first healthy node.
    async fn handle_health_update(&mut self, msg: NodeHealthState) {
        let current_snapshot = self.routing_snapshot.load_full();
        let mut new_snapshot = (*current_snapshot).clone();
        new_snapshot.update_node(&msg.node, msg.health.clone());
        self.routing_snapshot.store(Arc::new(new_snapshot));
        if !self.is_initialized && msg.health.is_healthy() {
            self.is_initialized = true;
            // The receiver may already be gone; initialization is best-effort.
            let _ = self.init_sender.send(true).await;
        }
    }
    /// Handles a newly fetched node set: syncs the snapshot and, if the set
    /// actually changed, restarts all per-node check actors.
    async fn handle_fetch_update(&mut self, nodes: Vec<Node>) {
        if nodes.is_empty() {
            log!(
                error,
                "{HEALTH_MANAGER_ACTOR}: list of fetched nodes is empty"
            );
            return;
        }
        log!(
            debug,
            "{HEALTH_MANAGER_ACTOR}: fetched nodes received {:?}",
            nodes
        );
        let current_snapshot = self.routing_snapshot.load_full();
        let mut new_snapshot = (*current_snapshot).clone();
        // Only restart the check actors when the node set actually changed.
        if new_snapshot.sync_nodes(&nodes) {
            self.routing_snapshot.store(Arc::new(new_snapshot));
            self.stop_all_checks().await;
            // `nodes` is already owned here, so it is moved straight through
            // (the previous `nodes.to_vec()` performed a redundant clone).
            self.start_checks(nodes);
        }
    }
    /// Spawns one `HealthCheckActor` per node under a fresh stop source.
    fn start_checks(&mut self, nodes: Vec<Node>) {
        // Fresh stop source: its tokens control only the actors spawned below.
        self.nodes_token = StopSource::new();
        for node in nodes {
            log!(
                debug,
                "{HEALTH_MANAGER_ACTOR}: starting health check for node {node:?}"
            );
            let actor = HealthCheckActor::new(
                Arc::clone(&self.checker),
                self.period,
                node,
                self.check_sender.clone(),
                self.nodes_token.token(),
            );
            crate::util::spawn(async move { actor.run().await });
        }
    }
    /// Cancels every running check actor.
    async fn stop_all_checks(&mut self) {
        log!(
            warn,
            "{HEALTH_MANAGER_ACTOR}: stopping all running health checks"
        );
        // Replacing the stop source drops the old one, which fires every
        // token cloned from it and thereby cancels all running actors.
        self.nodes_token = StopSource::new();
    }
}
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use arc_swap::ArcSwap;
    use stop_token::StopSource;
    use crate::agent::route_provider::dynamic_routing::{
        messages::FetchedNodes, node::Node,
        snapshot::latency_based_routing::LatencyRoutingSnapshot, test_utils::NodeHealthCheckerMock,
    };
    use super::{HealthCheck, HealthManagerActor};
    /// Builds `n` distinct test nodes with synthetic domains.
    fn make_nodes(n: usize) -> Vec<Node> {
        (1..=n)
            .map(|i| Node::new(format!("api{i}.example.com")).unwrap())
            .collect()
    }
    /// Floods the manager with rapid node-set updates while cancelling it
    /// mid-flood, asserting it neither panics nor hangs. Repeated 50 times to
    /// give races a chance to surface.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_health_manager_no_panic_on_rapid_updates_and_shutdown() {
        let nodes = make_nodes(5);
        for _ in 0..50 {
            let checker = Arc::new(NodeHealthCheckerMock::new());
            checker.overwrite_healthy_nodes(nodes.clone());
            let routing_snapshot = Arc::new(ArcSwap::from_pointee(LatencyRoutingSnapshot::new()));
            let (fetch_sender, fetch_receiver) = tokio::sync::watch::channel(None);
            let (init_sender, _init_receiver) = async_channel::bounded(1);
            let stop_source = StopSource::new();
            let actor = HealthManagerActor::new(
                Arc::clone(&checker) as Arc<dyn HealthCheck>,
                Duration::from_millis(1), Arc::clone(&routing_snapshot),
                fetch_receiver,
                init_sender,
                stop_source.token(),
            );
            let handle = tokio::spawn(actor.run());
            let nodes_clone = nodes.clone();
            // Concurrently publish 200 varying node subsets as fast as possible.
            let flood_handle = tokio::spawn(async move {
                for i in 0..200usize {
                    let batch = nodes_clone[..=(i % nodes_clone.len())].to_vec();
                    let _ = fetch_sender.send(Some(FetchedNodes { nodes: batch }));
                    tokio::task::yield_now().await;
                }
            });
            tokio::time::sleep(Duration::from_millis(5)).await;
            // Cancel the manager while the flood is still in progress.
            drop(stop_source);
            flood_handle.await.expect("flood task should not panic");
            tokio::time::timeout(Duration::from_secs(2), handle)
                .await
                .expect("HealthManagerActor timed out; it may be stuck in an infinite loop")
                .expect("HealthManagerActor panicked");
        }
    }
    /// Verifies the manager's run loop terminates (rather than spinning) once
    /// the fetch sender side of the watch channel is dropped.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_health_manager_exits_when_fetch_sender_dropped() {
        let nodes = make_nodes(3);
        let checker = Arc::new(NodeHealthCheckerMock::new());
        checker.overwrite_healthy_nodes(nodes.clone());
        let routing_snapshot = Arc::new(ArcSwap::from_pointee(LatencyRoutingSnapshot::new()));
        let (fetch_sender, fetch_receiver) = tokio::sync::watch::channel(None);
        let (init_sender, _init_receiver) = async_channel::bounded(1);
        let stop_source = StopSource::new();
        let actor = HealthManagerActor::new(
            Arc::clone(&checker) as Arc<dyn HealthCheck>,
            Duration::from_millis(10),
            Arc::clone(&routing_snapshot),
            fetch_receiver,
            init_sender,
            stop_source.token(),
        );
        let handle = tokio::spawn(actor.run());
        fetch_sender
            .send(Some(FetchedNodes {
                nodes: nodes.clone(),
            }))
            .expect("initial send should succeed");
        tokio::time::sleep(Duration::from_millis(20)).await;
        drop(fetch_sender);
        let result = tokio::time::timeout(Duration::from_secs(2), handle).await;
        // Keep the stop source alive until after the timeout so that any exit
        // is attributable to the dropped sender, not to cancellation.
        drop(stop_source);
        result
            .expect("HealthManagerActor did not exit within 2 s after fetch_sender was dropped")
            .expect("HealthManagerActor panicked");
    }
}