#![allow(missing_docs)]
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use crate::auth::TenantScope;
use crate::error::Error;
/// Tuning parameters shared by every [`ProviderCircuit`] a tracker creates.
#[derive(Debug, Clone)]
pub struct CircuitConfig {
    /// Number of consecutive failures, while closed, that trips the circuit open.
    pub failure_threshold: u32,
    /// Length of the open period the first time the circuit trips.
    pub initial_open_duration: Duration,
    /// Upper bound applied to the open period after a failed half-open probe.
    pub max_open_duration: Duration,
    /// Factor by which the open period is scaled when a half-open probe fails.
    pub backoff_multiplier: f64,
}
impl Default for CircuitConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
initial_open_duration: Duration::from_secs(30),
max_open_duration: Duration::from_secs(300),
backoff_multiplier: 2.0,
}
}
}
#[derive(Debug)]
enum CircuitState {
Closed {
consecutive_failures: u32,
},
Open {
until: Instant,
prev_duration: Duration,
},
HalfOpen,
}
pub struct ProviderCircuit {
state: Mutex<CircuitState>,
config: CircuitConfig,
}
pub struct CircuitPermit {
circuit: Arc<ProviderCircuit>,
consumed: std::sync::atomic::AtomicBool,
}
impl std::fmt::Debug for CircuitPermit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("CircuitPermit")
}
}
impl CircuitPermit {
pub fn record_success(self) {
self.consumed
.store(true, std::sync::atomic::Ordering::SeqCst);
self.circuit.record_success();
}
pub fn record_failure(self) {
self.consumed
.store(true, std::sync::atomic::Ordering::SeqCst);
self.circuit.record_failure();
}
}
impl Drop for CircuitPermit {
fn drop(&mut self) {
if !self.consumed.load(std::sync::atomic::Ordering::SeqCst) {
self.circuit.record_failure();
}
}
}
impl ProviderCircuit {
pub fn new(config: CircuitConfig) -> Self {
Self {
state: Mutex::new(CircuitState::Closed {
consecutive_failures: 0,
}),
config,
}
}
pub fn permit(self: &Arc<Self>) -> Result<CircuitPermit, Error> {
let mut state = self.state.lock();
match *state {
CircuitState::Closed { .. } => Ok(CircuitPermit {
circuit: Arc::clone(self),
consumed: std::sync::atomic::AtomicBool::new(false),
}),
CircuitState::Open {
until,
prev_duration,
} => {
if Instant::now() >= until {
*state = CircuitState::HalfOpen;
Ok(CircuitPermit {
circuit: Arc::clone(self),
consumed: std::sync::atomic::AtomicBool::new(false),
})
} else {
Err(Error::CircuitOpen {
until,
prev_duration,
})
}
}
CircuitState::HalfOpen => Err(Error::CircuitOpen {
until: Instant::now() + Duration::from_millis(50),
prev_duration: Duration::ZERO,
}),
}
}
fn record_success(&self) {
let mut state = self.state.lock();
*state = CircuitState::Closed {
consecutive_failures: 0,
};
}
fn record_failure(&self) {
let mut state = self.state.lock();
match *state {
CircuitState::Closed {
consecutive_failures,
} => {
let n = consecutive_failures + 1;
*state = if n >= self.config.failure_threshold {
CircuitState::Open {
until: Instant::now() + self.config.initial_open_duration,
prev_duration: self.config.initial_open_duration,
}
} else {
CircuitState::Closed {
consecutive_failures: n,
}
};
}
CircuitState::HalfOpen => {
let new_dur_secs = self.config.initial_open_duration.as_secs_f64()
* self.config.backoff_multiplier;
let new_dur =
Duration::from_secs_f64(new_dur_secs).min(self.config.max_open_duration);
*state = CircuitState::Open {
until: Instant::now() + new_dur,
prev_duration: new_dur,
};
}
CircuitState::Open { .. } => { }
}
}
}
/// Identifies one circuit: a single provider as used by a single tenant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CircuitKey {
    /// Tenant that owns the circuit.
    pub tenant_id: String,
    /// Provider name the circuit guards.
    pub provider: String,
}
/// Registry of circuits keyed by (tenant, provider), created lazily on first use.
pub struct CircuitTracker {
    circuits: RwLock<HashMap<CircuitKey, Arc<ProviderCircuit>>>,
    /// Template cloned into every circuit this tracker creates.
    config: CircuitConfig,
}

