Skip to main content

mcp_proxy/
failover.rs

1//! Backend failover middleware.
2//!
//! Routes requests to a primary backend, automatically failing over to
4//! secondary backends when the primary returns an error. Multiple failover
5//! backends can be configured per primary, ordered by [`priority`](crate::config::BackendConfig::priority)
6//! (lower values tried first).
7//!
8//! # Configuration
9//!
10//! ```toml
11//! [[backends]]
12//! name = "api"
13//! transport = "http"
14//! url = "http://primary:8080"
15//!
16//! [[backends]]
17//! name = "api-backup"
18//! transport = "http"
19//! url = "http://secondary:8080"
20//! failover_for = "api"
21//! priority = 0            # tried first (default)
22//!
23//! [[backends]]
24//! name = "api-backup-2"
25//! transport = "http"
26//! url = "http://tertiary:8080"
27//! failover_for = "api"
28//! priority = 10           # tried second
29//! ```
30//!
31//! # How it works
32//!
33//! 1. Request arrives targeting `api/search`
34//! 2. Request is forwarded to the `api` backend
35//! 3. If `api` returns an error, the request is retried against `api-backup/search`
36//! 4. If `api-backup` also fails, the request is retried against `api-backup-2/search`
37//! 5. Failover backend tools are hidden from `ListTools` (like canary backends)
38
39use std::collections::HashMap;
40use std::convert::Infallible;
41use std::future::Future;
42use std::pin::Pin;
43use std::sync::Arc;
44use std::task::{Context, Poll};
45
46use tower::{Layer, Service};
47use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
48use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
49
/// One primary backend's namespace together with its failover namespaces.
///
/// Prefixes are stored already joined with the configured separator, so they
/// can be compared directly against namespaced tool names, resource URIs and
/// prompt names.
#[derive(Debug, Clone)]
struct FailoverMapping {
    /// Namespace prefix of the primary backend: backend name plus separator
    /// (e.g. `"api/"`).
    primary_prefix: String,
    /// Namespace prefixes of the failover backends, in the order they are
    /// attempted (e.g. `["api-backup/", "api-backup-2/"]`).
    failover_prefixes: Vec<String>,
}
59
/// [`Layer`] that wraps an inner service in a [`FailoverService`].
///
/// Holds the failover table (primary backend name -> failover backend names,
/// in the order they should be tried) and the namespace separator. Both are
/// handed to [`FailoverService::new`] every time the layer is applied.
#[derive(Clone)]
pub struct FailoverLayer {
    failovers: HashMap<String, Vec<String>>,
    separator: String,
}
66
67impl FailoverLayer {
68    /// Create a new failover layer.
69    ///
70    /// `failovers` maps primary backend names to an ordered list of failover
71    /// backend names (sorted by priority, lowest first).
72    pub fn new(failovers: HashMap<String, Vec<String>>, separator: impl Into<String>) -> Self {
73        Self {
74            failovers,
75            separator: separator.into(),
76        }
77    }
78}
79
80impl<S> Layer<S> for FailoverLayer {
81    type Service = FailoverService<S>;
82
83    fn layer(&self, inner: S) -> Self::Service {
84        FailoverService::new(inner, self.failovers.clone(), &self.separator)
85    }
86}
87
88/// Tower service that fails over to secondary backends on primary error.
89///
90/// When a primary backend returns an error, failover backends are tried
91/// in priority order until one succeeds or all have been exhausted.
92#[derive(Clone)]
93pub struct FailoverService<S> {
94    inner: S,
95    mappings: Arc<Vec<FailoverMapping>>,
96}
97
98impl<S> FailoverService<S> {
99    /// Create a new failover service.
100    ///
101    /// `failovers` maps primary backend names to an ordered list of failover
102    /// backend names (sorted by priority, lowest first).
103    pub fn new(inner: S, failovers: HashMap<String, Vec<String>>, separator: &str) -> Self {
104        let mappings = failovers
105            .into_iter()
106            .map(|(primary, failover_names)| FailoverMapping {
107                primary_prefix: format!("{primary}{separator}"),
108                failover_prefixes: failover_names
109                    .into_iter()
110                    .map(|name| format!("{name}{separator}"))
111                    .collect(),
112            })
113            .collect();
114
115        Self {
116            inner,
117            mappings: Arc::new(mappings),
118        }
119    }
120}
121
122/// Rewrite a request's namespace from primary to failover.
123fn rewrite_request(req: &McpRequest, primary_prefix: &str, failover_prefix: &str) -> McpRequest {
124    match req {
125        McpRequest::CallTool(params) => {
126            if let Some(local) = params.name.strip_prefix(primary_prefix) {
127                McpRequest::CallTool(CallToolParams {
128                    name: format!("{failover_prefix}{local}"),
129                    arguments: params.arguments.clone(),
130                    meta: params.meta.clone(),
131                    task: params.task.clone(),
132                })
133            } else {
134                req.clone()
135            }
136        }
137        McpRequest::ReadResource(params) => {
138            if let Some(local) = params.uri.strip_prefix(primary_prefix) {
139                McpRequest::ReadResource(ReadResourceParams {
140                    uri: format!("{failover_prefix}{local}"),
141                    meta: params.meta.clone(),
142                })
143            } else {
144                req.clone()
145            }
146        }
147        McpRequest::GetPrompt(params) => {
148            if let Some(local) = params.name.strip_prefix(primary_prefix) {
149                McpRequest::GetPrompt(GetPromptParams {
150                    name: format!("{failover_prefix}{local}"),
151                    arguments: params.arguments.clone(),
152                    meta: params.meta.clone(),
153                })
154            } else {
155                req.clone()
156            }
157        }
158        other => other.clone(),
159    }
160}
161
162impl<S> Service<RouterRequest> for FailoverService<S>
163where
164    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
165        + Clone
166        + Send
167        + 'static,
168    S::Future: Send,
169{
170    type Response = RouterResponse;
171    type Error = Infallible;
172    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
173
174    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175        self.inner.poll_ready(cx)
176    }
177
178    fn call(&mut self, req: RouterRequest) -> Self::Future {
179        let mappings = Arc::clone(&self.mappings);
180        let mut inner = self.inner.clone();
181
182        Box::pin(async move {
183            // Find if this request targets a primary that has failovers
184            let mapping = mappings.iter().find(|m| match &req.inner {
185                McpRequest::CallTool(p) => p.name.starts_with(&m.primary_prefix),
186                McpRequest::ReadResource(p) => p.uri.starts_with(&m.primary_prefix),
187                McpRequest::GetPrompt(p) => p.name.starts_with(&m.primary_prefix),
188                _ => false,
189            });
190
191            let mapping = match mapping {
192                Some(m) => m.clone(),
193                None => {
194                    // No failover configured for this request, pass through
195                    return inner.call(req).await;
196                }
197            };
198
199            // Try primary
200            let primary_resp = inner.call(req.clone()).await?;
201
202            // If primary succeeded, return it
203            if primary_resp.inner.is_ok() {
204                return Ok(primary_resp);
205            }
206
207            // Primary failed -- attempt failovers in priority order
208            // TODO: When outlier detection is integrated, check if the primary
209            // backend is ejected and skip directly to failover without waiting
210            // for an error response. This requires sharing ejection state
211            // between the OutlierDetectionService and FailoverService layers.
212            let mut last_resp = primary_resp;
213
214            for failover_prefix in &mapping.failover_prefixes {
215                let failover_name = failover_prefix.trim_end_matches('/');
216                tracing::warn!(
217                    primary = %mapping.primary_prefix.trim_end_matches('/'),
218                    failover = %failover_name,
219                    "Backend failed, attempting failover"
220                );
221
222                let failover_request =
223                    rewrite_request(&req.inner, &mapping.primary_prefix, failover_prefix);
224
225                let failover_req = RouterRequest {
226                    id: req.id.clone(),
227                    inner: failover_request,
228                    extensions: Extensions::new(),
229                };
230
231                let resp = inner.call(failover_req).await?;
232
233                if resp.inner.is_ok() {
234                    return Ok(resp);
235                }
236
237                last_resp = resp;
238            }
239
240            // All failovers exhausted, return the last error
241            Ok(last_resp)
242        })
243    }
244}
245
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::convert::Infallible;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tower::Service;
    use tower_mcp::protocol::{CallToolParams, CallToolResult, McpRequest, McpResponse};
    use tower_mcp::router::{RouterRequest, RouterResponse};
    use tower_mcp_types::JsonRpcError;

    use super::FailoverService;
    use crate::test_util::{MockService, call_service};

    /// Decides a scripted backend's reply to a `CallTool` from the tool name.
    type Script = fn(&str) -> Result<McpResponse, JsonRpcError>;

    /// Backend double whose `CallTool` replies come from a [`Script`].
    /// Every other request is answered with `Pong`.
    #[derive(Clone)]
    struct ScriptedBackend(Script);

    impl Service<RouterRequest> for ScriptedBackend {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            let script = self.0;
            let id = req.id.clone();
            Box::pin(async move {
                let reply = if let McpRequest::CallTool(params) = &req.inner {
                    script(params.name.as_str())
                } else {
                    Ok(McpResponse::Pong(Default::default()))
                };
                Ok(RouterResponse { id, inner: reply })
            })
        }
    }

    /// JSON-RPC internal error (-32603) carrying `message`.
    fn internal_error(message: impl Into<String>) -> JsonRpcError {
        JsonRpcError {
            code: -32603,
            message: message.into(),
            data: None,
        }
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Failover table mapping the `primary` backend to `backups`, tried in order.
    fn primary_failovers(backups: &[&str]) -> HashMap<String, Vec<String>> {
        let chain: Vec<String> = backups.iter().map(|name| name.to_string()).collect();
        HashMap::from([("primary".to_string(), chain)])
    }

    fn make_failover_svc(mock: MockService) -> FailoverService<MockService> {
        FailoverService::new(mock, primary_failovers(&["backup"]), "/")
    }

    #[tokio::test]
    async fn test_failover_passes_through_when_no_mapping() {
        let mut svc = make_failover_svc(MockService::with_tools(&["other/tool"]));

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_failover_passes_through_on_success() {
        let mock = MockService::with_tools(&["primary/tool", "backup/tool"]);
        let mut svc = make_failover_svc(mock);

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;
        assert!(resp.inner.is_ok(), "successful primary should pass through");
    }

    #[tokio::test]
    async fn test_failover_retries_on_primary_error() {
        // Calls to "primary/" fail; calls to "backup/" succeed.
        fn script(name: &str) -> Result<McpResponse, JsonRpcError> {
            if name.starts_with("primary/") {
                Err(internal_error("primary down"))
            } else if name.starts_with("backup/") {
                Ok(McpResponse::CallTool(CallToolResult::text("from backup")))
            } else {
                Ok(McpResponse::Pong(Default::default()))
            }
        }

        let mut svc =
            FailoverService::new(ScriptedBackend(script), primary_failovers(&["backup"]), "/");

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;
        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => assert_eq!(result.all_text(), "from backup"),
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_failover_chain_tries_in_order() {
        // primary and backup-1 fail; backup-2 succeeds.
        fn script(name: &str) -> Result<McpResponse, JsonRpcError> {
            if name.starts_with("primary/") {
                Err(internal_error("primary down"))
            } else if name.starts_with("backup-1/") {
                Err(internal_error("backup-1 down"))
            } else if name.starts_with("backup-2/") {
                Ok(McpResponse::CallTool(CallToolResult::text("from backup-2")))
            } else {
                Ok(McpResponse::Pong(Default::default()))
            }
        }

        let failovers = primary_failovers(&["backup-1", "backup-2"]);
        let mut svc = FailoverService::new(ScriptedBackend(script), failovers, "/");

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;
        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => assert_eq!(result.all_text(), "from backup-2"),
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_failover_chain_all_fail_returns_last_error() {
        // Every tool call fails with "<tool name> down".
        fn script(name: &str) -> Result<McpResponse, JsonRpcError> {
            Err(internal_error(format!("{} down", name)))
        }

        let failovers = primary_failovers(&["backup-1", "backup-2"]);
        let mut svc = FailoverService::new(ScriptedBackend(script), failovers, "/");

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;

        // Should get the last failover's error.
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("backup-2"),
            "expected last failover error, got: {}",
            err.message
        );
    }
}