1use std::collections::HashMap;
33use std::convert::Infallible;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::task::{Context, Poll};
39
40use tower::Service;
41use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
43
/// One mirroring rule, derived from a `source -> (mirror, percent)` entry.
#[derive(Debug, Clone)]
struct MirrorMapping {
    /// Source name followed by the separator; matched with `starts_with`
    /// against tool names, resource URIs and prompt names.
    source_prefix: String,
    /// Mirror name followed by the separator; replaces `source_prefix` in the
    /// mirrored copy of a request.
    mirror_prefix: String,
    /// Share of matching requests to mirror, in percent. Values >= 100 mirror
    /// every request; `MirrorService::new` clamps configured values into 1..=100.
    percent: u32,
    /// Running count of sampling decisions for this rule. It lives behind an
    /// `Arc`, so clones of the mapping (and of the service holding it) share
    /// one counter.
    counter: Arc<AtomicU64>,
}
56
57#[derive(Clone)]
63pub struct MirrorService<S> {
64 inner: S,
65 mappings: Arc<Vec<MirrorMapping>>,
66}
67
68impl<S> MirrorService<S> {
69 pub fn new(inner: S, mirrors: HashMap<String, (String, u32)>, separator: &str) -> Self {
74 let mappings = mirrors
75 .into_iter()
76 .map(|(source, (mirror, percent))| MirrorMapping {
77 source_prefix: format!("{source}{separator}"),
78 mirror_prefix: format!("{mirror}{separator}"),
79 percent: percent.clamp(1, 100),
80 counter: Arc::new(AtomicU64::new(0)),
81 })
82 .collect();
83
84 Self {
85 inner,
86 mappings: Arc::new(mappings),
87 }
88 }
89}
90
91fn find_mirror<'a>(name: &str, mappings: &'a [MirrorMapping]) -> Option<&'a MirrorMapping> {
94 mappings.iter().find(|m| name.starts_with(&m.source_prefix))
95}
96
/// Replaces the leading `source_prefix` of `name` with `mirror_prefix`.
///
/// The caller must already have checked that `name` starts with
/// `source_prefix`. The split happens at byte offset `source_prefix.len()`, so
/// a `name` shorter than the prefix, or one where that offset is not a char
/// boundary, panics.
fn rewrite_name(name: &str, source_prefix: &str, mirror_prefix: &str) -> String {
    let (_, tail) = name.split_at(source_prefix.len());
    let mut rewritten = String::with_capacity(mirror_prefix.len() + tail.len());
    rewritten.push_str(mirror_prefix);
    rewritten.push_str(tail);
    rewritten
}
102
103fn clone_for_mirror(
105 req: &RouterRequest,
106 source_prefix: &str,
107 mirror_prefix: &str,
108) -> Option<RouterRequest> {
109 let new_inner = match &req.inner {
110 McpRequest::CallTool(params) if params.name.starts_with(source_prefix) => {
111 McpRequest::CallTool(CallToolParams {
112 name: rewrite_name(¶ms.name, source_prefix, mirror_prefix),
113 arguments: params.arguments.clone(),
114 meta: params.meta.clone(),
115 task: params.task.clone(),
116 })
117 }
118 McpRequest::ReadResource(params) if params.uri.starts_with(source_prefix) => {
119 McpRequest::ReadResource(ReadResourceParams {
120 uri: rewrite_name(¶ms.uri, source_prefix, mirror_prefix),
121 meta: params.meta.clone(),
122 })
123 }
124 McpRequest::GetPrompt(params) if params.name.starts_with(source_prefix) => {
125 McpRequest::GetPrompt(GetPromptParams {
126 name: rewrite_name(¶ms.name, source_prefix, mirror_prefix),
127 arguments: params.arguments.clone(),
128 meta: params.meta.clone(),
129 })
130 }
131 _ => return None,
133 };
134
135 Some(RouterRequest {
136 id: req.id.clone(),
137 inner: new_inner,
138 extensions: Extensions::new(),
139 })
140}
141
142fn should_mirror(mapping: &MirrorMapping) -> bool {
144 if mapping.percent >= 100 {
145 return true;
146 }
147 let count = mapping.counter.fetch_add(1, Ordering::Relaxed);
148 (count % 100) < mapping.percent as u64
149}
150
151fn request_name(req: &McpRequest) -> Option<&str> {
153 match req {
154 McpRequest::CallTool(params) => Some(¶ms.name),
155 McpRequest::ReadResource(params) => Some(¶ms.uri),
156 McpRequest::GetPrompt(params) => Some(¶ms.name),
157 _ => None,
158 }
159}
160
161impl<S> Service<RouterRequest> for MirrorService<S>
162where
163 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
164 + Clone
165 + Send
166 + 'static,
167 S::Future: Send,
168{
169 type Response = RouterResponse;
170 type Error = Infallible;
171 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
172
173 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174 self.inner.poll_ready(cx)
175 }
176
177 fn call(&mut self, req: RouterRequest) -> Self::Future {
178 let mirror_req = request_name(&req.inner)
180 .and_then(|name| find_mirror(name, &self.mappings))
181 .filter(|mapping| should_mirror(mapping))
182 .and_then(|mapping| {
183 clone_for_mirror(&req, &mapping.source_prefix, &mapping.mirror_prefix)
184 });
185
186 let primary_fut = self.inner.call(req);
188
189 let mut mirror_svc = if mirror_req.is_some() {
191 Some(self.inner.clone())
192 } else {
193 None
194 };
195
196 Box::pin(async move {
197 if let Some(mirror) = mirror_req
199 && let Some(ref mut svc) = mirror_svc
200 {
201 let mut svc = svc.clone();
202 tokio::spawn(async move {
203 match svc.call(mirror).await {
204 Ok(resp) => {
205 if resp.inner.is_err() {
206 tracing::debug!("Mirror request returned error (discarded)");
207 }
208 }
209 Err(e) => match e {},
210 }
211 });
212 }
213
214 primary_fut.await
215 })
216 }
217}
218
// Tests for the mirror layer. The pure helpers (`rewrite_name`, `find_mirror`,
// `should_mirror`, `clone_for_mirror`) are tested directly. `MirrorService` is
// driven end-to-end through the crate's `MockService` / `call_service` test
// utilities.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{MockService, call_service};
    use tower_mcp::protocol::RequestId;
    use tower_mcp::router::Extensions;
    use tower_mcp_types::protocol::McpRequest;

    /// Builds a single-entry `mirrors` map in the shape `MirrorService::new`
    /// takes: `source -> (mirror, percent)`.
    fn make_mirrors(source: &str, mirror: &str, percent: u32) -> HashMap<String, (String, u32)> {
        let mut m = HashMap::new();
        m.insert(source.to_string(), (mirror.to_string(), percent));
        m
    }

    // The source prefix is replaced by the mirror prefix; everything after it,
    // including further `/` segments, is kept verbatim.
    #[test]
    fn test_rewrite_name() {
        assert_eq!(
            rewrite_name("api/search", "api/", "api-v2/"),
            "api-v2/search"
        );
        assert_eq!(
            rewrite_name("api/nested/tool", "api/", "mirror/"),
            "mirror/nested/tool"
        );
    }

    // A name under `api/` matches the mapping; a name under another prefix does not.
    #[test]
    fn test_find_mirror_match() {
        let mappings = vec![MirrorMapping {
            source_prefix: "api/".to_string(),
            mirror_prefix: "api-v2/".to_string(),
            percent: 100,
            counter: Arc::new(AtomicU64::new(0)),
        }];
        assert!(find_mirror("api/search", &mappings).is_some());
        assert!(find_mirror("other/search", &mappings).is_none());
    }

    // At 100% every call is mirrored.
    #[test]
    fn test_should_mirror_100_percent() {
        let mapping = MirrorMapping {
            source_prefix: "api/".to_string(),
            mirror_prefix: "api-v2/".to_string(),
            percent: 100,
            counter: Arc::new(AtomicU64::new(0)),
        };
        for _ in 0..10 {
            assert!(should_mirror(&mapping));
        }
    }

    // At 10%, exactly 10 of 100 consecutive calls are mirrored. Sampling is
    // counter-based, so the count is exact rather than probabilistic.
    #[test]
    fn test_should_mirror_percentage() {
        let mapping = MirrorMapping {
            source_prefix: "api/".to_string(),
            mirror_prefix: "api-v2/".to_string(),
            percent: 10,
            counter: Arc::new(AtomicU64::new(0)),
        };
        let mirrored: u32 = (0..100).filter(|_| should_mirror(&mapping)).count() as u32;
        assert_eq!(mirrored, 10);
    }

    // A matching CallTool is renamed into the mirror prefix and keeps its arguments.
    #[test]
    fn test_clone_for_mirror_call_tool() {
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "api/search".to_string(),
                arguments: serde_json::json!({"q": "test"}),
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };

        let mirrored = clone_for_mirror(&req, "api/", "api-v2/").unwrap();
        match &mirrored.inner {
            McpRequest::CallTool(params) => {
                assert_eq!(params.name, "api-v2/search");
                assert_eq!(params.arguments, serde_json::json!({"q": "test"}));
            }
            _ => panic!("expected CallTool"),
        }
    }

    // A matching ReadResource has its URI prefix rewritten.
    #[test]
    fn test_clone_for_mirror_read_resource() {
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ReadResource(ReadResourceParams {
                uri: "api/docs/readme".to_string(),
                meta: None,
            }),
            extensions: Extensions::new(),
        };

        let mirrored = clone_for_mirror(&req, "api/", "mirror/").unwrap();
        match &mirrored.inner {
            McpRequest::ReadResource(params) => {
                assert_eq!(params.uri, "mirror/docs/readme");
            }
            _ => panic!("expected ReadResource"),
        }
    }

    // Request kinds without a name/URI (here ListTools) are never mirrored.
    #[test]
    fn test_clone_for_mirror_list_tools_returns_none() {
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(Default::default()),
            extensions: Extensions::new(),
        };
        assert!(clone_for_mirror(&req, "api/", "mirror/").is_none());
    }

    // A request under a mirrored prefix (mirroring at 100%) still gets a
    // successful primary response.
    // NOTE(review): only the primary response is asserted; whether the mirror
    // copy was actually dispatched to `api-v2/search` is not observed here.
    #[tokio::test]
    async fn test_mirror_service_passes_through() {
        let mock = MockService::with_tools(&["api/search", "api-v2/search"]);
        let mirrors = make_mirrors("api", "api-v2", 100);
        let mut svc = MirrorService::new(mock, mirrors, "/");

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: "api/search".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    // A name outside every source prefix is forwarded and succeeds.
    #[tokio::test]
    async fn test_mirror_service_non_mirrored_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let mirrors = make_mirrors("api", "api-v2", 100);
        let mut svc = MirrorService::new(mock, mirrors, "/");

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: "other/tool".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    // A ListTools request (no name to match) is forwarded and succeeds.
    #[tokio::test]
    async fn test_mirror_service_list_tools_not_mirrored() {
        let mock = MockService::with_tools(&["api/search"]);
        let mirrors = make_mirrors("api", "api-v2", 100);
        let mut svc = MirrorService::new(mock, mirrors, "/");

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }
}