use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::Mutex;
use brainwires_core::message::{ChatResponse, Message, StreamChunk};
use brainwires_core::provider::{ChatOptions, Provider};
use brainwires_core::tool::Tool;
use crate::error::ResilienceError;
/// The three states a per-key circuit can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are forwarded to the wrapped provider.
    Closed,
    /// Calls are rejected (or routed to the fallback) until the cooldown
    /// deadline has passed.
    Open,
    /// The cooldown has elapsed and calls are let through again as a trial:
    /// a success closes the circuit, a failure re-opens it.
    HalfOpen,
}
/// Tuning knobs for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of failures recorded since the last success at which the
    /// circuit opens.
    pub failure_threshold: u32,
    /// How long an opened circuit rejects calls before a trial call is
    /// allowed through.
    pub cooldown: Duration,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
cooldown: Duration::from_secs(30),
}
}
}
/// Breaker bookkeeping for a single `provider::model` key.
#[derive(Debug)]
struct Entry {
    /// Current state of this key's circuit.
    state: CircuitState,
    /// Failures recorded since the last success (saturating).
    failures: u32,
    /// Deadline after which an open circuit may admit a trial call;
    /// `None` when the circuit has not been opened.
    open_until: Option<Instant>,
}
impl Default for Entry {
fn default() -> Self {
Self {
state: CircuitState::Closed,
failures: 0,
open_until: None,
}
}
}
/// A [`Provider`] wrapper that tracks failures per `provider::model` key and
/// stops calling the wrapped provider while that key's circuit is open.
pub struct CircuitBreakerProvider<P: Provider + ?Sized> {
    /// The provider being protected.
    inner: Arc<P>,
    /// Optional provider that `chat` is routed to while a circuit is open.
    fallback: Option<Arc<dyn Provider>>,
    /// Failure threshold and cooldown settings.
    cfg: CircuitBreakerConfig,
    /// Per-key breaker state, guarded by an async mutex.
    state: Arc<Mutex<HashMap<String, Entry>>>,
}
impl<P: Provider + ?Sized> CircuitBreakerProvider<P> {
pub fn new(inner: Arc<P>, cfg: CircuitBreakerConfig) -> Self {
Self {
inner,
fallback: None,
cfg,
state: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn with_fallback(mut self, fallback: Arc<dyn Provider>) -> Self {
self.fallback = Some(fallback);
self
}
pub async fn state_for(&self, model: &str) -> CircuitState {
let guard = self.state.lock().await;
let key = self.key(model);
guard.get(&key).map_or(CircuitState::Closed, |e| e.state)
}
fn key(&self, model: &str) -> String {
format!("{}::{}", self.inner.name(), model)
}
async fn transition_in(&self, key: &str) -> Result<(), ResilienceError> {
let mut guard = self.state.lock().await;
let entry = guard.entry(key.to_string()).or_default();
if entry.state == CircuitState::Open {
if let Some(deadline) = entry.open_until
&& Instant::now() >= deadline
{
entry.state = CircuitState::HalfOpen;
} else {
return Err(ResilienceError::CircuitOpen {
provider: self.inner.name().to_string(),
model: key.to_string(),
failures: entry.failures,
});
}
}
Ok(())
}
async fn record_success(&self, key: &str) {
let mut guard = self.state.lock().await;
let entry = guard.entry(key.to_string()).or_default();
entry.state = CircuitState::Closed;
entry.failures = 0;
entry.open_until = None;
}
async fn record_failure(&self, key: &str) {
let mut guard = self.state.lock().await;
let entry = guard.entry(key.to_string()).or_default();
entry.failures = entry.failures.saturating_add(1);
if entry.failures >= self.cfg.failure_threshold || entry.state == CircuitState::HalfOpen {
entry.state = CircuitState::Open;
entry.open_until = Some(Instant::now() + self.cfg.cooldown);
}
}
}
#[async_trait]
impl<P: Provider + ?Sized + 'static> Provider for CircuitBreakerProvider<P> {
    /// Reports the wrapped provider's name.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Reports the wrapped provider's output-token limit.
    fn max_output_tokens(&self) -> Option<u32> {
        self.inner.max_output_tokens()
    }

    /// Sends a chat request through the breaker.
    ///
    /// The circuit is keyed on `options.model`, or `"default"` when no model
    /// is set. While the circuit is open the request goes to the fallback
    /// provider if one is configured; otherwise the circuit-open error is
    /// returned. When the call reaches the wrapped provider, its outcome is
    /// recorded as a success or failure before being returned unchanged.
    /// Fallback outcomes are not recorded.
    async fn chat(
        &self,
        messages: &[Message],
        tools: Option<&[Tool]>,
        options: &ChatOptions,
    ) -> Result<ChatResponse> {
        let model_label = options.model.as_deref().unwrap_or("default");
        let key = self.key(model_label);

        if let Err(open_err) = self.transition_in(&key).await {
            let Some(fallback) = self.fallback.as_ref() else {
                return Err(open_err.into());
            };
            tracing::warn!(
                provider = self.inner.name(),
                model = model_label,
                "circuit open; routing to fallback"
            );
            return fallback.chat(messages, tools, options).await;
        }

        let outcome = self.inner.chat(messages, tools, options).await;
        let succeeded = outcome.is_ok();
        if succeeded {
            self.record_success(&key).await;
        } else {
            self.record_failure(&key).await;
        }
        outcome
    }

    /// Streams a chat response straight from the wrapped provider.
    ///
    /// NOTE(review): this path neither consults nor updates the breaker and
    /// never uses the fallback, so streaming calls still reach a provider
    /// whose circuit is open. Confirm that this is intended.
    fn stream_chat<'a>(
        &'a self,
        messages: &'a [Message],
        tools: Option<&'a [Tool]>,
        options: &'a ChatOptions,
    ) -> BoxStream<'a, Result<StreamChunk>> {
        self.inner.stream_chat(messages, tools, options)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A breaker with no recorded history reports `Closed` for any model.
    #[tokio::test]
    async fn state_starts_closed() {
        use crate::tests_util::EchoProvider;
        let cb = CircuitBreakerProvider::new(
            Arc::new(EchoProvider::ok("p1")),
            CircuitBreakerConfig::default(),
        );
        assert_eq!(cb.state_for("any").await, CircuitState::Closed);
    }

