use std::collections::HashSet;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::Service;
use tower_mcp::router::{RouterRequest, RouterResponse};
use tower_mcp_types::protocol::McpRequest;
/// Caller identity attached to a request's extensions so that backends in a
/// forward-auth namespace can see who made the call.
///
/// `Debug` is implemented by hand so that `raw_token` is never printed: a
/// derived impl would emit the credential verbatim whenever this value is
/// logged with `{:?}`.
#[derive(Clone)]
pub struct ClientToken {
    /// Subject (`sub` claim) of the authenticated caller, if present.
    pub subject: Option<String>,
    /// Scope string (`scope` claim) granted to the caller, if present.
    pub scope: Option<String>,
    /// The raw token string, if available. `TokenPassthroughService` in this
    /// file always sets it to `None`.
    pub raw_token: Option<String>,
}

impl std::fmt::Debug for ClientToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientToken")
            .field("subject", &self.subject)
            .field("scope", &self.scope)
            // Show only whether a token is present, never its value.
            .field("raw_token", &self.raw_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Tower middleware that attaches a [`ClientToken`] to requests whose target
/// falls under one of a configured set of namespace prefixes.
///
/// The prefix set lives behind an [`Arc`], so cloning the service shares the
/// set instead of copying it.
#[derive(Clone)]
pub struct TokenPassthroughService<S> {
    /// Wrapped service that handles every request after (optional) injection.
    inner: S,
    /// Name/URI prefixes whose requests should receive a `ClientToken`.
    forward_namespaces: Arc<HashSet<String>>,
}

impl<S> TokenPassthroughService<S> {
    /// Wraps `inner`, forwarding caller identity for requests whose target
    /// starts with any entry of `forward_namespaces`.
    pub fn new(inner: S, forward_namespaces: HashSet<String>) -> Self {
        let forward_namespaces = Arc::new(forward_namespaces);
        Self {
            forward_namespaces,
            inner,
        }
    }
}
/// Reports whether `req` addresses something under one of `namespaces`.
///
/// Only tool calls (by tool name), resource reads (by URI) and prompt fetches
/// (by prompt name) carry a target; every other request variant yields
/// `false`. Matching is a plain string-prefix test.
fn request_targets_namespace(req: &McpRequest, namespaces: &HashSet<String>) -> bool {
    let target = match req {
        McpRequest::CallTool(call) => call.name.as_str(),
        McpRequest::ReadResource(read) => read.uri.as_str(),
        McpRequest::GetPrompt(prompt) => prompt.name.as_str(),
        _ => return false,
    };
    namespaces
        .iter()
        .any(|prefix| target.starts_with(prefix.as_str()))
}
impl<S> Service<RouterRequest> for TokenPassthroughService<S>
where
    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = RouterResponse;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Attaches a `ClientToken` when the request targets a forwarded
    /// namespace, then hands the request to the wrapped service.
    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
        let namespaces = &self.forward_namespaces;
        let forwards = !namespaces.is_empty() && request_targets_namespace(&req.inner, namespaces);
        if forwards {
            attach_client_token(&mut req);
        }
        Box::pin(self.inner.call(req))
    }
}

