1use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use forge_error::DispatchError;
10use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
11use serde_json::Value;
12use tokio::sync::Mutex;
13
/// Tuning parameters for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive breaker-tripping failures at which the circuit opens.
    pub failure_threshold: u32,
    /// How long an open circuit rejects calls before a probe call is let through.
    pub recovery_timeout: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Stock settings: open after 3 consecutive failures, probe again after 30 seconds.
    fn default() -> Self {
        let failure_threshold = 3;
        let recovery_timeout = Duration::from_secs(30);
        CircuitBreakerConfig {
            failure_threshold,
            recovery_timeout,
        }
    }
}
31
/// The three positions of the circuit breaker state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CircuitState {
    /// Normal operation: calls are forwarded to the wrapped dispatcher.
    Closed,
    /// Tripped: calls are rejected until the recovery timeout has elapsed.
    Open,
    /// Recovery timeout elapsed: calls are forwarded again to probe the server.
    HalfOpen,
}
38
39pub(crate) struct CircuitBreakerState {
41 pub(crate) state: CircuitState,
42 pub(crate) consecutive_failures: u32,
43 pub(crate) last_failure_time: Option<Instant>,
44}
45
46pub struct CircuitBreakerDispatcher {
48 inner: Arc<dyn ToolDispatcher>,
49 config: CircuitBreakerConfig,
50 server_name: String,
51 state: Mutex<CircuitBreakerState>,
52}
53
54impl CircuitBreakerDispatcher {
55 pub fn new(
57 inner: Arc<dyn ToolDispatcher>,
58 config: CircuitBreakerConfig,
59 server_name: impl Into<String>,
60 ) -> Self {
61 Self {
62 inner,
63 config,
64 server_name: server_name.into(),
65 state: Mutex::new(CircuitBreakerState {
66 state: CircuitState::Closed,
67 consecutive_failures: 0,
68 last_failure_time: None,
69 }),
70 }
71 }
72}
73
74#[async_trait::async_trait]
75impl ToolDispatcher for CircuitBreakerDispatcher {
76 #[tracing::instrument(skip(self, args), fields(server, tool))]
77 async fn call_tool(
78 &self,
79 server: &str,
80 tool: &str,
81 args: Value,
82 ) -> Result<Value, DispatchError> {
83 {
84 let mut st = self.state.lock().await;
85 match st.state {
86 CircuitState::Open => {
87 if let Some(last_fail) = st.last_failure_time {
88 if last_fail.elapsed() >= self.config.recovery_timeout {
89 st.state = CircuitState::HalfOpen;
90 tracing::info!(
91 server = %self.server_name,
92 "circuit breaker half-open, allowing probe call"
93 );
94 } else {
95 return Err(DispatchError::CircuitOpen(self.server_name.clone()));
96 }
97 }
98 }
99 CircuitState::HalfOpen | CircuitState::Closed => {}
100 }
101 }
102
103 let result = self.inner.call_tool(server, tool, args).await;
104
105 {
106 let mut st = self.state.lock().await;
107 match &result {
108 Ok(_) => {
109 if st.state == CircuitState::HalfOpen {
110 tracing::info!(
111 server = %self.server_name,
112 "circuit breaker closed after successful probe"
113 );
114 }
115 st.state = CircuitState::Closed;
116 st.consecutive_failures = 0;
117 st.last_failure_time = None;
118 }
119 Err(e) if e.trips_circuit_breaker() => {
120 st.consecutive_failures += 1;
121 st.last_failure_time = Some(Instant::now());
122 if st.state == CircuitState::HalfOpen {
123 st.state = CircuitState::Open;
124 tracing::warn!(
125 server = %self.server_name,
126 "circuit breaker re-opened after failed probe"
127 );
128 } else if st.consecutive_failures >= self.config.failure_threshold {
129 st.state = CircuitState::Open;
130 tracing::warn!(
131 server = %self.server_name,
132 failures = st.consecutive_failures,
133 "circuit breaker opened"
134 );
135 }
136 }
137 Err(_) => {
138 if st.state == CircuitState::HalfOpen {
142 tracing::info!(
143 server = %self.server_name,
144 "circuit breaker closed: server responded (non-fault error)"
145 );
146 st.state = CircuitState::Closed;
147 st.consecutive_failures = 0;
148 st.last_failure_time = None;
149 }
150 }
153 }
154 }
155
156 result
157 }
158}
159
160pub struct CircuitBreakerResourceDispatcher {
162 inner: Arc<dyn ResourceDispatcher>,
163 server_name: String,
164 config: CircuitBreakerConfig,
165 state: Arc<Mutex<CircuitBreakerState>>,
166}
167
168impl CircuitBreakerResourceDispatcher {
169 pub fn new(
171 inner: Arc<dyn ResourceDispatcher>,
172 config: CircuitBreakerConfig,
173 server_name: impl Into<String>,
174 ) -> Self {
175 Self {
176 inner,
177 config,
178 server_name: server_name.into(),
179 state: Arc::new(Mutex::new(CircuitBreakerState {
180 state: CircuitState::Closed,
181 consecutive_failures: 0,
182 last_failure_time: None,
183 })),
184 }
185 }
186}
187
188#[async_trait::async_trait]
189impl ResourceDispatcher for CircuitBreakerResourceDispatcher {
190 #[tracing::instrument(skip(self), fields(server, uri))]
191 async fn read_resource(
192 &self,
193 server: &str,
194 uri: &str,
195 ) -> Result<serde_json::Value, DispatchError> {
196 {
197 let mut st = self.state.lock().await;
198 match st.state {
199 CircuitState::Open => {
200 if let Some(last_fail) = st.last_failure_time {
201 if last_fail.elapsed() >= self.config.recovery_timeout {
202 st.state = CircuitState::HalfOpen;
203 } else {
204 return Err(DispatchError::CircuitOpen(self.server_name.clone()));
205 }
206 }
207 }
208 CircuitState::HalfOpen | CircuitState::Closed => {}
209 }
210 }
211
212 let result = self.inner.read_resource(server, uri).await;
213
214 {
215 let mut st = self.state.lock().await;
216 match &result {
217 Ok(_) => {
218 st.state = CircuitState::Closed;
219 st.consecutive_failures = 0;
220 st.last_failure_time = None;
221 }
222 Err(e) if e.trips_circuit_breaker() => {
223 st.consecutive_failures += 1;
224 st.last_failure_time = Some(Instant::now());
225 if st.state == CircuitState::HalfOpen
226 || st.consecutive_failures >= self.config.failure_threshold
227 {
228 st.state = CircuitState::Open;
229 }
230 }
231 Err(_) => {
232 if st.state == CircuitState::HalfOpen {
233 st.state = CircuitState::Closed;
234 st.consecutive_failures = 0;
235 st.last_failure_time = None;
236 }
237 }
238 }
239 }
240
241 result
242 }
243}
244
#[cfg(test)]
mod tests {
    //! Unit tests for the tool and resource circuit breakers.
    //!
    //! The recovery tests rely on real sleeps (60 ms against a 50 ms recovery timeout),
    //! so they are sensitive to a heavily loaded test host.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Empty JSON object used as the argument payload of every tool call.
    fn no_args() -> Value {
        serde_json::json!({})
    }

