mcp-proxy 0.3.1

Standalone MCP proxy -- config-driven reverse proxy with auth, rate limiting, and observability
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
//! Per-token tool scoping for bearer token authentication.
//!
//! When scoped bearer tokens are configured, this module provides:
//! - An Axum middleware that identifies which scoped token was used and
//!   injects scope info via [`TokenClaims`] into request extensions.
//! - An MCP middleware that reads scope info from extensions and enforces
//!   tool allow/deny lists per token.
//!
//! # Architecture
//!
//! tower-mcp's HTTP transport only bridges [`TokenClaims`] from Axum
//! extensions to MCP extensions. To pass bearer scope info across this
//! boundary, the Axum middleware inserts synthetic `TokenClaims` with
//! scope details in the `extra` map (key: `__bearer_scope`).
//!
//! The MCP-level [`BearerScopingService`] reads this marker and applies
//! the matching token's allow/deny rules.

use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use tower::Service;
use tower_mcp::oauth::token::TokenClaims;
use tower_mcp::protocol::{McpRequest, McpResponse};
use tower_mcp::{RouterRequest, RouterResponse};
use tower_mcp_types::JsonRpcError;

use crate::config::BearerTokenConfig;

/// Key used in `TokenClaims.extra` to store bearer scope info.
///
/// The value stored under this key is a JSON object with two entries,
/// `allow` and `deny`, each holding the tool list configured for the token.
const BEARER_SCOPE_KEY: &str = "__bearer_scope";

// ---------------------------------------------------------------------------
// Axum middleware: inject TokenClaims with bearer scope info
// ---------------------------------------------------------------------------

/// Axum middleware layer that validates bearer tokens and injects scope info.
///
/// For scoped tokens, inserts synthetic [`TokenClaims`] into request
/// extensions so tower-mcp's HTTP transport propagates them to MCP
/// extensions. Unscoped tokens pass through without `TokenClaims`.
///
/// Cloning the layer is cheap: the token tables live behind an [`Arc`] and
/// are shared by every service the layer produces.
#[derive(Clone)]
pub struct ScopedBearerAuthLayer {
    /// Shared, immutable token tables built once in `new`.
    inner: Arc<ScopedBearerAuthState>,
}

/// Token tables shared between the layer and the services it creates.
struct ScopedBearerAuthState {
    /// All valid tokens (for validation): simple and scoped tokens alike.
    valid_tokens: HashSet<String>,
    /// Token -> scope JSON (only for scoped tokens).
    ///
    /// Each value is an object with `allow` and `deny` entries.
    scopes: HashMap<String, serde_json::Value>,
}

impl ScopedBearerAuthLayer {
    /// Create a layer from the plain token list and the scoped token list.
    ///
    /// Every token from either list is accepted for authentication. Each
    /// scoped token additionally gets a scope object (`allow` / `deny`)
    /// recorded for it; a token that appears in both lists is therefore
    /// treated as scoped. If the same scoped token is configured more than
    /// once, the last entry wins.
    pub fn new(simple_tokens: &[String], scoped_tokens: &[BearerTokenConfig]) -> Self {
        // Union of both lists: anything in here authenticates successfully.
        let valid_tokens: HashSet<String> = simple_tokens
            .iter()
            .cloned()
            .chain(scoped_tokens.iter().map(|cfg| cfg.token.clone()))
            .collect();

        // Per-token scope objects, keyed by the token string itself.
        let scopes: HashMap<String, serde_json::Value> = scoped_tokens
            .iter()
            .map(|cfg| {
                let scope = serde_json::json!({
                    "allow": cfg.allow_tools,
                    "deny": cfg.deny_tools,
                });
                (cfg.token.clone(), scope)
            })
            .collect();

        let state = ScopedBearerAuthState {
            valid_tokens,
            scopes,
        };
        Self {
            inner: Arc::new(state),
        }
    }
}

impl<S> tower::Layer<S> for ScopedBearerAuthLayer {
    type Service = ScopedBearerAuthService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        ScopedBearerAuthService {
            inner,
            state: Arc::clone(&self.inner),
        }
    }
}

/// Axum service that validates bearer tokens and injects scope info.
///
/// Produced by [`ScopedBearerAuthLayer`]; requests without a valid bearer
/// token are answered with HTTP 401 and never reach the wrapped service.
#[derive(Clone)]
pub struct ScopedBearerAuthService<S> {
    /// The wrapped downstream service.
    inner: S,
    /// Token tables shared with the layer that created this service.
    state: Arc<ScopedBearerAuthState>,
}

