use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use reqwest::{IntoUrl, Method, Request, RequestBuilder, Response};
/// Default time an open circuit stays open before letting a probe request through.
///
/// Shared by [`CircuitState::default`] and [`CircuitBreakerConfig::default`] so a
/// freshly created state's backoff can never drift from the default configuration.
const DEFAULT_BASE_TIMEOUT: Duration = Duration::from_secs(30);

/// Circuit breaker bookkeeping for a single host.
#[derive(Debug, Clone)]
pub struct CircuitState {
    /// Current position in the Closed -> Open -> HalfOpen cycle.
    pub state: CircuitStatus,
    /// Failures counted while `Closed`; cleared by a success in that state.
    pub failure_count: u32,
    /// Successes counted while `HalfOpen`; the circuit closes once this reaches the
    /// configured success threshold.
    pub success_count: u32,
    /// When the circuit last opened (or had its open timer restarted); `None` on a
    /// fresh or recovered circuit.
    pub opened_at: Option<Instant>,
    /// How long an open circuit rejects requests before allowing a probe.
    pub current_backoff: Duration,
}

/// Position of a host's circuit in the breaker state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitStatus {
    /// Requests flow normally and failures are counted.
    Closed,
    /// Requests are rejected until the current backoff has elapsed.
    Open,
    /// Requests are allowed as probes; enough successes close the circuit and a
    /// failure reopens it with a longer backoff.
    HalfOpen,
}

impl Default for CircuitState {
    /// A closed circuit with no recorded outcomes and the default base backoff.
    fn default() -> Self {
        Self {
            state: CircuitStatus::Closed,
            failure_count: 0,
            success_count: 0,
            opened_at: None,
            current_backoff: DEFAULT_BASE_TIMEOUT,
        }
    }
}

/// Thresholds and backoff policy applied to every host.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures counted while closed that trip the circuit open.
    pub failure_threshold: u32,
    /// Half-open successes required to close the circuit again.
    pub success_threshold: u32,
    /// Base open duration; the backoff is reset to this value when a half-open
    /// circuit closes.
    pub base_timeout: Duration,
    /// Upper bound on the backoff as it grows after failed probes.
    pub max_backoff: Duration,
    /// Factor applied to the backoff each time a half-open probe fails.
    pub backoff_multiplier: f64,
    /// When `false`, the breaker neither blocks requests nor records outcomes.
    pub enabled: bool,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 failures, close after 2 probe successes, 30 s base backoff growing
    /// by 1.5x up to 10 minutes, enabled.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            base_timeout: DEFAULT_BASE_TIMEOUT,
            max_backoff: Duration::from_secs(600),
            backoff_multiplier: 1.5,
            enabled: true,
        }
    }
}
/// Rejection returned when a host's circuit is open and its backoff has not elapsed.
#[derive(Debug, Clone)]
pub struct CircuitBreakerOpen {
    /// Breaker key of the host whose circuit rejected the request.
    pub host: String,
    /// Time left before the circuit will let a probe request through.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitBreakerOpen {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { host, retry_after } = self;
        write!(
            f,
            "Circuit breaker open for {}: retry after {:?}",
            host, retry_after
        )
    }
}

