1use std::collections::{HashMap, HashSet};
20use std::convert::Infallible;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use tower::Service;
27use tower_mcp::oauth::token::TokenClaims;
28use tower_mcp::protocol::{McpRequest, McpResponse};
29use tower_mcp::{RouterRequest, RouterResponse};
30use tower_mcp_types::JsonRpcError;
31
32use crate::config::BearerTokenConfig;
33
/// Key under which `ScopedBearerAuthService` stores a scoped token's
/// `{"allow": [...], "deny": [...]}` object in `TokenClaims::extra`;
/// `ResolvedScope::from_claims` reads it back on the router side.
const BEARER_SCOPE_KEY: &str = "__bearer_scope";
36
/// Tower layer that authenticates HTTP requests against a fixed set of bearer tokens.
///
/// Requests without a bearer token, or with a token that is not configured, are
/// answered with a 401 JSON-RPC error. Tokens configured with allow/deny tool lists
/// additionally get that scope attached to the request as `TokenClaims`, for
/// `BearerScopingService` to enforce.
#[derive(Clone)]
pub struct ScopedBearerAuthLayer {
    /// Token table shared (via `Arc`) with every service this layer produces.
    inner: Arc<ScopedBearerAuthState>,
}
50
/// Token table shared between `ScopedBearerAuthLayer` and its services.
/// It is built once in `ScopedBearerAuthLayer::new` and only read afterwards.
struct ScopedBearerAuthState {
    /// Every accepted token, scoped or not.
    valid_tokens: HashSet<String>,
    /// Token -> `{"allow": [...], "deny": [...]}` scope; present only for scoped tokens.
    scopes: HashMap<String, serde_json::Value>,
}
57
58impl ScopedBearerAuthLayer {
59 pub fn new(simple_tokens: &[String], scoped_tokens: &[BearerTokenConfig]) -> Self {
61 let mut valid_tokens = HashSet::new();
62 let mut scopes = HashMap::new();
63
64 for t in simple_tokens {
65 valid_tokens.insert(t.clone());
66 }
67
68 for st in scoped_tokens {
69 valid_tokens.insert(st.token.clone());
70 let scope = serde_json::json!({
72 "allow": st.allow_tools,
73 "deny": st.deny_tools,
74 });
75 scopes.insert(st.token.clone(), scope);
76 }
77
78 Self {
79 inner: Arc::new(ScopedBearerAuthState {
80 valid_tokens,
81 scopes,
82 }),
83 }
84 }
85}
86
87impl<S> tower::Layer<S> for ScopedBearerAuthLayer {
88 type Service = ScopedBearerAuthService<S>;
89
90 fn layer(&self, inner: S) -> Self::Service {
91 ScopedBearerAuthService {
92 inner,
93 state: Arc::clone(&self.inner),
94 }
95 }
96}
97
/// Service produced by [`ScopedBearerAuthLayer`]. It validates the request's bearer
/// token and forwards authenticated requests to `inner`.
#[derive(Clone)]
pub struct ScopedBearerAuthService<S> {
    /// Wrapped service that handles authenticated requests.
    inner: S,
    /// Token table shared with the originating layer.
    state: Arc<ScopedBearerAuthState>,
}
104
105impl<S> Service<axum::http::Request<axum::body::Body>> for ScopedBearerAuthService<S>
106where
107 S: Service<axum::http::Request<axum::body::Body>, Response = axum::response::Response>
108 + Clone
109 + Send
110 + 'static,
111 S::Future: Send,
112 S::Error: Into<tower_mcp::BoxError> + Send,
113{
114 type Response = axum::response::Response;
115 type Error = S::Error;
116 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
117
118 fn poll_ready(
119 &mut self,
120 cx: &mut std::task::Context<'_>,
121 ) -> std::task::Poll<Result<(), Self::Error>> {
122 self.inner.poll_ready(cx)
123 }
124
125 fn call(&mut self, req: axum::http::Request<axum::body::Body>) -> Self::Future {
126 let token = req
127 .headers()
128 .get("Authorization")
129 .and_then(|v| v.to_str().ok())
130 .and_then(|s| s.strip_prefix("Bearer "))
131 .map(|s| s.trim().to_owned());
132
133 let state = Arc::clone(&self.state);
134 let inner = self.inner.clone();
135
136 Box::pin(async move {
137 let Some(token) = token else {
138 return Ok(unauthorized_response("Missing bearer token"));
139 };
140
141 if !state.valid_tokens.contains(&token) {
142 return Ok(unauthorized_response("Invalid bearer token"));
143 }
144
145 let mut req = req;
146
147 if let Some(scope) = state.scopes.get(&token) {
149 let mut extra = HashMap::new();
150 extra.insert(BEARER_SCOPE_KEY.to_string(), scope.clone());
151 let claims = TokenClaims {
152 sub: None,
153 iss: None,
154 aud: None,
155 exp: None,
156 scope: None,
157 client_id: None,
158 extra,
159 };
160 req.extensions_mut().insert(claims);
161 }
162
163 tower::ServiceExt::oneshot(inner, req).await
164 })
165 }
166}
167
168fn unauthorized_response(message: &str) -> axum::response::Response {
170 use axum::http::StatusCode;
171 use axum::response::IntoResponse;
172
173 let body = serde_json::json!({
174 "jsonrpc": "2.0",
175 "error": {
176 "code": -32001,
177 "message": message
178 },
179 "id": null
180 });
181
182 (StatusCode::UNAUTHORIZED, axum::Json(body)).into_response()
183}
184
/// Tool allow/deny lists decoded from a request's `TokenClaims`.
///
/// An empty `allow` set means "no allow-list restriction"; `deny` always applies
/// and takes precedence over `allow`.
#[derive(Debug, Clone)]
struct ResolvedScope {
    /// Tool names the token may use; empty means every tool that is not denied.
    allow: HashSet<String>,
    /// Tool names the token may never use, even if also listed in `allow`.
    deny: HashSet<String>,
}
195
196impl ResolvedScope {
197 fn from_claims(claims: &TokenClaims) -> Option<Self> {
199 let scope_val = claims.extra.get(BEARER_SCOPE_KEY)?;
200
201 let allow: HashSet<String> = scope_val
202 .get("allow")
203 .and_then(|v| v.as_array())
204 .map(|arr| {
205 arr.iter()
206 .filter_map(|v| v.as_str().map(String::from))
207 .collect()
208 })
209 .unwrap_or_default();
210
211 let deny: HashSet<String> = scope_val
212 .get("deny")
213 .and_then(|v| v.as_array())
214 .map(|arr| {
215 arr.iter()
216 .filter_map(|v| v.as_str().map(String::from))
217 .collect()
218 })
219 .unwrap_or_default();
220
221 if allow.is_empty() && deny.is_empty() {
223 return None;
224 }
225
226 Some(Self { allow, deny })
227 }
228
229 fn is_tool_allowed(&self, tool_name: &str) -> bool {
231 if !self.allow.is_empty() && !self.allow.contains(tool_name) {
232 return false;
233 }
234 if self.deny.contains(tool_name) {
235 return false;
236 }
237 true
238 }
239}
240
/// Router-level service that enforces the tool scope attached by
/// `ScopedBearerAuthService`. It rejects `CallTool` requests for tools outside the
/// scope and removes such tools from `ListTools` responses. Requests without a
/// bearer scope pass through unchanged.
#[derive(Clone)]
pub struct BearerScopingService<S> {
    /// Wrapped router service.
    inner: S,
}
249
250impl<S> BearerScopingService<S> {
251 pub fn new(inner: S) -> Self {
253 Self { inner }
254 }
255}
256
257impl<S> Service<RouterRequest> for BearerScopingService<S>
258where
259 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
260 + Clone
261 + Send
262 + 'static,
263 S::Future: Send,
264{
265 type Response = RouterResponse;
266 type Error = Infallible;
267 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
268
269 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
270 self.inner.poll_ready(cx)
271 }
272
273 fn call(&mut self, req: RouterRequest) -> Self::Future {
274 let request_id = req.id.clone();
275
276 let scope = req
278 .extensions
279 .get::<TokenClaims>()
280 .and_then(ResolvedScope::from_claims);
281
282 let Some(scope) = scope else {
284 let fut = self.inner.call(req);
285 return Box::pin(fut);
286 };
287
288 if let McpRequest::CallTool(ref params) = req.inner
290 && !scope.is_tool_allowed(¶ms.name)
291 {
292 let tool_name = params.name.clone();
293 return Box::pin(async move {
294 Ok(RouterResponse {
295 id: request_id,
296 inner: Err(JsonRpcError::invalid_params(format!(
297 "Token is not authorized to call tool: {tool_name}"
298 ))),
299 })
300 });
301 }
302
303 let fut = self.inner.call(req);
304
305 Box::pin(async move {
306 let mut resp = fut.await?;
307
308 if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
310 result
311 .tools
312 .retain(|tool| scope.is_tool_allowed(&tool.name));
313 }
314
315 Ok(resp)
316 })
317 }
318}
319
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{
        CallToolParams, ListToolsParams, McpRequest, McpResponse, RequestId,
    };
    use tower_mcp::router::Extensions;

    use super::{BEARER_SCOPE_KEY, BearerScopingService};
    use crate::test_util::{MockService, call_service};

    /// Builds a `RouterRequest` (id 1) whose extensions carry `TokenClaims` holding
    /// the given allow/deny lists under `BEARER_SCOPE_KEY`. This has the same shape
    /// `ScopedBearerAuthService` attaches for a scoped token.
    fn request_with_bearer_scope(
        allow: &[&str],
        deny: &[&str],
        inner: McpRequest,
    ) -> tower_mcp::RouterRequest {
        let scope = serde_json::json!({
            "allow": allow,
            "deny": deny,
        });
        let claims = TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: None,
            client_id: None,
            extra: HashMap::from([(BEARER_SCOPE_KEY.to_string(), scope)]),
        };

        let mut extensions = Extensions::new();
        extensions.insert(claims);

        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// A `ListTools` request with default parameters.
    fn list_tools() -> McpRequest {
        McpRequest::ListTools(ListToolsParams::default())
    }

    /// Unwraps a successful `ListTools` result into the listed tool names.
    fn listed_names<E: std::fmt::Debug>(inner: Result<McpResponse, E>) -> Vec<String> {
        match inner.unwrap() {
            McpResponse::ListTools(r) => r.tools.into_iter().map(|t| t.name.to_string()).collect(),
            other => panic!("expected ListTools, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn no_scope_passes_through() {
        let mut svc =
            BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write", "db/query"]));

        let resp = call_service(&mut svc, list_tools()).await;
        assert_eq!(listed_names(resp.inner).len(), 3);
    }

    #[tokio::test]
    async fn allow_list_filters_tools() {
        let mut svc =
            BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write", "db/query"]));

        let req = request_with_bearer_scope(&["fs/read"], &[], list_tools());
        let resp = svc.call(req).await.unwrap();
        assert_eq!(listed_names(resp.inner), ["fs/read"]);
    }

    #[tokio::test]
    async fn deny_list_filters_tools() {
        let mut svc =
            BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write", "db/query"]));

        let req = request_with_bearer_scope(&[], &["fs/write"], list_tools());
        let resp = svc.call(req).await.unwrap();
        let names = listed_names(resp.inner);
        assert_eq!(names.len(), 2);
        assert!(!names.iter().any(|n| n == "fs/write"));
    }

    #[tokio::test]
    async fn allow_list_blocks_call() {
        let mut svc = BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write"]));

        let req = request_with_bearer_scope(&["fs/read"], &[], call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp
            .inner
            .expect_err("should block disallowed tool call");
        assert!(err.message.contains("fs/write"));
    }

    #[tokio::test]
    async fn allow_list_permits_call() {
        let mut svc = BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write"]));

        let req = request_with_bearer_scope(&["fs/read"], &[], call_tool("fs/read"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "should allow permitted tool call");
    }

    #[tokio::test]
    async fn deny_list_blocks_call() {
        let mut svc = BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write"]));

        let req = request_with_bearer_scope(&[], &["fs/write"], call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_err(), "should block denied tool call");
    }
}