1use std::collections::HashMap;
36use std::convert::Infallible;
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::sync::atomic::{AtomicU64, Ordering};
41use std::task::{Context, Poll};
42
43use tower::{Layer, Service};
44use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
45use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
46
/// Tower [`Layer`] that wraps a service in a [`CanaryService`], diverting a
/// weighted share of requests aimed at a primary backend to its canary.
///
/// `canaries` is keyed by the primary backend name; each value is
/// `(canary_name, primary_weight, canary_weight)`. `separator` is the string
/// appended to each backend name to form the prefix that is matched against
/// (and substituted into) tool names, prompt names and resource URIs.
#[derive(Clone, Debug)]
pub struct CanaryLayer {
    canaries: HashMap<String, (String, u32, u32)>,
    separator: String,
}
53
54impl CanaryLayer {
55 pub fn new(
59 canaries: HashMap<String, (String, u32, u32)>,
60 separator: impl Into<String>,
61 ) -> Self {
62 Self {
63 canaries,
64 separator: separator.into(),
65 }
66 }
67}
68
69impl<S> Layer<S> for CanaryLayer {
70 type Service = CanaryService<S>;
71
72 fn layer(&self, inner: S) -> Self::Service {
73 CanaryService::new(inner, self.canaries.clone(), &self.separator)
74 }
75}
76
/// Pre-computed routing data for one primary -> canary pair.
#[derive(Debug, Clone)]
struct CanaryMapping {
    /// `"{primary}{separator}"`: request names/URIs starting with this are eligible.
    primary_prefix: String,
    /// `"{canary}{separator}"`: substituted for `primary_prefix` on diverted requests.
    canary_prefix: String,
    /// Positions `0..primary_weight` of each `total_weight`-long cycle stay on the primary.
    primary_weight: u32,
    /// Length of the routing cycle (primary weight plus canary weight when built by `new`).
    total_weight: u32,
    /// Running count of routing decisions for this mapping; clones share it via the `Arc`.
    counter: Arc<AtomicU64>,
}
91
92#[derive(Clone)]
97pub struct CanaryService<S> {
98 inner: S,
99 mappings: Arc<Vec<CanaryMapping>>,
100}
101
102impl<S> CanaryService<S> {
103 pub fn new(inner: S, canaries: HashMap<String, (String, u32, u32)>, separator: &str) -> Self {
108 let mappings = canaries
109 .into_iter()
110 .map(
111 |(primary, (canary, primary_weight, canary_weight))| CanaryMapping {
112 primary_prefix: format!("{primary}{separator}"),
113 canary_prefix: format!("{canary}{separator}"),
114 primary_weight,
115 total_weight: primary_weight + canary_weight,
116 counter: Arc::new(AtomicU64::new(0)),
117 },
118 )
119 .collect();
120
121 Self {
122 inner,
123 mappings: Arc::new(mappings),
124 }
125 }
126}
127
128fn find_canary<'a>(name: &str, mappings: &'a [CanaryMapping]) -> Option<&'a CanaryMapping> {
130 mappings
131 .iter()
132 .find(|m| name.starts_with(&m.primary_prefix))
133}
134
135fn should_route_to_canary(mapping: &CanaryMapping) -> bool {
137 let count = mapping.counter.fetch_add(1, Ordering::Relaxed);
138 let position = count % mapping.total_weight as u64;
139 position >= mapping.primary_weight as u64
141}
142
143fn rewrite_to_canary(req: RouterRequest, mapping: &CanaryMapping) -> RouterRequest {
145 let new_inner = match req.inner {
146 McpRequest::CallTool(params) if params.name.starts_with(&mapping.primary_prefix) => {
147 let suffix = ¶ms.name[mapping.primary_prefix.len()..];
148 McpRequest::CallTool(CallToolParams {
149 name: format!("{}{suffix}", mapping.canary_prefix),
150 arguments: params.arguments,
151 meta: params.meta,
152 task: params.task,
153 })
154 }
155 McpRequest::ReadResource(params) if params.uri.starts_with(&mapping.primary_prefix) => {
156 let suffix = ¶ms.uri[mapping.primary_prefix.len()..];
157 McpRequest::ReadResource(ReadResourceParams {
158 uri: format!("{}{suffix}", mapping.canary_prefix),
159 meta: params.meta,
160 })
161 }
162 McpRequest::GetPrompt(params) if params.name.starts_with(&mapping.primary_prefix) => {
163 let suffix = ¶ms.name[mapping.primary_prefix.len()..];
164 McpRequest::GetPrompt(GetPromptParams {
165 name: format!("{}{suffix}", mapping.canary_prefix),
166 arguments: params.arguments,
167 meta: params.meta,
168 })
169 }
170 other => other,
171 };
172
173 RouterRequest {
174 id: req.id,
175 inner: new_inner,
176 extensions: Extensions::new(),
177 }
178}
179
180fn request_name(req: &McpRequest) -> Option<&str> {
182 match req {
183 McpRequest::CallTool(params) => Some(¶ms.name),
184 McpRequest::ReadResource(params) => Some(¶ms.uri),
185 McpRequest::GetPrompt(params) => Some(¶ms.name),
186 _ => None,
187 }
188}
189
190impl<S> Service<RouterRequest> for CanaryService<S>
191where
192 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
193 + Clone
194 + Send
195 + 'static,
196 S::Future: Send,
197{
198 type Response = RouterResponse;
199 type Error = Infallible;
200 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
201
202 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203 self.inner.poll_ready(cx)
204 }
205
206 fn call(&mut self, req: RouterRequest) -> Self::Future {
207 let should_canary = request_name(&req.inner)
209 .and_then(|name| find_canary(name, &self.mappings))
210 .filter(|mapping| should_route_to_canary(mapping))
211 .cloned();
212
213 let req = if let Some(ref mapping) = should_canary {
214 tracing::debug!(
215 primary = %mapping.primary_prefix,
216 canary = %mapping.canary_prefix,
217 "Routing request to canary backend"
218 );
219 rewrite_to_canary(req, mapping)
220 } else {
221 req
222 };
223
224 let fut = self.inner.call(req);
225 Box::pin(fut)
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{MockService, call_service};
    use tower_mcp::protocol::RequestId;

    /// Single-entry canary table: `primary -> (canary, primary_weight, canary_weight)`.
    fn make_canaries(
        primary: &str,
        canary: &str,
        primary_weight: u32,
        canary_weight: u32,
    ) -> HashMap<String, (String, u32, u32)> {
        HashMap::from([(
            primary.to_string(),
            (canary.to_string(), primary_weight, canary_weight),
        )])
    }

    /// An `api/` -> `api-canary/` mapping with a fresh counter.
    fn api_mapping(primary_weight: u32, total_weight: u32) -> CanaryMapping {
        CanaryMapping {
            primary_prefix: "api/".to_string(),
            canary_prefix: "api-canary/".to_string(),
            primary_weight,
            total_weight,
            counter: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Number of canary decisions among `requests` consecutive routing calls.
    fn canary_hits(mapping: &CanaryMapping, requests: usize) -> usize {
        (0..requests)
            .filter(|_| should_route_to_canary(mapping))
            .count()
    }

    /// A `CallTool` request for `name` carrying `arguments`.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Wraps `inner` in a `RouterRequest` with id 1 and empty extensions.
    fn router_request(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_find_canary_match() {
        let mappings = [api_mapping(90, 100)];
        assert!(find_canary("api/search", &mappings).is_some());
        assert!(find_canary("other/search", &mappings).is_none());
    }

    #[test]
    fn test_should_route_to_canary_weights() {
        // 90:10 split -> exactly 10 of every 100 requests hit the canary.
        assert_eq!(canary_hits(&api_mapping(90, 100), 100), 10);
    }

    #[test]
    fn test_should_route_to_canary_50_50() {
        assert_eq!(canary_hits(&api_mapping(50, 100), 100), 50);
    }

    #[test]
    fn test_rewrite_to_canary_call_tool() {
        let req = router_request(call_tool("api/search", serde_json::json!({"q": "test"})));
        let rewritten = rewrite_to_canary(req, &api_mapping(90, 100));
        if let McpRequest::CallTool(params) = &rewritten.inner {
            assert_eq!(params.name, "api-canary/search");
            assert_eq!(params.arguments, serde_json::json!({"q": "test"}));
        } else {
            panic!("expected CallTool");
        }
    }

    #[test]
    fn test_rewrite_to_canary_read_resource() {
        let req = router_request(McpRequest::ReadResource(ReadResourceParams {
            uri: "api/docs/readme".to_string(),
            meta: None,
        }));
        let rewritten = rewrite_to_canary(req, &api_mapping(90, 100));
        if let McpRequest::ReadResource(params) = &rewritten.inner {
            assert_eq!(params.uri, "api-canary/docs/readme");
        } else {
            panic!("expected ReadResource");
        }
    }

    #[test]
    fn test_rewrite_leaves_non_matching_unchanged() {
        let req = router_request(McpRequest::ListTools(Default::default()));
        let rewritten = rewrite_to_canary(req, &api_mapping(90, 100));
        assert!(matches!(rewritten.inner, McpRequest::ListTools(_)));
    }

    #[tokio::test]
    async fn test_canary_service_routes_to_canary() {
        // 0:100 split: every matching request is sent to the canary.
        let mock = MockService::with_tools(&["api/search", "api-canary/search"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 0, 100), "/");
        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_passes_through_primary() {
        // The first request lands at position 0 of a 101-long cycle -> primary.
        let mock = MockService::with_tools(&["api/search"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 100, 1), "/");
        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_non_matching_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 0, 100), "/");
        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_list_tools_not_affected() {
        let mock = MockService::with_tools(&["api/search"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 0, 100), "/");
        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }
}