Skip to main content

hyperi_rustlib/tiered_sink/
circuit.rs

1// Project:   hyperi-rustlib
2// File:      src/tiered_sink/circuit.rs
3// Purpose:   Circuit breaker for sink health tracking
4// Language:  Rust
5//
6// License:   BUSL-1.1
7// Copyright: (c) 2026 HYPERI PTY LIMITED
8
9//! Circuit breaker for sink health tracking.
10
11use std::sync::Arc;
12use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
13use std::time::Duration;
14use tokio::sync::RwLock;
15
/// The three states a circuit breaker can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: the circuit is closed and requests flow through.
    Closed,
    /// The sink is considered unhealthy: the circuit is open and requests
    /// are rejected.
    Open,
    /// Recovery test: the circuit is half-open so a probe request can check
    /// whether the sink is healthy again.
    HalfOpen,
}
26
27/// Circuit breaker for protecting against unhealthy sinks.
28///
29/// The circuit breaker tracks consecutive failures and opens when
30/// a threshold is reached. After a timeout, it allows a single probe
31/// request to test if the sink has recovered.
32pub struct CircuitBreaker {
33    state: RwLock<CircuitState>,
34    consecutive_failures: AtomicU32,
35    failure_threshold: u32,
36    reset_timeout: Duration,
37    last_failure_time: AtomicU64, // epoch millis
38    /// Atomic mirror of circuit state for sync health check access.
39    /// 0 = Closed, 1 = Open, 2 = HalfOpen.
40    health_state: Arc<AtomicU8>,
41}
42
43impl CircuitBreaker {
44    /// Create a new circuit breaker.
45    ///
46    /// - `failure_threshold`: Number of consecutive failures before opening
47    /// - `reset_timeout`: Time to wait before allowing a probe request
48    #[must_use]
49    pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
50        let health_state = Arc::new(AtomicU8::new(0)); // 0 = Closed
51
52        #[cfg(feature = "health")]
53        {
54            let hs = Arc::clone(&health_state);
55            crate::health::HealthRegistry::register("circuit_breaker", move || {
56                match hs.load(Ordering::Relaxed) {
57                    0 => crate::health::HealthStatus::Healthy,   // Closed
58                    2 => crate::health::HealthStatus::Degraded,  // HalfOpen
59                    _ => crate::health::HealthStatus::Unhealthy, // Open
60                }
61            });
62        }
63
64        Self {
65            state: RwLock::new(CircuitState::Closed),
66            consecutive_failures: AtomicU32::new(0),
67            failure_threshold,
68            reset_timeout,
69            last_failure_time: AtomicU64::new(0),
70            health_state,
71        }
72    }
73
74    /// Sync the atomic health state mirror with the current circuit state.
75    fn sync_health_state(&self, state: CircuitState) {
76        let val = match state {
77            CircuitState::Closed => 0,
78            CircuitState::Open => 1,
79            CircuitState::HalfOpen => 2,
80        };
81        self.health_state.store(val, Ordering::Relaxed);
82    }
83
84    /// Get current circuit state.
85    pub async fn state(&self) -> CircuitState {
86        let mut state = self.state.write().await;
87
88        // Check if we should transition from Open to HalfOpen
89        if *state == CircuitState::Open {
90            let last_failure = self.last_failure_time.load(Ordering::SeqCst);
91            let now = current_epoch_millis();
92            let elapsed = Duration::from_millis(now.saturating_sub(last_failure));
93
94            if elapsed >= self.reset_timeout {
95                *state = CircuitState::HalfOpen;
96                self.sync_health_state(*state);
97            }
98        }
99
100        *state
101    }
102
103    /// Check if requests should be allowed through.
104    pub async fn is_closed(&self) -> bool {
105        self.state().await == CircuitState::Closed
106    }
107
108    /// Check if circuit is open (requests should be rejected).
109    pub async fn is_open(&self) -> bool {
110        self.state().await == CircuitState::Open
111    }
112
113    /// Record a successful request.
114    pub async fn record_success(&self) {
115        let mut state = self.state.write().await;
116        self.consecutive_failures.store(0, Ordering::SeqCst);
117        *state = CircuitState::Closed;
118        self.sync_health_state(*state);
119    }
120
121    /// Record a failed request.
122    pub async fn record_failure(&self) {
123        let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
124        self.last_failure_time
125            .store(current_epoch_millis(), Ordering::SeqCst);
126
127        if failures >= self.failure_threshold {
128            let mut state = self.state.write().await;
129            *state = CircuitState::Open;
130            self.sync_health_state(*state);
131        }
132    }
133
134    /// Get the number of consecutive failures.
135    #[must_use]
136    pub fn consecutive_failures(&self) -> u32 {
137        self.consecutive_failures.load(Ordering::SeqCst)
138    }
139
140    /// Reset the circuit breaker to closed state.
141    pub async fn reset(&self) {
142        self.consecutive_failures.store(0, Ordering::SeqCst);
143        let mut state = self.state.write().await;
144        *state = CircuitState::Closed;
145        self.sync_health_state(*state);
146    }
147}
148
149impl std::fmt::Debug for CircuitBreaker {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("CircuitBreaker")
152            .field("failure_threshold", &self.failure_threshold)
153            .field("reset_timeout", &self.reset_timeout)
154            .field("consecutive_failures", &self.consecutive_failures())
155            .finish_non_exhaustive()
156    }
157}
158
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields 0 when the system clock reads earlier than the epoch, and
/// `u64::MAX` when the millisecond count does not fit in a `u64`.
fn current_epoch_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built breaker reports the closed state.
    #[tokio::test]
    async fn test_initial_state_is_closed() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(30));

        assert_eq!(breaker.state().await, CircuitState::Closed);
        assert!(breaker.is_closed().await);
    }

    /// The circuit stays closed below the threshold and opens on reaching it.
    #[tokio::test]
    async fn test_opens_after_threshold() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(30));

        // The first two failures are below the threshold of three.
        for _ in 0..2 {
            breaker.record_failure().await;
            assert!(breaker.is_closed().await);
        }

        // The third failure trips the breaker.
        breaker.record_failure().await;
        assert!(breaker.is_open().await);
        assert_eq!(breaker.consecutive_failures(), 3);
    }

    /// A success clears the failure count and leaves the circuit closed.
    #[tokio::test]
    async fn test_success_resets_failures() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(30));

        for _ in 0..2 {
            breaker.record_failure().await;
        }
        assert_eq!(breaker.consecutive_failures(), 2);

        breaker.record_success().await;
        assert_eq!(breaker.consecutive_failures(), 0);
        assert!(breaker.is_closed().await);
    }

    /// An open circuit is reported as half-open once the timeout has passed.
    #[tokio::test]
    async fn test_half_open_after_timeout() {
        let reset_timeout = Duration::from_millis(50);
        let breaker = CircuitBreaker::new(1, reset_timeout);

        breaker.record_failure().await;
        assert!(breaker.is_open().await);

        // Sleep past the reset timeout.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert_eq!(breaker.state().await, CircuitState::HalfOpen);
    }

    /// A success while half-open closes the circuit.
    #[tokio::test]
    async fn test_half_open_success_closes() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(10));

        breaker.record_failure().await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);

        breaker.record_success().await;
        assert!(breaker.is_closed().await);
    }

    /// A failure while half-open opens the circuit again.
    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(10));

        breaker.record_failure().await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);

        breaker.record_failure().await;
        assert!(breaker.is_open().await);
    }

    /// `reset` closes an open circuit and zeroes the failure count.
    #[tokio::test]
    async fn test_reset() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(30));

        breaker.record_failure().await;
        assert!(breaker.is_open().await);

        breaker.reset().await;
        assert!(breaker.is_closed().await);
        assert_eq!(breaker.consecutive_failures(), 0);
    }
}
255}