    /// Reaching the failure threshold opens the circuit, and the next call is
    /// rejected with `CircuitOpen` instead of reaching the provider.
    #[tokio::test]
    async fn opens_after_threshold() {
        use crate::tests_util::EchoProvider;
        let cb = CircuitBreakerProvider::new(
            Arc::new(EchoProvider::always_err("p1", "500 internal server error")),
            CircuitBreakerConfig {
                failure_threshold: 3,
                cooldown: Duration::from_millis(50),
            },
        );
        let opts = ChatOptions::default();
        for _ in 0..3 {
            let _ = cb.chat(&[], None, &opts).await;
        }
        let key = cb.key("default");
        assert_eq!(
            cb.state.lock().await.get(&key).map(|e| e.state),
            Some(CircuitState::Open),
        );
        let err = cb.chat(&[], None, &opts).await.unwrap_err();
        assert!(matches!(
            err.downcast_ref::<ResilienceError>(),
            Some(ResilienceError::CircuitOpen { .. }),
        ));
    }

    /// After the cooldown, a successful trial call closes the circuit.
    #[tokio::test]
    async fn half_open_then_closes_on_success() {
        use crate::tests_util::ToggleProvider;
        let prov = Arc::new(ToggleProvider::new("p1"));
        let cb = CircuitBreakerProvider::new(
            prov.clone(),
            CircuitBreakerConfig {
                failure_threshold: 2,
                cooldown: Duration::from_millis(20),
            },
        );
        let opts = ChatOptions::default();
        prov.set_fail(true);
        let _ = cb.chat(&[], None, &opts).await;
        let _ = cb.chat(&[], None, &opts).await;
        // Without this check the test would also pass if the breaker never
        // opened, because the final call succeeds either way.
        assert_eq!(
            cb.state_for("default").await,
            CircuitState::Open,
            "threshold failures must open the circuit before the half-open trial",
        );
        // Sleep past the 20 ms cooldown so the next call is a half-open trial.
        tokio::time::sleep(Duration::from_millis(30)).await;
        prov.set_fail(false);
        cb.chat(&[], None, &opts).await.expect("half-open success");
        assert_eq!(cb.state_for("default").await, CircuitState::Closed);
    }

    /// After the cooldown, a failed trial call re-opens the circuit and the
    /// following call is rejected without reaching the provider.
    #[tokio::test]
    async fn half_open_reopens_on_failure() {
        use crate::tests_util::ToggleProvider;
        let prov = Arc::new(ToggleProvider::new("p1"));
        let cb = CircuitBreakerProvider::new(
            prov.clone(),
            CircuitBreakerConfig {
                failure_threshold: 2,
                cooldown: Duration::from_millis(20),
            },
        );
        let opts = ChatOptions::default();
        prov.set_fail(true);
        let _ = cb.chat(&[], None, &opts).await;
        let _ = cb.chat(&[], None, &opts).await;
        assert_eq!(cb.state_for("default").await, CircuitState::Open);
        // Sleep past the 20 ms cooldown so the next call is a half-open trial.
        tokio::time::sleep(Duration::from_millis(30)).await;
        // The trial reaches the provider, so the provider's own error (which
        // is expected to mention "500") comes back rather than `CircuitOpen`.
        let err = cb.chat(&[], None, &opts).await.unwrap_err();
        assert!(
            err.to_string().contains("500"),
            "expected provider error from half-open trial, got: {err}"
        );
        assert_eq!(
            cb.state_for("default").await,
            CircuitState::Open,
            "half-open + failure must re-open the circuit",
        );
        let err2 = cb.chat(&[], None, &opts).await.unwrap_err();
        assert!(matches!(
            err2.downcast_ref::<ResilienceError>(),
            Some(ResilienceError::CircuitOpen { .. })
        ));
    }
}