/// Uses the default `source()` (none): this rejection wraps no underlying error.
impl std::error::Error for CircuitBreakerOpen {}
/// `reqwest::Client` wrapper that keeps a circuit breaker per host.
///
/// Clones share the same per-host state map, because it is held behind an `Arc`.
#[derive(Clone)]
pub struct CircuitBreakerClient {
    /// Underlying HTTP client used to send requests.
    inner: reqwest::Client,
    /// Breaker state keyed by host string (as built by `extract_host` in `execute`).
    states: std::sync::Arc<RwLock<HashMap<String, CircuitState>>>,
    /// Thresholds and backoff policy applied to every host.
    config: CircuitBreakerConfig,
}
impl CircuitBreakerClient {
pub fn new(client: reqwest::Client, config: CircuitBreakerConfig) -> Self {
Self {
inner: client,
states: std::sync::Arc::new(RwLock::new(HashMap::new())),
config,
}
}
pub fn with_defaults(client: reqwest::Client) -> Self {
Self::new(client, CircuitBreakerConfig::default())
}
pub fn inner(&self) -> &reqwest::Client {
&self.inner
}
pub fn with_timeout(&self, timeout: Option<Duration>) -> HttpClient {
HttpClient::new(self.clone(), timeout)
}
fn extract_host(url: &reqwest::Url) -> String {
format!(
"{}://{}{}",
url.scheme(),
url.host_str().unwrap_or("unknown"),
url.port().map(|p| format!(":{}", p)).unwrap_or_default()
)
}
pub fn should_allow(&self, host: &str) -> Result<(), CircuitBreakerOpen> {
if !self.config.enabled {
return Ok(());
}
let states = self.states.read().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
let state = match states.get(host) {
Some(s) => s,
None => return Ok(()), };
match state.state {
CircuitStatus::Closed => Ok(()),
CircuitStatus::HalfOpen => Ok(()), CircuitStatus::Open => {
let opened_at = state.opened_at.unwrap_or_else(Instant::now);
let elapsed = opened_at.elapsed();
if elapsed >= state.current_backoff {
Ok(())
} else {
Err(CircuitBreakerOpen {
host: host.to_string(),
retry_after: state.current_backoff - elapsed,
})
}
}
}
}
pub fn record_success(&self, host: &str) {
if !self.config.enabled {
return;
}
let mut states = self.states.write().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
let state = states.entry(host.to_string()).or_default();
match state.state {
CircuitStatus::Closed => {
state.failure_count = 0;
}
CircuitStatus::HalfOpen => {
state.success_count += 1;
if state.success_count >= self.config.success_threshold {
tracing::info!(host = %host, "Circuit breaker closed, service recovered");
state.state = CircuitStatus::Closed;
state.failure_count = 0;
state.success_count = 0;
state.opened_at = None;
state.current_backoff = self.config.base_timeout;
}
}
CircuitStatus::Open => {
tracing::info!(host = %host, "Circuit breaker half-open, testing service");
state.state = CircuitStatus::HalfOpen;
state.success_count = 1;
}
}
}
pub fn record_failure(&self, host: &str) {
if !self.config.enabled {
return;
}
let mut states = self.states.write().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
let state = states.entry(host.to_string()).or_default();
match state.state {
CircuitStatus::Closed => {
state.failure_count += 1;
if state.failure_count >= self.config.failure_threshold {
tracing::warn!(
host = %host,
failures = state.failure_count,
"Circuit breaker opened, service unhealthy"
);
state.state = CircuitStatus::Open;
state.opened_at = Some(Instant::now());
}
}
CircuitStatus::HalfOpen => {
let new_backoff = Duration::from_secs_f64(
(state.current_backoff.as_secs_f64() * self.config.backoff_multiplier)
.min(self.config.max_backoff.as_secs_f64()),
);
tracing::warn!(
host = %host,
backoff_secs = new_backoff.as_secs(),
"Circuit breaker reopened, service still unhealthy"
);
state.state = CircuitStatus::Open;
state.opened_at = Some(Instant::now());
state.current_backoff = new_backoff;
state.success_count = 0;
}
CircuitStatus::Open => {
state.opened_at = Some(Instant::now());
}
}
}
pub async fn execute(&self, request: Request) -> Result<Response, CircuitBreakerError> {
let host = Self::extract_host(request.url());
self.should_allow(&host)
.map_err(CircuitBreakerError::CircuitOpen)?;
{
let mut states = self.states.write().unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
});
if let Some(state) = states.get_mut(&host)
&& state.state == CircuitStatus::Open
&& let Some(opened_at) = state.opened_at
&& opened_at.elapsed() >= state.current_backoff
{
tracing::info!(host = %host, "Circuit breaker half-open, testing service");
state.state = CircuitStatus::HalfOpen;
state.success_count = 0;
}
}
match self.inner.execute(request).await {
Ok(response) => {
if response.status().is_server_error() {
self.record_failure(&host);
} else {
self.record_success(&host);
}
Ok(response)
}
Err(e) => {
self.record_failure(&host);
Err(CircuitBreakerError::Request(e))
}
}
}
pub fn get_state(&self, host: &str) -> Option<CircuitState> {
self.states
.read()
.unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
})
.get(host)
.cloned()
}
pub fn reset(&self, host: &str) {
self.states
.write()
.unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
})
.remove(host);
}
pub fn reset_all(&self) {
self.states
.write()
.unwrap_or_else(|e| {
tracing::error!("Circuit breaker lock was poisoned, recovering");
e.into_inner()
})
.clear();
}
}
#[derive(Debug)]
pub enum CircuitBreakerError {
CircuitOpen(CircuitBreakerOpen),
Request(reqwest::Error),
}
impl std::fmt::Display for CircuitBreakerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
CircuitBreakerError::Request(e) => write!(f, "HTTP request failed: {}", e),
}
}
}
impl std::error::Error for CircuitBreakerError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CircuitBreakerError::CircuitOpen(e) => Some(e),
CircuitBreakerError::Request(e) => Some(e),
}
}
}
impl From<reqwest::Error> for CircuitBreakerError {
fn from(e: reqwest::Error) -> Self {
CircuitBreakerError::Request(e)
}
}
/// Circuit-breaking HTTP client with an optional default per-request timeout.
#[derive(Clone)]
pub struct HttpClient {
    /// Breaker (wrapping the `reqwest::Client`) that `execute` routes requests through.
    circuit_breaker: CircuitBreakerClient,
    /// Timeout that `execute` applies to requests that have none of their own.
    default_timeout: Option<Duration>,
}
impl HttpClient {
pub fn new(circuit_breaker: CircuitBreakerClient, default_timeout: Option<Duration>) -> Self {
Self {
circuit_breaker,
default_timeout,
}
}
pub fn inner(&self) -> &reqwest::Client {
self.circuit_breaker.inner()
}
pub fn circuit_breaker(&self) -> &CircuitBreakerClient {
&self.circuit_breaker
}
pub fn default_timeout(&self) -> Option<Duration> {
self.default_timeout
}
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> HttpRequestBuilder {
HttpRequestBuilder::new(self.clone(), self.inner().request(method, url))
}
pub fn get<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
self.request(Method::GET, url)
}
pub fn post<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
self.request(Method::POST, url)
}
pub fn put<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
self.request(Method::PUT, url)
}
pub fn patch<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
self.request(Method::PATCH, url)
}
pub fn delete<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
self.request(Method::DELETE, url)
}
pub fn head<U: IntoUrl>(&self, url: U) -> HttpRequestBuilder {
self.request(Method::HEAD, url)
}
pub async fn execute(&self, mut request: Request) -> crate::Result<Response> {
self.apply_default_timeout(&mut request);
self.circuit_breaker
.execute(request)
.await
.map_err(Into::into)
}
fn apply_default_timeout(&self, request: &mut Request) {
if request.timeout().is_none()
&& let Some(timeout) = self.default_timeout
{
*request.timeout_mut() = Some(timeout);
}
}
}
/// Request builder produced by [`HttpClient`]'s request methods. `send` executes the
/// finished request through that client's circuit breaker.
pub struct HttpRequestBuilder {
    /// Client that will execute the built request.
    client: HttpClient,
    /// Underlying `reqwest` builder being configured.
    request: RequestBuilder,
}
impl HttpRequestBuilder {
    /// Pairs a `reqwest` builder with the client that will send it.
    fn new(client: HttpClient, request: RequestBuilder) -> Self {
        Self { client, request }
    }

