use std::collections::HashMap;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use tower::{Layer, Service};
use tower_mcp::router::{Extensions, RouterRequest, RouterResponse};
use tower_mcp_types::protocol::{CallToolParams, GetPromptParams, McpRequest, ReadResourceParams};
/// Tower [`Layer`] that diverts a weighted share of requests for selected name prefixes
/// from a primary backend to a canary backend.
///
/// Each entry of `canaries` maps a primary name to `(canary_name, primary_weight,
/// canary_weight)`. Both names get `separator` appended to form the prefixes that are
/// matched and rewritten. With separator `"/"`, the entry `"api" -> ("api-canary", 90, 10)`
/// sends 10 of every 100 matching `api/...` requests to `api-canary/...`.
#[derive(Clone, Debug)]
pub struct CanaryLayer {
    /// Primary name -> `(canary name, primary weight, canary weight)`.
    canaries: HashMap<String, (String, u32, u32)>,
    /// Appended to both names to build the matched / rewritten prefixes.
    separator: String,
}
impl CanaryLayer {
pub fn new(
canaries: HashMap<String, (String, u32, u32)>,
separator: impl Into<String>,
) -> Self {
Self {
canaries,
separator: separator.into(),
}
}
}
impl<S> Layer<S> for CanaryLayer {
    type Service = CanaryService<S>;

