Skip to main content

mcp_proxy/
mirror.rs

1//! Traffic mirroring / shadowing middleware.
2//!
3//! Sends a copy of traffic to a secondary backend (fire-and-forget, response
4//! discarded). Useful for testing new backend versions, benchmarking, or
5//! audit recording.
6//!
7//! # Configuration
8//!
9//! ```toml
10//! [[backends]]
11//! name = "api"
12//! transport = "http"
13//! url = "http://api.internal:8080"
14//!
15//! [[backends]]
16//! name = "api-v2"
17//! transport = "http"
18//! url = "http://api-v2.internal:8080"
19//! mirror_of = "api"        # mirror traffic from "api" backend
20//! mirror_percent = 10      # mirror 10% of requests
21//! ```
22//!
23//! # How it works
24//!
25//! 1. Request arrives targeting `api/search`
26//! 2. Primary response is returned from the `api` backend as normal
27//! 3. A copy of the request is rewritten to `api-v2/search` and sent
28//!    fire-and-forget to the `api-v2` backend
29//! 4. The mirror response is discarded; errors are logged but don't
30//!    affect the primary response
31
32use std::collections::HashMap;
33use std::convert::Infallible;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::task::{Context, Poll};
39
40use tower::Service;
41use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
43
/// Pairing of a mirrored source namespace with the namespace of its mirror.
///
/// One value exists per configured mirror relationship. The sampling counter
/// lives behind an [`Arc`], so clones of a mapping advance the same counter
/// rather than each keeping their own.
#[derive(Debug, Clone)]
struct MirrorMapping {
    /// Prefix identifying requests addressed to the source backend (e.g. "api/").
    source_prefix: String,
    /// Prefix substituted in when a request is copied to the mirror (e.g. "api-v2/").
    mirror_prefix: String,
    /// Share of matching requests to mirror, as a whole percentage (1-100).
    percent: u32,
    /// Running request count that drives deterministic percentage sampling.
    counter: Arc<AtomicU64>,
}
56
57/// Traffic mirroring middleware.
58///
59/// Wraps the proxy service and sends copies of matching requests to
60/// mirror backends. The primary response is always returned; mirror
61/// responses are discarded.
62#[derive(Clone)]
63pub struct MirrorService<S> {
64    inner: S,
65    mappings: Arc<Vec<MirrorMapping>>,
66}
67
68impl<S> MirrorService<S> {
69    /// Create a new mirror service.
70    ///
71    /// `mirrors` maps source backend names to `(mirror_name, percent)`.
72    /// The `separator` is used to construct namespace prefixes.
73    pub fn new(inner: S, mirrors: HashMap<String, (String, u32)>, separator: &str) -> Self {
74        let mappings = mirrors
75            .into_iter()
76            .map(|(source, (mirror, percent))| MirrorMapping {
77                source_prefix: format!("{source}{separator}"),
78                mirror_prefix: format!("{mirror}{separator}"),
79                percent: percent.clamp(1, 100),
80                counter: Arc::new(AtomicU64::new(0)),
81            })
82            .collect();
83
84        Self {
85            inner,
86            mappings: Arc::new(mappings),
87        }
88    }
89}
90
91/// Check if a request name starts with a namespace prefix and return the
92/// matching mirror mapping.
93fn find_mirror<'a>(name: &str, mappings: &'a [MirrorMapping]) -> Option<&'a MirrorMapping> {
94    mappings.iter().find(|m| name.starts_with(&m.source_prefix))
95}
96
/// Rewrite a namespaced name from the source prefix to the mirror prefix.
///
/// When `name` begins with `source_prefix`, that prefix is replaced by
/// `mirror_prefix` and the remainder is kept verbatim. A `name` that does
/// not begin with `source_prefix` is returned unchanged.
///
/// The prefix is removed with [`str::strip_prefix`] rather than by slicing at
/// `source_prefix.len()`: a raw slice panics when `name` is shorter than the
/// prefix or when the cut lands inside a multi-byte character, and silently
/// produces a nonsensical name when `name` merely happens to be long enough.
fn rewrite_name(name: &str, source_prefix: &str, mirror_prefix: &str) -> String {
    match name.strip_prefix(source_prefix) {
        Some(suffix) => format!("{mirror_prefix}{suffix}"),
        // Not in the source namespace: there is nothing to rewrite.
        None => name.to_owned(),
    }
}
102
103/// Clone a request with its name rewritten to the mirror namespace.
104fn clone_for_mirror(
105    req: &RouterRequest,
106    source_prefix: &str,
107    mirror_prefix: &str,
108) -> Option<RouterRequest> {
109    let new_inner = match &req.inner {
110        McpRequest::CallTool(params) if params.name.starts_with(source_prefix) => {
111            McpRequest::CallTool(CallToolParams {
112                name: rewrite_name(&params.name, source_prefix, mirror_prefix),
113                arguments: params.arguments.clone(),
114                meta: params.meta.clone(),
115                task: params.task.clone(),
116            })
117        }
118        McpRequest::ReadResource(params) if params.uri.starts_with(source_prefix) => {
119            McpRequest::ReadResource(ReadResourceParams {
120                uri: rewrite_name(&params.uri, source_prefix, mirror_prefix),
121                meta: params.meta.clone(),
122            })
123        }
124        McpRequest::GetPrompt(params) if params.name.starts_with(source_prefix) => {
125            McpRequest::GetPrompt(GetPromptParams {
126                name: rewrite_name(&params.name, source_prefix, mirror_prefix),
127                arguments: params.arguments.clone(),
128                meta: params.meta.clone(),
129            })
130        }
131        // List requests and other types aren't mirrored
132        _ => return None,
133    };
134
135    Some(RouterRequest {
136        id: req.id.clone(),
137        inner: new_inner,
138        extensions: Extensions::new(),
139    })
140}
141
142/// Check if the sampling counter says this request should be mirrored.
143fn should_mirror(mapping: &MirrorMapping) -> bool {
144    if mapping.percent >= 100 {
145        return true;
146    }
147    let count = mapping.counter.fetch_add(1, Ordering::Relaxed);
148    (count % 100) < mapping.percent as u64
149}
150
151/// Extract the request name for namespace matching.
152fn request_name(req: &McpRequest) -> Option<&str> {
153    match req {
154        McpRequest::CallTool(params) => Some(&params.name),
155        McpRequest::ReadResource(params) => Some(&params.uri),
156        McpRequest::GetPrompt(params) => Some(&params.name),
157        _ => None,
158    }
159}
160
161impl<S> Service<RouterRequest> for MirrorService<S>
162where
163    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
164        + Clone
165        + Send
166        + 'static,
167    S::Future: Send,
168{
169    type Response = RouterResponse;
170    type Error = Infallible;
171    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
172
173    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174        self.inner.poll_ready(cx)
175    }
176
177    fn call(&mut self, req: RouterRequest) -> Self::Future {
178        // Check if this request should be mirrored
179        let mirror_req = request_name(&req.inner)
180            .and_then(|name| find_mirror(name, &self.mappings))
181            .filter(|mapping| should_mirror(mapping))
182            .and_then(|mapping| {
183                clone_for_mirror(&req, &mapping.source_prefix, &mapping.mirror_prefix)
184            });
185
186        // Send the primary request
187        let primary_fut = self.inner.call(req);
188
189        // If mirroring, clone the service and spawn a fire-and-forget task
190        let mut mirror_svc = if mirror_req.is_some() {
191            Some(self.inner.clone())
192        } else {
193            None
194        };
195
196        Box::pin(async move {
197            // Spawn mirror request as a fire-and-forget task
198            if let Some(mirror) = mirror_req
199                && let Some(ref mut svc) = mirror_svc
200            {
201                let mut svc = svc.clone();
202                tokio::spawn(async move {
203                    match svc.call(mirror).await {
204                        Ok(resp) => {
205                            if resp.inner.is_err() {
206                                tracing::debug!("Mirror request returned error (discarded)");
207                            }
208                        }
209                        Err(e) => match e {},
210                    }
211                });
212            }
213
214            primary_fut.await
215        })
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{MockService, call_service};
    use tower_mcp::protocol::RequestId;
    use tower_mcp::router::Extensions;
    use tower_mcp_types::protocol::McpRequest;

    /// Mirror configuration with a single `source` -> `mirror` entry.
    fn make_mirrors(source: &str, mirror: &str, percent: u32) -> HashMap<String, (String, u32)> {
        HashMap::from([(source.to_string(), (mirror.to_string(), percent))])
    }

    /// `api/` -> `api-v2/` mapping with a fresh counter and the given percentage.
    fn api_mapping(percent: u32) -> MirrorMapping {
        MirrorMapping {
            source_prefix: "api/".to_string(),
            mirror_prefix: "api-v2/".to_string(),
            percent,
            counter: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Tool-call request body with no metadata and no task.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Wrap a request body in a router request with id 1 and empty extensions.
    fn router_request(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_rewrite_name() {
        let cases = [
            ("api/search", "api/", "api-v2/", "api-v2/search"),
            ("api/nested/tool", "api/", "mirror/", "mirror/nested/tool"),
        ];
        for (name, source, mirror, expected) in cases {
            assert_eq!(rewrite_name(name, source, mirror), expected);
        }
    }

    #[test]
    fn test_find_mirror_match() {
        let mappings = [api_mapping(100)];
        assert!(find_mirror("api/search", &mappings).is_some());
        assert!(find_mirror("other/search", &mappings).is_none());
    }

    #[test]
    fn test_should_mirror_100_percent() {
        let mapping = api_mapping(100);
        // Every single request must be selected.
        assert!((0..10).all(|_| should_mirror(&mapping)));
    }

    #[test]
    fn test_should_mirror_percentage() {
        let mapping = api_mapping(10);
        // Out of 100 consecutive requests, exactly 10 are selected.
        let mut mirrored: u32 = 0;
        for _ in 0..100 {
            if should_mirror(&mapping) {
                mirrored += 1;
            }
        }
        assert_eq!(mirrored, 10);
    }

    #[test]
    fn test_clone_for_mirror_call_tool() {
        let req = router_request(call_tool("api/search", serde_json::json!({"q": "test"})));

        let mirrored = clone_for_mirror(&req, "api/", "api-v2/").unwrap();
        let McpRequest::CallTool(params) = &mirrored.inner else {
            panic!("expected CallTool");
        };
        assert_eq!(params.name, "api-v2/search");
        assert_eq!(params.arguments, serde_json::json!({"q": "test"}));
    }

    #[test]
    fn test_clone_for_mirror_read_resource() {
        let req = router_request(McpRequest::ReadResource(ReadResourceParams {
            uri: "api/docs/readme".to_string(),
            meta: None,
        }));

        let mirrored = clone_for_mirror(&req, "api/", "mirror/").unwrap();
        let McpRequest::ReadResource(params) = &mirrored.inner else {
            panic!("expected ReadResource");
        };
        assert_eq!(params.uri, "mirror/docs/readme");
    }

    #[test]
    fn test_clone_for_mirror_list_tools_returns_none() {
        let req = router_request(McpRequest::ListTools(Default::default()));
        assert!(clone_for_mirror(&req, "api/", "mirror/").is_none());
    }

    #[tokio::test]
    async fn test_mirror_service_passes_through() {
        let mock = MockService::with_tools(&["api/search", "api-v2/search"]);
        let mut svc = MirrorService::new(mock, make_mirrors("api", "api-v2", 100), "/");

        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;

        // The caller must receive the primary backend's response.
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_mirror_service_non_mirrored_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let mut svc = MirrorService::new(mock, make_mirrors("api", "api-v2", 100), "/");

        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_mirror_service_list_tools_not_mirrored() {
        let mock = MockService::with_tools(&["api/search"]);
        let mut svc = MirrorService::new(mock, make_mirrors("api", "api-v2", 100), "/");

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }
}