1use std::sync::atomic::{AtomicUsize, Ordering};
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23
/// The states a circuit breaker can report.
///
/// The discriminants are spelled out because the breaker stores the state as a
/// raw `u8` (via `as u8`) and decodes it again through `TryFrom<u8>`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Requests are let through and failures are counted.
    Closed = 0,
    /// Requests are rejected without running the operation.
    Open = 1,
    /// The open window has elapsed; requests are let through again as probes.
    HalfOpen = 2,
}
34
/// Tunable thresholds for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of failures counted while closed (a success while closed resets
    /// the count) at which the circuit opens.
    pub failure_threshold: usize,

    /// Number of successes recorded while half-open at which the circuit
    /// closes again.
    pub success_threshold: usize,

    /// How long the circuit stays open before it is reported as half-open.
    pub open_duration: Duration,
}
47
/// Outcome of running an operation through a circuit breaker.
#[derive(Clone, Debug)]
pub enum CircuitResult<T> {
    /// The operation ran and produced a value.
    Success(T),
    /// The operation ran and returned an error, rendered as text.
    Failure(String),
    /// The operation was not run; the payload explains why.
    Rejected(String),
    /// Carries a message instead of a value, signalling that a retry is permitted.
    RetryAllowed(String),
}

impl<T> CircuitResult<T> {
    /// Returns `true` only for [`CircuitResult::Success`].
    pub fn is_success(&self) -> bool {
        match self {
            Self::Success(_) => true,
            Self::Failure(_) | Self::Rejected(_) | Self::RetryAllowed(_) => false,
        }
    }

    /// Returns `true` only for [`CircuitResult::Rejected`].
    pub fn is_rejected(&self) -> bool {
        match self {
            Self::Rejected(_) => true,
            Self::Success(_) | Self::Failure(_) | Self::RetryAllowed(_) => false,
        }
    }