    /// Wraps `inner` in a [`CanaryService`] built from a copy of this layer's canary table.
    ///
    /// `CanaryService::new` creates fresh counters, so every service produced by this layer
    /// rotates independently of the others.
    fn layer(&self, inner: S) -> Self::Service {
        let table = self.canaries.clone();
        let separator = self.separator.as_str();
        CanaryService::new(inner, table, separator)
    }
}
/// One compiled canary rule: the prefix to match, the prefix to rewrite it to, and the
/// counter-based state used to pick between primary and canary.
#[derive(Debug, Clone)]
struct CanaryMapping {
    /// `"{primary}{separator}"`. Request names starting with this are candidates.
    primary_prefix: String,
    /// `"{canary}{separator}"`. Replaces `primary_prefix` when the canary is chosen.
    canary_prefix: String,
    /// Leading slots of each `total_weight`-long cycle that stay on the primary.
    primary_weight: u32,
    /// Cycle length: the sum of the primary and canary weights.
    total_weight: u32,
    /// Request counter behind an `Arc`, so cloning the mapping shares the rotation rather
    /// than resetting it.
    counter: Arc<AtomicU64>,
}
/// Tower [`Service`] produced by [`CanaryLayer`].
///
/// It diverts a weighted share of matching `CallTool`, `ReadResource` and `GetPrompt`
/// requests to a canary name, then forwards every request to `inner`.
#[derive(Clone)]
pub struct CanaryService<S> {
    /// The wrapped service. It receives every request, rewritten or not.
    inner: S,
    /// Compiled rules, held behind an `Arc` so that clones of the service share one table
    /// (and therefore the same counters).
    mappings: Arc<Vec<CanaryMapping>>,
}
impl<S> CanaryService<S> {
    /// Wraps `inner` and compiles each `primary -> (canary, primary_weight, canary_weight)`
    /// entry into a [`CanaryMapping`] whose prefixes are `"{name}{separator}"`.
    ///
    /// Each mapping starts with its own zeroed counter. The cycle length is
    /// `primary_weight + canary_weight`, saturating at `u32::MAX`. Without saturation, huge
    /// weights would panic in debug builds or wrap in release builds, possibly to zero.
    pub fn new(inner: S, canaries: HashMap<String, (String, u32, u32)>, separator: &str) -> Self {
        let mappings = canaries
            .into_iter()
            .map(
                |(primary, (canary, primary_weight, canary_weight))| CanaryMapping {
                    primary_prefix: format!("{primary}{separator}"),
                    canary_prefix: format!("{canary}{separator}"),
                    primary_weight,
                    total_weight: primary_weight.saturating_add(canary_weight),
                    counter: Arc::new(AtomicU64::new(0)),
                },
            )
            .collect();
        Self {
            inner,
            mappings: Arc::new(mappings),
        }
    }
}
/// Returns the mapping whose primary prefix is the longest prefix of `name`, if any.
///
/// Mappings are built by iterating a `HashMap`, so their order is unspecified. Picking the
/// longest match rather than the first keeps routing deterministic when one primary prefix
/// is itself a prefix of another (e.g. `api/` and `api/v2/`).
fn find_canary<'a>(name: &str, mappings: &'a [CanaryMapping]) -> Option<&'a CanaryMapping> {
    mappings
        .iter()
        .filter(|m| name.starts_with(m.primary_prefix.as_str()))
        .max_by_key(|m| m.primary_prefix.len())
}
/// Decides whether the next request matched by `mapping` goes to the canary.
///
/// Each call advances the mapping's shared counter and takes its position within a cycle
/// of `total_weight` slots. The first `primary_weight` slots stay on the primary and the
/// remaining ones go to the canary. When `primary_weight <= total_weight`, exactly
/// `total_weight - primary_weight` requests per full cycle are diverted.
///
/// A mapping with `total_weight == 0` (both weights zero) always stays on the primary and
/// does not advance the counter. Without this guard, the modulo would divide by zero.
fn should_route_to_canary(mapping: &CanaryMapping) -> bool {
    let cycle = u64::from(mapping.total_weight);
    if cycle == 0 {
        return false;
    }
    // `Relaxed` suffices: concurrent callers only need distinct, atomically-incremented
    // values; the counter does not guard or order any other memory.
    let slot = mapping.counter.fetch_add(1, Ordering::Relaxed) % cycle;
    slot >= u64::from(mapping.primary_weight)
}
/// Rewrites the routable target of `req` from the mapping's primary prefix to its canary
/// prefix.
///
/// `CallTool` and `GetPrompt` names and `ReadResource` URIs that start with
/// `mapping.primary_prefix` have that prefix replaced by `mapping.canary_prefix`. All other
/// parameter fields are moved over unchanged. Any other request kind, or a target without
/// the primary prefix, is returned as-is.
///
/// The request `id` and `extensions` are carried over untouched, so data attached by outer
/// layers is still visible to the inner service.
fn rewrite_to_canary(req: RouterRequest, mapping: &CanaryMapping) -> RouterRequest {
    let RouterRequest {
        id,
        inner,
        extensions,
    } = req;
    let primary = mapping.primary_prefix.as_str();
    // Only invoked on targets that passed a `starts_with(primary)` guard, so slicing at
    // `primary.len()` always lands on a char boundary.
    let to_canary =
        |target: &str| format!("{}{}", mapping.canary_prefix, &target[primary.len()..]);

    let inner = match inner {
        McpRequest::CallTool(params) if params.name.starts_with(primary) => {
            McpRequest::CallTool(CallToolParams {
                name: to_canary(params.name.as_str()),
                arguments: params.arguments,
                meta: params.meta,
                task: params.task,
            })
        }
        McpRequest::ReadResource(params) if params.uri.starts_with(primary) => {
            McpRequest::ReadResource(ReadResourceParams {
                uri: to_canary(params.uri.as_str()),
                meta: params.meta,
            })
        }
        McpRequest::GetPrompt(params) if params.name.starts_with(primary) => {
            McpRequest::GetPrompt(GetPromptParams {
                name: to_canary(params.name.as_str()),
                arguments: params.arguments,
                meta: params.meta,
            })
        }
        other => other,
    };

    RouterRequest {
        id,
        inner,
        extensions,
    }
}
/// Returns the routable target of a request: the tool name for `CallTool`, the URI for
/// `ReadResource`, the prompt name for `GetPrompt`, and `None` for every other kind.
fn request_name(req: &McpRequest) -> Option<&str> {
    match req {
        McpRequest::CallTool(params) => Some(params.name.as_str()),
        McpRequest::ReadResource(params) => Some(params.uri.as_str()),
        McpRequest::GetPrompt(params) => Some(params.name.as_str()),
        _ => None,
    }
}
impl<S> Service<RouterRequest> for CanaryService<S>
where
S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send,
{
type Response = RouterResponse;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: RouterRequest) -> Self::Future {
let should_canary = request_name(&req.inner)
.and_then(|name| find_canary(name, &self.mappings))
.filter(|mapping| should_route_to_canary(mapping))
.cloned();
let req = if let Some(ref mapping) = should_canary {
tracing::debug!(
primary = %mapping.primary_prefix,
canary = %mapping.canary_prefix,
"Routing request to canary backend"
);
rewrite_to_canary(req, mapping)
} else {
req
};
let fut = self.inner.call(req);
Box::pin(fut)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{MockService, call_service};
    use tower_mcp::protocol::RequestId;

    /// Single-entry canary config: `primary` -> (`canary`, primary weight, canary weight).
    fn make_canaries(
        primary: &str,
        canary: &str,
        primary_weight: u32,
        canary_weight: u32,
    ) -> HashMap<String, (String, u32, u32)> {
        let mut table = HashMap::new();
        table.insert(
            String::from(primary),
            (String::from(canary), primary_weight, canary_weight),
        );
        table
    }

    /// `api/` -> `api-canary/` mapping over a total weight of 100, with a fresh counter.
    fn api_mapping(primary_weight: u32) -> CanaryMapping {
        CanaryMapping {
            primary_prefix: String::from("api/"),
            canary_prefix: String::from("api-canary/"),
            primary_weight,
            total_weight: 100,
            counter: Arc::new(AtomicU64::new(0)),
        }
    }

    /// How many of `rounds` consecutive routing decisions picked the canary.
    fn canary_hits(mapping: &CanaryMapping, rounds: u32) -> u32 {
        (0..rounds)
            .map(|_| u32::from(should_route_to_canary(mapping)))
            .sum()
    }

    /// `CallTool` request for `name` with the given arguments and no meta/task.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: String::from(name),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Wraps `inner` as request id 1 with empty extensions.
    fn router_request(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    #[test]
    fn test_find_canary_match() {
        let mappings = [api_mapping(90)];
        assert!(find_canary("api/search", &mappings).is_some());
        assert!(find_canary("other/search", &mappings).is_none());
    }

    #[test]
    fn test_should_route_to_canary_weights() {
        assert_eq!(canary_hits(&api_mapping(90), 100), 10);
    }

    #[test]
    fn test_should_route_to_canary_50_50() {
        assert_eq!(canary_hits(&api_mapping(50), 100), 50);
    }

    #[test]
    fn test_rewrite_to_canary_call_tool() {
        let original = router_request(call_tool("api/search", serde_json::json!({"q": "test"})));
        let rewritten = rewrite_to_canary(original, &api_mapping(90));
        if let McpRequest::CallTool(params) = &rewritten.inner {
            assert_eq!(params.name, "api-canary/search");
            assert_eq!(params.arguments, serde_json::json!({"q": "test"}));
        } else {
            panic!("expected CallTool");
        }
    }

    #[test]
    fn test_rewrite_to_canary_read_resource() {
        let original = router_request(McpRequest::ReadResource(ReadResourceParams {
            uri: String::from("api/docs/readme"),
            meta: None,
        }));
        let rewritten = rewrite_to_canary(original, &api_mapping(90));
        if let McpRequest::ReadResource(params) = &rewritten.inner {
            assert_eq!(params.uri, "api-canary/docs/readme");
        } else {
            panic!("expected ReadResource");
        }
    }

    #[test]
    fn test_rewrite_leaves_non_matching_unchanged() {
        let original = router_request(McpRequest::ListTools(Default::default()));
        let rewritten = rewrite_to_canary(original, &api_mapping(90));
        assert!(matches!(rewritten.inner, McpRequest::ListTools(_)));
    }

    #[tokio::test]
    async fn test_canary_service_routes_to_canary() {
        let mock = MockService::with_tools(&["api/search", "api-canary/search"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 0, 100), "/");
        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_passes_through_primary() {
        let mock = MockService::with_tools(&["api/search"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 100, 1), "/");
        let resp = call_service(&mut svc, call_tool("api/search", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_non_matching_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 0, 100), "/");
        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_canary_service_list_tools_not_affected() {
        let mock = MockService::with_tools(&["api/search"]);
        let mut svc = CanaryService::new(mock, make_canaries("api", "api-canary", 0, 100), "/");
        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }
}