1use std::collections::HashMap;
40use std::convert::Infallible;
41use std::future::Future;
42use std::pin::Pin;
43use std::sync::Arc;
44use std::task::{Context, Poll};
45
46use tower::{Layer, Service};
47use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
48use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
49
/// A primary backend together with the backends that stand in for it.
///
/// Both sides are stored as ready-to-match prefixes: the backend name with the
/// configured separator already appended.
#[derive(Clone, Debug)]
struct FailoverMapping {
    /// Prefix that marks a request as routed to the primary backend.
    primary_prefix: String,
    /// Prefixes of the replacement backends, in the order they are tried.
    failover_prefixes: Vec<String>,
}
59
/// Tower [`Layer`] that wraps a service in a [`FailoverService`].
///
/// Holds the raw configuration (backend names plus the name separator); the
/// prefixes actually matched against requests are built when the layer is applied.
// `Debug` is derived so the layer can be logged/inspected like other public config types.
#[derive(Clone, Debug)]
pub struct FailoverLayer {
    /// Primary backend name -> ordered list of failover backend names.
    failovers: HashMap<String, Vec<String>>,
    /// Separator placed between a backend name and the local item name.
    separator: String,
}
66
67impl FailoverLayer {
68 pub fn new(failovers: HashMap<String, Vec<String>>, separator: impl Into<String>) -> Self {
73 Self {
74 failovers,
75 separator: separator.into(),
76 }
77 }
78}
79
80impl<S> Layer<S> for FailoverLayer {
81 type Service = FailoverService<S>;
82
83 fn layer(&self, inner: S) -> Self::Service {
84 FailoverService::new(inner, self.failovers.clone(), &self.separator)
85 }
86}
87
88#[derive(Clone)]
93pub struct FailoverService<S> {
94 inner: S,
95 mappings: Arc<Vec<FailoverMapping>>,
96}
97
98impl<S> FailoverService<S> {
99 pub fn new(inner: S, failovers: HashMap<String, Vec<String>>, separator: &str) -> Self {
104 let mappings = failovers
105 .into_iter()
106 .map(|(primary, failover_names)| FailoverMapping {
107 primary_prefix: format!("{primary}{separator}"),
108 failover_prefixes: failover_names
109 .into_iter()
110 .map(|name| format!("{name}{separator}"))
111 .collect(),
112 })
113 .collect();
114
115 Self {
116 inner,
117 mappings: Arc::new(mappings),
118 }
119 }
120}
121
122fn rewrite_request(req: &McpRequest, primary_prefix: &str, failover_prefix: &str) -> McpRequest {
124 match req {
125 McpRequest::CallTool(params) => {
126 if let Some(local) = params.name.strip_prefix(primary_prefix) {
127 McpRequest::CallTool(CallToolParams {
128 name: format!("{failover_prefix}{local}"),
129 arguments: params.arguments.clone(),
130 meta: params.meta.clone(),
131 task: params.task.clone(),
132 })
133 } else {
134 req.clone()
135 }
136 }
137 McpRequest::ReadResource(params) => {
138 if let Some(local) = params.uri.strip_prefix(primary_prefix) {
139 McpRequest::ReadResource(ReadResourceParams {
140 uri: format!("{failover_prefix}{local}"),
141 meta: params.meta.clone(),
142 })
143 } else {
144 req.clone()
145 }
146 }
147 McpRequest::GetPrompt(params) => {
148 if let Some(local) = params.name.strip_prefix(primary_prefix) {
149 McpRequest::GetPrompt(GetPromptParams {
150 name: format!("{failover_prefix}{local}"),
151 arguments: params.arguments.clone(),
152 meta: params.meta.clone(),
153 })
154 } else {
155 req.clone()
156 }
157 }
158 other => other.clone(),
159 }
160}
161
162impl<S> Service<RouterRequest> for FailoverService<S>
163where
164 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
165 + Clone
166 + Send
167 + 'static,
168 S::Future: Send,
169{
170 type Response = RouterResponse;
171 type Error = Infallible;
172 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
173
174 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175 self.inner.poll_ready(cx)
176 }
177
178 fn call(&mut self, req: RouterRequest) -> Self::Future {
179 let mappings = Arc::clone(&self.mappings);
180 let mut inner = self.inner.clone();
181
182 Box::pin(async move {
183 let mapping = mappings.iter().find(|m| match &req.inner {
185 McpRequest::CallTool(p) => p.name.starts_with(&m.primary_prefix),
186 McpRequest::ReadResource(p) => p.uri.starts_with(&m.primary_prefix),
187 McpRequest::GetPrompt(p) => p.name.starts_with(&m.primary_prefix),
188 _ => false,
189 });
190
191 let mapping = match mapping {
192 Some(m) => m.clone(),
193 None => {
194 return inner.call(req).await;
196 }
197 };
198
199 let primary_resp = inner.call(req.clone()).await?;
201
202 if primary_resp.inner.is_ok() {
204 return Ok(primary_resp);
205 }
206
207 let mut last_resp = primary_resp;
213
214 for failover_prefix in &mapping.failover_prefixes {
215 let failover_name = failover_prefix.trim_end_matches('/');
216 tracing::warn!(
217 primary = %mapping.primary_prefix.trim_end_matches('/'),
218 failover = %failover_name,
219 "Backend failed, attempting failover"
220 );
221
222 let failover_request =
223 rewrite_request(&req.inner, &mapping.primary_prefix, failover_prefix);
224
225 let failover_req = RouterRequest {
226 id: req.id.clone(),
227 inner: failover_request,
228 extensions: Extensions::new(),
229 };
230
231 let resp = inner.call(failover_req).await?;
232
233 if resp.inner.is_ok() {
234 return Ok(resp);
235 }
236
237 last_resp = resp;
238 }
239
240 Ok(last_resp)
242 })
243 }
244}
245
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::convert::Infallible;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tower::Service;
    use tower_mcp::protocol::{CallToolParams, CallToolResult, McpRequest, McpResponse};
    use tower_mcp::router::{RouterRequest, RouterResponse};

    use super::FailoverService;
    use crate::test_util::{MockService, call_service};

    /// Builds a failover service (separator `/`) in which backend `primary` falls back
    /// to `backups`, tried in the given order.
    fn failover_svc<S>(inner: S, backups: &[&str]) -> FailoverService<S> {
        let failovers = HashMap::from([(
            "primary".to_string(),
            backups.iter().map(|name| name.to_string()).collect::<Vec<_>>(),
        )]);
        FailoverService::new(inner, failovers, "/")
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Inner service whose `CallTool` outcome is scripted by tool-name prefix.
    ///
    /// A tool whose name starts with one of the `healthy` prefixes succeeds with the
    /// paired text; any other tool fails with a `"<name> down"` JSON-RPC error. Every
    /// non-`CallTool` request gets a `Pong`.
    #[derive(Clone)]
    struct ScriptedBackends {
        healthy: &'static [(&'static str, &'static str)],
    }

    impl Service<RouterRequest> for ScriptedBackends {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            let healthy = self.healthy;
            Box::pin(async move {
                let outcome = match &req.inner {
                    McpRequest::CallTool(params) => {
                        let hit = healthy
                            .iter()
                            .find(|(prefix, _)| params.name.starts_with(*prefix));
                        match hit {
                            Some((_, text)) => {
                                Ok(McpResponse::CallTool(CallToolResult::text(*text)))
                            }
                            None => Err(tower_mcp_types::JsonRpcError {
                                code: -32603,
                                message: format!("{} down", params.name),
                                data: None,
                            }),
                        }
                    }
                    _ => Ok(McpResponse::Pong(Default::default())),
                };
                Ok(RouterResponse {
                    id: req.id.clone(),
                    inner: outcome,
                })
            })
        }
    }

    #[tokio::test]
    async fn test_failover_passes_through_when_no_mapping() {
        let mut svc = failover_svc(MockService::with_tools(&["other/tool"]), &["backup"]);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_failover_passes_through_on_success() {
        let tools = MockService::with_tools(&["primary/tool", "backup/tool"]);
        let mut svc = failover_svc(tools, &["backup"]);

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;
        assert!(resp.inner.is_ok(), "successful primary should pass through");
    }

    #[tokio::test]
    async fn test_failover_retries_on_primary_error() {
        let backends = ScriptedBackends {
            healthy: &[("backup/", "from backup")],
        };
        let mut svc = failover_svc(backends, &["backup"]);

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;
        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => assert_eq!(result.all_text(), "from backup"),
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_failover_chain_tries_in_order() {
        // `primary` and `backup-1` are down; only `backup-2` answers.
        let backends = ScriptedBackends {
            healthy: &[("backup-2/", "from backup-2")],
        };
        let mut svc = failover_svc(backends, &["backup-1", "backup-2"]);

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;
        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => assert_eq!(result.all_text(), "from backup-2"),
            other => panic!("expected CallTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_failover_chain_all_fail_returns_last_error() {
        let backends = ScriptedBackends { healthy: &[] };
        let mut svc = failover_svc(backends, &["backup-1", "backup-2"]);

        let resp = call_service(&mut svc, call_tool("primary/tool")).await;

        // The error surfaced to the caller must come from the last backend tried.
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("backup-2"),
            "expected last failover error, got: {}",
            err.message
        );
    }
}