impl CircuitTracker {
    /// Creates an empty tracker whose circuits all start from a clone of `config`.
    pub fn new(config: CircuitConfig) -> Self {
        let circuits = RwLock::new(HashMap::new());
        CircuitTracker { circuits, config }
    }

    /// Returns the shared circuit for `scope`'s tenant and `provider`, creating it
    /// on first use.
    ///
    /// A hit takes only the read lock. On a miss the write lock is taken and the
    /// entry API re-checks, so racing first callers still end up sharing one circuit.
    pub fn circuit_for(&self, scope: &TenantScope, provider: &str) -> Arc<ProviderCircuit> {
        let key = CircuitKey {
            tenant_id: scope.tenant_id.clone(),
            provider: provider.to_owned(),
        };
        // Bind the lookup so the read guard is released before `write()` below.
        let cached = self.circuits.read().get(&key).cloned();
        if let Some(circuit) = cached {
            return circuit;
        }
        let mut circuits = self.circuits.write();
        let slot = circuits
            .entry(key)
            .or_insert_with(|| Arc::new(ProviderCircuit::new(self.config.clone())));
        Arc::clone(slot)
    }
}
/// Returns `true` when `err` counts against the provider's circuit.
///
/// Only errors classified as server errors, rate limiting, or network failures
/// count. All other classes (in this file's tests, HTTP 401 and 400) leave the
/// circuit alone.
pub fn is_circuit_failure(err: &Error) -> bool {
    use crate::llm::error_class::{classify, ErrorClass};
    let class = classify(err);
    matches!(
        class,
        ErrorClass::Network | ErrorClass::RateLimited | ErrorClass::ServerError
    )
}
pub struct CircuitBreakerProvider<P: super::LlmProvider> {
inner: P,
tracker: Arc<CircuitTracker>,
provider_name: String,
scope: TenantScope,
}
impl<P: super::LlmProvider> CircuitBreakerProvider<P> {
pub fn new(
inner: P,
tracker: Arc<CircuitTracker>,
provider_name: impl Into<String>,
scope: TenantScope,
) -> Self {
Self {
inner,
tracker,
provider_name: provider_name.into(),
scope,
}
}
}
impl<P: super::LlmProvider> super::LlmProvider for CircuitBreakerProvider<P> {
fn model_name(&self) -> Option<&str> {
self.inner.model_name()
}
async fn complete(
&self,
request: super::types::CompletionRequest,
) -> Result<super::types::CompletionResponse, Error> {
let circuit = self.tracker.circuit_for(&self.scope, &self.provider_name);
let permit = circuit.permit()?;
let result = self.inner.complete(request).await;
match &result {
Ok(_) => permit.record_success(),
Err(e) if is_circuit_failure(e) => permit.record_failure(),
Err(_) => permit.record_success(),
}
result
}
async fn stream_complete(
&self,
request: super::types::CompletionRequest,
on_text: &super::OnText,
) -> Result<super::types::CompletionResponse, Error> {
let circuit = self.tracker.circuit_for(&self.scope, &self.provider_name);
let permit = circuit.permit()?;
let result = self.inner.stream_complete(request, on_text).await;
match &result {
Ok(_) => permit.record_success(),
Err(e) if is_circuit_failure(e) => permit.record_failure(),
Err(_) => permit.record_success(),
}
result
}
}
#[cfg(test)]
mod tests {
// Tests for the circuit state machine, the per-(tenant, provider) tracker,
// error classification, and the `CircuitBreakerProvider` wrapper.
// NOTE(review): several tests depend on real wall-clock sleeps of tens of
// milliseconds and may be sensitive to scheduler delays on loaded CI hosts.
use super::*;
// Small-duration config used by most tests: trips after 3 consecutive failures,
// opens for 50ms initially, and caps the open period at 500ms.
fn cfg() -> CircuitConfig {
CircuitConfig {
failure_threshold: 3,
initial_open_duration: Duration::from_millis(50),
max_open_duration: Duration::from_millis(500),
backoff_multiplier: 2.0,
}
}
// A fresh circuit starts Closed and grants a permit.
#[test]
fn closed_circuit_passes_requests() {
let c = Arc::new(ProviderCircuit::new(cfg()));
let p = c.permit().unwrap();
p.record_success();
}
// Reaching `failure_threshold` consecutive failures opens the circuit, so the
// next permit request is rejected with `Error::CircuitOpen`.
#[test]
fn n_failures_open_circuit() {
let c = Arc::new(ProviderCircuit::new(cfg()));
for _ in 0..3 {
let p = c.permit().unwrap();
p.record_failure();
}
let err = c.permit().unwrap_err();
assert!(matches!(err, Error::CircuitOpen { .. }));
}
// fail, fail, succeed, fail leaves one consecutive failure, below the threshold
// of 3, so the circuit stays closed.
#[test]
fn success_resets_consecutive_failures() {
let c = Arc::new(ProviderCircuit::new(cfg()));
c.permit().unwrap().record_failure();
c.permit().unwrap().record_failure();
c.permit().unwrap().record_success();
c.permit().unwrap().record_failure();
// NOTE(review): the permit checked here is dropped unconsumed, which `Drop`
// records as an additional failure (harmless for this assertion).
assert!(c.permit().is_ok());
}
// After the 50ms open period elapses, the next permit request is granted and
// becomes the half-open probe.
#[test]
fn open_transitions_to_half_open_after_duration() {
let c = Arc::new(ProviderCircuit::new(cfg()));
for _ in 0..3 {
c.permit().unwrap().record_failure();
}
std::thread::sleep(Duration::from_millis(60));
assert!(c.permit().is_ok(), "should be HalfOpen permit");
}
// A successful probe closes the circuit; subsequent requests all pass.
#[test]
fn half_open_success_closes_circuit() {
let c = Arc::new(ProviderCircuit::new(cfg()));
for _ in 0..3 {
c.permit().unwrap().record_failure();
}
std::thread::sleep(Duration::from_millis(60));
c.permit().unwrap().record_success();
for _ in 0..10 {
let p = c.permit();
assert!(p.is_ok());
p.unwrap().record_success();
}
}
// A failed probe reopens the circuit for 50ms * 2.0 = 100ms: still open 60ms
// later, and admitting a new probe after a further 60ms.
#[test]
fn half_open_failure_reopens_with_doubled_duration() {
let c = Arc::new(ProviderCircuit::new(cfg()));
for _ in 0..3 {
c.permit().unwrap().record_failure();
}
std::thread::sleep(Duration::from_millis(70));
c.permit().unwrap().record_failure();
std::thread::sleep(Duration::from_millis(60));
assert!(c.permit().is_err());
std::thread::sleep(Duration::from_millis(60));
assert!(c.permit().is_ok());
}
// 100ms * 4.0 = 400ms is clamped to the 150ms maximum, so a probe is admitted
// again after 160ms.
// NOTE(review): despite the name, only one half-open failure happens here; the
// first `record_failure` trips the circuit from Closed.
#[test]
fn repeated_half_open_failures_clamp_at_max() {
let c = Arc::new(ProviderCircuit::new(CircuitConfig {
failure_threshold: 1,
initial_open_duration: Duration::from_millis(100),
max_open_duration: Duration::from_millis(150),
backoff_multiplier: 4.0,
}));
c.permit().unwrap().record_failure(); std::thread::sleep(Duration::from_millis(110));
c.permit().unwrap().record_failure(); std::thread::sleep(Duration::from_millis(160));
assert!(
c.permit().is_ok(),
"should be openable again at clamped duration"
);
}
// `tokio::spawn` requires `Send + 'static`, so this checks that a permit can be
// carried into another task and resolved there.
#[tokio::test(flavor = "multi_thread")]
async fn permit_can_be_moved_across_await() {
let c = Arc::new(ProviderCircuit::new(cfg()));
let p = c.permit().unwrap();
let task = tokio::spawn(async move {
tokio::task::yield_now().await;
p.record_success();
});
task.await.unwrap();
}
// While a probe is outstanding, other permit requests are rejected; a
// successful probe then closes the circuit.
// NOTE(review): the two permit requests are issued sequentially, not concurrently.
#[tokio::test(flavor = "multi_thread")]
async fn concurrent_requests_during_half_open_only_one_probes() {
let c = Arc::new(ProviderCircuit::new(CircuitConfig {
failure_threshold: 1,
initial_open_duration: Duration::from_millis(20),
max_open_duration: Duration::from_millis(200),
backoff_multiplier: 2.0,
}));
c.permit().unwrap().record_failure(); tokio::time::sleep(Duration::from_millis(30)).await;
let probe = c.permit().expect("first probe granted");
let second = c.permit();
assert!(matches!(second, Err(Error::CircuitOpen { .. })));
probe.record_success();
assert!(c.permit().is_ok());
}
// The same (tenant, provider) pair always maps to the same shared circuit.
#[test]
fn tracker_returns_same_arc_for_same_key() {
let t = CircuitTracker::new(cfg());
let a = t.circuit_for(&TenantScope::new("acme"), "anthropic");
let b = t.circuit_for(&TenantScope::new("acme"), "anthropic");
assert!(Arc::ptr_eq(&a, &b));
}
// Different tenants get independent circuits for the same provider.
#[test]
fn tracker_isolates_tenants() {
let t = CircuitTracker::new(cfg());
let a = t.circuit_for(&TenantScope::new("acme"), "anthropic");
let b = t.circuit_for(&TenantScope::new("globex"), "anthropic");
assert!(!Arc::ptr_eq(&a, &b));
}
// Different providers get independent circuits for the same tenant.
#[test]
fn tracker_isolates_providers() {
let t = CircuitTracker::new(cfg());
let a = t.circuit_for(&TenantScope::new("acme"), "anthropic");
let b = t.circuit_for(&TenantScope::new("acme"), "openai");
assert!(!Arc::ptr_eq(&a, &b));
}
// 503, 429 and a transport-level HTTP error count as circuit failures; 401 and
// 400 do not.
#[test]
fn is_circuit_failure_classifies_correctly() {
let server = Error::Api {
status: 503,
message: "service unavailable".into(),
};
assert!(is_circuit_failure(&server));
let rate = Error::Api {
status: 429,
message: "too many requests".into(),
};
assert!(is_circuit_failure(&rate));
// A real `reqwest` error is produced by connecting to port 1 on the unspecified
// IPv6 address, which is expected to fail.
// NOTE(review): this performs an actual connection attempt; presumably it is
// refused immediately, but confirm it stays fast in sandboxed CI.
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("test runtime");
let http_err = rt
.block_on(reqwest::get("http://[::0]:1"))
.expect_err("should fail");
assert!(is_circuit_failure(&Error::Http(http_err)));
let auth = Error::Api {
status: 401,
message: "unauthorized".into(),
};
assert!(!is_circuit_failure(&auth));
let bad = Error::Api {
status: 400,
message: "bad json".into(),
};
assert!(!is_circuit_failure(&bad));
}
use crate::llm::LlmProvider;
use crate::llm::types::{CompletionRequest, Message};
// Test double whose `complete` always returns the error built by `error`; other
// `LlmProvider` methods are left to the trait's default implementations.
struct FailingProvider {
error: Box<dyn Fn() -> Error + Send + Sync>,
}
impl LlmProvider for FailingProvider {
async fn complete(
&self,
_r: CompletionRequest,
) -> Result<crate::llm::types::CompletionResponse, Error> {
Err((self.error)())
}
}
// Minimal request; its contents are irrelevant because `FailingProvider` ignores it.
fn dummy_request() -> CompletionRequest {
CompletionRequest {
system: "test".into(),
messages: vec![Message::user("hi")],
tools: vec![],
max_tokens: 10,
tool_choice: None,
reasoning_effort: None,
}
}
// Three 503s through the wrapper open the circuit; the fourth call is rejected
// by `permit()?` before the inner provider is invoked.
#[tokio::test(flavor = "multi_thread")]
async fn circuit_opens_after_threshold_failures() {
let tracker = Arc::new(CircuitTracker::new(CircuitConfig {
failure_threshold: 3,
initial_open_duration: Duration::from_secs(60),
max_open_duration: Duration::from_secs(120),
backoff_multiplier: 2.0,
}));
let inner = FailingProvider {
error: Box::new(|| Error::Api {
status: 503,
message: "down".into(),
}),
};
let wrapper = CircuitBreakerProvider::new(
inner,
tracker.clone(),
"anthropic",
TenantScope::new("acme"),
);
for _ in 0..3 {
let _ = wrapper.complete(dummy_request()).await;
}
let err = wrapper.complete(dummy_request()).await.unwrap_err();
assert!(matches!(err, Error::CircuitOpen { .. }));
}
// 401s are not circuit failures, so the wrapper records them as successes and
// ten of them leave the circuit closed.
#[tokio::test(flavor = "multi_thread")]
async fn auth_errors_do_not_trip_circuit() {
let tracker = Arc::new(CircuitTracker::new(cfg()));
let inner = FailingProvider {
error: Box::new(|| Error::Api {
status: 401,
message: "no key".into(),
}),
};
let wrapper = CircuitBreakerProvider::new(
inner,
tracker.clone(),
"anthropic",
TenantScope::new("acme"),
);
for _ in 0..10 {
let _ = wrapper.complete(dummy_request()).await;
}
let circuit = tracker.circuit_for(&TenantScope::new("acme"), "anthropic");
// NOTE(review): the permit checked here is dropped unconsumed and therefore
// recorded as a failure after the assertion.
assert!(circuit.permit().is_ok());
}
// With retries inside the breaker, each outer call takes exactly one permit, so
// two outer calls reach the threshold of 2 however many retries happen inside.
// Assumes `RetryingProvider` retries 503s and finally returns the last error —
// confirm against `crate::llm::retry`.
#[tokio::test(flavor = "multi_thread")]
async fn circuit_outer_retry_inner_one_permit_per_outer_call() {
let tracker = Arc::new(CircuitTracker::new(CircuitConfig {
failure_threshold: 2,
initial_open_duration: Duration::from_secs(60),
max_open_duration: Duration::from_secs(120),
backoff_multiplier: 2.0,
}));
let inner = FailingProvider {
error: Box::new(|| Error::Api {
status: 503,
message: "down".into(),
}),
};
let retrying = crate::llm::retry::RetryingProvider::new(
inner,
crate::llm::retry::RetryConfig {
max_retries: 3,
base_delay: Duration::from_millis(1),
max_delay: Duration::from_millis(10),
},
);
let wrapper = CircuitBreakerProvider::new(
retrying,
tracker.clone(),
"anthropic",
TenantScope::new("acme"),
);
let _ = wrapper.complete(dummy_request()).await;
let _ = wrapper.complete(dummy_request()).await;
let err = wrapper.complete(dummy_request()).await.unwrap_err();
assert!(matches!(err, Error::CircuitOpen { .. }));
}
// Dropping a permit without recording an outcome counts as a failure; with a
// threshold of 1 that alone opens the circuit.
#[test]
fn permit_drop_without_consume_records_failure() {
let c = Arc::new(ProviderCircuit::new(CircuitConfig {
failure_threshold: 1,
initial_open_duration: Duration::from_millis(50),
max_open_duration: Duration::from_millis(500),
backoff_multiplier: 2.0,
}));
let permit = c.permit().unwrap();
drop(permit); assert!(
c.permit().is_err(),
"circuit should be open after unconsumed permit drop"
);
}
}