replication_engine/
circuit_breaker.rs

1// Copyright (c) 2025-2026 Adrian Robinson. Licensed under the AGPL-3.0.
2// See LICENSE file in the project root for full license text.
3
4//! Circuit breaker pattern for sync-engine protection.
5//!
6//! Prevents cascading failures when sync-engine is overloaded or unhealthy.
7//! Uses the `recloser` crate following sync-engine's proven patterns.
8//!
9//! # States
10//!
11//! - **Closed**: Normal operation, requests pass through
12//! - **Open**: Sync-engine unhealthy, requests fail-fast without attempting
13//! - **HalfOpen**: Testing if sync-engine recovered, limited requests allowed
14//!
15//! # Usage
16//!
17//! ```rust,no_run
18//! # use replication_engine::circuit_breaker::{SyncEngineCircuit, CircuitError};
19//! # async fn example() -> Result<(), CircuitError<String>> {
20//! let circuit = SyncEngineCircuit::new();
21//!
22//! // Wrap sync-engine write calls
23//! match circuit.writes.call(|| async { Ok::<(), String>(()) }).await {
24//!     Ok(()) => { /* success */ }
25//!     Err(CircuitError::Rejected) => { /* circuit open, backoff */ }
26//!     Err(CircuitError::Inner(e)) => { /* sync-engine error */ }
27//! }
28//! # Ok(())
29//! # }
30//! ```
31
32use recloser::{AsyncRecloser, Error as RecloserError, Recloser};
33use std::future::Future;
34use std::sync::atomic::{AtomicU64, Ordering};
35use std::time::Duration;
36use tracing::{debug, warn};
37
/// Coarse circuit breaker state, exposed for metrics and monitoring.
///
/// Each variant carries an explicit discriminant, so the numeric value of a
/// state stays fixed if variants are reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests pass through.
    Closed = 0,
    /// Probing whether the service has recovered.
    HalfOpen = 1,
    /// Service considered unhealthy: requests fail fast.
    Open = 2,
}

impl std::fmt::Display for CircuitState {
    /// Writes the lowercase `snake_case` label of the state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Closed => "closed",
            Self::HalfOpen => "half_open",
            Self::Open => "open",
        };
        f.write_str(label)
    }
}
58
59/// Error type for circuit-protected operations.
60#[derive(Debug, thiserror::Error)]
61pub enum CircuitError<E> {
62    /// The circuit breaker rejected the call (circuit is open).
63    #[error("circuit breaker open, request rejected")]
64    Rejected,
65
66    /// The underlying operation failed.
67    #[error("operation failed: {0}")]
68    Inner(#[source] E),
69}
70
71impl<E> CircuitError<E> {
72    /// Check if this is a rejection (circuit open).
73    pub fn is_rejected(&self) -> bool {
74        matches!(self, CircuitError::Rejected)
75    }
76
77    /// Check if this is an inner error.
78    pub fn is_inner(&self) -> bool {
79        matches!(self, CircuitError::Inner(_))
80    }
81
82    /// Get the inner error if present.
83    pub fn inner(&self) -> Option<&E> {
84        match self {
85            CircuitError::Inner(e) => Some(e),
86            _ => None,
87        }
88    }
89}
90
91impl<E> From<RecloserError<E>> for CircuitError<E> {
92    fn from(err: RecloserError<E>) -> Self {
93        match err {
94            RecloserError::Rejected => CircuitError::Rejected,
95            RecloserError::Inner(e) => CircuitError::Inner(e),
96        }
97    }
98}
99
/// Configuration for a circuit breaker.
///
/// Values can be compared with `==`, and every preset constructor is a
/// `const fn`, so presets can be stored in `const`/`static` items.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitConfig {
    /// Failure threshold for tripping the circuit.
    ///
    /// NOTE(review): `CircuitBreaker::new` hands this value to `recloser` both as
    /// the closed-state window length and, divided by 100, as the error-rate
    /// threshold. The effective trip condition is therefore whatever `recloser`
    /// derives from those two settings, which is not necessarily a strict count
    /// of consecutive failures — confirm the intended semantics.
    pub failure_threshold: u32,
    /// Success threshold for closing the circuit again; `CircuitBreaker::new`
    /// uses it as the half-open window length.
    pub success_threshold: u32,
    /// How long to wait before attempting recovery (half-open).
    pub recovery_timeout: Duration,
}

impl Default for CircuitConfig {
    /// Balanced preset: threshold 5, two successes to close, 30 s recovery wait.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
        }
    }
}

