Skip to main content

mcp_proxy/
bearer_scope.rs

1//! Per-token tool scoping for bearer token authentication.
2//!
3//! When scoped bearer tokens are configured, this module provides:
4//! - An Axum middleware that identifies which scoped token was used and
5//!   injects scope info via [`TokenClaims`] into request extensions.
6//! - An MCP middleware that reads scope info from extensions and enforces
7//!   tool allow/deny lists per token.
8//!
9//! # Architecture
10//!
11//! tower-mcp's HTTP transport only bridges [`TokenClaims`] from Axum
12//! extensions to MCP extensions. To pass bearer scope info across this
13//! boundary, the Axum middleware inserts synthetic `TokenClaims` with
14//! scope details in the `extra` map (key: `__bearer_scope`).
15//!
16//! The MCP-level [`BearerScopingService`] reads this marker and applies
17//! the matching token's allow/deny rules.
18
19use std::collections::{HashMap, HashSet};
20use std::convert::Infallible;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use tower::Service;
27use tower_mcp::oauth::token::TokenClaims;
28use tower_mcp::protocol::{McpRequest, McpResponse};
29use tower_mcp::{RouterRequest, RouterResponse};
30use tower_mcp_types::JsonRpcError;
31
32use crate::config::BearerTokenConfig;
33
/// Name of the `TokenClaims.extra` entry that carries a scoped bearer
/// token's allow/deny lists from the Axum layer to the MCP layer.
const BEARER_SCOPE_KEY: &str = "__bearer_scope";
36
37// ---------------------------------------------------------------------------
38// Axum middleware: inject TokenClaims with bearer scope info
39// ---------------------------------------------------------------------------
40
41/// Axum middleware layer that validates bearer tokens and injects scope info.
42///
43/// For scoped tokens, inserts synthetic [`TokenClaims`] into request
44/// extensions so tower-mcp's HTTP transport propagates them to MCP
45/// extensions. Unscoped tokens pass through without `TokenClaims`.
46#[derive(Clone)]
47pub struct ScopedBearerAuthLayer {
48    inner: Arc<ScopedBearerAuthState>,
49}
50
51struct ScopedBearerAuthState {
52    /// All valid tokens (for validation)
53    valid_tokens: HashSet<String>,
54    /// Token -> scope JSON (only for scoped tokens)
55    scopes: HashMap<String, serde_json::Value>,
56}
57
58impl ScopedBearerAuthLayer {
59    /// Build from combined simple + scoped token lists.
60    pub fn new(simple_tokens: &[String], scoped_tokens: &[BearerTokenConfig]) -> Self {
61        let mut valid_tokens = HashSet::new();
62        let mut scopes = HashMap::new();
63
64        for t in simple_tokens {
65            valid_tokens.insert(t.clone());
66        }
67
68        for st in scoped_tokens {
69            valid_tokens.insert(st.token.clone());
70            // Build scope JSON for this token
71            let scope = serde_json::json!({
72                "allow": st.allow_tools,
73                "deny": st.deny_tools,
74            });
75            scopes.insert(st.token.clone(), scope);
76        }
77
78        Self {
79            inner: Arc::new(ScopedBearerAuthState {
80                valid_tokens,
81                scopes,
82            }),
83        }
84    }
85}
86
87impl<S> tower::Layer<S> for ScopedBearerAuthLayer {
88    type Service = ScopedBearerAuthService<S>;
89
90    fn layer(&self, inner: S) -> Self::Service {
91        ScopedBearerAuthService {
92            inner,
93            state: Arc::clone(&self.inner),
94        }
95    }
96}
97
98/// Axum service that validates bearer tokens and injects scope info.
99#[derive(Clone)]
100pub struct ScopedBearerAuthService<S> {
101    inner: S,
102    state: Arc<ScopedBearerAuthState>,
103}
104
105impl<S> Service<axum::http::Request<axum::body::Body>> for ScopedBearerAuthService<S>
106where
107    S: Service<axum::http::Request<axum::body::Body>, Response = axum::response::Response>
108        + Clone
109        + Send
110        + 'static,
111    S::Future: Send,
112    S::Error: Into<tower_mcp::BoxError> + Send,
113{
114    type Response = axum::response::Response;
115    type Error = S::Error;
116    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
117
118    fn poll_ready(
119        &mut self,
120        cx: &mut std::task::Context<'_>,
121    ) -> std::task::Poll<Result<(), Self::Error>> {
122        self.inner.poll_ready(cx)
123    }
124
125    fn call(&mut self, req: axum::http::Request<axum::body::Body>) -> Self::Future {
126        let token = req
127            .headers()
128            .get("Authorization")
129            .and_then(|v| v.to_str().ok())
130            .and_then(|s| s.strip_prefix("Bearer "))
131            .map(|s| s.trim().to_owned());
132
133        let state = Arc::clone(&self.state);
134        let inner = self.inner.clone();
135
136        Box::pin(async move {
137            let Some(token) = token else {
138                return Ok(unauthorized_response("Missing bearer token"));
139            };
140
141            if !state.valid_tokens.contains(&token) {
142                return Ok(unauthorized_response("Invalid bearer token"));
143            }
144
145            let mut req = req;
146
147            // If this is a scoped token, inject TokenClaims with scope info
148            if let Some(scope) = state.scopes.get(&token) {
149                let mut extra = HashMap::new();
150                extra.insert(BEARER_SCOPE_KEY.to_string(), scope.clone());
151                let claims = TokenClaims {
152                    sub: None,
153                    iss: None,
154                    aud: None,
155                    exp: None,
156                    scope: None,
157                    client_id: None,
158                    extra,
159                };
160                req.extensions_mut().insert(claims);
161            }
162
163            tower::ServiceExt::oneshot(inner, req).await
164        })
165    }
166}
167
168/// Construct an HTTP 401 Unauthorized response.
169fn unauthorized_response(message: &str) -> axum::response::Response {
170    use axum::http::StatusCode;
171    use axum::response::IntoResponse;
172
173    let body = serde_json::json!({
174        "jsonrpc": "2.0",
175        "error": {
176            "code": -32001,
177            "message": message
178        },
179        "id": null
180    });
181
182    (StatusCode::UNAUTHORIZED, axum::Json(body)).into_response()
183}
184
185// ---------------------------------------------------------------------------
186// MCP middleware: enforce per-token tool scoping
187// ---------------------------------------------------------------------------
188
189/// Resolved bearer token scope (allow/deny tool sets).
190#[derive(Debug, Clone)]
191struct ResolvedScope {
192    allow: HashSet<String>,
193    deny: HashSet<String>,
194}
195
196impl ResolvedScope {
197    /// Parse scope from the `TokenClaims.extra` map.
198    fn from_claims(claims: &TokenClaims) -> Option<Self> {
199        let scope_val = claims.extra.get(BEARER_SCOPE_KEY)?;
200
201        let allow: HashSet<String> = scope_val
202            .get("allow")
203            .and_then(|v| v.as_array())
204            .map(|arr| {
205                arr.iter()
206                    .filter_map(|v| v.as_str().map(String::from))
207                    .collect()
208            })
209            .unwrap_or_default();
210
211        let deny: HashSet<String> = scope_val
212            .get("deny")
213            .and_then(|v| v.as_array())
214            .map(|arr| {
215                arr.iter()
216                    .filter_map(|v| v.as_str().map(String::from))
217                    .collect()
218            })
219            .unwrap_or_default();
220
221        // If both are empty, this is an unscoped token
222        if allow.is_empty() && deny.is_empty() {
223            return None;
224        }
225
226        Some(Self { allow, deny })
227    }
228
229    /// Check if a tool is allowed under this scope.
230    fn is_tool_allowed(&self, tool_name: &str) -> bool {
231        if !self.allow.is_empty() && !self.allow.contains(tool_name) {
232            return false;
233        }
234        if self.deny.contains(tool_name) {
235            return false;
236        }
237        true
238    }
239}
240
/// MCP-layer middleware that applies per-bearer-token tool restrictions.
///
/// It looks up the scope that [`ScopedBearerAuthLayer`] stores in
/// `TokenClaims.extra`. When a scope is present, tool calls outside it are
/// rejected and list-tools results are filtered down to permitted tools.
#[derive(Clone)]
pub struct BearerScopingService<S> {
    inner: S,
}