impl<S> Service<axum::http::Request<axum::body::Body>> for ScopedBearerAuthService<S>
where
    S: Service<axum::http::Request<axum::body::Body>, Response = axum::response::Response>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
    S::Error: Into<tower_mcp::BoxError> + Send,
{
    type Response = axum::response::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: axum::http::Request<axum::body::Body>) -> Self::Future {
        let token = req
            .headers()
            .get("Authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.strip_prefix("Bearer "))
            .map(|s| s.trim().to_owned());

        let state = Arc::clone(&self.state);
        let inner = self.inner.clone();

        Box::pin(async move {
            let Some(token) = token else {
                return Ok(unauthorized_response("Missing bearer token"));
            };

            if !state.valid_tokens.contains(&token) {
                return Ok(unauthorized_response("Invalid bearer token"));
            }

            let mut req = req;

            // If this is a scoped token, inject TokenClaims with scope info
            if let Some(scope) = state.scopes.get(&token) {
                let mut extra = HashMap::new();
                extra.insert(BEARER_SCOPE_KEY.to_string(), scope.clone());
                let claims = TokenClaims {
                    sub: None,
                    iss: None,
                    aud: None,
                    exp: None,
                    scope: None,
                    client_id: None,
                    extra,
                };
                req.extensions_mut().insert(claims);
            }

            tower::ServiceExt::oneshot(inner, req).await
        })
    }
}

/// Build the HTTP 401 reply sent when bearer authentication fails.
///
/// The body is a JSON-RPC 2.0 error object with code `-32001`, a `null`
/// id, and `message` as the error message.
fn unauthorized_response(message: &str) -> axum::response::Response {
    use axum::response::IntoResponse;

    let payload = axum::Json(serde_json::json!({
        "jsonrpc": "2.0",
        "error": {
            "code": -32001,
            "message": message
        },
        "id": null
    }));

    (axum::http::StatusCode::UNAUTHORIZED, payload).into_response()
}

// ---------------------------------------------------------------------------
// MCP middleware: enforce per-token tool scoping
// ---------------------------------------------------------------------------

/// Tool access rules resolved for one scoped bearer token.
#[derive(Debug, Clone)]
struct ResolvedScope {
    /// Tools the token may use; an empty set means "no allow-list".
    allow: HashSet<String>,
    /// Tools the token may never use.
    deny: HashSet<String>,
}

impl ResolvedScope {
    /// Read the scope stored under `BEARER_SCOPE_KEY` in `claims.extra`.
    ///
    /// Returns `None` when the key is absent, or when both lists come out
    /// empty (such a token is treated as unscoped). A missing or non-array
    /// `allow` / `deny` entry counts as an empty list, and non-string array
    /// items are skipped.
    fn from_claims(claims: &TokenClaims) -> Option<Self> {
        let scope_val = claims.extra.get(BEARER_SCOPE_KEY)?;

        // Collect the string items of the array stored under `key`.
        let names_under = |key: &str| -> HashSet<String> {
            scope_val
                .get(key)
                .and_then(|list| list.as_array())
                .into_iter()
                .flatten()
                .filter_map(|item| item.as_str())
                .map(str::to_owned)
                .collect()
        };

        let allow = names_under("allow");
        let deny = names_under("deny");

        // Nothing to enforce: behave as if no scope had been attached.
        if allow.is_empty() && deny.is_empty() {
            None
        } else {
            Some(Self { allow, deny })
        }
    }

    /// Report whether `tool_name` may be used under this scope.
    ///
    /// A tool passes when it is on the allow-list (or the allow-list is
    /// empty) and it is not on the deny-list; deny wins over allow.
    fn is_tool_allowed(&self, tool_name: &str) -> bool {
        let within_allow_list = self.allow.is_empty() || self.allow.contains(tool_name);
        within_allow_list && !self.deny.contains(tool_name)
    }
}

/// MCP middleware applying per-bearer-token tool access control.
///
/// The scope is read from `TokenClaims.extra` (injected by
/// [`ScopedBearerAuthLayer`]); its allow/deny lists are applied to tool
/// calls and to tool-list responses.
#[derive(Clone)]
pub struct BearerScopingService<S> {
    /// The wrapped MCP service that handles permitted requests.
    inner: S,
}

impl<S> BearerScopingService<S> {
    /// Create the scoping middleware around `inner`.
    pub fn new(inner: S) -> Self {
        BearerScopingService { inner }
    }
}

