// camel_processor/circuit_breaker.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5use std::time::Instant;
6
7use tower::{Layer, Service};
8
9use camel_api::{CamelError, CircuitBreakerConfig, Exchange};
10
11// ── State ──────────────────────────────────────────────────────────────
12
/// State machine tracked by the circuit breaker.
///
/// The derives make the state printable and comparable; every field is `Copy`,
/// so the enum itself can be `Copy` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    /// Calls are forwarded to the inner service; `consecutive_failures` counts
    /// the failures observed since the last success.
    Closed { consecutive_failures: u32 },
    /// The breaker tripped at `opened_at`; calls are short-circuited until the
    /// configured open duration has elapsed.
    Open { opened_at: Instant },
    /// The open duration has elapsed; the outcome of a probe call decides
    /// whether the breaker closes again or reopens.
    HalfOpen,
}
18
19// ── Layer ──────────────────────────────────────────────────────────────
20
21/// Tower Layer that wraps an inner service with circuit-breaker logic.
22#[derive(Clone)]
23pub struct CircuitBreakerLayer {
24    config: CircuitBreakerConfig,
25    state: Arc<Mutex<CircuitState>>,
26}
27
28impl CircuitBreakerLayer {
29    pub fn new(config: CircuitBreakerConfig) -> Self {
30        Self {
31            config,
32            state: Arc::new(Mutex::new(CircuitState::Closed {
33                consecutive_failures: 0,
34            })),
35        }
36    }
37}
38
39impl<S> Layer<S> for CircuitBreakerLayer {
40    type Service = CircuitBreakerService<S>;
41
42    fn layer(&self, inner: S) -> Self::Service {
43        CircuitBreakerService {
44            inner,
45            config: self.config.clone(),
46            state: Arc::clone(&self.state),
47        }
48    }
49}
50
51// ── Service ────────────────────────────────────────────────────────────
52
53/// Tower Service implementing the circuit-breaker pattern.
54pub struct CircuitBreakerService<S> {
55    inner: S,
56    config: CircuitBreakerConfig,
57    state: Arc<Mutex<CircuitState>>,
58}
59
60impl<S: Clone> Clone for CircuitBreakerService<S> {
61    fn clone(&self) -> Self {
62        Self {
63            inner: self.inner.clone(),
64            config: self.config.clone(),
65            state: Arc::clone(&self.state),
66        }
67    }
68}
69
70impl<S> Service<Exchange> for CircuitBreakerService<S>
71where
72    S: Service<Exchange, Response = Exchange, Error = CamelError> + Clone + Send + 'static,
73    S::Future: Send,
74{
75    type Response = Exchange;
76    type Error = CamelError;
77    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
78
79    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
81        match *state {
82            CircuitState::Closed { .. } => {
83                drop(state);
84                self.inner.poll_ready(cx)
85            }
86            CircuitState::Open { opened_at } => {
87                if opened_at.elapsed() >= self.config.open_duration {
88                    tracing::info!("Circuit breaker transitioning from Open to HalfOpen");
89                    *state = CircuitState::HalfOpen;
90                    drop(state);
91                    self.inner.poll_ready(cx)
92                } else if self.config.fallback.is_some() {
93                    Poll::Ready(Ok(()))
94                } else {
95                    Poll::Ready(Err(CamelError::CircuitOpen(
96                        "circuit breaker is open".into(),
97                    )))
98                }
99            }
100            CircuitState::HalfOpen => {
101                drop(state);
102                self.inner.poll_ready(cx)
103            }
104        }
105    }
106
107    fn call(&mut self, exchange: Exchange) -> Self::Future {
108        {
109            let mut st = self.state.lock().unwrap_or_else(|e| e.into_inner());
110            if let CircuitState::Open { opened_at } = *st {
111                if opened_at.elapsed() < self.config.open_duration {
112                    if let Some(mut fallback) = self.config.fallback.clone() {
113                        return Box::pin(async move { fallback.call(exchange).await });
114                    }
115                    return Box::pin(async {
116                        Err(CamelError::CircuitOpen("circuit breaker is open".into()))
117                    });
118                }
119
120                tracing::info!("Circuit breaker transitioning from Open to HalfOpen");
121                *st = CircuitState::HalfOpen;
122            }
123        }
124
125        // Clone inner service (Tower pattern) and state handle.
126        let mut inner = self.inner.clone();
127        let state = Arc::clone(&self.state);
128        let config = self.config.clone();
129
130        // Snapshot the current state before calling (briefly lock).
131        let current_is_half_open = matches!(
132            *state.lock().unwrap_or_else(|e| e.into_inner()),
133            CircuitState::HalfOpen
134        );
135
136        Box::pin(async move {
137            let result = inner.call(exchange).await;
138
139            // Update state based on result (briefly lock).
140            let mut st = state.lock().unwrap_or_else(|e| e.into_inner());
141            match &result {
142                Ok(_) => {
143                    // Success → reset to Closed.
144                    if current_is_half_open {
145                        tracing::info!("Circuit breaker transitioning from HalfOpen to Closed");
146                    }
147                    *st = CircuitState::Closed {
148                        consecutive_failures: 0,
149                    };
150                }
151                Err(_) => {
152                    if current_is_half_open {
153                        // Half-open failure → reopen circuit.
154                        tracing::warn!(
155                            "Circuit breaker transitioning from HalfOpen to Open (probe failed)"
156                        );
157                        *st = CircuitState::Open {
158                            opened_at: Instant::now(),
159                        };
160                    } else if let CircuitState::Closed {
161                        consecutive_failures,
162                    } = &mut *st
163                    {
164                        *consecutive_failures += 1;
165                        if *consecutive_failures >= config.failure_threshold {
166                            tracing::warn!(
167                                threshold = config.failure_threshold,
168                                "Circuit breaker transitioning from Closed to Open (failure threshold reached)"
169                            );
170                            *st = CircuitState::Open {
171                                opened_at: Instant::now(),
172                            };
173                        }
174                    }
175                }
176            }
177
178            result
179        })
180    }
181}
182
183// ── Tests ──────────────────────────────────────────────────────────────
184
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessor, BoxProcessorExt, Message};
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;
    use tower::ServiceExt;

    /// Concrete service type exercised by every test in this module.
    type Breaker = CircuitBreakerService<BoxProcessor>;

    // ── Fixtures ───────────────────────────────────────────────────────

    fn make_exchange() -> Exchange {
        Exchange::new(Message::new("test"))
    }

    /// Processor that echoes the exchange back.
    fn ok_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Processor that fails on every call.
    fn failing_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    /// Processor that fails its first `n` calls and succeeds afterwards.
    fn fail_n_times(n: u32) -> BoxProcessor {
        let attempts = Arc::new(AtomicU32::new(0));
        BoxProcessor::from_fn(move |ex| {
            let attempts = Arc::clone(&attempts);
            Box::pin(async move {
                let c = attempts.fetch_add(1, Ordering::SeqCst);
                if c < n {
                    Err(CamelError::ProcessorError(format!("attempt {c}")))
                } else {
                    Ok(ex)
                }
            })
        })
    }

    /// Processor that answers with a fresh exchange whose body is `tag`.
    fn tag_processor(tag: &'static str) -> BoxProcessor {
        BoxProcessor::from_fn(move |_ex| {
            Box::pin(async move {
                let mut out = make_exchange();
                out.input.body = tag.to_string().into();
                Ok(out)
            })
        })
    }

    /// Wraps `inner` in a breaker configured with `config`.
    fn breaker(config: CircuitBreakerConfig, inner: BoxProcessor) -> Breaker {
        CircuitBreakerLayer::new(config).layer(inner)
    }

    /// Waits for readiness (panicking if the breaker refuses) and sends one
    /// test exchange through `svc`.
    async fn send(svc: &mut Breaker) -> Result<Exchange, CamelError> {
        svc.ready().await.unwrap().call(make_exchange()).await
    }

    /// Failure count when the breaker is `Closed`, `None` in any other state.
    fn closed_failures(svc: &Breaker) -> Option<u32> {
        let guard = svc.state.lock().unwrap();
        match *guard {
            CircuitState::Closed {
                consecutive_failures,
            } => Some(consecutive_failures),
            CircuitState::Open { .. } | CircuitState::HalfOpen => None,
        }
    }

    // ── Cases ──────────────────────────────────────────────────────────

    /// 1. Circuit stays closed on success.
    #[tokio::test]
    async fn test_stays_closed_on_success() {
        let mut svc = breaker(
            CircuitBreakerConfig::new().failure_threshold(3),
            ok_processor(),
        );

        for _ in 0..5 {
            assert!(send(&mut svc).await.is_ok());
        }

        // Still closed, and no failures were recorded.
        let failures = closed_failures(&svc).expect("expected Closed state");
        assert_eq!(failures, 0);
    }

    /// 2. Circuit opens after failure_threshold consecutive failures.
    #[tokio::test]
    async fn test_opens_after_failure_threshold() {
        let mut svc = breaker(
            CircuitBreakerConfig::new().failure_threshold(3),
            failing_processor(),
        );

        // Reaching the threshold trips the breaker.
        for _ in 0..3 {
            assert!(send(&mut svc).await.is_err());
        }

        // With no fallback configured, readiness now reports CircuitOpen.
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        match svc.poll_ready(&mut cx) {
            Poll::Ready(Err(CamelError::CircuitOpen(_))) => {} // expected
            other => panic!("expected CircuitOpen error, got {other:?}"),
        }
    }

    /// 3. Circuit transitions to half-open after open_duration.
    #[tokio::test]
    async fn test_transitions_to_half_open_after_duration() {
        let config = CircuitBreakerConfig::new()
            .failure_threshold(2)
            .open_duration(Duration::from_millis(50));
        // The first two calls fail and open the circuit; the third call, the
        // half-open probe, succeeds.
        let mut svc = breaker(config, fail_n_times(2));

        for _ in 0..2 {
            let _ = send(&mut svc).await;
        }

        // Let the open period run out.
        tokio::time::sleep(Duration::from_millis(60)).await;

        // Readiness moves the breaker to HalfOpen and the probe goes through.
        let probe = send(&mut svc).await;
        assert!(probe.is_ok(), "half-open probe should succeed");

        // A successful probe closes the circuit again.
        let failures = closed_failures(&svc)
            .expect("expected Closed state after successful half-open probe");
        assert_eq!(failures, 0);
    }

    /// 4. Half-open failure reopens circuit.
    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let config = CircuitBreakerConfig::new()
            .failure_threshold(2)
            .open_duration(Duration::from_millis(50));
        let mut svc = breaker(config, failing_processor());

        // Two failures open the circuit.
        for _ in 0..2 {
            let _ = send(&mut svc).await;
        }

        // After the open period the next call is a half-open probe.
        tokio::time::sleep(Duration::from_millis(60)).await;

        // The probe fails, so the circuit must open again.
        assert!(send(&mut svc).await.is_err());
        assert!(
            matches!(*svc.state.lock().unwrap(), CircuitState::Open { .. }),
            "expected Open state after half-open failure"
        );
    }

    /// 5. Intermittent failures below threshold don't open circuit.
    #[tokio::test]
    async fn test_intermittent_failures_dont_open() {
        // Outcome pattern: fail, fail, success, fail, fail, success.
        // Each success resets the counter, so a threshold of 3 is never hit.
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = Arc::clone(&call_count);
        let inner = BoxProcessor::from_fn(move |ex| {
            let cc = Arc::clone(&cc);
            Box::pin(async move {
                let c = cc.fetch_add(1, Ordering::SeqCst);
                if c % 3 == 2 {
                    Ok(ex)
                } else {
                    Err(CamelError::ProcessorError("intermittent".into()))
                }
            })
        });

        let mut svc = breaker(CircuitBreakerConfig::new().failure_threshold(3), inner);

        for _ in 0..6 {
            let _ = send(&mut svc).await;
        }

        assert!(
            closed_failures(&svc).is_some(),
            "expected circuit to remain Closed"
        );
    }

    /// 6. An open circuit routes calls to the configured fallback.
    #[tokio::test]
    async fn test_open_uses_fallback_when_configured() {
        let config = CircuitBreakerConfig::new()
            .failure_threshold(1)
            .open_duration(Duration::from_secs(60))
            .fallback(tag_processor("fallback"));
        let mut svc = breaker(config, failing_processor());

        // A single failure opens the circuit (threshold is 1).
        let _ = send(&mut svc).await;

        let result = send(&mut svc).await.unwrap();
        assert_eq!(result.input.body.as_text(), Some("fallback"));
    }

    /// 7. An open circuit without a fallback refuses readiness.
    #[tokio::test]
    async fn test_open_without_fallback_returns_err() {
        let config = CircuitBreakerConfig::new()
            .failure_threshold(1)
            .open_duration(Duration::from_secs(60));
        let mut svc = breaker(config, failing_processor());

        let _ = send(&mut svc).await;

        let result = svc.ready().await;
        assert!(matches!(result, Err(CamelError::CircuitOpen(_))));
    }
}