use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use reqwest::Client;
use tokio::sync::{Mutex, Semaphore};
use tokio::time::sleep;
use crate::service::{
config::{CircuitBreakerConfig, ConnectionPoolConfig},
types::ServiceError,
};
/// State of the client's circuit breaker.
///
/// `Eq` is derived in addition to `PartialEq` because every payload (`Instant`)
/// is itself `Eq`, so equality is a total equivalence relation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are allowed through; consecutive failures are counted toward
    /// the configured failure threshold.
    Closed,
    /// Requests are rejected with `ServiceError::CircuitBreakerOpen` until the
    /// configured recovery timeout has elapsed since `opened_at`.
    Open {
        /// Moment the breaker tripped; used to measure the recovery timeout.
        opened_at: Instant,
    },
    /// Trial state after the recovery timeout: requests are allowed through,
    /// enough consecutive successes close the breaker, and a failure reopens it.
    HalfOpen,
}
/// reqwest-based HTTP client that layers three policies over every request:
/// serialisation through a semaphore, optional minimum spacing between
/// requests, and a circuit breaker driven by request outcomes.
#[derive(Debug)]
pub struct EnhancedHttpClient {
    /// Underlying reqwest client (connection pool, compression, keep-alive).
    client: Client,
    /// Semaphore whose permit is held for the duration of each request.
    rate_limiter: Arc<Semaphore>,
    /// Minimum gap between consecutive requests; `None` disables spacing.
    rate_limit_delay: Option<Duration>,
    /// When the previous request was released for sending, if any.
    last_request_time: Arc<Mutex<Option<Instant>>>,
    /// Current breaker state, guarded by an async mutex.
    circuit_state: Arc<Mutex<CircuitState>>,
    /// Thresholds and timeouts driving the breaker.
    circuit_config: CircuitBreakerConfig,
    /// Failures counted toward the breaker's failure threshold.
    failure_count: Arc<AtomicU32>,
    /// Successes counted toward closing the breaker from half-open.
    success_count: Arc<AtomicU32>,
    /// Total number of requests actually sent.
    request_count: Arc<AtomicU64>,
}
impl EnhancedHttpClient {
    /// Builds a client from optional connection-pool, circuit-breaker and rate-limit settings.
    ///
    /// * `pool_config` — when present, `max_connections` is passed to reqwest's
    ///   `pool_max_idle_per_host`, `idle_timeout_seconds` to `pool_idle_timeout`, and
    ///   `keep_alive_seconds` to `tcp_keepalive`; for unset fields the setter is not called.
    /// * `circuit_config` — cloned; `CircuitBreakerConfig::default()` is used when `None`.
    /// * `rate_limit_ms` — minimum spacing, in milliseconds, between consecutive requests.
    ///
    /// Gzip and deflate response decompression are always enabled. The circuit breaker
    /// starts closed and all counters start at zero.
    ///
    /// # Errors
    ///
    /// Returns `ServiceError::InternalError` if the reqwest client cannot be built.
    pub fn new(
        pool_config: Option<&ConnectionPoolConfig>,
        circuit_config: Option<&CircuitBreakerConfig>,
        rate_limit_ms: Option<u64>,
    ) -> Result<Self, ServiceError> {
        let mut client_builder = Client::builder();
        if let Some(pool) = pool_config {
            if let Some(max_conn) = pool.max_connections {
                client_builder = client_builder.pool_max_idle_per_host(max_conn);
            }
            if let Some(idle_timeout) = pool.idle_timeout_seconds {
                client_builder =
                    client_builder.pool_idle_timeout(Duration::from_secs(idle_timeout));
            }
            if let Some(keep_alive) = pool.keep_alive_seconds {
                client_builder = client_builder.tcp_keepalive(Duration::from_secs(keep_alive));
            }
        }
        client_builder = client_builder.gzip(true).deflate(true);
        let client = client_builder.build().map_err(|e| {
            ServiceError::InternalError(format!("Failed to create HTTP client: {}", e))
        })?;
        // A single permit: each request holds it until it finishes, so requests
        // from this client never overlap.
        let rate_limiter = Arc::new(Semaphore::new(1));
        let rate_limit_delay = rate_limit_ms.map(Duration::from_millis);
        let circuit_config = circuit_config.cloned().unwrap_or_default();
        Ok(Self {
            client,
            rate_limiter,
            rate_limit_delay,
            last_request_time: Arc::new(Mutex::new(None)),
            circuit_state: Arc::new(Mutex::new(CircuitState::Closed)),
            circuit_config,
            failure_count: Arc::new(AtomicU32::new(0)),
            success_count: Arc::new(AtomicU32::new(0)),
            request_count: Arc::new(AtomicU64::new(0)),
        })
    }

