Skip to main content

mcp_proxy/
canary.rs

1//! Canary / weighted routing middleware.
2//!
3//! Routes a percentage of requests to a canary backend instead of the primary.
4//! The canary backend is registered as a separate backend with its own namespace,
//! but its tools are hidden from `ListTools` (via capability filtering). When a
//! request targets the primary namespace, this middleware rewrites a weighted
//! share of those requests to target the canary namespace instead.
8//!
9//! # Configuration
10//!
11//! ```toml
12//! [[backends]]
13//! name = "api"
14//! transport = "http"
15//! url = "http://api-v1.internal:8080"
16//! weight = 90
17//!
18//! [[backends]]
19//! name = "api-canary"
20//! transport = "http"
21//! url = "http://api-v2.internal:8080"
22//! weight = 10
23//! canary_of = "api"  # share namespace with api
24//! ```
25//!
26//! # How it works
27//!
28//! 1. Both `api` and `api-canary` are registered as separate backends
29//! 2. `api-canary`'s tools are auto-hidden via capability filtering
//! 3. When `CallTool("api/search")` arrives, this middleware advances a
//!    per-mapping request counter: in every cycle of 100 requests the first 90
//!    pass through to `api` and the remaining 10 are rewritten to
//!    `CallTool("api-canary/search")`
33//! 4. `ListTools` always returns only the primary's tools
34
35use std::collections::HashMap;
36use std::convert::Infallible;
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::sync::atomic::{AtomicU64, Ordering};
41use std::task::{Context, Poll};
42
43use tower::{Layer, Service};
44use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
45use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
46
/// Tower layer that produces a [`CanaryService`].
///
/// Stores the canary configuration and the namespace separator; both are
/// handed to every service this layer builds.
#[derive(Debug, Clone)]
pub struct CanaryLayer {
    /// Primary backend name mapped to `(canary_name, primary_weight, canary_weight)`.
    canaries: HashMap<String, (String, u32, u32)>,
    /// Separator appended to a backend name to form its namespace prefix.
    separator: String,
}
53
54impl CanaryLayer {
55    /// Create a new canary routing layer.
56    ///
57    /// `canaries` maps primary backend names to `(canary_name, primary_weight, canary_weight)`.
58    pub fn new(
59        canaries: HashMap<String, (String, u32, u32)>,
60        separator: impl Into<String>,
61    ) -> Self {
62        Self {
63            canaries,
64            separator: separator.into(),
65        }
66    }
67}
68
69impl<S> Layer<S> for CanaryLayer {
70    type Service = CanaryService<S>;
71
72    fn layer(&self, inner: S) -> Self::Service {
73        CanaryService::new(inner, self.canaries.clone(), &self.separator)
74    }
75}
76
/// Routing rule that ties one primary namespace to its canary.
///
/// Cloning a mapping clones the `Arc`, so every clone advances the same
/// request counter.
#[derive(Clone, Debug)]
struct CanaryMapping {
    /// Prefix identifying requests for the primary, e.g. `"api/"`.
    primary_prefix: String,
    /// Prefix substituted when a request is sent to the canary, e.g. `"api-canary/"`.
    canary_prefix: String,
    /// Number of slots per cycle that stay on the primary, e.g. `90`.
    primary_weight: u32,
    /// Cycle length: primary weight plus canary weight, e.g. `100`.
    total_weight: u32,
    /// Request counter whose position within the cycle decides the route.
    counter: Arc<AtomicU64>,
}
91
92/// Canary routing middleware.
93///
94/// Wraps the proxy service and probabilistically rewrites requests from
95/// the primary namespace to the canary namespace based on configured weights.
96#[derive(Clone)]
97pub struct CanaryService<S> {
98    inner: S,
99    mappings: Arc<Vec<CanaryMapping>>,
100}
101
102impl<S> CanaryService<S> {
103    /// Create a new canary service.
104    ///
105    /// `canaries` maps primary backend names to `(canary_name, primary_weight, canary_weight)`.
106    /// The `separator` is used to construct namespace prefixes.
107    pub fn new(inner: S, canaries: HashMap<String, (String, u32, u32)>, separator: &str) -> Self {
108        let mappings = canaries
109            .into_iter()
110            .map(
111                |(primary, (canary, primary_weight, canary_weight))| CanaryMapping {
112                    primary_prefix: format!("{primary}{separator}"),
113                    canary_prefix: format!("{canary}{separator}"),
114                    primary_weight,
115                    total_weight: primary_weight + canary_weight,
116                    counter: Arc::new(AtomicU64::new(0)),
117                },
118            )
119            .collect();
120
121        Self {
122            inner,
123            mappings: Arc::new(mappings),
124        }
125    }
126}
127
128/// Check if a request targets a primary namespace and return the mapping.
129fn find_canary<'a>(name: &str, mappings: &'a [CanaryMapping]) -> Option<&'a CanaryMapping> {
130    mappings
131        .iter()
132        .find(|m| name.starts_with(&m.primary_prefix))
133}
134
135/// Deterministic check: should this request go to the canary?
136fn should_route_to_canary(mapping: &CanaryMapping) -> bool {
137    let count = mapping.counter.fetch_add(1, Ordering::Relaxed);
138    let position = count % mapping.total_weight as u64;
139    // Primary gets the first primary_weight slots, canary gets the rest
140    position >= mapping.primary_weight as u64
141}
142
143/// Rewrite a request to target the canary namespace.
144fn rewrite_to_canary(req: RouterRequest, mapping: &CanaryMapping) -> RouterRequest {
145    let new_inner = match req.inner {
146        McpRequest::CallTool(params) if params.name.starts_with(&mapping.primary_prefix) => {
147            let suffix = &params.name[mapping.primary_prefix.len()..];
148            McpRequest::CallTool(CallToolParams {
149                name: format!("{}{suffix}", mapping.canary_prefix),
150                arguments: params.arguments,
151                meta: params.meta,
152                task: params.task,
153            })
154        }
155        McpRequest::ReadResource(params) if params.uri.starts_with(&mapping.primary_prefix) => {
156            let suffix = &params.uri[mapping.primary_prefix.len()..];
157            McpRequest::ReadResource(ReadResourceParams {
158                uri: format!("{}{suffix}", mapping.canary_prefix),
159                meta: params.meta,
160            })
161        }
162        McpRequest::GetPrompt(params) if params.name.starts_with(&mapping.primary_prefix) => {
163            let suffix = &params.name[mapping.primary_prefix.len()..];
164            McpRequest::GetPrompt(GetPromptParams {
165                name: format!("{}{suffix}", mapping.canary_prefix),
166                arguments: params.arguments,
167                meta: params.meta,
168            })
169        }
170        other => other,
171    };
172
173    RouterRequest {
174        id: req.id,
175        inner: new_inner,
176        extensions: Extensions::new(),
177    }
178}
179
180/// Extract the request name for namespace matching.
181fn request_name(req: &McpRequest) -> Option<&str> {
182    match req {
183        McpRequest::CallTool(params) => Some(&params.name),
184        McpRequest::ReadResource(params) => Some(&params.uri),
185        McpRequest::GetPrompt(params) => Some(&params.name),
186        _ => None,
187    }
188}
189
190impl<S> Service<RouterRequest> for CanaryService<S>
191where
192    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
193        + Clone
194        + Send
195        + 'static,
196    S::Future: Send,
197{
198    type Response = RouterResponse;
199    type Error = Infallible;
200    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
201
202    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203        self.inner.poll_ready(cx)
204    }
205
206    fn call(&mut self, req: RouterRequest) -> Self::Future {
207        // Check if this request should be routed to a canary
208        let should_canary = request_name(&req.inner)
209            .and_then(|name| find_canary(name, &self.mappings))
210            .filter(|mapping| should_route_to_canary(mapping))
211            .cloned();
212
213        let req = if let Some(ref mapping) = should_canary {
214            tracing::debug!(
215                primary = %mapping.primary_prefix,
216                canary = %mapping.canary_prefix,
217                "Routing request to canary backend"
218            );
219            rewrite_to_canary(req, mapping)
220        } else {
221            req
222        };
223
224        let fut = self.inner.call(req);
225        Box::pin(fut)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{MockService, call_service};
    use tower_mcp::protocol::RequestId;

    /// Build a canary table holding a single primary/canary pair.
    fn make_canaries(
        primary: &str,
        canary: &str,
        primary_weight: u32,
        canary_weight: u32,
    ) -> HashMap<String, (String, u32, u32)> {
        HashMap::from([(
            primary.to_string(),
            (canary.to_string(), primary_weight, canary_weight),
        )])
    }

    /// `api/` -> `api-canary/` mapping over a 100-slot cycle with a fresh counter.
    fn api_mapping(primary_weight: u32) -> CanaryMapping {
        CanaryMapping {
            primary_prefix: "api/".to_string(),
            canary_prefix: "api-canary/".to_string(),
            primary_weight,
            total_weight: 100,
            counter: Arc::new(AtomicU64::new(0)),
        }
    }

    /// `CallTool` request for `name` carrying `arguments` and no meta/task.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Wrap `inner` in a router request with id 1 and empty extensions.
    fn router_request(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    /// Number of canary decisions among the next 100 requests for `mapping`.
    fn canary_hits_in_100(mapping: &CanaryMapping) -> u32 {
        let mut hits = 0u32;
        for _ in 0..100 {
            if should_route_to_canary(mapping) {
                hits += 1;
            }
        }
        hits
    }

    #[test]
    fn test_find_canary_match() {
        let mappings = vec![api_mapping(90)];
        assert!(find_canary("api/search", &mappings).is_some());
        assert!(find_canary("other/search", &mappings).is_none());
    }

    #[test]
    fn test_should_route_to_canary_weights() {
        // 90/10 split: one full cycle of 100 requests sends exactly 10 to the canary.
        let mapping = api_mapping(90);
        assert_eq!(canary_hits_in_100(&mapping), 10);
    }

    #[test]
    fn test_should_route_to_canary_50_50() {
        let mapping = api_mapping(50);
        assert_eq!(canary_hits_in_100(&mapping), 50);
    }

    #[test]
    fn test_rewrite_to_canary_call_tool() {
        let mapping = api_mapping(90);
        let req = router_request(call_tool("api/search", serde_json::json!({"q": "test"})));

        let rewritten = rewrite_to_canary(req, &mapping);
        match &rewritten.inner {
            McpRequest::CallTool(params) => {
                assert_eq!(params.name, "api-canary/search");
                assert_eq!(params.arguments, serde_json::json!({"q": "test"}));
            }
            _ => panic!("expected CallTool"),
        }
    }

    #[test]
    fn test_rewrite_to_canary_read_resource() {
        let mapping = api_mapping(90);
        let req = router_request(McpRequest::ReadResource(ReadResourceParams {
            uri: "api/docs/readme".to_string(),
            meta: None,
        }));

        let rewritten = rewrite_to_canary(req, &mapping);
        match &rewritten.inner {
            McpRequest::ReadResource(params) => {
                assert_eq!(params.uri, "api-canary/docs/readme");
            }
            _ => panic!("expected ReadResource"),
        }
    }

    #[test]
    fn test_rewrite_leaves_non_matching_unchanged() {
        let mapping = api_mapping(90);
        let req = router_request(McpRequest::ListTools(Default::default()));

        let rewritten = rewrite_to_canary(req, &mapping);
        assert!(matches!(rewritten.inner, McpRequest::ListTools(_)));
    }

    #[tokio::test]
    async fn test_canary_service_routes_to_canary() {
        // Primary weight 0, canary weight 100: every matching request is rewritten.
        let mock = MockService::with_tools(&["api/search", "api-canary/search"]);
        let canaries = make_canaries("api", "api-canary", 0, 100);
        let mut svc = CanaryService::new(mock, canaries, "/");

        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;

        // The mock knows `api-canary/search`, so the rewritten call succeeds.
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_passes_through_primary() {
        // Weights 100/1 give a 101-slot cycle in which only the last slot is canary.
        let mock = MockService::with_tools(&["api/search"]);
        let canaries = make_canaries("api", "api-canary", 100, 1);
        let mut svc = CanaryService::new(mock, canaries, "/");

        // The first request lands on slot 0, which belongs to the primary.
        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_non_matching_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let canaries = make_canaries("api", "api-canary", 0, 100);
        let mut svc = CanaryService::new(mock, canaries, "/");

        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_list_tools_not_affected() {
        let mock = MockService::with_tools(&["api/search"]);
        let canaries = make_canaries("api", "api-canary", 0, 100);
        let mut svc = CanaryService::new(mock, canaries, "/");

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }
}