    /// Builds a config with the given threshold and a recovery timeout in milliseconds.
    fn test_config(threshold: u32, recovery_ms: u64) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: threshold,
            recovery_timeout: Duration::from_millis(recovery_ms),
        }
    }

    // ---------------------------------------------------------------------------------
    // Test doubles
    // ---------------------------------------------------------------------------------

    /// Tool dispatcher whose every call succeeds.
    struct OkDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for OkDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Ok(serde_json::json!({"tool": tool, "status": "ok"}))
        }
    }

    /// Tool dispatcher that always times out and counts how often it was reached.
    struct FailDispatcher {
        calls: AtomicUsize,
    }

    impl FailDispatcher {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for FailDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Err(DispatchError::Timeout {
                server: "s".into(),
                timeout_ms: 5000,
            })
        }
    }

    /// Tool dispatcher that times out for the first `fail_count` calls, then succeeds.
    struct FailThenOkDispatcher {
        calls: AtomicUsize,
        fail_count: usize,
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for FailThenOkDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                Err(DispatchError::Timeout {
                    server: "s".into(),
                    timeout_ms: 5000,
                })
            } else {
                Ok(serde_json::json!({"tool": tool, "status": "ok"}))
            }
        }
    }

    /// Dispatcher (tool and resource) that always answers with a `ToolError`, which the
    /// tests below expect NOT to count toward the failure threshold.
    struct ToolErrorDispatcher {
        calls: AtomicUsize,
    }

    impl ToolErrorDispatcher {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for ToolErrorDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Err(DispatchError::ToolError {
                server: "s".into(),
                tool: "scan".into(),
                message: "Invalid params: missing field 'base_url'".into(),
            })
        }
    }

    #[async_trait::async_trait]
    impl ResourceDispatcher for ToolErrorDispatcher {
        async fn read_resource(&self, _server: &str, _uri: &str) -> Result<Value, DispatchError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Err(DispatchError::ToolError {
                server: "s".into(),
                tool: "read".into(),
                message: "Invalid params".into(),
            })
        }
    }

    /// One scripted outcome of a [`SequencedDispatcher`] call.
    #[derive(Debug, Clone, Copy)]
    enum Step {
        /// Answer with `DispatchError::Timeout` (expected to count toward the threshold).
        Fault,
        /// Answer with `DispatchError::ToolError` (expected not to count).
        ClientError,
        /// Answer with `Ok`.
        Success,
    }

    /// Tool dispatcher that replays `sequence` call by call and succeeds once the
    /// script is exhausted.
    struct SequencedDispatcher {
        sequence: Vec<Step>,
        calls: AtomicUsize,
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for SequencedDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            match self.sequence.get(n).copied().unwrap_or(Step::Success) {
                Step::Fault => Err(DispatchError::Timeout {
                    server: "s".into(),
                    timeout_ms: 5000,
                }),
                Step::ClientError => Err(DispatchError::ToolError {
                    server: "s".into(),
                    tool: tool.into(),
                    message: "bad params".into(),
                }),
                Step::Success => Ok(serde_json::json!({"tool": tool, "ok": true})),
            }
        }
    }

    /// Resource dispatcher that always times out.
    struct FailResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for FailResourceDispatcher {
        async fn read_resource(&self, _server: &str, _uri: &str) -> Result<Value, DispatchError> {
            Err(DispatchError::Timeout {
                server: "flaky".into(),
                timeout_ms: 5000,
            })
        }
    }

    /// Resource dispatcher that times out for the first `fail_count` reads, then succeeds.
    struct FailThenOkResourceDispatcher {
        calls: AtomicUsize,
        fail_count: usize,
    }

    #[async_trait::async_trait]
    impl ResourceDispatcher for FailThenOkResourceDispatcher {
        async fn read_resource(&self, _server: &str, uri: &str) -> Result<Value, DispatchError> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                Err(DispatchError::Timeout {
                    server: "s".into(),
                    timeout_ms: 5000,
                })
            } else {
                Ok(serde_json::json!({"uri": uri, "status": "ok"}))
            }
        }
    }

    // ---------------------------------------------------------------------------------
    // Tool breaker: basic open / reject / recover behaviour
    // ---------------------------------------------------------------------------------

    #[tokio::test]
    async fn passes_through_on_success() {
        let inner = Arc::new(OkDispatcher);
        let cb = CircuitBreakerDispatcher::new(inner, test_config(3, 1000), "test");

        let result = cb.call_tool("test", "echo", no_args()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap()["status"], "ok");
    }

    #[tokio::test]
    async fn opens_after_threshold_failures() {
        let inner = Arc::new(FailDispatcher::new());
        let cb = CircuitBreakerDispatcher::new(inner.clone(), test_config(3, 60_000), "flaky");

        for _ in 0..3 {
            let _ = cb.call_tool("flaky", "tool", no_args()).await;
        }
        assert_eq!(inner.call_count(), 3);

        let result = cb.call_tool("flaky", "tool", no_args()).await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));
        assert_eq!(
            inner.call_count(),
            3,
            "inner should not be called when open"
        );
    }

    #[tokio::test]
    async fn rejects_when_open() {
        let inner = Arc::new(FailDispatcher::new());
        let cb = CircuitBreakerDispatcher::new(inner.clone(), test_config(2, 60_000), "s");

        for _ in 0..2 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        for _ in 0..5 {
            let result = cb.call_tool("s", "t", no_args()).await;
            assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));
        }
        assert_eq!(
            inner.call_count(),
            2,
            "no additional calls should reach inner"
        );
    }

    #[tokio::test]
    async fn half_open_after_recovery_timeout() {
        let inner = Arc::new(FailThenOkDispatcher {
            calls: AtomicUsize::new(0),
            fail_count: 3,
        });
        let cb = CircuitBreakerDispatcher::new(inner, test_config(3, 50), "s");

        for _ in 0..3 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        // Still inside the recovery window: rejected.
        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(result.is_err());

        tokio::time::sleep(Duration::from_millis(60)).await;

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(result.is_ok(), "probe should succeed after recovery");
    }

    #[tokio::test]
    async fn probe_failure_reopens_circuit() {
        let inner = Arc::new(FailDispatcher::new());
        let cb = CircuitBreakerDispatcher::new(inner.clone(), test_config(2, 50), "s");

        for _ in 0..2 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        tokio::time::sleep(Duration::from_millis(60)).await;

        // The probe reaches the inner dispatcher and fails.
        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(result.is_err());

        let before = inner.call_count();
        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));
        assert_eq!(
            inner.call_count(),
            before,
            "should not reach inner after probe failure"
        );
    }

    #[tokio::test]
    async fn probe_success_closes_circuit() {
        let inner = Arc::new(FailThenOkDispatcher {
            calls: AtomicUsize::new(0),
            fail_count: 2,
        });
        let cb = CircuitBreakerDispatcher::new(inner, test_config(2, 50), "s");

        for _ in 0..2 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        tokio::time::sleep(Duration::from_millis(60)).await;

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(result.is_ok());

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(result.is_ok());

        let st = cb.state.lock().await;
        assert_eq!(st.state, CircuitState::Closed);
        assert_eq!(st.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn success_resets_failure_counter() {
        let inner = Arc::new(FailThenOkDispatcher {
            calls: AtomicUsize::new(0),
            fail_count: 2,
        });
        let cb = CircuitBreakerDispatcher::new(inner, test_config(3, 60_000), "s");

        let _ = cb.call_tool("s", "t", no_args()).await;
        let _ = cb.call_tool("s", "t", no_args()).await;

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(result.is_ok());

        let st = cb.state.lock().await;
        assert_eq!(st.state, CircuitState::Closed);
        assert_eq!(st.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn open_circuit_error_carries_server_name() {
        let inner = Arc::new(FailDispatcher::new());
        let cb = CircuitBreakerDispatcher::new(inner, test_config(2, 60_000), "my-server");

        for _ in 0..2 {
            let _ = cb.call_tool("my-server", "tool", no_args()).await;
        }

        let err = cb
            .call_tool("my-server", "tool", no_args())
            .await
            .unwrap_err();
        assert!(matches!(err, DispatchError::CircuitOpen(ref s) if s == "my-server"));
    }

    // ---------------------------------------------------------------------------------
    // Tool breaker: only server faults count toward the threshold
    // ---------------------------------------------------------------------------------

    #[tokio::test]
    async fn tool_error_does_not_count_toward_threshold() {
        let inner = Arc::new(ToolErrorDispatcher::new());
        let cb = CircuitBreakerDispatcher::new(inner.clone(), test_config(3, 60_000), "arbiter");

        // Far more errors than the threshold, yet every call must still reach inner.
        for _ in 0..10 {
            let result = cb.call_tool("arbiter", "scan", no_args()).await;
            assert!(
                matches!(result, Err(DispatchError::ToolError { .. })),
                "expected ToolError"
            );
        }

        assert_eq!(inner.call_count(), 10);

        let st = cb.state.lock().await;
        assert_eq!(st.state, CircuitState::Closed);
        assert_eq!(st.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn timeout_still_trips_after_threshold() {
        let inner = Arc::new(FailDispatcher::new());
        let cb = CircuitBreakerDispatcher::new(inner.clone(), test_config(3, 60_000), "s");

        for _ in 0..3 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));
        assert_eq!(inner.call_count(), 3);
    }

    #[tokio::test]
    async fn mixed_errors_only_server_faults_count() {
        let inner = Arc::new(SequencedDispatcher {
            sequence: vec![
                Step::ClientError,
                Step::Fault,
                Step::ClientError,
                Step::Fault,
                Step::ClientError,
                Step::Fault,
            ],
            calls: AtomicUsize::new(0),
        });
        let cb = CircuitBreakerDispatcher::new(inner, test_config(3, 60_000), "s");

        for _ in 0..6 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(
            matches!(result, Err(DispatchError::CircuitOpen(_))),
            "expected CircuitOpen after 3 timeouts, got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn client_error_preserves_failure_counter() {
        let inner = Arc::new(SequencedDispatcher {
            sequence: vec![Step::Fault, Step::ClientError, Step::Fault],
            calls: AtomicUsize::new(0),
        });
        let cb = CircuitBreakerDispatcher::new(inner, test_config(3, 60_000), "s");

        for _ in 0..3 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        let st = cb.state.lock().await;
        assert_eq!(
            st.consecutive_failures, 2,
            "ToolError should not reset counter"
        );
        assert_eq!(st.state, CircuitState::Closed);
    }

    #[tokio::test]
    async fn success_still_resets_counter_after_tool_errors() {
        let inner = Arc::new(SequencedDispatcher {
            sequence: vec![Step::Fault, Step::ClientError, Step::Success],
            calls: AtomicUsize::new(0),
        });
        let cb = CircuitBreakerDispatcher::new(inner, test_config(3, 60_000), "s");

        for _ in 0..3 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        let st = cb.state.lock().await;
        assert_eq!(st.consecutive_failures, 0);
        assert_eq!(st.state, CircuitState::Closed);
    }

    #[tokio::test]
    async fn half_open_probe_tool_error_closes_circuit() {
        let inner = Arc::new(SequencedDispatcher {
            sequence: vec![
                Step::Fault,
                Step::Fault,
                Step::ClientError,
                Step::Success,
            ],
            calls: AtomicUsize::new(0),
        });
        let cb = CircuitBreakerDispatcher::new(inner, test_config(2, 50), "s");

        // Two faults open the circuit.
        let _ = cb.call_tool("s", "t", no_args()).await;
        let _ = cb.call_tool("s", "t", no_args()).await;

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));

        tokio::time::sleep(Duration::from_millis(60)).await;

        // The probe gets a ToolError: the server answered, so the circuit closes.
        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(
            matches!(result, Err(DispatchError::ToolError { .. })),
            "probe should return ToolError, got: {:?}",
            result
        );

        let st = cb.state.lock().await;
        assert_eq!(st.state, CircuitState::Closed);
        assert_eq!(st.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn half_open_probe_timeout_reopens_circuit() {
        let inner = Arc::new(FailDispatcher::new());
        let cb = CircuitBreakerDispatcher::new(inner, test_config(2, 50), "s");

        for _ in 0..2 {
            let _ = cb.call_tool("s", "t", no_args()).await;
        }

        tokio::time::sleep(Duration::from_millis(60)).await;

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(result.is_err());

        let result = cb.call_tool("s", "t", no_args()).await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));
    }

    // ---------------------------------------------------------------------------------
    // Resource breaker
    // ---------------------------------------------------------------------------------

    #[tokio::test]
    async fn rs_c08_circuit_breaker_trips_on_repeated_resource_failures() {
        let inner: Arc<dyn ResourceDispatcher> = Arc::new(FailResourceDispatcher);
        let cb = CircuitBreakerResourceDispatcher::new(inner, test_config(2, 60_000), "flaky");

        for _ in 0..2 {
            let _ = cb.read_resource("flaky", "file:///log").await;
        }

        let result = cb.read_resource("flaky", "file:///log").await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));
    }

    #[tokio::test]
    async fn resource_tool_error_does_not_trip_breaker() {
        let inner: Arc<dyn ResourceDispatcher> = Arc::new(ToolErrorDispatcher::new());
        let cb = CircuitBreakerResourceDispatcher::new(inner, test_config(2, 60_000), "s");

        for _ in 0..5 {
            let result = cb.read_resource("s", "file:///log").await;
            assert!(matches!(result, Err(DispatchError::ToolError { .. })));
        }

        let st = cb.state.lock().await;
        assert_eq!(st.state, CircuitState::Closed);
        assert_eq!(st.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn resource_probe_success_closes_circuit() {
        let inner: Arc<dyn ResourceDispatcher> = Arc::new(FailThenOkResourceDispatcher {
            calls: AtomicUsize::new(0),
            fail_count: 2,
        });
        let cb = CircuitBreakerResourceDispatcher::new(inner, test_config(2, 50), "s");

        for _ in 0..2 {
            let _ = cb.read_resource("s", "file:///log").await;
        }

        let result = cb.read_resource("s", "file:///log").await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));

        tokio::time::sleep(Duration::from_millis(60)).await;

        let result = cb.read_resource("s", "file:///log").await;
        assert!(result.is_ok(), "probe read should succeed after recovery");

        let st = cb.state.lock().await;
        assert_eq!(st.state, CircuitState::Closed);
        assert_eq!(st.consecutive_failures, 0);
    }

    #[tokio::test]
    async fn resource_probe_failure_reopens_circuit() {
        let inner: Arc<dyn ResourceDispatcher> = Arc::new(FailResourceDispatcher);
        let cb = CircuitBreakerResourceDispatcher::new(inner, test_config(2, 50), "flaky");

        for _ in 0..2 {
            let _ = cb.read_resource("flaky", "file:///log").await;
        }

        tokio::time::sleep(Duration::from_millis(60)).await;

        // The probe reaches the inner dispatcher and times out again.
        let result = cb.read_resource("flaky", "file:///log").await;
        assert!(matches!(result, Err(DispatchError::Timeout { .. })));

        let result = cb.read_resource("flaky", "file:///log").await;
        assert!(matches!(result, Err(DispatchError::CircuitOpen(_))));
    }
}