/// Copies `sub`/`scope` from the request's verified `TokenClaims` extension
/// into a new `ClientToken` extension. Does nothing when no claims are present.
/// `raw_token` is always left as `None`.
fn attach_client_token(req: &mut RouterRequest) {
    let Some(claims) = req
        .extensions
        .get::<tower_mcp::oauth::token::TokenClaims>()
    else {
        return;
    };
    let token = ClientToken {
        subject: claims.sub.clone(),
        scope: claims.scope.clone(),
        raw_token: None,
    };
    tracing::debug!(
        subject = ?token.subject,
        "Injected ClientToken for forward_auth backend"
    );
    req.extensions.insert(token);
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll};

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{CallToolParams, McpRequest, RequestId};
    use tower_mcp::router::{Extensions, RouterRequest};

    use super::{ClientToken, TokenPassthroughService, request_targets_namespace};
    use crate::test_util::{MockService, call_service};

    /// Namespace set shared by the tests: forward only `github/` targets.
    fn github_namespaces() -> HashSet<String> {
        ["github/".to_string()].into()
    }

    /// Builds a `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Verified claims for a fixed test user.
    fn test_claims() -> TokenClaims {
        TokenClaims {
            sub: Some("user-123".to_string()),
            scope: Some("mcp:read".to_string()),
            iss: None,
            aud: None,
            exp: None,
            client_id: None,
            extra: Default::default(),
        }
    }

    /// Wraps `inner` in a router request whose extensions carry `test_claims()`.
    fn request_with_claims(inner: McpRequest) -> RouterRequest {
        let mut extensions = Extensions::new();
        extensions.insert(test_claims());
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    /// Inner service that records the `ClientToken` extension (if any) of the
    /// most recent request, then delegates to `inner`. Clones share the record.
    #[derive(Clone)]
    struct Recording<S> {
        inner: S,
        seen: Arc<Mutex<Option<ClientToken>>>,
    }

    impl<S> Recording<S> {
        fn new(inner: S) -> Self {
            Self {
                inner,
                seen: Arc::new(Mutex::new(None)),
            }
        }

        fn last_token(&self) -> Option<ClientToken> {
            self.seen.lock().unwrap().clone()
        }
    }

    impl<S: Service<RouterRequest>> Service<RouterRequest> for Recording<S> {
        type Response = S::Response;
        type Error = S::Error;
        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            *self.seen.lock().unwrap() = req.extensions.get::<ClientToken>().cloned();
            self.inner.call(req)
        }
    }

    #[test]
    fn test_request_targets_namespace_match() {
        assert!(request_targets_namespace(
            &call_tool("github/search"),
            &github_namespaces()
        ));
    }

    #[test]
    fn test_request_targets_namespace_no_match() {
        assert!(!request_targets_namespace(
            &call_tool("db/query"),
            &github_namespaces()
        ));
    }

    #[test]
    fn test_request_targets_namespace_list_tools() {
        let req = McpRequest::ListTools(Default::default());
        assert!(!request_targets_namespace(&req, &github_namespaces()));
    }

    #[tokio::test]
    async fn test_passthrough_injects_client_token() {
        let recorder = Recording::new(MockService::with_tools(&["github/search"]));
        let probe = recorder.clone();
        let mut svc = TokenPassthroughService::new(recorder, github_namespaces());

        let resp = svc
            .call(request_with_claims(call_tool("github/search")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok());

        // The inner service must actually have received the injected token.
        let token = probe
            .last_token()
            .expect("ClientToken should be injected for forwarded namespace");
        assert_eq!(token.subject.as_deref(), Some("user-123"));
        assert_eq!(token.scope.as_deref(), Some("mcp:read"));
        assert!(token.raw_token.is_none());
    }

    #[tokio::test]
    async fn test_passthrough_skips_non_forward_backends() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = TokenPassthroughService::new(mock, github_namespaces());
        let resp = call_service(&mut svc, call_tool("db/query")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_passthrough_skips_non_forward_backends_even_with_claims() {
        let recorder = Recording::new(MockService::with_tools(&["db/query"]));
        let probe = recorder.clone();
        let mut svc = TokenPassthroughService::new(recorder, github_namespaces());

        let resp = svc
            .call(request_with_claims(call_tool("db/query")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok());
        assert!(probe.last_token().is_none());
    }

    #[tokio::test]
    async fn test_passthrough_empty_namespaces_never_inject() {
        let recorder = Recording::new(MockService::with_tools(&["github/search"]));
        let probe = recorder.clone();
        let mut svc = TokenPassthroughService::new(recorder, HashSet::new());

        let resp = svc
            .call(request_with_claims(call_tool("github/search")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok());
        assert!(probe.last_token().is_none());
    }

    #[tokio::test]
    async fn test_passthrough_no_claims_passes_through() {
        let mock = MockService::with_tools(&["github/search"]);
        let mut svc = TokenPassthroughService::new(mock, github_namespaces());
        let resp = call_service(&mut svc, call_tool("github/search")).await;
        assert!(resp.inner.is_ok());
    }
}