Skip to main content

pe_core/
circuit_breaker.rs

1//! Circuit breaker -- fail-fast on repeated transient LLM errors.
2//!
3//! States: **Closed** (normal), **Open** (rejecting), **HalfOpen** (probing).
4//! Uses atomic operations for lock-free state management.
5
6use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
7use std::time::Duration;
8
9use async_trait::async_trait;
10
11use crate::error::PeError;
12use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
13use crate::message::Message;
14use crate::provider_middleware::ProviderMiddleware;
15
/// Breaker states, encoded as `u32` so they can live in an [`AtomicU32`].
///
/// Closed: normal operation, every call is admitted.
const STATE_CLOSED: u32 = 0;
/// Open: the breaker has tripped and calls are rejected until
/// `recovery_timeout` has elapsed.
const STATE_OPEN: u32 = 1;
/// Half-open: one probe call has been admitted and has not reported back yet.
const STATE_HALF_OPEN: u32 = 2;

/// Lock-free circuit breaker for LLM providers.
///
/// Tracks consecutive transient failures. When `failure_threshold` is reached,
/// the breaker opens and rejects calls immediately. After `recovery_timeout`,
/// it moves to half-open and admits a single probe call; everyone else keeps
/// being rejected until that probe reports success (breaker closes) or failure
/// (breaker re-opens).
///
/// If a probe never reports back (for example it ended with a non-transient
/// error, which does not touch the breaker), a replacement probe is admitted
/// once another `recovery_timeout` has passed, so the breaker cannot get stuck
/// in the half-open state.
pub struct CircuitBreaker {
    failure_threshold: u32,
    recovery_timeout: Duration,
    /// Current state (STATE_CLOSED / STATE_OPEN / STATE_HALF_OPEN).
    state: AtomicU32,
    /// Consecutive failure count.
    failure_count: AtomicU32,
    /// Milliseconds since `epoch` at which the breaker last opened, or at
    /// which the most recent probe was admitted. Doubles as the probe token:
    /// whoever wins the compare-exchange on this value owns the probe.
    opened_at: AtomicU64,
    /// Monotonic reference point for `opened_at`; immune to wall-clock jumps.
    epoch: std::time::Instant,
}

