// memlink_runtime/circuit.rs

//! Circuit breaker pattern for fault tolerance.
//!
//! Prevents cascading failures by rejecting requests to failing modules,
//! giving them time to recover before traffic is allowed through again.

use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

/// The three positions a [`CircuitBreaker`] can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: requests are let through.
    Closed,
    /// Tripped: requests are turned away without being attempted.
    Open,
    /// Probing: requests are let through to find out whether the module has
    /// recovered.
    HalfOpen,
}

/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitConfig {
    /// Consecutive failures needed to open the circuit.
    pub failure_threshold: u32,
    /// Consecutive successes, recorded while half-open, needed to close the
    /// circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays open before a probe request is let through
    /// (the move to half-open).
    pub open_timeout: Duration,
    /// Window of time over which failures are counted.
    pub failure_window: Duration,
    /// Minimum number of recorded calls before the circuit is allowed to open.
    pub min_calls: u32,
}

impl Default for CircuitConfig {
    /// Opens after 5 consecutive failures once at least 10 calls have been
    /// recorded, stays open for 30 s, closes after 3 half-open successes, and
    /// uses a 60 s failure window.
    fn default() -> Self {
        CircuitConfig {
            min_calls: 10,
            failure_threshold: 5,
            success_threshold: 3,
            failure_window: Duration::from_secs(60),
            open_timeout: Duration::from_secs(30),
        }
    }
}

