use reqwest::Client;
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::{
sync::{Mutex, Notify, RwLock},
time::interval,
};
use url::Url;
use crate::{
error::FusekiResult,
federation::{CircuitBreakerConfig, FederationConfig, ServiceEndpoint, ServiceHealth},
};
/// Outcome of a single health probe against one endpoint, as assembled by
/// `HealthMonitor::check_endpoint`.
#[derive(Debug, Clone)]
pub struct HealthCheckResult {
/// URL of the endpoint that was probed.
pub url: Url,
/// Moment at which this result was assembled, i.e. after the probe request returned.
pub timestamp: Instant,
/// Time the HTTP request took; `None` when the request itself failed
/// (no response was received).
pub response_time: Option<Duration>,
/// `true` only when a response arrived and its HTTP status was a success status.
pub success: bool,
/// Failure description: `"HTTP <status>"` for a non-success response, or the
/// client error text when the request failed. `None` on success.
pub error: Option<String>,
}
/// Per-endpoint circuit breaker driven by health-check outcomes.
///
/// The breaker moves between the `CircuitState` variants based on the counters
/// below and on the thresholds/timeout carried in `config`.
#[derive(Debug, Clone)]
struct CircuitBreaker {
/// Current position in the closed / open / half-open cycle.
state: CircuitState,
/// Failures counted so far; reset to zero by a success while closed and when
/// the breaker closes again after recovery.
failure_count: u32,
/// Successes counted while recovering; compared against
/// `config.success_threshold` to decide when to close the breaker.
success_count: u32,
/// When `state` last changed; used to measure `config.timeout` while open.
last_state_change: Instant,
/// Thresholds and open-state timeout governing the transitions.
config: CircuitBreakerConfig,
}
/// The three positions of a `CircuitBreaker`.
#[derive(Debug, Clone, PartialEq)]
enum CircuitState {
/// Normal operation: requests are allowed; reported as `ServiceHealth::Healthy`.
Closed,
/// Tripped: requests are rejected until `config.timeout` has elapsed since the
/// last state change; reported as `ServiceHealth::Unhealthy`.
Open,
/// Recovery trial: requests are allowed while successes are counted towards
/// closing the breaker; reported as `ServiceHealth::Degraded`.
HalfOpen,
}
impl CircuitBreaker {
fn new(config: CircuitBreakerConfig) -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
last_state_change: Instant::now(),
config,
}
}
fn record_success(&mut self) {
match self.state {
CircuitState::Closed => {
self.failure_count = 0;
}
CircuitState::HalfOpen => {
self.success_count += 1;
if self.success_count >= self.config.success_threshold {
self.state = CircuitState::Closed;
self.failure_count = 0;
self.success_count = 0;
self.last_state_change = Instant::now();
tracing::info!("Circuit breaker closed after recovery");
}
}
CircuitState::Open => {
self.state = CircuitState::HalfOpen;
self.success_count = 1;
self.last_state_change = Instant::now();
}
}
}
fn record_failure(&mut self) {
match self.state {
CircuitState::Closed => {
self.failure_count += 1;
if self.failure_count >= self.config.failure_threshold {
self.state = CircuitState::Open;
self.last_state_change = Instant::now();
tracing::warn!(
"Circuit breaker opened after {} failures",
self.failure_count
);
}
}
CircuitState::HalfOpen => {
self.state = CircuitState::Open;
self.failure_count = self.config.failure_threshold;
self.success_count = 0;
self.last_state_change = Instant::now();
tracing::warn!("Circuit breaker reopened after failure in half-open state");
}
CircuitState::Open => {
self.failure_count += 1;
}
}
}
fn should_allow_request(&mut self) -> bool {
match self.state {
CircuitState::Closed => true,
CircuitState::HalfOpen => true,
CircuitState::Open => {
if self.last_state_change.elapsed() >= self.config.timeout {
self.state = CircuitState::HalfOpen;
self.success_count = 0;
self.last_state_change = Instant::now();
tracing::info!("Circuit breaker entering half-open state");
true
} else {
false
}
}
}
}
fn get_health_status(&self) -> ServiceHealth {
match self.state {
CircuitState::Closed => ServiceHealth::Healthy,
CircuitState::HalfOpen => ServiceHealth::Degraded,
CircuitState::Open => ServiceHealth::Unhealthy,
}
}
}
/// Periodically probes the registered endpoints and tracks a circuit breaker
/// per endpoint.
pub struct HealthMonitor {
/// Federation configuration; its `circuit_breaker` section is cloned into
/// every breaker created by this monitor.
config: FederationConfig,
/// Shared endpoint registry. The monitor reads it to find the endpoints to
/// probe and writes back each endpoint's health and average response time.
endpoints: Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
/// One breaker per endpoint, keyed by the same string key as `endpoints`;
/// entries are created on the first check of an endpoint.
circuit_breakers: Arc<Mutex<HashMap<String, CircuitBreaker>>>,
/// HTTP client used for the probe requests (built with a request timeout in `new`).
http_client: Client,
/// Notification used by `stop` to end the background task spawned by `start`.
shutdown: Arc<Notify>,
}
impl HealthMonitor {
    /// Pause between two background health-check rounds.
    const CHECK_INTERVAL: Duration = Duration::from_secs(30);
    /// Overall time budget of a single probe request.
    const PROBE_TIMEOUT: Duration = Duration::from_secs(5);