impl CircuitConfig {
    /// Aggressive config for critical paths (trips faster, recovers cautiously).
    ///
    /// Use this for sync-engine calls where we don't want to hammer a struggling service.
    #[must_use]
    pub const fn aggressive() -> Self {
        Self {
            failure_threshold: 3,
            success_threshold: 3,
            recovery_timeout: Duration::from_secs(60),
        }
    }

    /// Lenient config for less critical paths (tolerates more failures).
    #[must_use]
    pub const fn lenient() -> Self {
        Self {
            failure_threshold: 10,
            success_threshold: 1,
            recovery_timeout: Duration::from_secs(15),
        }
    }

    /// Fast recovery for testing (50 ms recovery wait); only compiled for tests.
    #[cfg(test)]
    #[must_use]
    pub const fn test() -> Self {
        Self {
            failure_threshold: 2,
            success_threshold: 1,
            recovery_timeout: Duration::from_millis(50),
        }
    }
}
154
155/// A named circuit breaker with metrics tracking.
156pub struct CircuitBreaker {
157    name: String,
158    inner: AsyncRecloser,
159
160    // Metrics
161    calls_total: AtomicU64,
162    successes: AtomicU64,
163    failures: AtomicU64,
164    rejections: AtomicU64,
165}
166
167impl CircuitBreaker {
168    /// Create a new circuit breaker with the given name and config.
169    pub fn new(name: impl Into<String>, config: CircuitConfig) -> Self {
170        let recloser = Recloser::custom()
171            .error_rate(config.failure_threshold as f32 / 100.0)
172            .closed_len(config.failure_threshold as usize)
173            .half_open_len(config.success_threshold as usize)
174            .open_wait(config.recovery_timeout)
175            .build();
176
177        Self {
178            name: name.into(),
179            inner: recloser.into(),
180            calls_total: AtomicU64::new(0),
181            successes: AtomicU64::new(0),
182            failures: AtomicU64::new(0),
183            rejections: AtomicU64::new(0),
184        }
185    }
186
187    /// Create with default config.
188    pub fn with_defaults(name: impl Into<String>) -> Self {
189        Self::new(name, CircuitConfig::default())
190    }
191
192    /// Get the circuit breaker name.
193    #[must_use]
194    pub fn name(&self) -> &str {
195        &self.name
196    }
197
198    /// Execute an async operation through the circuit breaker.
199    ///
200    /// Takes a closure that returns a Future, allowing lazy evaluation.
201    pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
202    where
203        F: FnOnce() -> Fut,
204        Fut: Future<Output = Result<T, E>>,
205    {
206        self.calls_total.fetch_add(1, Ordering::Relaxed);
207
208        match self.inner.call(f()).await {
209            Ok(result) => {
210                self.successes.fetch_add(1, Ordering::Relaxed);
211                debug!(circuit = %self.name, "Circuit call succeeded");
212                Ok(result)
213            }
214            Err(RecloserError::Rejected) => {
215                self.rejections.fetch_add(1, Ordering::Relaxed);
216                warn!(circuit = %self.name, "Circuit breaker rejected call (open)");
217                Err(CircuitError::Rejected)
218            }
219            Err(RecloserError::Inner(e)) => {
220                self.failures.fetch_add(1, Ordering::Relaxed);
221                debug!(circuit = %self.name, "Circuit call failed");
222                Err(CircuitError::Inner(e))
223            }
224        }
225    }
226
227    /// Get total number of calls.
228    #[must_use]
229    pub fn calls_total(&self) -> u64 {
230        self.calls_total.load(Ordering::Relaxed)
231    }
232
233    /// Get number of successful calls.
234    #[must_use]
235    pub fn successes(&self) -> u64 {
236        self.successes.load(Ordering::Relaxed)
237    }
238
239    /// Get number of failed calls (operation errors).
240    #[must_use]
241    pub fn failures(&self) -> u64 {
242        self.failures.load(Ordering::Relaxed)
243    }
244
245    /// Get number of rejected calls (circuit open).
246    #[must_use]
247    pub fn rejections(&self) -> u64 {
248        self.rejections.load(Ordering::Relaxed)
249    }
250
251    /// Get failure rate (0.0 - 1.0).
252    #[must_use]
253    pub fn failure_rate(&self) -> f64 {
254        let total = self.calls_total();
255        if total == 0 {
256            return 0.0;
257        }
258        self.failures() as f64 / total as f64
259    }
260
261    /// Check if circuit is likely open (based on recent rejections).
262    #[must_use]
263    pub fn is_likely_open(&self) -> bool {
264        self.rejections() > 0 && self.rejections() > self.successes()
265    }
266
267    /// Reset all metrics.
268    pub fn reset_metrics(&self) {
269        self.calls_total.store(0, Ordering::Relaxed);
270        self.successes.store(0, Ordering::Relaxed);
271        self.failures.store(0, Ordering::Relaxed);
272        self.rejections.store(0, Ordering::Relaxed);
273    }
274}
275
276/// Circuit breaker specifically for sync-engine operations.
277///
278/// This wraps sync-engine calls (submit, delete, is_current) with
279/// circuit breaker protection to prevent hammering an overloaded service.
280pub struct SyncEngineCircuit {
281    /// Circuit for write operations (submit, delete)
282    pub writes: CircuitBreaker,
283    /// Circuit for read operations (is_current, get)
284    pub reads: CircuitBreaker,
285}
286
287impl Default for SyncEngineCircuit {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
293impl SyncEngineCircuit {
294    /// Create sync-engine circuits with appropriate configs.
295    pub fn new() -> Self {
296        Self {
297            // Writes: aggressive (don't hammer sync-engine when struggling)
298            writes: CircuitBreaker::new("sync_engine_writes", CircuitConfig::aggressive()),
299            // Reads: more lenient (is_current can tolerate more failures)
300            reads: CircuitBreaker::new("sync_engine_reads", CircuitConfig::default()),
301        }
302    }
303
304    /// Create with custom configs.
305    pub fn with_configs(writes_config: CircuitConfig, reads_config: CircuitConfig) -> Self {
306        Self {
307            writes: CircuitBreaker::new("sync_engine_writes", writes_config),
308            reads: CircuitBreaker::new("sync_engine_reads", reads_config),
309        }
310    }
311
312    /// Get aggregated metrics.
313    pub fn metrics(&self) -> SyncEngineCircuitMetrics {
314        SyncEngineCircuitMetrics {
315            writes_total: self.writes.calls_total(),
316            writes_successes: self.writes.successes(),
317            writes_failures: self.writes.failures(),
318            writes_rejections: self.writes.rejections(),
319            reads_total: self.reads.calls_total(),
320            reads_successes: self.reads.successes(),
321            reads_failures: self.reads.failures(),
322            reads_rejections: self.reads.rejections(),
323        }
324    }
325
326    /// Check if any circuit is open.
327    pub fn any_open(&self) -> bool {
328        self.writes.is_likely_open() || self.reads.is_likely_open()
329    }
330}
331
/// Aggregated metrics from sync-engine circuits.
///
/// A plain value snapshot: two snapshots can be compared with `==`, which makes
/// it easy to assert on metrics or detect changes between polls.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SyncEngineCircuitMetrics {
    /// Calls made through the writes breaker.
    pub writes_total: u64,
    /// Write calls that succeeded.
    pub writes_successes: u64,
    /// Write calls whose operation returned an error.
    pub writes_failures: u64,
    /// Write calls rejected by the breaker.
    pub writes_rejections: u64,
    /// Calls made through the reads breaker.
    pub reads_total: u64,
    /// Read calls that succeeded.
    pub reads_successes: u64,
    /// Read calls whose operation returned an error.
    pub reads_failures: u64,
    /// Read calls rejected by the breaker.
    pub reads_rejections: u64,
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// A successful operation passes through and is counted as a success.
    #[tokio::test]
    async fn test_circuit_passes_successful_calls() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(42) }).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(cb.successes(), 1);
        assert_eq!(cb.failures(), 0);
    }

    /// A failing operation surfaces as `Inner` and is counted as a failure.
    #[tokio::test]
    async fn test_circuit_tracks_failures() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("boom") }).await;

        assert!(matches!(result, Err(CircuitError::Inner("boom"))));
        assert_eq!(cb.successes(), 0);
        assert_eq!(cb.failures(), 1);
    }

    /// Repeated failures must open the breaker so that later calls are rejected.
    #[tokio::test]
    async fn test_circuit_opens_after_threshold() {
        let config = CircuitConfig {
            failure_threshold: 2,
            success_threshold: 1,
            // Long enough that the breaker cannot leave the open state mid-test.
            recovery_timeout: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        // Fail multiple times to trip the breaker
        for _ in 0..5 {
            let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("fail") }).await;
        }

        // Every call either ran and failed, or was rejected by the open breaker.
        assert_eq!(cb.calls_total(), 5);
        assert_eq!(cb.successes(), 0);
        assert_eq!(cb.failures() + cb.rejections(), 5);

        // The breaker must actually have opened: at least one call was rejected.
        // (Asserting `failures >= 2 || rejections >= 1` would hold even if the
        // breaker never opened, since all five calls fail.)
        assert!(
            cb.rejections() >= 1,
            "breaker never opened: failures={}, rejections={}",
            cb.failures(),
            cb.rejections()
        );
        assert!(cb.is_likely_open());
    }

    /// Counters accumulate across calls.
    #[tokio::test]
    async fn test_circuit_metrics_accumulate() {
        // Use high threshold to avoid tripping
        let config = CircuitConfig {
            failure_threshold: 100,
            success_threshold: 1,
            recovery_timeout: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(2) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(3) }).await;

        assert_eq!(cb.calls_total(), 3);
        assert_eq!(cb.successes(), 3);
        assert_eq!(cb.failures(), 0);
    }

    /// `failure_rate` is failures divided by total calls.
    #[tokio::test]
    async fn test_failure_rate_calculation() {
        let config = CircuitConfig {
            failure_threshold: 100,
            success_threshold: 1,
            recovery_timeout: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        // 2 success, 2 failure = 50% failure rate
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("x") }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(2) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("y") }).await;

        assert!((cb.failure_rate() - 0.5).abs() < 0.01);
    }

    /// `reset_metrics` zeroes every counter.
    #[tokio::test]
    async fn test_reset_metrics() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;

        assert!(cb.calls_total() > 0);

        cb.reset_metrics();

        assert_eq!(cb.calls_total(), 0);
        assert_eq!(cb.successes(), 0);
        assert_eq!(cb.failures(), 0);
        assert_eq!(cb.rejections(), 0);
    }

    /// The two sync-engine breakers carry their expected names and start closed.
    #[tokio::test]
    async fn test_sync_engine_circuits() {
        let circuits = SyncEngineCircuit::new();

        assert_eq!(circuits.writes.name(), "sync_engine_writes");
        assert_eq!(circuits.reads.name(), "sync_engine_reads");
        // No calls yet, so neither breaker can look open.
        assert!(!circuits.any_open());
    }

    /// The wrapped future can capture and mutate shared state.
    #[tokio::test]
    async fn test_circuit_with_async_state() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());
        let counter = std::sync::Arc::new(AtomicUsize::new(0));

        let counter_clone = counter.clone();
        let result: Result<usize, CircuitError<&str>> = cb
            .call(|| async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(counter_clone.load(Ordering::SeqCst))
            })
            .await;

        assert_eq!(result.unwrap(), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Aggregated metrics reflect each breaker's own counters.
    #[tokio::test]
    async fn test_sync_engine_circuit_metrics() {
        let circuits = SyncEngineCircuit::new();

        let _: Result<i32, CircuitError<&str>> = circuits.writes.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> =
            circuits.reads.call(|| async { Err("timeout") }).await;

        let metrics = circuits.metrics();

        assert_eq!(metrics.writes_total, 1);
        assert_eq!(metrics.writes_successes, 1);
        assert_eq!(metrics.reads_total, 1);
        assert_eq!(metrics.reads_failures, 1);
    }

    /// The presets are ordered as their names suggest.
    #[test]
    fn test_circuit_config_presets() {
        let default = CircuitConfig::default();
        let aggressive = CircuitConfig::aggressive();
        let lenient = CircuitConfig::lenient();

        // Aggressive trips faster
        assert!(aggressive.failure_threshold < default.failure_threshold);
        // Lenient tolerates more
        assert!(lenient.failure_threshold > default.failure_threshold);
        // Aggressive waits longer to recover
        assert!(aggressive.recovery_timeout > lenient.recovery_timeout);
    }

    /// The `CircuitError` helper methods agree with the variant.
    #[test]
    fn test_circuit_error_methods() {
        let rejected: CircuitError<&str> = CircuitError::Rejected;
        assert!(rejected.is_rejected());
        assert!(!rejected.is_inner());
        assert!(rejected.inner().is_none());

        let inner: CircuitError<&str> = CircuitError::Inner("boom");
        assert!(!inner.is_rejected());
        assert!(inner.is_inner());
        assert_eq!(inner.inner(), Some(&"boom"));
    }
}