48/// Statistics about circuit breaker state.
49#[derive(Debug, Clone)]
50pub struct CircuitStats {
51    pub state: CircuitState,
52    pub consecutive_failures: u32,
53    pub consecutive_successes: u32,
54    pub total_calls: u64,
55    pub total_failures: u64,
56    pub total_successes: u64,
57    pub total_rejected: u64,
58    pub last_failure: Option<Instant>,
59    pub last_state_change: Instant,
60}
61
62/// Circuit breaker for a module.
63#[derive(Debug)]
64pub struct CircuitBreaker {
65    /// Module name.
66    module_name: String,
67    /// Current state.
68    state: std::sync::Mutex<CircuitState>,
69    /// Consecutive failure count.
70    consecutive_failures: AtomicU32,
71    /// Consecutive success count (in half-open).
72    consecutive_successes: AtomicU32,
73    /// Total calls.
74    total_calls: AtomicU64,
75    /// Total failures.
76    total_failures: AtomicU64,
77    /// Total successes.
78    total_successes: AtomicU64,
79    /// Total rejected (circuit open).
80    total_rejected: AtomicU64,
81    /// Timestamp of last failure.
82    last_failure: std::sync::Mutex<Option<Instant>>,
83    /// Timestamp of last state change.
84    last_state_change: std::sync::Mutex<Instant>,
85    /// Configuration.
86    config: CircuitConfig,
87}
88
89impl CircuitBreaker {
90    /// Creates a new circuit breaker.
91    pub fn new(module_name: String, config: CircuitConfig) -> Self {
92        Self {
93            module_name,
94            state: std::sync::Mutex::new(CircuitState::Closed),
95            consecutive_failures: AtomicU32::new(0),
96            consecutive_successes: AtomicU32::new(0),
97            total_calls: AtomicU64::new(0),
98            total_failures: AtomicU64::new(0),
99            total_successes: AtomicU64::new(0),
100            total_rejected: AtomicU64::new(0),
101            last_failure: std::sync::Mutex::new(None),
102            last_state_change: std::sync::Mutex::new(Instant::now()),
103            config,
104        }
105    }
106
107    /// Creates a circuit breaker with default configuration.
108    pub fn with_defaults(module_name: String) -> Self {
109        Self::new(module_name, CircuitConfig::default())
110    }
111
112    /// Checks if a request can proceed.
113    pub fn can_execute(&self) -> bool {
114        let mut state = self.state.lock().unwrap();
115
116        match *state {
117            CircuitState::Closed => true,
118            CircuitState::Open => {
119                // Check if timeout has elapsed
120                let elapsed = self.last_state_change.lock().unwrap().elapsed();
121                if elapsed >= self.config.open_timeout {
122                    // Transition to half-open
123                    *state = CircuitState::HalfOpen;
124                    *self.last_state_change.lock().unwrap() = Instant::now();
125                    self.consecutive_successes.store(0, Ordering::Relaxed);
126                    true
127                } else {
128                    self.total_rejected.fetch_add(1, Ordering::Relaxed);
129                    false
130                }
131            }
132            CircuitState::HalfOpen => true,
133        }
134    }
135
136    /// Records a successful call.
137    pub fn record_success(&self) {
138        self.total_calls.fetch_add(1, Ordering::Relaxed);
139        self.total_successes.fetch_add(1, Ordering::Relaxed);
140        self.consecutive_failures.store(0, Ordering::Relaxed);
141
142        let mut state = self.state.lock().unwrap();
143
144        if *state == CircuitState::HalfOpen {
145            let successes = self.consecutive_successes.fetch_add(1, Ordering::Relaxed) + 1;
146            if successes >= self.config.success_threshold {
147                // Transition to closed
148                *state = CircuitState::Closed;
149                *self.last_state_change.lock().unwrap() = Instant::now();
150                self.consecutive_failures.store(0, Ordering::Relaxed);
151            }
152        }
153    }
154
155    /// Records a failed call.
156    pub fn record_failure(&self) {
157        self.total_calls.fetch_add(1, Ordering::Relaxed);
158        self.total_failures.fetch_add(1, Ordering::Relaxed);
159
160        let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
161        *self.last_failure.lock().unwrap() = Some(Instant::now());
162
163        let mut state = self.state.lock().unwrap();
164
165        if *state == CircuitState::HalfOpen {
166            // Any failure in half-open goes back to open
167            *state = CircuitState::Open;
168            *self.last_state_change.lock().unwrap() = Instant::now();
169        } else if *state == CircuitState::Closed {
170            // Check if we should open
171            if failures >= self.config.failure_threshold {
172                let total = self.total_calls.load(Ordering::Relaxed);
173                if total >= self.config.min_calls as u64 {
174                    *state = CircuitState::Open;
175                    *self.last_state_change.lock().unwrap() = Instant::now();
176                }
177            }
178        }
179    }
180
181    /// Returns the current state.
182    pub fn state(&self) -> CircuitState {
183        *self.state.lock().unwrap()
184    }
185
186    /// Returns statistics.
187    pub fn stats(&self) -> CircuitStats {
188        CircuitStats {
189            state: self.state(),
190            consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed),
191            consecutive_successes: self.consecutive_successes.load(Ordering::Relaxed),
192            total_calls: self.total_calls.load(Ordering::Relaxed),
193            total_failures: self.total_failures.load(Ordering::Relaxed),
194            total_successes: self.total_successes.load(Ordering::Relaxed),
195            total_rejected: self.total_rejected.load(Ordering::Relaxed),
196            last_failure: *self.last_failure.lock().unwrap(),
197            last_state_change: *self.last_state_change.lock().unwrap(),
198        }
199    }
200
201    /// Returns the module name.
202    pub fn module_name(&self) -> &str {
203        &self.module_name
204    }
205
206    /// Forces the circuit to open (for testing/manual intervention).
207    pub fn force_open(&self) {
208        *self.state.lock().unwrap() = CircuitState::Open;
209        *self.last_state_change.lock().unwrap() = Instant::now();
210    }
211
212    /// Forces the circuit to close (for testing/manual intervention).
213    pub fn force_close(&self) {
214        *self.state.lock().unwrap() = CircuitState::Closed;
215        *self.last_state_change.lock().unwrap() = Instant::now();
216        self.consecutive_failures.store(0, Ordering::Relaxed);
217    }
218
219    /// Resets all counters.
220    pub fn reset(&self) {
221        self.consecutive_failures.store(0, Ordering::Relaxed);
222        self.consecutive_successes.store(0, Ordering::Relaxed);
223        self.total_calls.store(0, Ordering::Relaxed);
224        self.total_failures.store(0, Ordering::Relaxed);
225        self.total_successes.store(0, Ordering::Relaxed);
226        self.total_rejected.store(0, Ordering::Relaxed);
227        *self.last_failure.lock().unwrap() = None;
228        self.force_close();
229    }
230}
231
232/// Registry of circuit breakers for all modules.
233#[derive(Debug)]
234pub struct CircuitRegistry {
235    /// Circuit breakers by module name.
236    circuits: DashMap<String, Arc<CircuitBreaker>>,
237    /// Default configuration.
238    default_config: CircuitConfig,
239}
240
241impl CircuitRegistry {
242    /// Creates a new circuit registry.
243    pub fn new() -> Self {
244        Self {
245            circuits: DashMap::new(),
246            default_config: CircuitConfig::default(),
247        }
248    }
249
250    /// Creates a registry with custom default configuration.
251    pub fn with_config(config: CircuitConfig) -> Self {
252        Self {
253            circuits: DashMap::new(),
254            default_config: config,
255        }
256    }
257
258    /// Gets or creates a circuit breaker for a module.
259    pub fn get_or_create(&self, module_name: &str) -> Arc<CircuitBreaker> {
260        self.circuits
261            .entry(module_name.to_string())
262            .or_insert_with(|| {
263                Arc::new(CircuitBreaker::new(
264                    module_name.to_string(),
265                    self.default_config.clone(),
266                ))
267            })
268            .clone()
269    }
270
271    /// Registers a module with custom configuration.
272    pub fn register(&self, module_name: &str, config: CircuitConfig) {
273        self.circuits.insert(
274            module_name.to_string(),
275            Arc::new(CircuitBreaker::new(module_name.to_string(), config)),
276        );
277    }
278
279    /// Checks if a request can proceed.
280    pub fn can_execute(&self, module_name: &str) -> bool {
281        self.get_or_create(module_name).can_execute()
282    }
283
284    /// Records a success.
285    pub fn record_success(&self, module_name: &str) {
286        self.get_or_create(module_name).record_success();
287    }
288
289    /// Records a failure.
290    pub fn record_failure(&self, module_name: &str) {
291        self.get_or_create(module_name).record_failure();
292    }
293
294    /// Returns statistics for a module.
295    pub fn stats(&self, module_name: &str) -> Option<CircuitStats> {
296        self.circuits.get(module_name).map(|c| c.stats())
297    }
298
299    /// Returns all circuit breaker stats.
300    pub fn all_stats(&self) -> Vec<(String, CircuitStats)> {
301        self.circuits
302            .iter()
303            .map(|e| (e.key().clone(), e.value().stats()))
304            .collect()
305    }
306
307    /// Returns the number of registered circuits.
308    pub fn circuit_count(&self) -> usize {
309        self.circuits.len()
310    }
311
312    /// Returns the number of open circuits.
313    pub fn open_circuit_count(&self) -> usize {
314        self.circuits
315            .iter()
316            .filter(|e| *e.value().state.lock().unwrap() == CircuitState::Open)
317            .count()
318    }
319}
320
321impl Default for CircuitRegistry {
322    fn default() -> Self {
323        Self::new()
324    }
325}
326
327use dashmap::DashMap;
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_closed_initially() {
        let cb = CircuitBreaker::with_defaults("test".to_string());
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_execute());
        assert_eq!(cb.module_name(), "test");
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let config = CircuitConfig {
            failure_threshold: 3,
            min_calls: 0,
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        // Requests turned away while open are counted.
        assert!(!cb.can_execute());
        assert_eq!(cb.stats().total_rejected, 1);
    }

    #[test]
    fn test_circuit_half_open_after_timeout() {
        let config = CircuitConfig {
            failure_threshold: 1,
            min_calls: 0,
            open_timeout: Duration::from_millis(100),
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        // Open the circuit
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        // Wait for timeout
        std::thread::sleep(Duration::from_millis(150));

        // Should transition to half-open
        assert!(cb.can_execute());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_closes_after_successes() {
        let config = CircuitConfig {
            failure_threshold: 1,
            success_threshold: 2,
            min_calls: 0,
            open_timeout: Duration::from_millis(10),
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        // Open the circuit
        cb.record_failure();

        // Wait for half-open; the probe must be admitted.
        std::thread::sleep(Duration::from_millis(50));
        assert!(cb.can_execute());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        // One success is not enough with success_threshold = 2.
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_reopens_on_failure_in_half_open() {
        let config = CircuitConfig {
            failure_threshold: 1,
            success_threshold: 2,
            min_calls: 0,
            open_timeout: Duration::from_millis(10),
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        // Open the circuit
        cb.record_failure();

        // Wait for half-open
        std::thread::sleep(Duration::from_millis(50));
        assert!(cb.can_execute());

        // Record one success
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        // Record failure - should go back to open
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_min_calls_delays_opening() {
        let config = CircuitConfig {
            failure_threshold: 2,
            min_calls: 5,
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        // The streak passes the threshold early, but too few calls are recorded.
        for _ in 0..4 {
            cb.record_failure();
        }
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_success_breaks_failure_streak() {
        let config = CircuitConfig {
            failure_threshold: 3,
            min_calls: 0,
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.stats().consecutive_failures, 2);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_force_open_and_force_close() {
        let cb = CircuitBreaker::with_defaults("test".to_string());

        cb.force_open();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.can_execute());
        assert_eq!(cb.stats().total_rejected, 1);

        cb.force_close();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_execute());
    }

    #[test]
    fn test_reset_clears_counters() {
        let config = CircuitConfig {
            failure_threshold: 1,
            min_calls: 0,
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        cb.record_success();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.can_execute());

        cb.reset();
        let stats = cb.stats();
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.consecutive_failures, 0);
        assert_eq!(stats.consecutive_successes, 0);
        assert_eq!(stats.total_calls, 0);
        assert_eq!(stats.total_failures, 0);
        assert_eq!(stats.total_successes, 0);
        assert_eq!(stats.total_rejected, 0);
        assert!(stats.last_failure.is_none());
    }

    #[test]
    fn test_circuit_registry() {
        let registry = CircuitRegistry::new();

        let cb1 = registry.get_or_create("module1");
        let cb2 = registry.get_or_create("module2");

        assert_eq!(registry.circuit_count(), 2);

        // Repeated lookups return the same breaker instead of a new one.
        assert!(Arc::ptr_eq(&cb1, &registry.get_or_create("module1")));
        assert_eq!(registry.circuit_count(), 2);

        cb1.record_failure();
        cb2.record_success();

        let stats1 = registry.stats("module1").unwrap();
        let stats2 = registry.stats("module2").unwrap();

        assert_eq!(stats1.total_failures, 1);
        assert_eq!(stats2.total_successes, 1);
    }

    #[test]
    fn test_registry_tracks_open_circuits() {
        let registry = CircuitRegistry::new();
        registry.register(
            "fragile",
            CircuitConfig {
                failure_threshold: 1,
                min_calls: 0,
                ..CircuitConfig::default()
            },
        );

        registry.record_success("sturdy");
        registry.record_failure("fragile");

        assert_eq!(registry.circuit_count(), 2);
        assert_eq!(registry.open_circuit_count(), 1);
        assert!(!registry.can_execute("fragile"));
        assert!(registry.can_execute("sturdy"));

        // Looking up stats never creates a breaker.
        assert!(registry.stats("unknown").is_none());
        assert_eq!(registry.all_stats().len(), 2);
    }
}