impl<S> BearerScopingService<S> {
    /// Put bearer-scope enforcement in front of `inner`.
    pub fn new(inner: S) -> Self {
        BearerScopingService { inner }
    }
}
256
257impl<S> Service<RouterRequest> for BearerScopingService<S>
258where
259    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
260        + Clone
261        + Send
262        + 'static,
263    S::Future: Send,
264{
265    type Response = RouterResponse;
266    type Error = Infallible;
267    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
268
269    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
270        self.inner.poll_ready(cx)
271    }
272
273    fn call(&mut self, req: RouterRequest) -> Self::Future {
274        let request_id = req.id.clone();
275
276        // Try to extract bearer scope from extensions
277        let scope = req
278            .extensions
279            .get::<TokenClaims>()
280            .and_then(ResolvedScope::from_claims);
281
282        // No scope = unscoped token or no auth; pass through
283        let Some(scope) = scope else {
284            let fut = self.inner.call(req);
285            return Box::pin(fut);
286        };
287
288        // Check tool calls against scope
289        if let McpRequest::CallTool(ref params) = req.inner
290            && !scope.is_tool_allowed(&params.name)
291        {
292            let tool_name = params.name.clone();
293            return Box::pin(async move {
294                Ok(RouterResponse {
295                    id: request_id,
296                    inner: Err(JsonRpcError::invalid_params(format!(
297                        "Token is not authorized to call tool: {tool_name}"
298                    ))),
299                })
300            });
301        }
302
303        let fut = self.inner.call(req);
304
305        Box::pin(async move {
306            let mut resp = fut.await?;
307
308            // Filter list_tools response
309            if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
310                result
311                    .tools
312                    .retain(|tool| scope.is_tool_allowed(&tool.name));
313            }
314
315            Ok(resp)
316        })
317    }
318}
319
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{
        CallToolParams, ListToolsParams, McpRequest, McpResponse, RequestId,
    };
    use tower_mcp::router::Extensions;

    use super::{BEARER_SCOPE_KEY, BearerScopingService};
    use crate::test_util::{MockService, call_service};

    /// Build a router request whose extensions hold synthetic claims with the
    /// given allow/deny lists. The key and JSON shape match what
    /// `ScopedBearerAuthLayer` injects.
    fn request_with_bearer_scope(
        allow: &[&str],
        deny: &[&str],
        inner: McpRequest,
    ) -> tower_mcp::RouterRequest {
        let scope = serde_json::json!({ "allow": allow, "deny": deny });
        let mut extra = HashMap::new();
        extra.insert(BEARER_SCOPE_KEY.to_string(), scope);

        let claims = TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: None,
            client_id: None,
            extra,
        };
        let mut extensions = Extensions::new();
        extensions.insert(claims);

        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    /// A list-tools request with default parameters.
    fn list_tools() -> McpRequest {
        McpRequest::ListTools(ListToolsParams::default())
    }

    /// A tool-call request for `name` with empty JSON-object arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Unwrap a successful list-tools response into the listed tool names.
    fn listed_names(resp: tower_mcp::RouterResponse) -> Vec<String> {
        match resp.inner.unwrap() {
            McpResponse::ListTools(r) => r.tools.into_iter().map(|t| t.name.to_string()).collect(),
            other => panic!("expected ListTools, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn no_scope_passes_through() {
        let mut svc =
            BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write", "db/query"]));

        let resp = call_service(&mut svc, list_tools()).await;
        let tools = match resp.inner.unwrap() {
            McpResponse::ListTools(r) => r.tools,
            other => panic!("expected ListTools, got: {other:?}"),
        };
        assert_eq!(tools.len(), 3);
    }

    #[tokio::test]
    async fn allow_list_filters_tools() {
        let mut svc =
            BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write", "db/query"]));

        let req = request_with_bearer_scope(&["fs/read"], &[], list_tools());
        let names = listed_names(svc.call(req).await.unwrap());
        assert_eq!(names, ["fs/read"]);
    }

    #[tokio::test]
    async fn deny_list_filters_tools() {
        let mut svc =
            BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write", "db/query"]));

        let req = request_with_bearer_scope(&[], &["fs/write"], list_tools());
        let names = listed_names(svc.call(req).await.unwrap());
        assert_eq!(names.len(), 2);
        assert!(names.iter().all(|n| n != "fs/write"));
    }

    #[tokio::test]
    async fn allow_list_blocks_call() {
        let mut svc = BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write"]));

        let req = request_with_bearer_scope(&["fs/read"], &[], call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp
            .inner
            .expect_err("should block disallowed tool call");
        assert!(err.message.contains("fs/write"));
    }

    #[tokio::test]
    async fn allow_list_permits_call() {
        let mut svc = BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write"]));

        let req = request_with_bearer_scope(&["fs/read"], &[], call_tool("fs/read"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "should allow permitted tool call");
    }

    #[tokio::test]
    async fn deny_list_blocks_call() {
        let mut svc = BearerScopingService::new(MockService::with_tools(&["fs/read", "fs/write"]));

        let req = request_with_bearer_scope(&[], &["fs/write"], call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_err(), "should block denied tool call");
    }
}