1use std::collections::HashMap;
33use std::convert::Infallible;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::task::{Context, Poll};
39
40use tower::{Layer, Service};
41use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
43
/// Tower [`Layer`] that duplicates ("mirrors") a sampled share of matching
/// requests to a second name, discarding the mirrored responses.
///
/// Each entry of `mirrors` maps a source name to `(mirror name, percent)`.
#[derive(Debug, Clone)]
pub struct MirrorLayer {
    /// Source name -> (mirror name, percentage of matching requests to mirror).
    mirrors: HashMap<String, (String, u32)>,
    /// Text appended to source and mirror names to form the prefixes that
    /// request names are matched against (e.g. `/` turns `api` into `api/`).
    separator: String,
}
50
51impl MirrorLayer {
52 pub fn new(mirrors: HashMap<String, (String, u32)>, separator: impl Into<String>) -> Self {
56 Self {
57 mirrors,
58 separator: separator.into(),
59 }
60 }
61}
62
63impl<S> Layer<S> for MirrorLayer {
64 type Service = MirrorService<S>;
65
66 fn layer(&self, inner: S) -> Self::Service {
67 MirrorService::new(inner, self.mirrors.clone(), &self.separator)
68 }
69}
70
/// One resolved mirror rule: a request whose name starts with
/// `source_prefix` may be duplicated with that prefix swapped for
/// `mirror_prefix`.
#[derive(Debug, Clone)]
struct MirrorMapping {
    /// `"{source}{separator}"`, matched against request names.
    source_prefix: String,
    /// `"{mirror}{separator}"`, substituted for `source_prefix` in the copy.
    mirror_prefix: String,
    /// Share of matching requests to mirror; `MirrorService::new` clamps it to `1..=100`.
    percent: u32,
    /// Call counter driving the sampling decision. It lives behind an `Arc`,
    /// so clones of a mapping advance the same counter.
    counter: Arc<AtomicU64>,
}
83
84#[derive(Clone)]
90pub struct MirrorService<S> {
91 inner: S,
92 mappings: Arc<Vec<MirrorMapping>>,
93}
94
95impl<S> MirrorService<S> {
96 pub fn new(inner: S, mirrors: HashMap<String, (String, u32)>, separator: &str) -> Self {
101 let mappings = mirrors
102 .into_iter()
103 .map(|(source, (mirror, percent))| MirrorMapping {
104 source_prefix: format!("{source}{separator}"),
105 mirror_prefix: format!("{mirror}{separator}"),
106 percent: percent.clamp(1, 100),
107 counter: Arc::new(AtomicU64::new(0)),
108 })
109 .collect();
110
111 Self {
112 inner,
113 mappings: Arc::new(mappings),
114 }
115 }
116}
117
118fn find_mirror<'a>(name: &str, mappings: &'a [MirrorMapping]) -> Option<&'a MirrorMapping> {
121 mappings.iter().find(|m| name.starts_with(&m.source_prefix))
122}
123
/// Replaces a leading `source_prefix` in `name` with `mirror_prefix`.
///
/// If `name` does not start with `source_prefix`, it is returned unchanged.
/// Slicing by the prefix length alone would silently produce a mangled name
/// for such input, or panic when `name` is shorter than the prefix or the
/// cut falls inside a multi-byte character.
fn rewrite_name(name: &str, source_prefix: &str, mirror_prefix: &str) -> String {
    match name.strip_prefix(source_prefix) {
        Some(suffix) => format!("{mirror_prefix}{suffix}"),
        None => name.to_string(),
    }
}
129
130fn clone_for_mirror(
132 req: &RouterRequest,
133 source_prefix: &str,
134 mirror_prefix: &str,
135) -> Option<RouterRequest> {
136 let new_inner = match &req.inner {
137 McpRequest::CallTool(params) if params.name.starts_with(source_prefix) => {
138 McpRequest::CallTool(CallToolParams {
139 name: rewrite_name(¶ms.name, source_prefix, mirror_prefix),
140 arguments: params.arguments.clone(),
141 meta: params.meta.clone(),
142 task: params.task.clone(),
143 })
144 }
145 McpRequest::ReadResource(params) if params.uri.starts_with(source_prefix) => {
146 McpRequest::ReadResource(ReadResourceParams {
147 uri: rewrite_name(¶ms.uri, source_prefix, mirror_prefix),
148 meta: params.meta.clone(),
149 })
150 }
151 McpRequest::GetPrompt(params) if params.name.starts_with(source_prefix) => {
152 McpRequest::GetPrompt(GetPromptParams {
153 name: rewrite_name(¶ms.name, source_prefix, mirror_prefix),
154 arguments: params.arguments.clone(),
155 meta: params.meta.clone(),
156 })
157 }
158 _ => return None,
160 };
161
162 Some(RouterRequest {
163 id: req.id.clone(),
164 inner: new_inner,
165 extensions: Extensions::new(),
166 })
167}
168
169fn should_mirror(mapping: &MirrorMapping) -> bool {
171 if mapping.percent >= 100 {
172 return true;
173 }
174 let count = mapping.counter.fetch_add(1, Ordering::Relaxed);
175 (count % 100) < mapping.percent as u64
176}
177
178fn request_name(req: &McpRequest) -> Option<&str> {
180 match req {
181 McpRequest::CallTool(params) => Some(¶ms.name),
182 McpRequest::ReadResource(params) => Some(¶ms.uri),
183 McpRequest::GetPrompt(params) => Some(¶ms.name),
184 _ => None,
185 }
186}
187
188impl<S> Service<RouterRequest> for MirrorService<S>
189where
190 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
191 + Clone
192 + Send
193 + 'static,
194 S::Future: Send,
195{
196 type Response = RouterResponse;
197 type Error = Infallible;
198 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
199
200 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
201 self.inner.poll_ready(cx)
202 }
203
204 fn call(&mut self, req: RouterRequest) -> Self::Future {
205 let mirror_req = request_name(&req.inner)
207 .and_then(|name| find_mirror(name, &self.mappings))
208 .filter(|mapping| should_mirror(mapping))
209 .and_then(|mapping| {
210 clone_for_mirror(&req, &mapping.source_prefix, &mapping.mirror_prefix)
211 });
212
213 let primary_fut = self.inner.call(req);
215
216 let mut mirror_svc = if mirror_req.is_some() {
218 Some(self.inner.clone())
219 } else {
220 None
221 };
222
223 Box::pin(async move {
224 if let Some(mirror) = mirror_req
226 && let Some(ref mut svc) = mirror_svc
227 {
228 let mut svc = svc.clone();
229 tokio::spawn(async move {
230 match svc.call(mirror).await {
231 Ok(resp) => {
232 if resp.inner.is_err() {
233 tracing::debug!("Mirror request returned error (discarded)");
234 }
235 }
236 Err(e) => match e {},
237 }
238 });
239 }
240
241 primary_fut.await
242 })
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{MockService, call_service};
    use tower_mcp::protocol::RequestId;
    use tower_mcp::router::Extensions;
    use tower_mcp_types::protocol::McpRequest;

    /// Single-entry mirror table: `source -> (mirror, percent)`.
    fn make_mirrors(source: &str, mirror: &str, percent: u32) -> HashMap<String, (String, u32)> {
        HashMap::from([(source.to_string(), (mirror.to_string(), percent))])
    }

    /// An `api/` -> `api-v2/` rule with a fresh counter.
    fn api_mapping(percent: u32) -> MirrorMapping {
        MirrorMapping {
            source_prefix: "api/".to_string(),
            mirror_prefix: "api-v2/".to_string(),
            percent,
            counter: Arc::new(AtomicU64::new(0)),
        }
    }

    /// A `CallTool` request for `name` with no meta or task.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Wraps `inner` in a `RouterRequest` with id 1 and empty extensions.
    fn routed(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_rewrite_name() {
        let cases = [
            ("api/search", "api/", "api-v2/", "api-v2/search"),
            ("api/nested/tool", "api/", "mirror/", "mirror/nested/tool"),
        ];
        for (name, from, to, expected) in cases {
            assert_eq!(rewrite_name(name, from, to), expected);
        }
    }

    #[test]
    fn test_find_mirror_match() {
        let mappings = [api_mapping(100)];
        assert!(find_mirror("api/search", &mappings).is_some());
        assert!(find_mirror("other/search", &mappings).is_none());
    }

    #[test]
    fn test_should_mirror_100_percent() {
        let mapping = api_mapping(100);
        assert!((0..10).all(|_| should_mirror(&mapping)));
    }

    #[test]
    fn test_should_mirror_percentage() {
        let mapping = api_mapping(10);
        let mirrored = (0..100).filter(|_| should_mirror(&mapping)).count();
        assert_eq!(mirrored, 10);
    }

    #[test]
    fn test_clone_for_mirror_call_tool() {
        let req = routed(call_tool("api/search", serde_json::json!({"q": "test"})));

        let mirrored = clone_for_mirror(&req, "api/", "api-v2/").unwrap();
        let McpRequest::CallTool(params) = &mirrored.inner else {
            panic!("expected CallTool");
        };
        assert_eq!(params.name, "api-v2/search");
        assert_eq!(params.arguments, serde_json::json!({"q": "test"}));
    }

    #[test]
    fn test_clone_for_mirror_read_resource() {
        let req = routed(McpRequest::ReadResource(ReadResourceParams {
            uri: "api/docs/readme".to_string(),
            meta: None,
        }));

        let mirrored = clone_for_mirror(&req, "api/", "mirror/").unwrap();
        let McpRequest::ReadResource(params) = &mirrored.inner else {
            panic!("expected ReadResource");
        };
        assert_eq!(params.uri, "mirror/docs/readme");
    }

    #[test]
    fn test_clone_for_mirror_list_tools_returns_none() {
        let req = routed(McpRequest::ListTools(Default::default()));
        assert!(clone_for_mirror(&req, "api/", "mirror/").is_none());
    }

    #[tokio::test]
    async fn test_mirror_service_passes_through() {
        let mock = MockService::with_tools(&["api/search", "api-v2/search"]);
        let mut svc = MirrorService::new(mock, make_mirrors("api", "api-v2", 100), "/");

        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_mirror_service_non_mirrored_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let mut svc = MirrorService::new(mock, make_mirrors("api", "api-v2", 100), "/");

        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_mirror_service_list_tools_not_mirrored() {
        let mock = MockService::with_tools(&["api/search"]);
        let mut svc = MirrorService::new(mock, make_mirrors("api", "api-v2", 100), "/");

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }
}