Skip to main content

mcp_proxy/
mirror.rs

//! Traffic mirroring / shadowing middleware.
//!
//! Sends a copy of matching traffic to a secondary backend (fire-and-forget;
//! the mirror's response is discarded). Useful for testing new backend
//! versions, benchmarking, or audit recording.
//!
//! # Configuration
//!
//! ```toml
//! [[backends]]
//! name = "api"
//! transport = "http"
//! url = "http://api.internal:8080"
//!
//! [[backends]]
//! name = "api-v2"
//! transport = "http"
//! url = "http://api-v2.internal:8080"
//! mirror_of = "api"        # mirror traffic from "api" backend
//! mirror_percent = 10      # mirror 10% of requests
//! ```
//!
//! # How it works
//!
//! 1. A request arrives targeting `api/search`.
//! 2. The request is forwarded to the `api` backend as normal, and that
//!    backend's response is what the caller receives.
//! 3. A copy of the request is rewritten to `api-v2/search` and sent
//!    fire-and-forget to the `api-v2` backend.
//! 4. The mirror response is discarded; errors are logged but don't
//!    affect the primary response.
//!
//! Only tool calls, resource reads, and prompt requests are mirrored; list
//! requests and other request kinds pass through untouched. Sampling is
//! counter-based rather than random: with `mirror_percent = 10`, the first
//! 10 of every 100 matching requests are mirrored. Percentages are clamped
//! to the range 1-100.

32use std::collections::HashMap;
33use std::convert::Infallible;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::task::{Context, Poll};
39
40use tower::{Layer, Service};
41use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
43
/// Tower [`Layer`] that wraps an inner service in a [`MirrorService`].
#[derive(Clone)]
pub struct MirrorLayer {
    /// Source backend name -> `(mirror backend name, sampling percent)`.
    mirrors: HashMap<String, (String, u32)>,
    /// Namespace separator used to build `"{backend}{separator}"` prefixes.
    separator: String,
}

impl MirrorLayer {
    /// Build a layer from a source-to-mirror table.
    ///
    /// `mirrors` maps each source backend name to `(mirror_name, percent)`;
    /// `separator` is the namespace separator joining backend names to the
    /// rest of a request name (e.g. `"/"`).
    pub fn new(mirrors: HashMap<String, (String, u32)>, separator: impl Into<String>) -> Self {
        let separator: String = separator.into();
        Self { separator, mirrors }
    }
}

63impl<S> Layer<S> for MirrorLayer {
64    type Service = MirrorService<S>;
65
66    fn layer(&self, inner: S) -> Self::Service {
67        MirrorService::new(inner, self.mirrors.clone(), &self.separator)
68    }
69}
70
/// Routing rule linking one source namespace to its mirror namespace.
#[derive(Debug, Clone)]
struct MirrorMapping {
    /// Prefix that selects requests to mirror, e.g. `"api/"`.
    source_prefix: String,
    /// Prefix substituted into the mirrored copy, e.g. `"api-v2/"`.
    mirror_prefix: String,
    /// Whole-number share of matching requests to mirror (expected 1..=100).
    percent: u32,
    /// Request counter driving deterministic sampling; held in an `Arc` so
    /// clones of the mapping share a single count.
    counter: Arc<AtomicU64>,
}

84/// Traffic mirroring middleware.
85///
86/// Wraps the proxy service and sends copies of matching requests to
87/// mirror backends. The primary response is always returned; mirror
88/// responses are discarded.
89#[derive(Clone)]
90pub struct MirrorService<S> {
91    inner: S,
92    mappings: Arc<Vec<MirrorMapping>>,
93}
94
95impl<S> MirrorService<S> {
96    /// Create a new mirror service.
97    ///
98    /// `mirrors` maps source backend names to `(mirror_name, percent)`.
99    /// The `separator` is used to construct namespace prefixes.
100    pub fn new(inner: S, mirrors: HashMap<String, (String, u32)>, separator: &str) -> Self {
101        let mappings = mirrors
102            .into_iter()
103            .map(|(source, (mirror, percent))| MirrorMapping {
104                source_prefix: format!("{source}{separator}"),
105                mirror_prefix: format!("{mirror}{separator}"),
106                percent: percent.clamp(1, 100),
107                counter: Arc::new(AtomicU64::new(0)),
108            })
109            .collect();
110
111        Self {
112            inner,
113            mappings: Arc::new(mappings),
114        }
115    }
116}
117
118/// Check if a request name starts with a namespace prefix and return the
119/// matching mirror mapping.
120fn find_mirror<'a>(name: &str, mappings: &'a [MirrorMapping]) -> Option<&'a MirrorMapping> {
121    mappings.iter().find(|m| name.starts_with(&m.source_prefix))
122}
123
/// Swap the leading `source_prefix` of `name` for `mirror_prefix`.
///
/// The caller is expected to have checked that `name` starts with
/// `source_prefix`; the first `source_prefix.len()` bytes of `name` are
/// dropped unconditionally.
///
/// # Panics
///
/// Panics if `source_prefix.len()` exceeds `name.len()` or does not fall on a
/// UTF-8 character boundary of `name`.
fn rewrite_name(name: &str, source_prefix: &str, mirror_prefix: &str) -> String {
    let (_, rest) = name.split_at(source_prefix.len());
    let mut renamed = String::with_capacity(mirror_prefix.len() + rest.len());
    renamed.push_str(mirror_prefix);
    renamed.push_str(rest);
    renamed
}

