1use std::convert::Infallible;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
29use std::task::{Context, Poll};
30
31use tower::Service;
32use tower_mcp::router::{RouterRequest, RouterResponse};
33use tower_mcp_types::JsonRpcError;
34
35use crate::config::OutlierDetectionConfig;
36
37#[derive(Clone)]
39pub struct OutlierDetectionLayer {
40 name: String,
41 config: OutlierDetectionConfig,
42 detector: OutlierDetector,
43}
44
45impl OutlierDetectionLayer {
46 pub fn new(name: String, config: OutlierDetectionConfig, detector: OutlierDetector) -> Self {
48 Self {
49 name,
50 config,
51 detector,
52 }
53 }
54}
55
56impl<S> tower::Layer<S> for OutlierDetectionLayer {
57 type Service = OutlierDetectionService<S>;
58
59 fn layer(&self, inner: S) -> Self::Service {
60 OutlierDetectionService::new(
61 inner,
62 self.name.clone(),
63 self.config.clone(),
64 self.detector.clone(),
65 )
66 }
67}
68
69#[derive(Clone)]
74pub struct OutlierDetector {
75 inner: Arc<OutlierDetectorInner>,
76}
77
78struct OutlierDetectorInner {
79 total_backends: AtomicU32,
81 ejected_count: AtomicU32,
83 max_ejection_percent: u32,
85}
86
87impl OutlierDetector {
88 pub fn new(max_ejection_percent: u32) -> Self {
93 Self {
94 inner: Arc::new(OutlierDetectorInner {
95 total_backends: AtomicU32::new(0),
96 ejected_count: AtomicU32::new(0),
97 max_ejection_percent,
98 }),
99 }
100 }
101
102 pub fn register_backend(&self) {
104 self.inner.total_backends.fetch_add(1, Ordering::Relaxed);
105 }
106
107 pub fn try_eject(&self) -> bool {
112 let total = self.inner.total_backends.load(Ordering::Relaxed);
113 if total == 0 {
114 return false;
115 }
116
117 let currently_ejected = self.inner.ejected_count.load(Ordering::Relaxed);
118 let max_ejectable = (total as u64 * self.inner.max_ejection_percent as u64 / 100) as u32;
119 let max_ejectable = if self.inner.max_ejection_percent > 0 {
121 max_ejectable.max(1)
122 } else {
123 0
124 };
125
126 if currently_ejected >= max_ejectable {
127 tracing::debug!(
128 currently_ejected,
129 max_ejectable,
130 total,
131 "Ejection blocked: max_ejection_percent reached"
132 );
133 return false;
134 }
135
136 self.inner.ejected_count.fetch_add(1, Ordering::Relaxed);
137 true
138 }
139
140 pub fn record_uneject(&self) {
142 self.inner.ejected_count.fetch_sub(1, Ordering::Relaxed);
143 }
144
145 pub fn ejected_count(&self) -> u32 {
147 self.inner.ejected_count.load(Ordering::Relaxed)
148 }
149
150 pub fn total_backends(&self) -> u32 {
152 self.inner.total_backends.load(Ordering::Relaxed)
153 }
154}
155
156struct BackendOutlierState {
158 consecutive_errors: AtomicU32,
160 ejected: AtomicBool,
162 ejected_at_ms: AtomicU64,
164}
165
166#[derive(Clone)]
172pub struct OutlierDetectionService<S> {
173 inner: S,
174 state: Arc<BackendOutlierState>,
175 detector: OutlierDetector,
176 config: OutlierDetectionConfig,
177 name: String,
178}
179
180impl<S> OutlierDetectionService<S> {
181 pub fn new(
183 inner: S,
184 name: String,
185 config: OutlierDetectionConfig,
186 detector: OutlierDetector,
187 ) -> Self {
188 detector.register_backend();
189 Self {
190 inner,
191 state: Arc::new(BackendOutlierState {
192 consecutive_errors: AtomicU32::new(0),
193 ejected: AtomicBool::new(false),
194 ejected_at_ms: AtomicU64::new(0),
195 }),
196 detector,
197 config,
198 name,
199 }
200 }
201
202 fn maybe_uneject(&self) -> bool {
204 if !self.state.ejected.load(Ordering::Relaxed) {
205 return false;
206 }
207
208 let ejected_at = self.state.ejected_at_ms.load(Ordering::Relaxed);
209 let now = now_ms();
210 let elapsed_secs = now.saturating_sub(ejected_at) / 1000;
211
212 if elapsed_secs >= self.config.base_ejection_seconds {
213 self.state.ejected.store(false, Ordering::Relaxed);
214 self.state.consecutive_errors.store(0, Ordering::Relaxed);
215 self.detector.record_uneject();
216 tracing::info!(
217 backend = %self.name,
218 ejected_for_secs = elapsed_secs,
219 "Backend un-ejected, allowing traffic"
220 );
221 true
222 } else {
223 false
224 }
225 }
226}
227
/// Milliseconds elapsed on a monotonic clock since this process first called `now_ms`.
///
/// Only differences between two readings are meaningful. A monotonic
/// [`std::time::Instant`] is used instead of wall-clock time, so NTP adjustments or manual
/// clock changes cannot end an ejection early (clock jumping forward) or stretch it (clock
/// stepping back).
fn now_ms() -> u64 {
    static ORIGIN: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    let origin = *ORIGIN.get_or_init(std::time::Instant::now);
    u64::try_from(origin.elapsed().as_millis()).unwrap_or(u64::MAX)
}
234
235fn is_server_error(response: &RouterResponse) -> bool {
237 match &response.inner {
238 Err(err) => {
239 err.code == -32603 || (-32099..=-32000).contains(&err.code)
241 }
242 Ok(_) => false,
243 }
244}
245
246impl<S> Service<RouterRequest> for OutlierDetectionService<S>
247where
248 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
249 + Clone
250 + Send
251 + 'static,
252 S::Future: Send,
253{
254 type Response = RouterResponse;
255 type Error = Infallible;
256 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
257
258 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
259 self.inner.poll_ready(cx)
260 }
261
262 fn call(&mut self, req: RouterRequest) -> Self::Future {
263 self.maybe_uneject();
265
266 if self.state.ejected.load(Ordering::Relaxed) {
267 let id = req.id.clone();
268 let name = self.name.clone();
269 return Box::pin(async move {
270 tracing::debug!(backend = %name, "Request rejected: backend ejected");
271 Ok(RouterResponse {
272 id,
273 inner: Err(JsonRpcError {
274 code: -32000,
275 message: format!("backend '{name}' is ejected due to consecutive errors"),
276 data: None,
277 }),
278 })
279 });
280 }
281
282 let state = Arc::clone(&self.state);
283 let detector = self.detector.clone();
284 let config = self.config.clone();
285 let name = self.name.clone();
286 let fut = self.inner.call(req);
287
288 Box::pin(async move {
289 let response = fut.await?;
290
291 if is_server_error(&response) {
292 let errors = state.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
293 tracing::debug!(
294 backend = %name,
295 consecutive_errors = errors,
296 threshold = config.consecutive_errors,
297 "Backend error observed"
298 );
299
300 if errors >= config.consecutive_errors && !state.ejected.load(Ordering::Relaxed) {
301 if detector.try_eject() {
302 state.ejected.store(true, Ordering::Relaxed);
303 state.ejected_at_ms.store(now_ms(), Ordering::Relaxed);
304 tracing::warn!(
305 backend = %name,
306 consecutive_errors = errors,
307 ejection_seconds = config.base_ejection_seconds,
308 "Backend ejected due to consecutive errors"
309 );
310 } else {
311 tracing::warn!(
312 backend = %name,
313 consecutive_errors = errors,
314 "Backend would be ejected but max_ejection_percent reached"
315 );
316 }
317 }
318 } else {
319 state.consecutive_errors.store(0, Ordering::Relaxed);
321 }
322
323 Ok(response)
324 })
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::OutlierDetectionConfig;
    use crate::test_util::{MockService, call_service};
    use tower::Service;
    use tower_mcp::protocol::RequestId;
    use tower_mcp::router::Extensions;
    use tower_mcp_types::protocol::McpRequest;

    fn make_config(consecutive: u32, ejection_secs: u64, max_pct: u32) -> OutlierDetectionConfig {
        OutlierDetectionConfig {
            consecutive_errors: consecutive,
            interval_seconds: 10,
            base_ejection_seconds: ejection_secs,
            max_ejection_percent: max_pct,
        }
    }

    fn make_error_request() -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(tower_mcp_types::protocol::CallToolParams {
                name: "test/fail".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        }
    }

    fn internal_error() -> JsonRpcError {
        JsonRpcError {
            code: -32603,
            message: "internal error".to_string(),
            data: None,
        }
    }

    /// Inner service that always fails with JSON-RPC "Internal error" (-32603).
    #[derive(Clone)]
    struct ErrorService;

    impl Service<RouterRequest> for ErrorService {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            let id = req.id.clone();
            Box::pin(async move {
                Ok(RouterResponse {
                    id,
                    inner: Err(internal_error()),
                })
            })
        }
    }

    /// Inner service that fails with -32603 for its first `failures` calls, then succeeds.
    #[derive(Clone)]
    struct FlakyService {
        failures_left: Arc<AtomicU32>,
    }

    impl FlakyService {
        fn new(failures: u32) -> Self {
            Self {
                failures_left: Arc::new(AtomicU32::new(failures)),
            }
        }
    }

    impl Service<RouterRequest> for FlakyService {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            let id = req.id.clone();
            let fail = self
                .failures_left
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1))
                .is_ok();
            Box::pin(async move {
                let inner = if fail {
                    Err(internal_error())
                } else {
                    Ok(tower_mcp_types::protocol::McpResponse::ListTools(
                        tower_mcp_types::protocol::ListToolsResult {
                            tools: vec![],
                            next_cursor: None,
                            meta: None,
                        },
                    ))
                };
                Ok(RouterResponse { id, inner })
            })
        }
    }

    #[tokio::test]
    async fn test_passes_through_on_success() {
        let mock = MockService::with_tools(&["test/hello"]);
        let detector = OutlierDetector::new(50);
        let config = make_config(5, 30, 50);
        let mut svc = OutlierDetectionService::new(mock, "test".to_string(), config, detector);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_tracks_consecutive_errors() {
        let detector = OutlierDetector::new(50);
        let config = make_config(3, 30, 50);
        let mut svc =
            OutlierDetectionService::new(ErrorService, "flaky".to_string(), config, detector);

        // Two errors stay below the threshold of three.
        for _ in 0..2 {
            let _ = svc.call(make_error_request()).await;
        }
        assert!(!svc.state.ejected.load(Ordering::Relaxed));

        // The third consecutive error ejects the backend.
        let _ = svc.call(make_error_request()).await;
        assert!(svc.state.ejected.load(Ordering::Relaxed));
        assert_eq!(svc.detector.ejected_count(), 1);
    }

    #[tokio::test]
    async fn test_success_resets_counter() {
        let detector = OutlierDetector::new(50);
        let config = make_config(3, 30, 50);
        let mut svc = OutlierDetectionService::new(
            FlakyService::new(2),
            "flaky".to_string(),
            config,
            detector,
        );

        let _ = svc.call(make_error_request()).await;
        let _ = svc.call(make_error_request()).await;
        assert_eq!(svc.state.consecutive_errors.load(Ordering::Relaxed), 2);
        assert!(!svc.state.ejected.load(Ordering::Relaxed));

        // The third call succeeds, which must clear the streak through the service itself.
        let resp = svc.call(make_error_request()).await.unwrap();
        assert!(resp.inner.is_ok());
        assert_eq!(svc.state.consecutive_errors.load(Ordering::Relaxed), 0);
        assert!(!svc.state.ejected.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_ejected_backend_returns_error() {
        let detector = OutlierDetector::new(50);
        // Eject after two errors and stay ejected for an hour.
        let config = make_config(2, 3600, 50);
        let mut svc =
            OutlierDetectionService::new(ErrorService, "bad".to_string(), config, detector);

        let _ = svc.call(make_error_request()).await;
        let _ = svc.call(make_error_request()).await;
        assert!(svc.state.ejected.load(Ordering::Relaxed));

        let resp = svc.call(make_error_request()).await.unwrap();
        match &resp.inner {
            Err(err) => {
                assert!(err.message.contains("ejected"));
            }
            Ok(_) => panic!("expected error for ejected backend"),
        }
    }

    #[tokio::test]
    async fn test_uneject_after_timeout() {
        let detector = OutlierDetector::new(50);
        // A zero-second ejection expires immediately, so the next call un-ejects.
        let config = make_config(1, 0, 50);
        let mut svc =
            OutlierDetectionService::new(ErrorService, "recover".to_string(), config, detector);

        let _ = svc.call(make_error_request()).await;
        assert!(svc.state.ejected.load(Ordering::Relaxed));
        assert_eq!(svc.detector.ejected_count(), 1);

        // The request reaches the inner service instead of being rejected as ejected...
        let resp = svc.call(make_error_request()).await.unwrap();
        match &resp.inner {
            Err(err) => assert_eq!(err.code, -32603, "request should reach the inner service"),
            Ok(_) => panic!("ErrorService always fails"),
        }

        // ...and its failure re-ejects the backend without leaking an ejection slot.
        assert!(svc.state.ejected.load(Ordering::Relaxed));
        assert_eq!(svc.detector.ejected_count(), 1);
    }

    #[tokio::test]
    async fn test_max_ejection_percent_caps_service_ejections() {
        // Two backends at 50% allow exactly one ejection at a time.
        let detector = OutlierDetector::new(50);
        let config = make_config(1, 3600, 50);
        let mut first = OutlierDetectionService::new(
            ErrorService,
            "first".to_string(),
            config.clone(),
            detector.clone(),
        );
        let mut second = OutlierDetectionService::new(
            ErrorService,
            "second".to_string(),
            config,
            detector.clone(),
        );

        let _ = first.call(make_error_request()).await;
        assert!(first.state.ejected.load(Ordering::Relaxed));

        let _ = second.call(make_error_request()).await;
        assert!(!second.state.ejected.load(Ordering::Relaxed));
        assert_eq!(detector.ejected_count(), 1);

        // The second backend keeps receiving traffic rather than being rejected.
        let resp = second.call(make_error_request()).await.unwrap();
        match &resp.inner {
            Err(err) => assert_eq!(err.code, -32603),
            Ok(_) => panic!("ErrorService always fails"),
        }
    }

    #[test]
    fn test_max_ejection_percent_blocks() {
        // Two backends at 50%: one may be ejected, the second may not.
        let detector = OutlierDetector::new(50);
        detector.register_backend();
        detector.register_backend();

        assert!(detector.try_eject());

        assert!(!detector.try_eject());
    }

    #[test]
    fn test_max_ejection_percent_zero_blocks_all() {
        let detector = OutlierDetector::new(0);
        detector.register_backend();
        assert!(!detector.try_eject());
    }

    #[test]
    fn test_max_ejection_percent_100_allows_all() {
        let detector = OutlierDetector::new(100);
        detector.register_backend();
        detector.register_backend();
        assert!(detector.try_eject());
        assert!(detector.try_eject());
    }

    #[test]
    fn test_uneject_decrements_count() {
        let detector = OutlierDetector::new(100);
        detector.register_backend();
        assert!(detector.try_eject());
        assert_eq!(detector.ejected_count(), 1);
        detector.record_uneject();
        assert_eq!(detector.ejected_count(), 0);
    }

    #[test]
    fn test_is_server_error() {
        let err_resp = RouterResponse {
            id: RequestId::Number(1),
            inner: Err(JsonRpcError {
                code: -32603,
                message: "internal".to_string(),
                data: None,
            }),
        };
        assert!(is_server_error(&err_resp));

        let err_resp2 = RouterResponse {
            id: RequestId::Number(1),
            inner: Err(JsonRpcError {
                code: -32000,
                message: "server error".to_string(),
                data: None,
            }),
        };
        assert!(is_server_error(&err_resp2));

        // Client errors such as "method not found" do not count against the backend.
        let client_err = RouterResponse {
            id: RequestId::Number(1),
            inner: Err(JsonRpcError {
                code: -32601,
                message: "method not found".to_string(),
                data: None,
            }),
        };
        assert!(!is_server_error(&client_err));

        // Successful responses never count.
        let ok_resp = RouterResponse {
            id: RequestId::Number(1),
            inner: Ok(tower_mcp_types::protocol::McpResponse::ListTools(
                tower_mcp_types::protocol::ListToolsResult {
                    tools: vec![],
                    next_cursor: None,
                    meta: None,
                },
            )),
        };
        assert!(!is_server_error(&ok_resp));
    }
}