    /// Sends a GET request to `url`, subject to request serialisation, rate
    /// limiting and the circuit breaker.
    ///
    /// Non-2xx responses are returned as `Ok` but count as failures for the breaker.
    ///
    /// # Errors
    ///
    /// * `ServiceError::CircuitBreakerOpen` if the breaker is open and its recovery
    ///   timeout has not yet elapsed.
    /// * `ServiceError::InternalError` if the transport-level request fails.
    pub async fn get(&self, url: &str) -> Result<reqwest::Response, ServiceError> {
        self.execute_request(|| self.client.get(url)).await
    }

    /// Runs one request under the client's serialisation, circuit-breaker and rate-limit policy.
    ///
    /// Order of operations:
    /// 1. Acquire the single semaphore permit. It is held until this function returns,
    ///    so at most one request is in flight at a time.
    /// 2. Consult the circuit breaker *while holding the permit*. A request that was
    ///    queued behind the permit therefore cannot slip through after other requests
    ///    have tripped the breaker.
    /// 3. If a rate-limit delay is configured, sleep until at least that long has passed
    ///    since the previous request was released.
    /// 4. Build and send the request, then record the outcome. A 2xx status is a
    ///    success; any other status (including 4xx) or a transport error is a failure.
    ///
    /// The response is returned unchanged even when its status is not 2xx.
    async fn execute_request<F>(
        &self,
        request_builder: F,
    ) -> Result<reqwest::Response, ServiceError>
    where
        F: FnOnce() -> reqwest::RequestBuilder,
    {
        let _permit = self
            .rate_limiter
            .acquire()
            .await
            .map_err(|e| ServiceError::InternalError(format!("Rate limiter error: {}", e)))?;

        // Checked only after the permit is held. Outcomes are recorded while the permit
        // is still held, so the breaker state cannot change between this check and the
        // send below. If the check fails, the permit is released on return and the
        // rate-limit slot is not consumed.
        self.check_circuit_breaker().await?;

        if let Some(delay) = self.rate_limit_delay {
            let mut last_time = self.last_request_time.lock().await;
            if let Some(last) = *last_time {
                let elapsed = last.elapsed();
                if elapsed < delay {
                    sleep(delay - elapsed).await;
                }
            }
            *last_time = Some(Instant::now());
        }

        self.request_count.fetch_add(1, Ordering::Relaxed);
        let result = request_builder()
            .send()
            .await
            .map_err(|e| ServiceError::InternalError(format!("HTTP request failed: {}", e)));

        match &result {
            Ok(response) if response.status().is_success() => self.record_success().await,
            _ => self.record_failure().await,
        }
        result
    }

