use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};

use async_trait::async_trait;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;
15
/// Breaker state: every request is admitted; recorded failures are counted.
const STATE_CLOSED: u32 = 0;
/// Breaker state: requests are rejected until the recovery window has elapsed.
const STATE_OPEN: u32 = 1;
/// Breaker state: one probe request has been admitted to test whether the
/// provider recovered; other requests are rejected while the probe is pending.
const STATE_HALF_OPEN: u32 = 2;

/// Lock-free circuit breaker that guards calls to an LLM provider.
///
/// * **Closed**: every request is admitted. Each recorded failure bumps a
///   counter and each success resets it. Reaching `failure_threshold` trips the
///   breaker open (a threshold of `0` behaves like `1`).
/// * **Open**: requests are rejected until `recovery_timeout` has elapsed
///   since the breaker tripped.
/// * **Half-open**: exactly one caller is admitted as a probe. A success closes
///   the breaker and a failure re-opens it. If the probe never reports back (for
///   example because its future was dropped), the probe slot is offered to the
///   next caller once another `recovery_timeout` has passed, so the breaker
///   cannot get stuck half-open.
///
/// Elapsed time is measured with a monotonic clock ([`Instant`]), so wall-clock
/// adjustments cannot shorten or extend the recovery window.
#[derive(Debug)]
pub struct CircuitBreaker {
    failure_threshold: u32,
    recovery_timeout: Duration,
    /// One of `STATE_CLOSED`, `STATE_OPEN`, `STATE_HALF_OPEN`.
    state: AtomicU32,
    /// Failures recorded since the last success.
    failure_count: AtomicU32,
    /// Milliseconds since `epoch`. While open, this is when the breaker tripped.
    /// While half-open, it is when the current probe was admitted.
    opened_at: AtomicU64,
    /// Monotonic reference point that `opened_at` is measured from.
    epoch: Instant,
}