    /// Consumes the result and returns the success value.
    ///
    /// # Panics
    ///
    /// Panics on every variant other than [`CircuitResult::Success`]; the
    /// panic message names the variant and includes its payload.
    pub fn unwrap(self) -> T {
        match self {
            Self::Success(value) => value,
            Self::Failure(reason) => panic!("unwrap on Failure: {}", reason),
            Self::Rejected(reason) => panic!("unwrap on Rejected: {}", reason),
            Self::RetryAllowed(reason) => panic!("unwrap on RetryAllowed: {}", reason),
        }
    }
}
82
83#[derive(Debug)]
85pub struct CircuitBreaker {
86 name: String,
88
89 state: std::sync::atomic::AtomicU8,
91
92 failure_count: Arc<AtomicUsize>,
94
95 success_count: Arc<AtomicUsize>,
97
98 open_since_ms: std::sync::atomic::AtomicU64,
100
101 config: CircuitBreakerConfig,
103}
104
105impl CircuitBreaker {
106 pub fn new(name: &str, failure_threshold: usize, open_duration: Duration) -> Self {
112 Self {
113 name: name.to_string(),
114 state: std::sync::atomic::AtomicU8::new(CircuitState::Closed as u8),
115 failure_count: Arc::new(AtomicUsize::new(0)),
116 success_count: Arc::new(AtomicUsize::new(0)),
117 open_since_ms: std::sync::atomic::AtomicU64::new(0),
118 config: CircuitBreakerConfig {
119 failure_threshold,
120 success_threshold: 3,
121 open_duration,
122 },
123 }
124 }
125
126 pub fn default_for(name: &str) -> Self {
128 Self::new(name, 5, Duration::from_secs(60))
129 }
130
131 pub fn state(&self) -> CircuitState {
133 let state = self.state.load(Ordering::SeqCst);
134 let state = CircuitState::try_from(state).unwrap_or(CircuitState::Closed);
135
136 if state == CircuitState::Open {
138 if let Some(since) = self.open_time() {
139 if since.elapsed() >= self.config.open_duration {
140 return CircuitState::HalfOpen;
141 }
142 }
143 }
144
145 state
146 }
147
148 fn open_time(&self) -> Option<Instant> {
150 let ts = self.open_since_ms.load(Ordering::SeqCst);
151 if ts == 0 {
152 None
153 } else {
154 Some(Instant::now() - Duration::from_millis(ts))
155 }
156 }
157
158 pub fn record_success(&self) {
160 let state = self.state();
161
162 match state {
163 CircuitState::Closed => {
164 self.failure_count.store(0, Ordering::SeqCst);
166 }
167 CircuitState::HalfOpen => {
168 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
169 if count >= self.config.success_threshold {
170 self.state
172 .store(CircuitState::Closed as u8, Ordering::SeqCst);
173 self.failure_count.store(0, Ordering::SeqCst);
174 self.success_count.store(0, Ordering::SeqCst);
175 self.open_since_ms.store(0, Ordering::SeqCst);
176 tracing::info!(
177 "[circuit-breaker] {}: circuit closed (recovered)",
178 self.name
179 );
180 }
181 }
182 CircuitState::Open => {
183 }
185 }
186 }
187
188 pub fn record_failure(&self) {
190 let state = self.state();
191
192 match state {
193 CircuitState::Closed => {
194 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
195 if count >= self.config.failure_threshold {
196 self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
198 self.open_since_ms.store(
199 Instant::now().elapsed().as_millis().try_into().unwrap_or(0),
200 Ordering::SeqCst,
201 );
202 tracing::warn!(
203 "[circuit-breaker] {}: circuit opened ({} failures)",
204 self.name,
205 count
206 );
207 }
208 }
209 CircuitState::HalfOpen => {
210 self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
212 self.success_count.store(0, Ordering::SeqCst);
213 tracing::warn!(
214 "[circuit-breaker] {}: circuit reopened (failure in half-open)",
215 self.name
216 );
217 }
218 CircuitState::Open => {
219 }
221 }
222 }
223
224 pub fn can_request(&self) -> bool {
226 let state = self.state();
227 match state {
228 CircuitState::Closed | CircuitState::HalfOpen => true,
229 CircuitState::Open => false,
230 }
231 }
232
233 pub async fn execute<F, T, E>(&self, operation: F) -> CircuitResult<T>
237 where
238 F: std::future::Future<Output = Result<T, E>>,
239 E: std::fmt::Display,
240 {
241 let state = self.state();
242
243 match state {
244 CircuitState::Closed => match operation.await {
245 Ok(result) => {
246 self.record_success();
247 CircuitResult::Success(result)
248 }
249 Err(e) => {
250 self.record_failure();
251 CircuitResult::Failure(e.to_string())
252 }
253 },
254 CircuitState::Open => CircuitResult::Rejected(format!(
255 "circuit is open for {} (source may be temporarily unavailable)",
256 self.name
257 )),
258 CircuitState::HalfOpen => {
259 match operation.await {
261 Ok(_result) => {
262 self.record_success();
263 CircuitResult::RetryAllowed("half-open: success".to_string())
264 }
265 Err(e) => {
266 self.record_failure();
267 CircuitResult::Failure(e.to_string())
268 }
269 }
270 }
271 }
272 }
273
274 pub fn reset(&self) {
276 self.state
277 .store(CircuitState::Closed as u8, Ordering::SeqCst);
278 self.failure_count.store(0, Ordering::SeqCst);
279 self.success_count.store(0, Ordering::SeqCst);
280 self.open_since_ms.store(0, Ordering::SeqCst);
281 }
282}
283
284impl TryFrom<u8> for CircuitState {
285 type Error = ();
286
287 fn try_from(value: u8) -> Result<Self, Self::Error> {
288 match value {
289 0 => Ok(CircuitState::Closed),
290 1 => Ok(CircuitState::Open),
291 2 => Ok(CircuitState::HalfOpen),
292 _ => Err(()),
293 }
294 }
295}
296
297#[derive(Debug, Default)]
299pub struct CircuitBreakerManager {
300 breakers: Arc<std::sync::RwLock<std::collections::HashMap<String, Arc<CircuitBreaker>>>>,
301}
302
303impl CircuitBreakerManager {
304 pub fn new() -> Self {
306 Self {
307 breakers: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
308 }
309 }
310
311 pub fn get(&self, source_id: &str) -> Arc<CircuitBreaker> {
313 {
314 let read_guard = self.breakers.read().expect("RwLock poisoned");
315 if let Some(breaker) = read_guard.get(source_id) {
316 return Arc::clone(breaker);
317 }
318 }
319
320 {
321 let mut write_guard = self.breakers.write().expect("RwLock poisoned");
322 if let Some(breaker) = write_guard.get(source_id) {
324 return Arc::clone(breaker);
325 }
326
327 let breaker = Arc::new(CircuitBreaker::default_for(source_id));
328 write_guard.insert(source_id.to_string(), Arc::clone(&breaker));
329 breaker
330 }
331 }
332
333 pub fn reset_all(&self) {
335 let guard = self.breakers.write().expect("RwLock poisoned");
336 for breaker in guard.values() {
337 breaker.reset();
338 }
339 }
340
341 pub fn status(&self) -> Vec<(String, CircuitState, bool)> {
343 let guard = self.breakers.read().expect("RwLock poisoned");
344 guard
345 .iter()
346 .map(|(name, breaker)| (name.clone(), breaker.state(), breaker.can_request()))
347 .collect()
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a shared breaker named "test" with a 60 second open window.
    fn breaker_with_threshold(failure_threshold: usize) -> Arc<CircuitBreaker> {
        Arc::new(CircuitBreaker::new(
            "test",
            failure_threshold,
            Duration::from_secs(60),
        ))
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_by_default() {
        let fresh = CircuitBreaker::default_for("test");

        assert_eq!(fresh.state(), CircuitState::Closed);
        assert!(fresh.can_request());
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let breaker = breaker_with_threshold(3);

        // Two failures stay below the threshold of three.
        for _ in 0..2 {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.can_request());

        // The third failure trips the breaker.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.can_request());
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets() {
        let breaker = breaker_with_threshold(3);

        for _ in 0..2 {
            breaker.record_failure();
        }
        assert_eq!(breaker.failure_count.load(Ordering::SeqCst), 2);

        breaker.record_success();
        assert_eq!(breaker.failure_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_execute_success() {
        let breaker = breaker_with_threshold(3);

        let outcome = breaker.execute(async { Ok::<i32, &str>(42) }).await;

        assert!(outcome.is_success());
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_circuit_breaker_execute_rejected() {
        let breaker = breaker_with_threshold(1);

        // A threshold of one means a single failure opens the circuit.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);

        // The operation itself would succeed, but the open circuit rejects it.
        let outcome = breaker.execute(async { Ok::<i32, &str>(42) }).await;
        assert!(outcome.is_rejected());
    }

    #[test]
    fn test_manager() {
        let manager = CircuitBreakerManager::new();

        let first = manager.get("source1");
        let second = manager.get("source2");
        let first_again = manager.get("source1");

        // Same id yields the same instance; different ids yield different ones.
        assert!(Arc::ptr_eq(&first, &first_again));
        assert!(!Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_manager_status() {
        let manager = CircuitBreakerManager::new();

        for source_id in ["arxiv", "semantic"] {
            let _ = manager.get(source_id);
        }

        let snapshot = manager.status();
        assert_eq!(snapshot.len(), 2);
        assert!(snapshot
            .iter()
            .all(|(_, state, _)| *state == CircuitState::Closed));
    }
}