impl<S> Service<RouterRequest> for BearerScopingService<S>
where
    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = RouterResponse;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Apply the caller's bearer scope to `req`.
    ///
    /// * Without a resolvable scope the request is forwarded untouched.
    /// * A `CallTool` request for a tool outside the scope is answered with
    ///   an invalid-params error and never reaches the wrapped service.
    /// * Everything else is forwarded; a successful `ListTools` response is
    ///   trimmed to the tools the scope permits.
    fn call(&mut self, req: RouterRequest) -> Self::Future {
        let resolved = req
            .extensions
            .get::<TokenClaims>()
            .and_then(ResolvedScope::from_claims);

        // Unscoped token or no auth info at all: nothing to enforce.
        let scope = match resolved {
            Some(scope) => scope,
            None => return Box::pin(self.inner.call(req)),
        };

        // Reject calls to tools the scope does not permit.
        if let McpRequest::CallTool(params) = &req.inner
            && !scope.is_tool_allowed(&params.name)
        {
            let request_id = req.id.clone();
            let tool_name = params.name.clone();
            return Box::pin(async move {
                let denied = JsonRpcError::invalid_params(format!(
                    "Token is not authorized to call tool: {tool_name}"
                ));
                Ok(RouterResponse {
                    id: request_id,
                    inner: Err(denied),
                })
            });
        }

        let pending = self.inner.call(req);

        Box::pin(async move {
            let mut response = pending.await?;

            // Hide tools the token cannot call from list responses.
            if let Ok(McpResponse::ListTools(listing)) = &mut response.inner {
                listing
                    .tools
                    .retain(|tool| scope.is_tool_allowed(&tool.name));
            }

            Ok(response)
        })
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{
        CallToolParams, ListToolsParams, McpRequest, McpResponse, RequestId,
    };
    use tower_mcp::router::Extensions;

    use super::{BEARER_SCOPE_KEY, BearerScopingService};
    use crate::test_util::{MockService, call_service};

    /// Build a request whose extensions carry `TokenClaims` with the given
    /// allow/deny lists stored under the bearer scope key.
    fn request_with_bearer_scope(
        allow: &[&str],
        deny: &[&str],
        inner: McpRequest,
    ) -> tower_mcp::RouterRequest {
        let scope = serde_json::json!({
            "allow": allow,
            "deny": deny,
        });
        let claims = TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: None,
            client_id: None,
            extra: HashMap::from([(BEARER_SCOPE_KEY.to_string(), scope)]),
        };

        let mut extensions = Extensions::new();
        extensions.insert(claims);

        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    /// A `tools/list` request with default parameters.
    fn list_tools() -> McpRequest {
        McpRequest::ListTools(ListToolsParams::default())
    }

    /// A `tools/call` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Unwrap a response payload and return the names of the listed tools,
    /// panicking if the payload is an error or not a `ListTools` response.
    fn listed_names<E: std::fmt::Debug>(outcome: Result<McpResponse, E>) -> Vec<String> {
        match outcome.unwrap() {
            McpResponse::ListTools(r) => r.tools.iter().map(|t| t.name.to_string()).collect(),
            other => panic!("expected ListTools, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn no_scope_passes_through() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "db/query"]);
        let mut svc = BearerScopingService::new(mock);

        let resp = call_service(&mut svc, list_tools()).await;
        let names = listed_names(resp.inner);
        assert_eq!(names.len(), 3);
    }

    #[tokio::test]
    async fn allow_list_filters_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "db/query"]);
        let mut svc = BearerScopingService::new(mock);

        let req = request_with_bearer_scope(&["fs/read"], &[], list_tools());
        let resp = svc.call(req).await.unwrap();
        let names = listed_names(resp.inner);
        assert_eq!(names.len(), 1);
        assert_eq!(names[0], "fs/read");
    }

    #[tokio::test]
    async fn deny_list_filters_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "db/query"]);
        let mut svc = BearerScopingService::new(mock);

        let req = request_with_bearer_scope(&[], &["fs/write"], list_tools());
        let resp = svc.call(req).await.unwrap();
        let names = listed_names(resp.inner);
        assert_eq!(names.len(), 2);
        assert!(names.iter().all(|n| n.as_str() != "fs/write"));
    }

    #[tokio::test]
    async fn allow_list_blocks_call() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = BearerScopingService::new(mock);

        let req = request_with_bearer_scope(&["fs/read"], &[], call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_err(), "should block disallowed tool call");
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("fs/write"));
    }

    #[tokio::test]
    async fn allow_list_permits_call() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = BearerScopingService::new(mock);

        let req = request_with_bearer_scope(&["fs/read"], &[], call_tool("fs/read"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "should allow permitted tool call");
    }

    #[tokio::test]
    async fn deny_list_blocks_call() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = BearerScopingService::new(mock);

        let req = request_with_bearer_scope(&[], &["fs/write"], call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_err(), "should block denied tool call");
    }
}