impl CircuitBreaker {
    /// Creates a closed breaker.
    ///
    /// * `failure_threshold`: number of failures, with no success in between,
    ///   that trips the breaker. `0` behaves like `1`.
    /// * `recovery_timeout`: how long the breaker stays open before admitting a
    ///   probe. It is also how long a pending probe holds its slot.
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            state: AtomicU32::new(STATE_CLOSED),
            failure_count: AtomicU32::new(0),
            opened_at: AtomicU64::new(0),
            epoch: Instant::now(),
        }
    }

    /// Returns the current state as `"closed"`, `"open"` or `"half-open"`.
    pub fn state_name(&self) -> &'static str {
        match self.state.load(Ordering::SeqCst) {
            STATE_CLOSED => "closed",
            STATE_OPEN => "open",
            STATE_HALF_OPEN => "half-open",
            _ => "unknown",
        }
    }

    /// Returns the number of failures recorded since the last success.
    pub fn failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::SeqCst)
    }

    /// Returns the milliseconds elapsed since this breaker was created, read
    /// from a monotonic clock. Saturates at `u64::MAX`.
    fn now_millis(&self) -> u64 {
        u64::try_from(self.epoch.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// Returns `recovery_timeout` in milliseconds. Saturates at `u64::MAX`
    /// instead of silently truncating very large durations.
    fn recovery_timeout_millis(&self) -> u64 {
        u64::try_from(self.recovery_timeout.as_millis()).unwrap_or(u64::MAX)
    }

    /// Records a successful call: clears the failure count and closes the breaker.
    fn record_success(&self) {
        self.failure_count.store(0, Ordering::SeqCst);
        self.state.store(STATE_CLOSED, Ordering::SeqCst);
    }

    /// Records a failed call. Trips the breaker (or re-opens it after a failed
    /// probe) once the failure count reaches the threshold.
    fn record_failure(&self) {
        // Saturating add avoids a debug-build overflow panic at u32::MAX.
        let count = self
            .failure_count
            .fetch_add(1, Ordering::SeqCst)
            .saturating_add(1);
        if count >= self.failure_threshold {
            // Publish the timestamp *before* the state. A caller that observes
            // STATE_OPEN must never pair it with a stale `opened_at`; a stale
            // value would make the recovery window appear to have elapsed already.
            self.opened_at.store(self.now_millis(), Ordering::SeqCst);
            self.state.store(STATE_OPEN, Ordering::SeqCst);
        }
    }

    /// Decides whether a request may proceed.
    ///
    /// * Closed: always.
    /// * Open or half-open: only the single caller that claims the probe slot,
    ///   and only after the recovery window has elapsed.
    fn should_allow(&self) -> bool {
        let observed = self.state.load(Ordering::SeqCst);
        match observed {
            STATE_CLOSED => true,
            STATE_OPEN | STATE_HALF_OPEN => self.try_admit_probe(observed),
            _ => false,
        }
    }

    /// Tries to become the (single) half-open probe.
    ///
    /// `observed` is the state value the caller just loaded.
    fn try_admit_probe(&self, observed: u32) -> bool {
        let since = self.opened_at.load(Ordering::SeqCst);
        let now = self.now_millis();
        if now.saturating_sub(since) < self.recovery_timeout_millis() {
            return false;
        }
        // Claim the probe slot. Only one caller can move `opened_at` off `since`,
        // so concurrent callers cannot all become probes. Restarting the clock
        // also leases the slot for one more recovery window, which keeps a lost
        // probe from blocking the breaker forever.
        if self
            .opened_at
            .compare_exchange(since, now, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            return false;
        }
        match self.state.compare_exchange(
            observed,
            STATE_HALF_OPEN,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => true,
            // A concurrent success closed the breaker meanwhile: let the call
            // through. Any other change (e.g. a re-open) means we lost the race.
            Err(current) => current == STATE_CLOSED,
        }
    }
}
123
124#[async_trait]
125impl ProviderMiddleware for CircuitBreaker {
126 async fn wrap_complete(
127 &self,
128 messages: &[Message],
129 tools: &[ToolSchema],
130 next: &dyn LlmProvider,
131 ) -> Result<LlmResponse, PeError> {
132 if !self.should_allow() {
133 return Err(PeError::LlmProvider {
134 details: "circuit breaker open — provider is unavailable".into(),
135 });
136 }
137
138 match next.complete(messages, tools).await {
139 Ok(resp) => {
140 self.record_success();
141 Ok(resp)
142 }
143 Err(e) if e.is_transient() => {
144 self.record_failure();
145 Err(e)
146 }
147 Err(e) => Err(e), }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// A provider error that the breaker counts as a failure. The tripping tests
    /// below rely on this error being classified as transient.
    fn llm_err() -> PeError {
        PeError::LlmProvider {
            details: "err".into(),
        }
    }

    /// Builds a mock provider whose next `n` calls each fail with [`llm_err`].
    fn fail_provider(n: usize) -> MockProvider {
        (0..n).fold(MockProvider::new(), |provider, _| {
            provider.respond_with_error(llm_err())
        })
    }

    #[tokio::test]
    async fn test_closed_allows_calls() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let upstream = MockProvider::new().respond_with("ok");

        let reply = breaker.wrap_complete(&[], &[], &upstream).await.unwrap();

        assert_eq!(reply.message.content.as_text(), Some("ok"));
        assert_eq!(breaker.state_name(), "closed");
    }

    #[tokio::test]
    async fn test_opens_after_threshold_failures() {
        let breaker = CircuitBreaker::new(2, Duration::from_secs(60));
        let upstream = fail_provider(2);

        let _ = breaker.wrap_complete(&[], &[], &upstream).await;
        assert_eq!(breaker.state_name(), "closed", "one failure is below the threshold");

        let _ = breaker.wrap_complete(&[], &[], &upstream).await;
        assert_eq!(breaker.state_name(), "open");
    }

    #[tokio::test]
    async fn test_open_rejects_immediately() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let _ = breaker.wrap_complete(&[], &[], &fail_provider(1)).await;
        assert_eq!(breaker.state_name(), "open");

        let healthy = MockProvider::new().respond_with("should not reach");
        let rejection = breaker.wrap_complete(&[], &[], &healthy).await.unwrap_err();
        assert!(matches!(rejection, PeError::LlmProvider { .. }));
        // The queued response is untouched: the call never reached the provider.
        assert_eq!(healthy.remaining(), 1);
    }

    #[tokio::test]
    async fn test_half_open_recovery_and_reopen() {
        let recovery = Duration::from_millis(10);
        let breaker = CircuitBreaker::new(1, recovery);
        let _ = breaker.wrap_complete(&[], &[], &fail_provider(1)).await;
        assert_eq!(breaker.state_name(), "open");
        tokio::time::sleep(recovery * 2).await;

        // Once the recovery window has passed, a probe gets through and its
        // success closes the breaker.
        let probe = MockProvider::new().respond_with("recovered");
        let reply = breaker.wrap_complete(&[], &[], &probe).await.unwrap();
        assert_eq!(reply.message.content.as_text(), Some("recovered"));
        assert_eq!(breaker.state_name(), "closed");

        // Trip it again, wait out the window, and fail the probe: it must re-open.
        let _ = breaker.wrap_complete(&[], &[], &fail_provider(1)).await;
        tokio::time::sleep(recovery * 2).await;
        let _ = breaker.wrap_complete(&[], &[], &fail_provider(1)).await;
        assert_eq!(breaker.state_name(), "open");
    }

    #[tokio::test]
    async fn test_permanent_errors_dont_trip_breaker() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let denied = PeError::PermissionDenied {
            action: "write".into(),
        };
        let upstream = MockProvider::new().respond_with_error(denied);

        let _ = breaker.wrap_complete(&[], &[], &upstream).await;

        assert_eq!(breaker.state_name(), "closed");
        assert_eq!(breaker.failure_count(), 0);
    }

    #[tokio::test]
    async fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let upstream = fail_provider(2).respond_with("ok");

        for _ in 0..2 {
            let _ = breaker.wrap_complete(&[], &[], &upstream).await;
        }
        assert_eq!(breaker.failure_count(), 2);

        let _ = breaker.wrap_complete(&[], &[], &upstream).await;
        assert_eq!(breaker.failure_count(), 0);
    }
}