Skip to main content

mcp_proxy/
outlier.rs

//! Passive health checks via outlier detection.
//!
//! Tracks error rates on live traffic and automatically ejects unhealthy backends.
//! Unlike the circuit breaker (which uses failure *rate* over a sliding window),
//! outlier detection triggers on *consecutive* errors, catching hard-down backends
//! faster. Cross-backend coordination via `max_ejection_percent` prevents ejecting
//! all backends simultaneously.
//!
//! # Configuration
//!
//! ```toml
//! [[backends]]
//! name = "flaky-api"
//! transport = "http"
//! url = "http://localhost:8080"
//!
//! [backends.outlier_detection]
//! consecutive_errors = 5       # eject after 5 consecutive errors
//! interval_seconds = 10        # evaluation interval
//! base_ejection_seconds = 30   # how long to eject
//! max_ejection_percent = 50    # never eject more than half of backends
//! ```
23
24use std::convert::Infallible;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
29use std::task::{Context, Poll};
30
31use tower::Service;
32use tower_mcp::router::{RouterRequest, RouterResponse};
33use tower_mcp_types::JsonRpcError;
34
35use crate::config::OutlierDetectionConfig;
36
37/// A Tower [`Layer`](tower::Layer) that applies outlier detection to a backend.
38#[derive(Clone)]
39pub struct OutlierDetectionLayer {
40    name: String,
41    config: OutlierDetectionConfig,
42    detector: OutlierDetector,
43}
44
45impl OutlierDetectionLayer {
46    /// Create a new outlier detection layer for a specific backend.
47    pub fn new(name: String, config: OutlierDetectionConfig, detector: OutlierDetector) -> Self {
48        Self {
49            name,
50            config,
51            detector,
52        }
53    }
54}
55
56impl<S> tower::Layer<S> for OutlierDetectionLayer {
57    type Service = OutlierDetectionService<S>;
58
59    fn layer(&self, inner: S) -> Self::Service {
60        OutlierDetectionService::new(
61            inner,
62            self.name.clone(),
63            self.config.clone(),
64            self.detector.clone(),
65        )
66    }
67}
68
69/// Shared state tracking ejection status across all backends.
70///
71/// Each backend registers with the detector and reports errors.
72/// The detector enforces `max_ejection_percent` globally.
73#[derive(Clone)]
74pub struct OutlierDetector {
75    inner: Arc<OutlierDetectorInner>,
76}
77
/// Counters behind an [`OutlierDetector`] handle.
struct OutlierDetectorInner {
    /// How many backends have registered so far.
    total_backends: AtomicU32,
    /// How many backends are ejected right now.
    ejected_count: AtomicU32,
    /// Upper bound on simultaneously ejected backends, as a percentage (0-100)
    /// of `total_backends`.
    max_ejection_percent: u32,
}
86
87impl OutlierDetector {
88    /// Create a new outlier detector.
89    ///
90    /// `max_ejection_percent` caps how many backends can be ejected at once
91    /// (as a percentage of total registered backends).
92    pub fn new(max_ejection_percent: u32) -> Self {
93        Self {
94            inner: Arc::new(OutlierDetectorInner {
95                total_backends: AtomicU32::new(0),
96                ejected_count: AtomicU32::new(0),
97                max_ejection_percent,
98            }),
99        }
100    }
101
102    /// Register a backend. Call once per backend at startup.
103    pub fn register_backend(&self) {
104        self.inner.total_backends.fetch_add(1, Ordering::Relaxed);
105    }
106
107    /// Try to eject a backend. Returns `true` if ejection is allowed.
108    ///
109    /// Respects `max_ejection_percent` -- if ejecting this backend would
110    /// exceed the threshold, returns `false`.
111    pub fn try_eject(&self) -> bool {
112        let total = self.inner.total_backends.load(Ordering::Relaxed);
113        if total == 0 {
114            return false;
115        }
116
117        let currently_ejected = self.inner.ejected_count.load(Ordering::Relaxed);
118        let max_ejectable = (total as u64 * self.inner.max_ejection_percent as u64 / 100) as u32;
119        // Always allow at least 1 ejection if max_ejection_percent > 0
120        let max_ejectable = if self.inner.max_ejection_percent > 0 {
121            max_ejectable.max(1)
122        } else {
123            0
124        };
125
126        if currently_ejected >= max_ejectable {
127            tracing::debug!(
128                currently_ejected,
129                max_ejectable,
130                total,
131                "Ejection blocked: max_ejection_percent reached"
132            );
133            return false;
134        }
135
136        self.inner.ejected_count.fetch_add(1, Ordering::Relaxed);
137        true
138    }
139
140    /// Record that a backend has been un-ejected.
141    pub fn record_uneject(&self) {
142        self.inner.ejected_count.fetch_sub(1, Ordering::Relaxed);
143    }
144
145    /// Current number of ejected backends (for observability).
146    pub fn ejected_count(&self) -> u32 {
147        self.inner.ejected_count.load(Ordering::Relaxed)
148    }
149
150    /// Total registered backends.
151    pub fn total_backends(&self) -> u32 {
152        self.inner.total_backends.load(Ordering::Relaxed)
153    }
154}
155
/// Mutable outlier-detection bookkeeping for a single backend.
///
/// All fields are atomics so the state can be updated through a shared
/// reference.
struct BackendOutlierState {
    /// Length of the current run of back-to-back server errors.
    consecutive_errors: AtomicU32,
    /// `true` while the backend is ejected.
    ejected: AtomicBool,
    /// Timestamp of the latest ejection, in milliseconds since the UNIX epoch.
    ejected_at_ms: AtomicU64,
}
165
166/// Per-backend outlier detection middleware.
167///
168/// Wraps a backend service and tracks consecutive errors. When the threshold
169/// is exceeded, the backend is ejected (requests fail immediately) for the
170/// configured duration.
171#[derive(Clone)]
172pub struct OutlierDetectionService<S> {
173    inner: S,
174    state: Arc<BackendOutlierState>,
175    detector: OutlierDetector,
176    config: OutlierDetectionConfig,
177    name: String,
178}
179
180impl<S> OutlierDetectionService<S> {
181    /// Create a new outlier detection service for a specific backend.
182    pub fn new(
183        inner: S,
184        name: String,
185        config: OutlierDetectionConfig,
186        detector: OutlierDetector,
187    ) -> Self {
188        detector.register_backend();
189        Self {
190            inner,
191            state: Arc::new(BackendOutlierState {
192                consecutive_errors: AtomicU32::new(0),
193                ejected: AtomicBool::new(false),
194                ejected_at_ms: AtomicU64::new(0),
195            }),
196            detector,
197            config,
198            name,
199        }
200    }
201
202    /// Check if the ejection period has expired and un-eject if so.
203    fn maybe_uneject(&self) -> bool {
204        if !self.state.ejected.load(Ordering::Relaxed) {
205            return false;
206        }
207
208        let ejected_at = self.state.ejected_at_ms.load(Ordering::Relaxed);
209        let now = now_ms();
210        let elapsed_secs = now.saturating_sub(ejected_at) / 1000;
211
212        if elapsed_secs >= self.config.base_ejection_seconds {
213            self.state.ejected.store(false, Ordering::Relaxed);
214            self.state.consecutive_errors.store(0, Ordering::Relaxed);
215            self.detector.record_uneject();
216            tracing::info!(
217                backend = %self.name,
218                ejected_for_secs = elapsed_secs,
219                "Backend un-ejected, allowing traffic"
220            );
221            true
222        } else {
223            false
224        }
225    }
226}
227
/// Current wall-clock time as milliseconds since the UNIX epoch.
///
/// Yields 0 when the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
234
235/// Returns true if the MCP response indicates a server-side error.
236fn is_server_error(response: &RouterResponse) -> bool {
237    match &response.inner {
238        Err(err) => {
239            // JSON-RPC internal error (-32603) and server errors (-32000 to -32099)
240            err.code == -32603 || (-32099..=-32000).contains(&err.code)
241        }
242        Ok(_) => false,
243    }
244}
245
246impl<S> Service<RouterRequest> for OutlierDetectionService<S>
247where
248    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
249        + Clone
250        + Send
251        + 'static,
252    S::Future: Send,
253{
254    type Response = RouterResponse;
255    type Error = Infallible;
256    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
257
258    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
259        self.inner.poll_ready(cx)
260    }
261
262    fn call(&mut self, req: RouterRequest) -> Self::Future {
263        // Check if ejected
264        self.maybe_uneject();
265
266        if self.state.ejected.load(Ordering::Relaxed) {
267            let id = req.id.clone();
268            let name = self.name.clone();
269            return Box::pin(async move {
270                tracing::debug!(backend = %name, "Request rejected: backend ejected");
271                Ok(RouterResponse {
272                    id,
273                    inner: Err(JsonRpcError {
274                        code: -32000,
275                        message: format!("backend '{name}' is ejected due to consecutive errors"),
276                        data: None,
277                    }),
278                })
279            });
280        }
281
282        let state = Arc::clone(&self.state);
283        let detector = self.detector.clone();
284        let config = self.config.clone();
285        let name = self.name.clone();
286        let fut = self.inner.call(req);
287
288        Box::pin(async move {
289            let response = fut.await?;
290
291            if is_server_error(&response) {
292                let errors = state.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
293                tracing::debug!(
294                    backend = %name,
295                    consecutive_errors = errors,
296                    threshold = config.consecutive_errors,
297                    "Backend error observed"
298                );
299
300                if errors >= config.consecutive_errors && !state.ejected.load(Ordering::Relaxed) {
301                    if detector.try_eject() {
302                        state.ejected.store(true, Ordering::Relaxed);
303                        state.ejected_at_ms.store(now_ms(), Ordering::Relaxed);
304                        tracing::warn!(
305                            backend = %name,
306                            consecutive_errors = errors,
307                            ejection_seconds = config.base_ejection_seconds,
308                            "Backend ejected due to consecutive errors"
309                        );
310                    } else {
311                        tracing::warn!(
312                            backend = %name,
313                            consecutive_errors = errors,
314                            "Backend would be ejected but max_ejection_percent reached"
315                        );
316                    }
317                }
318            } else {
319                // Success resets the counter
320                state.consecutive_errors.store(0, Ordering::Relaxed);
321            }
322
323            Ok(response)
324        })
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::OutlierDetectionConfig;
    use crate::test_util::{MockService, call_service};
    use tower::Service;
    use tower_mcp::protocol::RequestId;
    use tower_mcp::router::Extensions;
    use tower_mcp_types::protocol::McpRequest;

    /// Build a config with the given error threshold, ejection duration and
    /// ejection cap; the evaluation interval is fixed at 10 seconds.
    fn make_config(consecutive: u32, ejection_secs: u64, max_pct: u32) -> OutlierDetectionConfig {
        OutlierDetectionConfig {
            consecutive_errors: consecutive,
            interval_seconds: 10,
            base_ejection_seconds: ejection_secs,
            max_ejection_percent: max_pct,
        }
    }

    /// Build a tool-call request to send through the service under test.
    fn make_error_request() -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(tower_mcp_types::protocol::CallToolParams {
                name: "test/fail".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        }
    }

    /// A mock service that returns server errors.
    #[derive(Clone)]
    struct ErrorService;

    impl Service<RouterRequest> for ErrorService {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            let id = req.id.clone();
            Box::pin(async move {
                Ok(RouterResponse {
                    id,
                    inner: Err(JsonRpcError {
                        code: -32603,
                        message: "internal error".to_string(),
                        data: None,
                    }),
                })
            })
        }
    }

    #[tokio::test]
    async fn test_passes_through_on_success() {
        let mock = MockService::with_tools(&["test/hello"]);
        let detector = OutlierDetector::new(50);
        let config = make_config(5, 30, 50);
        let mut svc = OutlierDetectionService::new(mock, "test".to_string(), config, detector);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_tracks_consecutive_errors() {
        let detector = OutlierDetector::new(50);
        let config = make_config(3, 30, 50);
        let mut svc =
            OutlierDetectionService::new(ErrorService, "flaky".to_string(), config, detector);

        // 2 errors -- not yet ejected
        for _ in 0..2 {
            let _ = svc.call(make_error_request()).await;
        }
        assert!(!svc.state.ejected.load(Ordering::Relaxed));

        // 3rd error triggers ejection
        let _ = svc.call(make_error_request()).await;
        assert!(svc.state.ejected.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_success_resets_counter() {
        let detector = OutlierDetector::new(50);
        let config = make_config(3, 30, 50);

        // Errors accumulate while the backend keeps failing.
        let mut error_svc = OutlierDetectionService::new(
            ErrorService,
            "test".to_string(),
            config.clone(),
            detector.clone(),
        );
        let _ = error_svc.call(make_error_request()).await;
        let _ = error_svc.call(make_error_request()).await;
        assert_eq!(
            error_svc.state.consecutive_errors.load(Ordering::Relaxed),
            2
        );
        assert!(!error_svc.state.ejected.load(Ordering::Relaxed));

        // Seed a non-zero error run on a healthy backend so the assertion
        // below can only pass if the success path really clears the counter.
        let mock = MockService::with_tools(&["test/hello"]);
        let mut success_svc =
            OutlierDetectionService::new(mock, "test2".to_string(), config, detector);
        success_svc
            .state
            .consecutive_errors
            .store(2, Ordering::Relaxed);

        let resp = call_service(&mut success_svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
        assert_eq!(
            success_svc.state.consecutive_errors.load(Ordering::Relaxed),
            0
        );
    }

    #[tokio::test]
    async fn test_ejected_backend_returns_error() {
        let detector = OutlierDetector::new(50);
        let config = make_config(2, 3600, 50); // long ejection so it doesn't expire
        let mut svc =
            OutlierDetectionService::new(ErrorService, "bad".to_string(), config, detector);

        // Trigger ejection
        let _ = svc.call(make_error_request()).await;
        let _ = svc.call(make_error_request()).await;
        assert!(svc.state.ejected.load(Ordering::Relaxed));

        // Next request should be rejected without hitting the backend
        let resp = svc.call(make_error_request()).await.unwrap();
        match &resp.inner {
            Err(err) => {
                assert!(err.message.contains("ejected"));
            }
            Ok(_) => panic!("expected error for ejected backend"),
        }
    }

    #[tokio::test]
    async fn test_uneject_after_timeout() {
        let detector = OutlierDetector::new(50);
        let config = make_config(1, 0, 50); // 0-second ejection = immediate uneject
        let mut svc = OutlierDetectionService::new(
            ErrorService,
            "recover".to_string(),
            config,
            detector.clone(),
        );

        // Trigger ejection
        let _ = svc.call(make_error_request()).await;
        assert!(svc.state.ejected.load(Ordering::Relaxed));
        assert_eq!(detector.ejected_count(), 1);

        // With 0-second ejection, the next call un-ejects before checking the
        // ejection status, so the request must reach the backend: the error
        // returned is the backend's own (-32603), not the synthetic
        // "ejected" rejection.
        let resp = svc.call(make_error_request()).await.unwrap();
        match &resp.inner {
            Err(err) => {
                assert!(err.code == -32603);
                assert!(!err.message.contains("ejected"));
            }
            Ok(_) => panic!("ErrorService always returns an error"),
        }

        // With threshold=1 that backend error re-ejects immediately. The
        // shared count must be back at exactly 1: the un-eject released the
        // first slot and the re-eject claimed one again.
        assert!(svc.state.ejected.load(Ordering::Relaxed));
        assert_eq!(detector.ejected_count(), 1);
    }

    #[test]
    fn test_max_ejection_percent_blocks() {
        let detector = OutlierDetector::new(50); // 50%

        // Register 2 backends
        detector.register_backend();
        detector.register_backend();

        // First ejection should work (1/2 = 50%)
        assert!(detector.try_eject());

        // Second ejection should be blocked (2/2 = 100% > 50%)
        assert!(!detector.try_eject());
    }

    #[test]
    fn test_max_ejection_percent_zero_blocks_all() {
        let detector = OutlierDetector::new(0);
        detector.register_backend();
        assert!(!detector.try_eject());
    }

    #[test]
    fn test_max_ejection_percent_100_allows_all() {
        let detector = OutlierDetector::new(100);
        detector.register_backend();
        detector.register_backend();
        assert!(detector.try_eject());
        assert!(detector.try_eject());
    }

    #[test]
    fn test_uneject_decrements_count() {
        let detector = OutlierDetector::new(100);
        detector.register_backend();
        assert!(detector.try_eject());
        assert_eq!(detector.ejected_count(), 1);
        detector.record_uneject();
        assert_eq!(detector.ejected_count(), 0);
    }

    #[test]
    fn test_is_server_error() {
        let err_resp = RouterResponse {
            id: RequestId::Number(1),
            inner: Err(JsonRpcError {
                code: -32603,
                message: "internal".to_string(),
                data: None,
            }),
        };
        assert!(is_server_error(&err_resp));

        let err_resp2 = RouterResponse {
            id: RequestId::Number(1),
            inner: Err(JsonRpcError {
                code: -32000,
                message: "server error".to_string(),
                data: None,
            }),
        };
        assert!(is_server_error(&err_resp2));

        // Client error -- not a server error
        let client_err = RouterResponse {
            id: RequestId::Number(1),
            inner: Err(JsonRpcError {
                code: -32601,
                message: "method not found".to_string(),
                data: None,
            }),
        };
        assert!(!is_server_error(&client_err));

        // Success -- not an error
        let ok_resp = RouterResponse {
            id: RequestId::Number(1),
            inner: Ok(tower_mcp_types::protocol::McpResponse::ListTools(
                tower_mcp_types::protocol::ListToolsResult {
                    tools: vec![],
                    next_cursor: None,
                    meta: None,
                },
            )),
        };
        assert!(!is_server_error(&ok_resp));
    }
}