130/// Clone a request with its name rewritten to the mirror namespace.
131fn clone_for_mirror(
132    req: &RouterRequest,
133    source_prefix: &str,
134    mirror_prefix: &str,
135) -> Option<RouterRequest> {
136    let new_inner = match &req.inner {
137        McpRequest::CallTool(params) if params.name.starts_with(source_prefix) => {
138            McpRequest::CallTool(CallToolParams {
139                name: rewrite_name(&params.name, source_prefix, mirror_prefix),
140                arguments: params.arguments.clone(),
141                meta: params.meta.clone(),
142                task: params.task.clone(),
143            })
144        }
145        McpRequest::ReadResource(params) if params.uri.starts_with(source_prefix) => {
146            McpRequest::ReadResource(ReadResourceParams {
147                uri: rewrite_name(&params.uri, source_prefix, mirror_prefix),
148                meta: params.meta.clone(),
149            })
150        }
151        McpRequest::GetPrompt(params) if params.name.starts_with(source_prefix) => {
152            McpRequest::GetPrompt(GetPromptParams {
153                name: rewrite_name(&params.name, source_prefix, mirror_prefix),
154                arguments: params.arguments.clone(),
155                meta: params.meta.clone(),
156            })
157        }
158        // List requests and other types aren't mirrored
159        _ => return None,
160    };
161
162    Some(RouterRequest {
163        id: req.id.clone(),
164        inner: new_inner,
165        extensions: Extensions::new(),
166    })
167}
168
169/// Check if the sampling counter says this request should be mirrored.
170fn should_mirror(mapping: &MirrorMapping) -> bool {
171    if mapping.percent >= 100 {
172        return true;
173    }
174    let count = mapping.counter.fetch_add(1, Ordering::Relaxed);
175    (count % 100) < mapping.percent as u64
176}
177
178/// Extract the request name for namespace matching.
179fn request_name(req: &McpRequest) -> Option<&str> {
180    match req {
181        McpRequest::CallTool(params) => Some(&params.name),
182        McpRequest::ReadResource(params) => Some(&params.uri),
183        McpRequest::GetPrompt(params) => Some(&params.name),
184        _ => None,
185    }
186}
187
188impl<S> Service<RouterRequest> for MirrorService<S>
189where
190    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
191        + Clone
192        + Send
193        + 'static,
194    S::Future: Send,
195{
196    type Response = RouterResponse;
197    type Error = Infallible;
198    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
199
200    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
201        self.inner.poll_ready(cx)
202    }
203
204    fn call(&mut self, req: RouterRequest) -> Self::Future {
205        // Check if this request should be mirrored
206        let mirror_req = request_name(&req.inner)
207            .and_then(|name| find_mirror(name, &self.mappings))
208            .filter(|mapping| should_mirror(mapping))
209            .and_then(|mapping| {
210                clone_for_mirror(&req, &mapping.source_prefix, &mapping.mirror_prefix)
211            });
212
213        // Send the primary request
214        let primary_fut = self.inner.call(req);
215
216        // If mirroring, clone the service and spawn a fire-and-forget task
217        let mut mirror_svc = if mirror_req.is_some() {
218            Some(self.inner.clone())
219        } else {
220            None
221        };
222
223        Box::pin(async move {
224            // Spawn mirror request as a fire-and-forget task
225            if let Some(mirror) = mirror_req
226                && let Some(ref mut svc) = mirror_svc
227            {
228                let mut svc = svc.clone();
229                tokio::spawn(async move {
230                    match svc.call(mirror).await {
231                        Ok(resp) => {
232                            if resp.inner.is_err() {
233                                tracing::debug!("Mirror request returned error (discarded)");
234                            }
235                        }
236                        Err(e) => match e {},
237                    }
238                });
239            }
240
241            primary_fut.await
242        })
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{MockService, call_service};
    use tower_mcp::protocol::RequestId;
    use tower_mcp::router::Extensions;
    use tower_mcp_types::protocol::McpRequest;

    /// Mirror config with a single `source -> mirror` entry at `percent`.
    fn make_mirrors(source: &str, mirror: &str, percent: u32) -> HashMap<String, (String, u32)> {
        HashMap::from([(source.to_string(), (mirror.to_string(), percent))])
    }

    /// An `api/` -> `api-v2/` mapping with a zeroed sampling counter.
    fn api_mapping(percent: u32) -> MirrorMapping {
        MirrorMapping {
            source_prefix: "api/".to_string(),
            mirror_prefix: "api-v2/".to_string(),
            percent,
            counter: Arc::new(AtomicU64::new(0)),
        }
    }

    /// A `CallTool` request for `name` with no meta or task.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Wrap `inner` in a `RouterRequest` with id 1 and empty extensions.
    fn router_request(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_rewrite_name() {
        let cases = [
            ("api/search", "api/", "api-v2/", "api-v2/search"),
            ("api/nested/tool", "api/", "mirror/", "mirror/nested/tool"),
        ];
        for (name, source, mirror, expected) in cases {
            assert_eq!(rewrite_name(name, source, mirror), expected);
        }
    }

    #[test]
    fn test_find_mirror_match() {
        let mappings = [api_mapping(100)];
        assert!(find_mirror("api/search", &mappings).is_some());
        assert!(find_mirror("other/search", &mappings).is_none());
    }

    #[test]
    fn test_should_mirror_100_percent() {
        let mapping = api_mapping(100);
        // At 100% every request is mirrored.
        for _ in 0..10 {
            assert!(should_mirror(&mapping));
        }
    }

    #[test]
    fn test_should_mirror_percentage() {
        let mapping = api_mapping(10);
        // The counter-based sampler mirrors exactly 10 of every 100 requests.
        let mirrored = (0..100).filter(|_| should_mirror(&mapping)).count();
        assert_eq!(mirrored, 10);
    }

    #[test]
    fn test_clone_for_mirror_call_tool() {
        let req = router_request(call_tool("api/search", serde_json::json!({"q": "test"})));

        let mirrored = clone_for_mirror(&req, "api/", "api-v2/")
            .expect("CallTool under api/ should be mirrored");
        let McpRequest::CallTool(params) = &mirrored.inner else {
            panic!("expected CallTool");
        };
        assert_eq!(params.name, "api-v2/search");
        assert_eq!(params.arguments, serde_json::json!({"q": "test"}));
    }

    #[test]
    fn test_clone_for_mirror_read_resource() {
        let req = router_request(McpRequest::ReadResource(ReadResourceParams {
            uri: "api/docs/readme".to_string(),
            meta: None,
        }));

        let mirrored = clone_for_mirror(&req, "api/", "mirror/")
            .expect("ReadResource under api/ should be mirrored");
        let McpRequest::ReadResource(params) = &mirrored.inner else {
            panic!("expected ReadResource");
        };
        assert_eq!(params.uri, "mirror/docs/readme");
    }

    #[test]
    fn test_clone_for_mirror_list_tools_returns_none() {
        let req = router_request(McpRequest::ListTools(Default::default()));
        assert!(clone_for_mirror(&req, "api/", "mirror/").is_none());
    }

    #[tokio::test]
    async fn test_mirror_service_passes_through() {
        let mut svc = MirrorService::new(
            MockService::with_tools(&["api/search", "api-v2/search"]),
            make_mirrors("api", "api-v2", 100),
            "/",
        );

        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;

        // The primary response is what the caller gets back.
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_mirror_service_non_mirrored_passes_through() {
        let mut svc = MirrorService::new(
            MockService::with_tools(&["other/tool"]),
            make_mirrors("api", "api-v2", 100),
            "/",
        );

        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_mirror_service_list_tools_not_mirrored() {
        let mut svc = MirrorService::new(
            MockService::with_tools(&["api/search"]),
            make_mirrors("api", "api-v2", 100),
            "/",
        );

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }
}