1#![allow(missing_docs)]
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use parking_lot::RwLock;
21use std::time::{Duration, Instant};
22
23use parking_lot::Mutex;
24
25use crate::auth::TenantScope;
26use crate::error::Error;
27
/// Tuning parameters for a [`ProviderCircuit`].
#[derive(Clone, Debug)]
pub struct CircuitConfig {
    /// Number of consecutive failures, counted while the circuit is closed,
    /// at which the circuit opens.
    pub failure_threshold: u32,
    /// Length of the open window used when the circuit trips from closed.
    pub initial_open_duration: Duration,
    /// Upper bound applied to the open window computed after a failed
    /// half-open probe.
    pub max_open_duration: Duration,
    /// Multiplier used to lengthen the open window after a failed half-open
    /// probe; the result is capped at `max_open_duration`.
    pub backoff_multiplier: f64,
}

impl Default for CircuitConfig {
    /// Defaults: open after 5 consecutive failures, stay open for 30 seconds
    /// at first, never longer than 300 seconds, backing off by a factor of 2.
    fn default() -> Self {
        let initial_open_duration = Duration::from_secs(30);
        let max_open_duration = Duration::from_secs(300);
        Self {
            failure_threshold: 5,
            initial_open_duration,
            max_open_duration,
            backoff_multiplier: 2.0,
        }
    }
}
53
54#[derive(Debug)]
55enum CircuitState {
56 Closed {
57 consecutive_failures: u32,
58 },
59 Open {
60 until: Instant,
61 prev_duration: Duration,
62 },
63 HalfOpen,
64}
65
66pub struct ProviderCircuit {
76 state: Mutex<CircuitState>,
77 config: CircuitConfig,
78}
79
80pub struct CircuitPermit {
88 circuit: Arc<ProviderCircuit>,
89 consumed: std::sync::atomic::AtomicBool,
90}
91
92impl std::fmt::Debug for CircuitPermit {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.write_str("CircuitPermit")
95 }
96}
97
98impl CircuitPermit {
99 pub fn record_success(self) {
100 self.consumed
101 .store(true, std::sync::atomic::Ordering::SeqCst);
102 self.circuit.record_success();
103 }
104 pub fn record_failure(self) {
105 self.consumed
106 .store(true, std::sync::atomic::Ordering::SeqCst);
107 self.circuit.record_failure();
108 }
109}
110
111impl Drop for CircuitPermit {
112 fn drop(&mut self) {
113 if !self.consumed.load(std::sync::atomic::Ordering::SeqCst) {
117 self.circuit.record_failure();
118 }
119 }
120}
121
122impl ProviderCircuit {
123 pub fn new(config: CircuitConfig) -> Self {
124 Self {
125 state: Mutex::new(CircuitState::Closed {
126 consecutive_failures: 0,
127 }),
128 config,
129 }
130 }
131
132 pub fn permit(self: &Arc<Self>) -> Result<CircuitPermit, Error> {
135 let mut state = self.state.lock();
136 match *state {
137 CircuitState::Closed { .. } => Ok(CircuitPermit {
138 circuit: Arc::clone(self),
139 consumed: std::sync::atomic::AtomicBool::new(false),
140 }),
141 CircuitState::Open {
142 until,
143 prev_duration,
144 } => {
145 if Instant::now() >= until {
146 *state = CircuitState::HalfOpen;
147 Ok(CircuitPermit {
148 circuit: Arc::clone(self),
149 consumed: std::sync::atomic::AtomicBool::new(false),
150 })
151 } else {
152 Err(Error::CircuitOpen {
153 until,
154 prev_duration,
155 })
156 }
157 }
158 CircuitState::HalfOpen => Err(Error::CircuitOpen {
159 until: Instant::now() + Duration::from_millis(50),
160 prev_duration: Duration::ZERO,
161 }),
162 }
163 }
164
165 fn record_success(&self) {
166 let mut state = self.state.lock();
167 *state = CircuitState::Closed {
168 consecutive_failures: 0,
169 };
170 }
171
172 fn record_failure(&self) {
173 let mut state = self.state.lock();
174 match *state {
175 CircuitState::Closed {
176 consecutive_failures,
177 } => {
178 let n = consecutive_failures + 1;
179 *state = if n >= self.config.failure_threshold {
180 CircuitState::Open {
181 until: Instant::now() + self.config.initial_open_duration,
182 prev_duration: self.config.initial_open_duration,
183 }
184 } else {
185 CircuitState::Closed {
186 consecutive_failures: n,
187 }
188 };
189 }
190 CircuitState::HalfOpen => {
191 let new_dur_secs = self.config.initial_open_duration.as_secs_f64()
192 * self.config.backoff_multiplier;
193 let new_dur =
194 Duration::from_secs_f64(new_dur_secs).min(self.config.max_open_duration);
195 *state = CircuitState::Open {
196 until: Instant::now() + new_dur,
197 prev_duration: new_dur,
198 };
199 }
200 CircuitState::Open { .. } => { }
201 }
202 }
203}
204
/// Identifies one circuit: a provider as seen by a particular tenant.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CircuitKey {
    /// Tenant the circuit belongs to.
    pub tenant_id: String,
    /// Name of the provider the circuit guards.
    pub provider: String,
}
215
216pub struct CircuitTracker {
222 circuits: RwLock<HashMap<CircuitKey, Arc<ProviderCircuit>>>,
223 config: CircuitConfig,
224}
225
226impl CircuitTracker {
227 pub fn new(config: CircuitConfig) -> Self {
228 Self {
229 circuits: RwLock::new(HashMap::new()),
230 config,
231 }
232 }
233
234 pub fn circuit_for(&self, scope: &TenantScope, provider: &str) -> Arc<ProviderCircuit> {
236 let key = CircuitKey {
237 tenant_id: scope.tenant_id.clone(),
238 provider: provider.to_string(),
239 };
240 if let Some(c) = self.circuits.read().get(&key) {
242 return Arc::clone(c);
243 }
244 let mut g = self.circuits.write();
246 Arc::clone(
247 g.entry(key)
248 .or_insert_with(|| Arc::new(ProviderCircuit::new(self.config.clone()))),
249 )
250 }
251}
252
253pub fn is_circuit_failure(err: &Error) -> bool {
262 use crate::llm::error_class::ErrorClass;
263 matches!(
264 crate::llm::error_class::classify(err),
265 ErrorClass::ServerError | ErrorClass::RateLimited | ErrorClass::Network
266 )
267}
268
269pub struct CircuitBreakerProvider<P: super::LlmProvider> {
300 inner: P,
301 tracker: Arc<CircuitTracker>,
302 provider_name: String,
303 scope: TenantScope,
304}
305
306impl<P: super::LlmProvider> CircuitBreakerProvider<P> {
307 pub fn new(
308 inner: P,
309 tracker: Arc<CircuitTracker>,
310 provider_name: impl Into<String>,
311 scope: TenantScope,
312 ) -> Self {
313 Self {
314 inner,
315 tracker,
316 provider_name: provider_name.into(),
317 scope,
318 }
319 }
320}
321
322impl<P: super::LlmProvider> super::LlmProvider for CircuitBreakerProvider<P> {
323 fn model_name(&self) -> Option<&str> {
324 self.inner.model_name()
325 }
326
327 async fn complete(
328 &self,
329 request: super::types::CompletionRequest,
330 ) -> Result<super::types::CompletionResponse, Error> {
331 let circuit = self.tracker.circuit_for(&self.scope, &self.provider_name);
332 let permit = circuit.permit()?;
333 let result = self.inner.complete(request).await;
334 match &result {
335 Ok(_) => permit.record_success(),
336 Err(e) if is_circuit_failure(e) => permit.record_failure(),
338 Err(_) => permit.record_success(),
342 }
343 result
344 }
345
346 async fn stream_complete(
347 &self,
348 request: super::types::CompletionRequest,
349 on_text: &super::OnText,
350 ) -> Result<super::types::CompletionResponse, Error> {
351 let circuit = self.tracker.circuit_for(&self.scope, &self.provider_name);
352 let permit = circuit.permit()?;
353 let result = self.inner.stream_complete(request, on_text).await;
354 match &result {
355 Ok(_) => permit.record_success(),
356 Err(e) if is_circuit_failure(e) => permit.record_failure(),
357 Err(_) => permit.record_success(),
358 }
359 result
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared test configuration: opens after 3 failures, 50 ms initial open
    /// window, 500 ms cap, multiplier 2.
    fn cfg() -> CircuitConfig {
        CircuitConfig {
            failure_threshold: 3,
            initial_open_duration: Duration::from_millis(50),
            max_open_duration: Duration::from_millis(500),
            backoff_multiplier: 2.0,
        }
    }

    // A freshly created circuit is closed and grants a permit.
    #[test]
    fn closed_circuit_passes_requests() {
        let c = Arc::new(ProviderCircuit::new(cfg()));
        let p = c.permit().unwrap();
        p.record_success();
    }

    // Reaching the failure threshold (3) opens the circuit.
    #[test]
    fn n_failures_open_circuit() {
        let c = Arc::new(ProviderCircuit::new(cfg()));
        for _ in 0..3 {
            let p = c.permit().unwrap();
            p.record_failure();
        }
        let err = c.permit().unwrap_err();
        assert!(matches!(err, Error::CircuitOpen { .. }));
    }

    // Two failures, a success, then one more failure: the success reset the
    // count, so the circuit is still closed.
    #[test]
    fn success_resets_consecutive_failures() {
        let c = Arc::new(ProviderCircuit::new(cfg()));
        c.permit().unwrap().record_failure();
        c.permit().unwrap().record_failure();
        c.permit().unwrap().record_success();
        c.permit().unwrap().record_failure();
        assert!(c.permit().is_ok());
    }

    // After the 50 ms open window has elapsed, the next caller gets a permit.
    #[test]
    fn open_transitions_to_half_open_after_duration() {
        let c = Arc::new(ProviderCircuit::new(cfg()));
        for _ in 0..3 {
            c.permit().unwrap().record_failure();
        }
        std::thread::sleep(Duration::from_millis(60));
        assert!(c.permit().is_ok(), "should be HalfOpen permit");
    }

    // A successful probe closes the circuit, so subsequent requests all pass.
    #[test]
    fn half_open_success_closes_circuit() {
        let c = Arc::new(ProviderCircuit::new(cfg()));
        for _ in 0..3 {
            c.permit().unwrap().record_failure();
        }
        std::thread::sleep(Duration::from_millis(60));
        c.permit().unwrap().record_success();
        for _ in 0..10 {
            let p = c.permit();
            assert!(p.is_ok());
            p.unwrap().record_success();
        }
    }

    // A failed probe reopens the circuit for 100 ms (50 ms * 2, below the
    // 500 ms cap): still open after 60 ms, available again after 120 ms.
    #[test]
    fn half_open_failure_reopens_with_doubled_duration() {
        let c = Arc::new(ProviderCircuit::new(cfg()));
        for _ in 0..3 {
            c.permit().unwrap().record_failure();
        }
        std::thread::sleep(Duration::from_millis(70));
        // This permit is the probe; failing it triggers the backoff.
        c.permit().unwrap().record_failure();
        std::thread::sleep(Duration::from_millis(60));
        assert!(c.permit().is_err());
        std::thread::sleep(Duration::from_millis(60));
        assert!(c.permit().is_ok());
    }

    // 100 ms * 4 would be 400 ms, but the cap is 150 ms, so the circuit must
    // grant a permit again after 160 ms.
    #[test]
    fn repeated_half_open_failures_clamp_at_max() {
        let c = Arc::new(ProviderCircuit::new(CircuitConfig {
            failure_threshold: 1,
            initial_open_duration: Duration::from_millis(100),
            max_open_duration: Duration::from_millis(150),
            backoff_multiplier: 4.0,
        }));
        // First failure opens the circuit; wait out the initial window.
        c.permit().unwrap().record_failure();
        std::thread::sleep(Duration::from_millis(110));
        // Probe fails; wait out the clamped window.
        c.permit().unwrap().record_failure();
        std::thread::sleep(Duration::from_millis(160));
        assert!(
            c.permit().is_ok(),
            "should be openable again at clamped duration"
        );
    }

    // The permit can be moved into a spawned task and consumed there after an
    // await point.
    #[tokio::test(flavor = "multi_thread")]
    async fn permit_can_be_moved_across_await() {
        let c = Arc::new(ProviderCircuit::new(cfg()));
        let p = c.permit().unwrap();
        let task = tokio::spawn(async move {
            tokio::task::yield_now().await;
            p.record_success();
        });
        task.await.unwrap();
    }

    // While a probe permit is outstanding, further permit requests are
    // rejected; once the probe succeeds the circuit admits requests again.
    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_requests_during_half_open_only_one_probes() {
        let c = Arc::new(ProviderCircuit::new(CircuitConfig {
            failure_threshold: 1,
            initial_open_duration: Duration::from_millis(20),
            max_open_duration: Duration::from_millis(200),
            backoff_multiplier: 2.0,
        }));
        // Open the circuit, then wait past the 20 ms window.
        c.permit().unwrap().record_failure();
        tokio::time::sleep(Duration::from_millis(30)).await;

        let probe = c.permit().expect("first probe granted");

        let second = c.permit();
        assert!(matches!(second, Err(Error::CircuitOpen { .. })));

        probe.record_success();
        assert!(c.permit().is_ok());
    }

    // Same tenant and provider yield the very same circuit instance.
    #[test]
    fn tracker_returns_same_arc_for_same_key() {
        let t = CircuitTracker::new(cfg());
        let a = t.circuit_for(&TenantScope::new("acme"), "anthropic");
        let b = t.circuit_for(&TenantScope::new("acme"), "anthropic");
        assert!(Arc::ptr_eq(&a, &b));
    }

    // Different tenants get different circuits for the same provider.
    #[test]
    fn tracker_isolates_tenants() {
        let t = CircuitTracker::new(cfg());
        let a = t.circuit_for(&TenantScope::new("acme"), "anthropic");
        let b = t.circuit_for(&TenantScope::new("globex"), "anthropic");
        assert!(!Arc::ptr_eq(&a, &b));
    }

    // Different providers get different circuits for the same tenant.
    #[test]
    fn tracker_isolates_providers() {
        let t = CircuitTracker::new(cfg());
        let a = t.circuit_for(&TenantScope::new("acme"), "anthropic");
        let b = t.circuit_for(&TenantScope::new("acme"), "openai");
        assert!(!Arc::ptr_eq(&a, &b));
    }

    // 503, 429 and a transport-level HTTP error count against the circuit;
    // 401 and 400 do not.
    #[test]
    fn is_circuit_failure_classifies_correctly() {
        let server = Error::Api {
            status: 503,
            message: "service unavailable".into(),
        };
        assert!(is_circuit_failure(&server));

        let rate = Error::Api {
            status: 429,
            message: "too many requests".into(),
        };
        assert!(is_circuit_failure(&rate));

        // Obtain a genuine reqwest error by issuing a request that is
        // expected to fail; a throwaway single-threaded runtime drives it
        // because this is a synchronous test.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test runtime");
        let http_err = rt
            .block_on(reqwest::get("http://[::0]:1"))
            .expect_err("should fail");
        assert!(is_circuit_failure(&Error::Http(http_err)));

        let auth = Error::Api {
            status: 401,
            message: "unauthorized".into(),
        };
        assert!(!is_circuit_failure(&auth));

        let bad = Error::Api {
            status: 400,
            message: "bad json".into(),
        };
        assert!(!is_circuit_failure(&bad));
    }

    use crate::llm::LlmProvider;
    use crate::llm::types::{CompletionRequest, Message};

    /// Test double whose `complete` always fails with the error produced by
    /// the `error` closure.
    struct FailingProvider {
        error: Box<dyn Fn() -> Error + Send + Sync>,
    }

    // Only `complete` is overridden here; the remaining trait methods
    // presumably come from defaults on `LlmProvider` (not visible in this file).
    impl LlmProvider for FailingProvider {
        async fn complete(
            &self,
            _r: CompletionRequest,
        ) -> Result<crate::llm::types::CompletionResponse, Error> {
            Err((self.error)())
        }
    }

    /// Minimal request used by the wrapper tests; its content is irrelevant
    /// because `FailingProvider` ignores it.
    fn dummy_request() -> CompletionRequest {
        CompletionRequest {
            system: "test".into(),
            messages: vec![Message::user("hi")],
            tools: vec![],
            max_tokens: 10,
            tool_choice: None,
            reasoning_effort: None,
        }
    }

    // Three 503 responses through the wrapper reach the threshold; the fourth
    // call is rejected by the circuit instead of reaching the provider.
    #[tokio::test(flavor = "multi_thread")]
    async fn circuit_opens_after_threshold_failures() {
        let tracker = Arc::new(CircuitTracker::new(CircuitConfig {
            failure_threshold: 3,
            initial_open_duration: Duration::from_secs(60),
            max_open_duration: Duration::from_secs(120),
            backoff_multiplier: 2.0,
        }));
        let inner = FailingProvider {
            error: Box::new(|| Error::Api {
                status: 503,
                message: "down".into(),
            }),
        };
        let wrapper = CircuitBreakerProvider::new(
            inner,
            tracker.clone(),
            "anthropic",
            TenantScope::new("acme"),
        );

        for _ in 0..3 {
            let _ = wrapper.complete(dummy_request()).await;
        }
        let err = wrapper.complete(dummy_request()).await.unwrap_err();
        assert!(matches!(err, Error::CircuitOpen { .. }));
    }

    // Ten 401 responses (more than the threshold of 3) leave the circuit
    // closed, because 401 is not a circuit failure.
    #[tokio::test(flavor = "multi_thread")]
    async fn auth_errors_do_not_trip_circuit() {
        let tracker = Arc::new(CircuitTracker::new(cfg()));
        let inner = FailingProvider {
            error: Box::new(|| Error::Api {
                status: 401,
                message: "no key".into(),
            }),
        };
        let wrapper = CircuitBreakerProvider::new(
            inner,
            tracker.clone(),
            "anthropic",
            TenantScope::new("acme"),
        );

        for _ in 0..10 {
            let _ = wrapper.complete(dummy_request()).await;
        }
        // Look up the same circuit the wrapper used and check it still admits.
        let circuit = tracker.circuit_for(&TenantScope::new("acme"), "anthropic");
        assert!(circuit.permit().is_ok());
    }

    // With the circuit breaker outside a retrying provider, each outer call
    // takes one permit and records one outcome: with a threshold of 2, two
    // failing outer calls open the circuit and the third is rejected.
    #[tokio::test(flavor = "multi_thread")]
    async fn circuit_outer_retry_inner_one_permit_per_outer_call() {
        let tracker = Arc::new(CircuitTracker::new(CircuitConfig {
            failure_threshold: 2,
            initial_open_duration: Duration::from_secs(60),
            max_open_duration: Duration::from_secs(120),
            backoff_multiplier: 2.0,
        }));
        let inner = FailingProvider {
            error: Box::new(|| Error::Api {
                status: 503,
                message: "down".into(),
            }),
        };
        let retrying = crate::llm::retry::RetryingProvider::new(
            inner,
            crate::llm::retry::RetryConfig {
                max_retries: 3,
                base_delay: Duration::from_millis(1),
                max_delay: Duration::from_millis(10),
            },
        );
        let wrapper = CircuitBreakerProvider::new(
            retrying,
            tracker.clone(),
            "anthropic",
            TenantScope::new("acme"),
        );

        let _ = wrapper.complete(dummy_request()).await;
        let _ = wrapper.complete(dummy_request()).await;
        let err = wrapper.complete(dummy_request()).await.unwrap_err();
        assert!(matches!(err, Error::CircuitOpen { .. }));
    }

    // Dropping a permit without reporting an outcome counts as a failure;
    // with a threshold of 1 that single drop opens the circuit.
    #[test]
    fn permit_drop_without_consume_records_failure() {
        let c = Arc::new(ProviderCircuit::new(CircuitConfig {
            failure_threshold: 1,
            initial_open_duration: Duration::from_millis(50),
            max_open_duration: Duration::from_millis(500),
            backoff_multiplier: 2.0,
        }));
        let permit = c.permit().unwrap();
        drop(permit);
        assert!(
            c.permit().is_err(),
            "circuit should be open after unconsumed permit drop"
        );
    }
}