    /// Decides whether a request may proceed, moving `Open` to `HalfOpen` once the
    /// recovery timeout (`recovery_timeout_seconds`, default 60s) has elapsed.
    ///
    /// Entering `HalfOpen` resets the success counter so the trial window starts fresh.
    ///
    /// # Errors
    ///
    /// Returns `ServiceError::CircuitBreakerOpen` while the breaker is open and the
    /// recovery timeout has not yet elapsed.
    async fn check_circuit_breaker(&self) -> Result<(), ServiceError> {
        let mut state = self.circuit_state.lock().await;
        match *state {
            CircuitState::Closed | CircuitState::HalfOpen => Ok(()),
            CircuitState::Open { opened_at } => {
                let recovery_timeout =
                    Duration::from_secs(self.circuit_config.recovery_timeout_seconds.unwrap_or(60));
                if opened_at.elapsed() >= recovery_timeout {
                    *state = CircuitState::HalfOpen;
                    self.success_count.store(0, Ordering::Relaxed);
                    tracing::info!("Circuit breaker transitioning to half-open state");
                    Ok(())
                } else {
                    Err(ServiceError::CircuitBreakerOpen)
                }
            }
        }
    }

    /// Records a successful (2xx) request.
    ///
    /// * `Closed`: clears the failure counter, so only consecutive failures trip the breaker.
    /// * `HalfOpen`: counts the success; reaching `success_threshold` (default 3) closes
    ///   the breaker and resets both counters.
    /// * `Open`: ignored. A success observed while open belongs to a request admitted
    ///   before the breaker tripped and says nothing about recovery. Recovery always
    ///   goes through the timeout and the half-open trial.
    async fn record_success(&self) {
        let mut state = self.circuit_state.lock().await;
        match *state {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::Relaxed);
            }
            CircuitState::HalfOpen => {
                let success_count = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
                let success_threshold = self.circuit_config.success_threshold.unwrap_or(3);
                if success_count >= success_threshold {
                    *state = CircuitState::Closed;
                    self.failure_count.store(0, Ordering::Relaxed);
                    self.success_count.store(0, Ordering::Relaxed);
                    tracing::info!(
                        "Circuit breaker closed after {} successful requests",
                        success_count
                    );
                }
            }
            CircuitState::Open { .. } => {
                // Intentionally a no-op: a stale success must not close an open breaker.
            }
        }
    }

    /// Records a failed request (transport error or non-2xx status).
    ///
    /// * `Closed`: trips the breaker once `failure_threshold` (default 5) consecutive
    ///   failures have been counted.
    /// * `HalfOpen`: any failed trial request reopens the breaker immediately.
    /// * `Open`: the failure is counted, but the state does not change.
    async fn record_failure(&self) {
        let failure_count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        let failure_threshold = self.circuit_config.failure_threshold.unwrap_or(5);
        let mut state = self.circuit_state.lock().await;
        let should_open = match *state {
            // Explicit: a failed trial reopens regardless of the running failure count.
            CircuitState::HalfOpen => true,
            CircuitState::Closed => failure_count >= failure_threshold,
            CircuitState::Open { .. } => false,
        };
        if should_open {
            *state = CircuitState::Open {
                opened_at: Instant::now(),
            };
            tracing::warn!("Circuit breaker opened after {} failures", failure_count);
        }
    }

    /// Returns a snapshot of the current circuit-breaker state.
    pub async fn circuit_state(&self) -> CircuitState {
        self.circuit_state.lock().await.clone()
    }

    /// Returns a snapshot of the request counters.
    ///
    /// `failure_count` is the running failure count used by the breaker. It is reset
    /// by a success while closed and when the breaker closes. `success_count` is the
    /// number of successes in the current half-open trial window.
    /// The counters are read individually with relaxed ordering, so the snapshot is
    /// not guaranteed to be mutually consistent under concurrent use.
    pub fn get_stats(&self) -> ClientStats {
        ClientStats {
            total_requests: self.request_count.load(Ordering::Relaxed),
            failure_count: self.failure_count.load(Ordering::Relaxed),
            success_count: self.success_count.load(Ordering::Relaxed),
        }
    }
}
/// Point-in-time snapshot of an HTTP client's request counters.
#[derive(Clone, Debug)]
pub struct ClientStats {
    /// Number of requests that were actually sent.
    pub total_requests: u64,
    /// Running failure count tracked by the circuit breaker.
    pub failure_count: u32,
    /// Successes counted in the current half-open trial window.
    pub success_count: u32,
}