    /// Creates a monitor over the shared `endpoints` registry.
    ///
    /// No probing happens until [`HealthMonitor::start`] is called.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub fn new(
        config: FederationConfig,
        endpoints: Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
    ) -> Self {
        Self {
            config,
            endpoints,
            circuit_breakers: Arc::new(Mutex::new(HashMap::new())),
            http_client: Client::builder()
                .timeout(Self::PROBE_TIMEOUT)
                .build()
                .expect("HTTP client build should succeed"),
            shutdown: Arc::new(Notify::new()),
        }
    }

    /// Starts periodic health checking.
    ///
    /// Spawns a background task that re-checks every endpoint once per
    /// `CHECK_INTERVAL` until [`HealthMonitor::stop`] is called, then runs one
    /// check round inline so that health information is populated by the time
    /// this method returns. Currently always returns `Ok(())`.
    ///
    /// NOTE(review): each call spawns a new background task, while `stop` sends
    /// a single notification — callers should pair one `start` with one `stop`.
    pub async fn start(&self) -> FusekiResult<()> {
        let shutdown = Arc::clone(&self.shutdown);
        let endpoints = Arc::clone(&self.endpoints);
        let circuit_breakers = Arc::clone(&self.circuit_breakers);
        let client = self.http_client.clone();
        let circuit_config = self.config.circuit_breaker.clone();

        tokio::spawn(async move {
            let mut ticker = interval(Self::CHECK_INTERVAL);
            // If a round takes longer than the period, wait a full period after
            // it instead of firing the missed ticks back to back.
            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // The first tick of a tokio interval completes immediately. The
            // initial round is already performed inline below, so consume that
            // tick here; otherwise every endpoint would be probed twice at
            // startup and a single outage would count as two failures.
            ticker.tick().await;

            loop {
                tokio::select! {
                    _ = ticker.tick() => {
                        Self::check_all_endpoints(
                            &endpoints,
                            &circuit_breakers,
                            &client,
                            &circuit_config,
                        )
                        .await;
                    }
                    _ = shutdown.notified() => {
                        tracing::info!("Health monitor shutting down");
                        break;
                    }
                }
            }
        });

        // Initial round, awaited so callers see fresh health data on return.
        Self::check_all_endpoints(
            &self.endpoints,
            &self.circuit_breakers,
            &self.http_client,
            &self.config.circuit_breaker,
        )
        .await;
        Ok(())
    }

    /// Asks the background task to terminate. Currently always returns `Ok(())`.
    ///
    /// The request is delivered through the shared `Notify`; per tokio's
    /// documentation `notify_one` keeps a permit when nobody is waiting, so a
    /// stop issued while a check round is running takes effect once that round
    /// finishes.
    pub async fn stop(&self) -> FusekiResult<()> {
        self.shutdown.notify_one();
        Ok(())
    }

    /// Runs one check round: probes every endpoint in turn, feeds the outcome
    /// into that endpoint's circuit breaker and writes the resulting health and
    /// response-time average back into the registry.
    ///
    /// Endpoints are probed sequentially on a snapshot of the registry, so no
    /// registry lock is held while a request is in flight.
    async fn check_all_endpoints(
        endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
        circuit_breakers: &Arc<Mutex<HashMap<String, CircuitBreaker>>>,
        client: &Client,
        circuit_config: &CircuitBreakerConfig,
    ) {
        let snapshot = endpoints.read().await.clone();
        for (id, endpoint) in snapshot {
            let result = Self::check_endpoint(&endpoint.url, client).await;

            // Update the breaker in its own scope so the breaker lock is
            // released before the registry write lock is taken.
            let health_status = {
                let mut breakers = circuit_breakers.lock().await;
                let breaker = breakers
                    .entry(id.clone())
                    .or_insert_with(|| CircuitBreaker::new(circuit_config.clone()));
                if result.success {
                    breaker.record_success();
                } else {
                    breaker.record_failure();
                }
                breaker.get_health_status()
            };

            let mut registry = endpoints.write().await;
            // The endpoint may have been removed while it was being probed.
            if let Some(ep) = registry.get_mut(&id) {
                ep.health = health_status;
                if let Some(response_time) = result.response_time {
                    if let Some(avg) = &mut ep.capabilities.avg_response_time {
                        // Running average weighting the newest sample at one half.
                        *avg = (*avg + response_time) / 2;
                    } else {
                        ep.capabilities.avg_response_time = Some(response_time);
                    }
                }
            }
        }
    }

    /// Probes a single endpoint by sending a trivial SPARQL `ASK` query as an
    /// HTTP GET and timing the request.
    ///
    /// The probe counts as successful when a response with a success status is
    /// received; the response body is not inspected. A failed request yields a
    /// result without a response time.
    async fn check_endpoint(url: &Url, client: &Client) -> HealthCheckResult {
        let start = Instant::now();
        let query = "ASK { ?s ?p ?o } LIMIT 1";
        let result = client
            .get(url.as_str())
            .query(&[("query", query)])
            .header("Accept", "application/sparql-results+json")
            .send()
            .await;
        let response_time = start.elapsed();

        match result {
            Ok(response) => {
                let status = response.status();
                let success = status.is_success();
                let error = if success {
                    None
                } else {
                    Some(format!("HTTP {}", status))
                };
                HealthCheckResult {
                    url: url.clone(),
                    timestamp: Instant::now(),
                    response_time: Some(response_time),
                    success,
                    error,
                }
            }
            Err(e) => HealthCheckResult {
                url: url.clone(),
                timestamp: Instant::now(),
                response_time: None,
                success: false,
                error: Some(e.to_string()),
            },
        }
    }

    /// Reports whether requests may currently be routed to `service_id`.
    ///
    /// Services without a breaker (never checked so far) are allowed. For known
    /// services the decision is delegated to the breaker, which may move from
    /// open to half-open as a side effect once its timeout has elapsed.
    pub async fn should_use_service(&self, service_id: &str) -> bool {
        let mut breakers = self.circuit_breakers.lock().await;
        match breakers.get_mut(service_id) {
            Some(breaker) => breaker.should_allow_request(),
            None => true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one breaker through closed -> open -> half-open -> closed.
    #[test]
    fn test_circuit_breaker() {
        let mut cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            success_threshold: 2,
            timeout: Duration::from_millis(50),
        });

        // A fresh breaker is closed and lets requests through.
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.should_allow_request());

        // Two failures stay below the threshold of three.
        for _ in 0..2 {
            cb.record_failure();
        }
        assert_eq!(cb.state, CircuitState::Closed);

        // The third failure trips the breaker, which then rejects requests.
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(!cb.should_allow_request());

        // Once the open timeout has passed, the next request is let through
        // and the breaker becomes half-open.
        std::thread::sleep(Duration::from_millis(100));
        assert!(cb.should_allow_request());
        assert_eq!(cb.state, CircuitState::HalfOpen);

        // One success is not enough to close; the second one is.
        cb.record_success();
        assert_eq!(cb.state, CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state, CircuitState::Closed);
    }
}