1use std::collections::HashSet;
33use std::convert::Infallible;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::task::{Context, Poll};
38
39use tower::Service;
40
41use tower_mcp::router::{RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::McpRequest;
43
/// Identity information about the calling client, inserted into request
/// extensions so that downstream backends can see who made the request.
///
/// `TokenPassthroughService` builds this from the request's `TokenClaims`
/// extension.
///
/// `Debug` is implemented by hand so that `raw_token` is never printed: the
/// derived impl would write the bearer token verbatim into any log line or
/// panic message that formats this value with `{:?}`.
#[derive(Clone)]
pub struct ClientToken {
    /// Subject of the token, copied from the claims' `sub` field when present.
    pub subject: Option<String>,
    /// Scope string, copied from the claims' `scope` field when present.
    pub scope: Option<String>,
    /// The raw token, when one is available. Redacted in `Debug` output.
    pub raw_token: Option<String>,
}

impl std::fmt::Debug for ClientToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientToken")
            .field("subject", &self.subject)
            .field("scope", &self.scope)
            // Show only whether a token is present, never its value.
            .field("raw_token", &self.raw_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
57
/// Tower middleware that, for requests aimed at selected namespaces, copies
/// the caller's `TokenClaims` extension into a [`ClientToken`] extension
/// before delegating to the wrapped service.
///
/// Requests outside those namespaces are forwarded unchanged.
#[derive(Clone)]
pub struct TokenPassthroughService<S> {
    /// The wrapped service; every request is passed on to it.
    inner: S,
    /// Prefixes of tool names, resource URIs, or prompt names whose requests
    /// receive a `ClientToken`. Held in an `Arc` so that cloning the service
    /// does not copy the set.
    forward_namespaces: Arc<HashSet<String>>,
}
65
66impl<S> TokenPassthroughService<S> {
67 pub fn new(inner: S, forward_namespaces: HashSet<String>) -> Self {
72 Self {
73 inner,
74 forward_namespaces: Arc::new(forward_namespaces),
75 }
76 }
77}
78
79fn request_targets_namespace(req: &McpRequest, namespaces: &HashSet<String>) -> bool {
81 let name = match req {
82 McpRequest::CallTool(params) => Some(params.name.as_str()),
83 McpRequest::ReadResource(params) => Some(params.uri.as_str()),
84 McpRequest::GetPrompt(params) => Some(params.name.as_str()),
85 _ => None,
86 };
87 if let Some(name) = name {
88 namespaces.iter().any(|ns| name.starts_with(ns))
89 } else {
90 false
91 }
92}
93
94impl<S> Service<RouterRequest> for TokenPassthroughService<S>
95where
96 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
97 + Clone
98 + Send
99 + 'static,
100 S::Future: Send,
101{
102 type Response = RouterResponse;
103 type Error = Infallible;
104 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
105
106 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107 self.inner.poll_ready(cx)
108 }
109
110 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
111 if !self.forward_namespaces.is_empty()
113 && request_targets_namespace(&req.inner, &self.forward_namespaces)
114 {
115 let client_token = req
116 .extensions
117 .get::<tower_mcp::oauth::token::TokenClaims>()
118 .map(|claims| ClientToken {
119 subject: claims.sub.clone(),
120 scope: claims.scope.clone(),
121 raw_token: None, });
123 if let Some(token) = client_token {
124 tracing::debug!(
125 subject = ?token.subject,
126 "Injected ClientToken for forward_auth backend"
127 );
128 req.extensions.insert(token);
129 }
130 }
131
132 let fut = self.inner.call(req);
133 Box::pin(fut)
134 }
135}
136
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll};

    use tower::Service;
    use tower_mcp::protocol::{CallToolParams, McpRequest, RequestId};
    use tower_mcp::router::{Extensions, RouterRequest};

    use super::{ClientToken, TokenPassthroughService, request_targets_namespace};
    use crate::test_util::{MockService, call_service};

    /// Builds a `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Wraps `inner` in a `RouterRequest` whose extensions carry
    /// `TokenClaims` for subject `user-123` with scope `mcp:read`.
    fn request_with_claims(inner: McpRequest) -> RouterRequest {
        let mut extensions = Extensions::new();
        extensions.insert(tower_mcp::oauth::token::TokenClaims {
            sub: Some("user-123".to_string()),
            scope: Some("mcp:read".to_string()),
            iss: None,
            aud: None,
            exp: None,
            client_id: None,
            extra: Default::default(),
        });
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    /// Test-only service wrapper that records the `ClientToken` extension
    /// (or its absence) on the last request before delegating to `inner`.
    #[derive(Clone)]
    struct ExtensionProbe<S> {
        inner: S,
        seen: Arc<Mutex<Option<ClientToken>>>,
    }

    impl<S> ExtensionProbe<S> {
        /// Returns the probe plus a handle for reading what it observed.
        fn new(inner: S) -> (Self, Arc<Mutex<Option<ClientToken>>>) {
            let seen = Arc::new(Mutex::new(None));
            let probe = Self {
                inner,
                seen: Arc::clone(&seen),
            };
            (probe, seen)
        }
    }

    impl<S> Service<RouterRequest> for ExtensionProbe<S>
    where
        S: Service<RouterRequest>,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            *self.seen.lock().unwrap() = req.extensions.get::<ClientToken>().cloned();
            self.inner.call(req)
        }
    }

    #[test]
    fn test_request_targets_namespace_match() {
        let namespaces: HashSet<String> = ["github/".to_string()].into();
        assert!(request_targets_namespace(&call_tool("github/search"), &namespaces));
    }

    #[test]
    fn test_request_targets_namespace_no_match() {
        let namespaces: HashSet<String> = ["github/".to_string()].into();
        assert!(!request_targets_namespace(&call_tool("db/query"), &namespaces));
    }

    #[test]
    fn test_request_targets_namespace_empty_set() {
        let namespaces: HashSet<String> = HashSet::new();
        assert!(!request_targets_namespace(&call_tool("github/search"), &namespaces));
    }

    #[test]
    fn test_request_targets_namespace_list_tools() {
        let namespaces: HashSet<String> = ["github/".to_string()].into();
        let req = McpRequest::ListTools(Default::default());
        assert!(!request_targets_namespace(&req, &namespaces));
    }

    #[tokio::test]
    async fn test_passthrough_injects_client_token() {
        let (probe, seen) = ExtensionProbe::new(MockService::with_tools(&["github/search"]));
        let namespaces: HashSet<String> = ["github/".to_string()].into();
        let mut svc = TokenPassthroughService::new(probe, namespaces);

        let resp = svc
            .call(request_with_claims(call_tool("github/search")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok());

        // The inner service must actually observe the injected token.
        let token = seen
            .lock()
            .unwrap()
            .clone()
            .expect("ClientToken should reach the inner service for forward namespaces");
        assert_eq!(token.subject.as_deref(), Some("user-123"));
        assert_eq!(token.scope.as_deref(), Some("mcp:read"));
        assert!(token.raw_token.is_none());
    }

    #[tokio::test]
    async fn test_passthrough_skips_non_forward_backends() {
        let (probe, seen) = ExtensionProbe::new(MockService::with_tools(&["db/query"]));
        let namespaces: HashSet<String> = ["github/".to_string()].into();
        let mut svc = TokenPassthroughService::new(probe, namespaces);

        // Claims are present, so an absent ClientToken proves the namespace
        // filter (not missing claims) is what suppressed injection.
        let resp = svc
            .call(request_with_claims(call_tool("db/query")))
            .await
            .unwrap();

        assert!(resp.inner.is_ok());
        assert!(
            seen.lock().unwrap().is_none(),
            "ClientToken must not be injected for non-forward namespaces"
        );
    }

    #[tokio::test]
    async fn test_passthrough_no_claims_passes_through() {
        let (probe, seen) = ExtensionProbe::new(MockService::with_tools(&["github/search"]));
        let namespaces: HashSet<String> = ["github/".to_string()].into();
        let mut svc = TokenPassthroughService::new(probe, namespaces);

        // No TokenClaims extension: request proceeds without a ClientToken.
        let resp = call_service(&mut svc, call_tool("github/search")).await;

        assert!(resp.inner.is_ok());
        assert!(seen.lock().unwrap().is_none());
    }
}