    /// Applies `f` to the wrapped `reqwest` builder and keeps the same client.
    fn map_request(self, f: impl FnOnce(RequestBuilder) -> RequestBuilder) -> Self {
        Self {
            request: f(self.request),
            client: self.client,
        }
    }

    /// Finalizes a `reqwest` builder, mapping build errors to `ForgeError::Internal`.
    fn finish(request: RequestBuilder) -> crate::Result<Request> {
        request
            .build()
            .map_err(|e| crate::ForgeError::Internal(e.to_string()))
    }

    /// Adds a single header.
    pub fn header(self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        self.map_request(|r| r.header(key.as_ref(), value.as_ref()))
    }

    /// Adds every header in `headers`.
    pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
        self.map_request(|r| r.headers(headers))
    }

    /// Sets bearer-token authorization.
    pub fn bearer_auth(self, token: impl std::fmt::Display) -> Self {
        self.map_request(|r| r.bearer_auth(token))
    }

    /// Sets HTTP basic authorization.
    pub fn basic_auth(
        self,
        username: impl std::fmt::Display,
        password: Option<impl std::fmt::Display>,
    ) -> Self {
        self.map_request(|r| r.basic_auth(username, password))
    }

    /// Sets the request body.
    pub fn body(self, body: impl Into<reqwest::Body>) -> Self {
        self.map_request(|r| r.body(body))
    }

    /// Serializes `json` as the request body.
    pub fn json(self, json: &impl serde::Serialize) -> Self {
        self.map_request(|r| r.json(json))
    }

    /// Serializes `form` as a URL-encoded form body.
    pub fn form(self, form: &impl serde::Serialize) -> Self {
        self.map_request(|r| r.form(form))
    }

    /// Appends `query` to the URL's query string.
    pub fn query(self, query: &impl serde::Serialize) -> Self {
        self.map_request(|r| r.query(query))
    }

    /// Sets a per-request timeout, which takes precedence over the client's default.
    pub fn timeout(self, timeout: Duration) -> Self {
        self.map_request(|r| r.timeout(timeout))
    }

    /// Sets the HTTP version.
    pub fn version(self, version: reqwest::Version) -> Self {
        self.map_request(|r| r.version(version))
    }

    /// Clones the builder. Returns `None` when the underlying `reqwest` builder
    /// cannot be cloned.
    pub fn try_clone(&self) -> Option<Self> {
        let request = self.request.try_clone()?;
        Some(Self {
            client: self.client.clone(),
            request,
        })
    }

    /// Builds the request without sending it.
    ///
    /// # Errors
    ///
    /// `ForgeError::Internal` carrying `reqwest`'s message if the request is invalid.
    pub fn build(self) -> crate::Result<Request> {
        Self::finish(self.request)
    }

    /// Builds the request and sends it through [`HttpClient::execute`], which applies
    /// the default timeout and the circuit breaker.
    ///
    /// # Errors
    ///
    /// Build failures as in [`Self::build`], plus any error from `HttpClient::execute`.
    pub async fn send(self) -> crate::Result<Response> {
        // Destructure rather than clone the client just to satisfy `build(self)`.
        let Self { client, request } = self;
        let request = Self::finish(request)?;
        client.execute(request).await
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// A bare GET request to example.com, used by the timeout tests.
    fn example_get() -> reqwest::Request {
        let url = reqwest::Url::parse("https://example.com").unwrap();
        reqwest::Request::new(Method::GET, url)
    }

    /// An `HttpClient` with default breaker settings and a five-second default timeout.
    fn five_second_client() -> HttpClient {
        CircuitBreakerClient::with_defaults(reqwest::Client::new())
            .with_timeout(Some(Duration::from_secs(5)))
    }

    #[test]
    fn test_circuit_breaker_defaults() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            enabled,
            ..
        } = CircuitBreakerConfig::default();
        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 2);
        assert!(enabled);
    }

    #[test]
    fn test_circuit_state_transitions() {
        let host = "https://api.example.com";
        let breaker = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        assert!(breaker.should_allow(host).is_ok());

        // Five failures meet the default threshold and trip the circuit open.
        (0..5).for_each(|_| breaker.record_failure(host));
        let snapshot = breaker.get_state(host).unwrap();
        assert_eq!(snapshot.state, CircuitStatus::Open);
        assert!(breaker.should_allow(host).is_err());

        // Resetting forgets the host entirely, so requests flow again.
        breaker.reset(host);
        assert!(breaker.should_allow(host).is_ok());
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            ("https://api.example.com:8080/path", "https://api.example.com:8080"),
            ("http://localhost/api", "http://localhost"),
        ];
        for (raw, expected) in cases {
            let url = reqwest::Url::parse(raw).unwrap();
            assert_eq!(CircuitBreakerClient::extract_host(&url), expected);
        }
    }

    #[test]
    fn test_http_client_applies_default_timeout_when_missing() {
        let client = five_second_client();
        let mut request = example_get();
        client.apply_default_timeout(&mut request);
        assert_eq!(request.timeout(), Some(&Duration::from_secs(5)));
    }

    #[test]
    fn test_http_client_preserves_explicit_timeout() {
        let client = five_second_client();
        let mut request = example_get();
        *request.timeout_mut() = Some(Duration::from_secs(1));
        client.apply_default_timeout(&mut request);
        assert_eq!(request.timeout(), Some(&Duration::from_secs(1)));
    }
}