impl CircuitBreaker {
    /// Create a circuit breaker in the closed state.
    ///
    /// - `failure_threshold`: consecutive transient failures before opening.
    /// - `recovery_timeout`: time to wait before admitting a probe call.
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            state: AtomicU32::new(STATE_CLOSED),
            failure_count: AtomicU32::new(0),
            opened_at: AtomicU64::new(0),
            epoch: std::time::Instant::now(),
        }
    }

    /// Current state as a human-readable string (for diagnostics).
    pub fn state_name(&self) -> &'static str {
        match self.state.load(Ordering::SeqCst) {
            STATE_CLOSED => "closed",
            STATE_OPEN => "open",
            STATE_HALF_OPEN => "half-open",
            _ => "unknown",
        }
    }

    /// Current consecutive failure count.
    pub fn failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::SeqCst)
    }

    /// Milliseconds elapsed on the monotonic clock since this breaker was created.
    fn now_millis(&self) -> u64 {
        // Saturate instead of truncating; u64 milliseconds is far beyond any uptime.
        u64::try_from(self.epoch.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// `recovery_timeout` in milliseconds, saturating at `u64::MAX`.
    fn recovery_timeout_millis(&self) -> u64 {
        // A truncating cast would turn a huge timeout into a tiny one.
        u64::try_from(self.recovery_timeout.as_millis()).unwrap_or(u64::MAX)
    }

    /// Record a successful call: clear the failure streak and close the breaker.
    fn record_success(&self) {
        self.failure_count.store(0, Ordering::SeqCst);
        self.state.store(STATE_CLOSED, Ordering::SeqCst);
    }

    /// Record a transient failure; opens (or re-opens) the breaker once the
    /// streak reaches `failure_threshold`.
    fn record_failure(&self) {
        let streak = self
            .failure_count
            .fetch_add(1, Ordering::SeqCst)
            .saturating_add(1);
        if streak >= self.failure_threshold {
            // Publish the timestamp BEFORE the state: a reader that observes
            // STATE_OPEN must never pair it with a stale `opened_at`, or it
            // would believe the recovery timeout has already elapsed.
            self.opened_at.store(self.now_millis(), Ordering::SeqCst);
            self.state.store(STATE_OPEN, Ordering::SeqCst);
        }
    }

    /// Decide whether a call may proceed right now.
    ///
    /// Returns `true` when the breaker is closed, or when this caller has
    /// just been chosen as the probe after the recovery timeout. Returns
    /// `false` while the breaker is open or another probe is still in flight.
    fn should_allow(&self) -> bool {
        match self.state.load(Ordering::SeqCst) {
            STATE_CLOSED => true,
            // Open: maybe time for the first probe. Half-open: a probe is in
            // flight; only replace it if it has gone stale.
            STATE_OPEN | STATE_HALF_OPEN => self.try_claim_probe(),
            _ => false,
        }
    }

    /// Try to become the probe call.
    ///
    /// At most one caller per `recovery_timeout` window wins: the claim is a
    /// compare-exchange on `opened_at`, which also restarts the window so
    /// concurrent and subsequent callers see "not yet elapsed" and are rejected.
    fn try_claim_probe(&self) -> bool {
        let stamp = self.opened_at.load(Ordering::SeqCst);
        let now = self.now_millis();
        if now.saturating_sub(stamp) < self.recovery_timeout_millis() {
            return false;
        }
        let claimed = self
            .opened_at
            .compare_exchange(stamp, now, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok();
        if !claimed {
            // Someone else took the probe, or the breaker was re-opened meanwhile.
            return false;
        }
        // Mark the probe as in flight. If the state is no longer OPEN it is
        // either already HALF_OPEN (stale probe being replaced) or CLOSED
        // (a late success closed the breaker); both are fine to leave as-is.
        let _ = self.state.compare_exchange(
            STATE_OPEN,
            STATE_HALF_OPEN,
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        true
    }
}
123
124#[async_trait]
125impl ProviderMiddleware for CircuitBreaker {
126    async fn wrap_complete(
127        &self,
128        messages: &[Message],
129        tools: &[ToolSchema],
130        next: &dyn LlmProvider,
131    ) -> Result<LlmResponse, PeError> {
132        if !self.should_allow() {
133            return Err(PeError::LlmProvider {
134                details: "circuit breaker open — provider is unavailable".into(),
135            });
136        }
137
138        match next.complete(messages, tools).await {
139            Ok(resp) => {
140                self.record_success();
141                Ok(resp)
142            }
143            Err(e) if e.is_transient() => {
144                self.record_failure();
145                Err(e)
146            }
147            Err(e) => Err(e), // permanent errors don't affect the breaker
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// The provider error used by these tests to count as a breaker failure.
    fn llm_err() -> PeError {
        PeError::LlmProvider {
            details: "err".into(),
        }
    }

    /// A mock provider scripted to return `n` errors in a row.
    fn fail_provider(n: usize) -> MockProvider {
        (0..n).fold(MockProvider::new(), |provider, _| {
            provider.respond_with_error(llm_err())
        })
    }

    /// A closed breaker forwards the call and stays closed on success.
    #[tokio::test]
    async fn test_closed_allows_calls() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let provider = MockProvider::new().respond_with("ok");

        let reply = breaker.wrap_complete(&[], &[], &provider).await.unwrap();

        assert_eq!(reply.message.content.as_text(), Some("ok"));
        assert_eq!(breaker.state_name(), "closed");
    }

    /// The breaker opens exactly when the failure threshold is reached.
    #[tokio::test]
    async fn test_opens_after_threshold_failures() {
        let breaker = CircuitBreaker::new(2, Duration::from_secs(60));
        let failing = fail_provider(2);

        let _ = breaker.wrap_complete(&[], &[], &failing).await;
        assert_eq!(breaker.state_name(), "closed");

        let _ = breaker.wrap_complete(&[], &[], &failing).await;
        assert_eq!(breaker.state_name(), "open");
    }

    /// An open breaker rejects without consuming the provider's response.
    #[tokio::test]
    async fn test_open_rejects_immediately() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let failing = fail_provider(1);
        let _ = breaker.wrap_complete(&[], &[], &failing).await;
        assert_eq!(breaker.state_name(), "open");

        let healthy = MockProvider::new().respond_with("should not reach");
        let rejection = breaker
            .wrap_complete(&[], &[], &healthy)
            .await
            .unwrap_err();
        assert!(matches!(rejection, PeError::LlmProvider { .. }));
        assert_eq!(healthy.remaining(), 1);
    }

    /// After the recovery timeout a probe is let through: success closes the
    /// breaker, failure opens it again.
    #[tokio::test]
    async fn test_half_open_recovery_and_reopen() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(10));
        let recovery_wait = Duration::from_millis(20);

        let _ = breaker.wrap_complete(&[], &[], &fail_provider(1)).await;
        assert_eq!(breaker.state_name(), "open");
        tokio::time::sleep(recovery_wait).await;

        // Successful probe resets to closed.
        let healthy = MockProvider::new().respond_with("recovered");
        let reply = breaker.wrap_complete(&[], &[], &healthy).await.unwrap();
        assert_eq!(reply.message.content.as_text(), Some("recovered"));
        assert_eq!(breaker.state_name(), "closed");

        // Trip again, wait, and let the probe fail: the breaker re-opens.
        let _ = breaker.wrap_complete(&[], &[], &fail_provider(1)).await;
        tokio::time::sleep(recovery_wait).await;
        let _ = breaker.wrap_complete(&[], &[], &fail_provider(1)).await;
        assert_eq!(breaker.state_name(), "open");
    }

    /// Non-transient errors neither open the breaker nor count as failures.
    #[tokio::test]
    async fn test_permanent_errors_dont_trip_breaker() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let denied = PeError::PermissionDenied {
            action: "write".into(),
        };
        let provider = MockProvider::new().respond_with_error(denied);

        let _ = breaker.wrap_complete(&[], &[], &provider).await;

        assert_eq!(breaker.state_name(), "closed");
        assert_eq!(breaker.failure_count(), 0);
    }

    /// A success wipes out the accumulated failure streak.
    #[tokio::test]
    async fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let scripted = fail_provider(2).respond_with("ok");

        for _ in 0..2 {
            let _ = breaker.wrap_complete(&[], &[], &scripted).await;
        }
        assert_eq!(breaker.failure_count(), 2);

        let _ = breaker.wrap_complete(&[], &[], &scripted).await;
        assert_eq!(breaker.failure_count(), 0);
    }
}