Skip to main content

mcp_proxy/
rbac.rs

1//! Role-based access control (RBAC) middleware for the proxy.
2//!
3//! This module enforces per-role tool access policies on authenticated requests.
4//! It reads JWT claims from [`RouterRequest`] extensions, maps claim values to
5//! named roles via a configurable mapping, and applies per-role allow/deny lists
6//! to both `tools/call` and `tools/list` requests.
7//!
8//! # How role resolution works
9//!
10//! 1. The auth layer (JWT or introspection) validates the token and inserts
11//!    [`TokenClaims`](tower_mcp::oauth::token::TokenClaims) into the request
12//!    extensions.
13//! 2. [`RbacConfig`] reads a configured claim (e.g. `"scope"`, `"role"`,
14//!    `"groups"`) from those claims.
15//! 3. The claim value is matched against `role_mapping.mapping` to resolve a
16//!    role name (e.g. `"read-only"` -> `"reader"`).
17//! 4. The resolved role's `allow_tools` and `deny_tools` lists determine access.
18//!
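//! If no [`TokenClaims`](tower_mcp::oauth::token::TokenClaims) are present in
//! the request extensions (e.g. unauthenticated or bearer-token-only requests),
//! the RBAC layer passes the request through without restriction.
//!
//! If claims are present but no value of the configured claim maps to a role,
//! the outcome depends on `role_mapping.default_deny`: when `true` the request
//! is rejected with a JSON-RPC `InvalidParams` error; when `false` (the legacy
//! behavior) it passes through without restriction.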
22//!
23//! # Allow/deny list semantics
24//!
25//! Each role can define an allow list, a deny list, or both:
26//!
27//! - **Allow list only**: only the listed tools are accessible. All others are
28//!   denied.
29//! - **Deny list only**: all tools are accessible except those listed.
30//! - **Both**: a tool must appear in the allow list AND not appear in the deny
31//!   list.
32//! - **Neither** (empty lists): the role has unrestricted access (e.g. an admin
33//!   role).
34//!
35//! # Interaction with capability filtering
36//!
37//! RBAC runs **on top of** the static capability filter configured per backend.
38//! The final set of visible tools is the **intersection** of what the backend
39//! exposes and what the role permits -- RBAC can only further restrict, never
40//! widen, the tools a client can see or call.
41//!
42//! # Configuration example
43//!
44//! ```toml
45//! [auth]
46//! type = "jwt"
47//! issuer = "https://auth.example.com"
48//! audience = "mcp-proxy"
49//! jwks_uri = "https://auth.example.com/.well-known/jwks.json"
50//!
51//! [[auth.roles]]
52//! name = "admin"
53//! # Empty allow/deny = unrestricted access
54//!
55//! [[auth.roles]]
56//! name = "reader"
57//! allow_tools = ["files/read_file", "files/list_dir"]
58//!
59//! [[auth.roles]]
60//! name = "developer"
61//! deny_tools = ["admin/restart", "admin/shutdown"]
62//!
63//! [auth.role_mapping]
64//! claim = "scope"
65//!
66//! [auth.role_mapping.mapping]
67//! admin = "admin"
68//! read-only = "reader"
69//! dev = "developer"
70//! ```
71//!
72//! # Enforcement
73//!
74//! [`RbacService`] is a Tower middleware that wraps the proxy's inner service.
75//! On `tools/call` requests, it checks the tool name against the resolved role
76//! before forwarding. On `tools/list` responses, it filters out tools the role
77//! cannot access. Denied calls receive a JSON-RPC `InvalidParams` error with
78//! a message identifying the role and tool.
79
80use std::collections::{HashMap, HashSet};
81use std::convert::Infallible;
82use std::future::Future;
83use std::pin::Pin;
84use std::sync::Arc;
85use std::task::{Context, Poll};
86
87use tower::Service;
88
89use tower_mcp::protocol::{McpRequest, McpResponse};
90use tower_mcp::{RouterRequest, RouterResponse};
91use tower_mcp_types::JsonRpcError;
92
93use crate::config::{RoleConfig, RoleMappingConfig};
94
/// Outcome of resolving a role from the token claims attached to a request.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RoleResolution {
    /// No `TokenClaims` extension is present (e.g. unauthenticated or
    /// bearer-token-only requests). Such requests always pass through.
    NoClaims,
    /// Claims are present and the configured claim resolved to this role name.
    Role(String),
    /// Claims are present, but no value of the configured claim appears in
    /// `claim_to_role`. Whether the request is rejected is governed by
    /// `default_deny`.
    Unmapped,
}
106
/// Resolved RBAC rules, built by [`RbacConfig::new`] from role definitions
/// and a claim-to-role mapping.
#[derive(Debug, Clone)]
pub struct RbacConfig {
    /// Claim name to read from `TokenClaims` (e.g. `"scope"`, `"role"`).
    claim: String,
    /// Map of claim value -> role name.
    claim_to_role: HashMap<String, String>,
    /// Map of role name -> allowed tools. A role with no entry here passes
    /// the allow check for every tool.
    role_allow: HashMap<String, HashSet<String>>,
    /// Map of role name -> denied tools.
    role_deny: HashMap<String, HashSet<String>>,
    /// Deny authenticated requests whose claim value is not in `claim_to_role`.
    default_deny: bool,
}
121
122impl RbacConfig {
123    /// Build RBAC config from role definitions and claim-to-role mapping.
124    pub fn new(roles: &[RoleConfig], mapping: &RoleMappingConfig) -> Self {
125        let mut role_allow = HashMap::new();
126        let mut role_deny = HashMap::new();
127
128        for role in roles {
129            if !role.allow_tools.is_empty() {
130                role_allow.insert(
131                    role.name.clone(),
132                    role.allow_tools.iter().cloned().collect(),
133                );
134            }
135            if !role.deny_tools.is_empty() {
136                role_deny.insert(role.name.clone(), role.deny_tools.iter().cloned().collect());
137            }
138        }
139
140        Self {
141            claim: mapping.claim.clone(),
142            claim_to_role: mapping.mapping.clone(),
143            role_allow,
144            role_deny,
145            default_deny: mapping.default_deny,
146        }
147    }
148
149    /// Resolve the role for the current request from TokenClaims.
150    ///
151    /// Distinguishes three cases: no claims present (pass through), claims that
152    /// map to a named role, and claims whose value is unrecognized (governed by
153    /// `default_deny`).
154    fn resolve_role(&self, extensions: &tower_mcp::router::Extensions) -> RoleResolution {
155        let Some(claims) = extensions.get::<tower_mcp::oauth::token::TokenClaims>() else {
156            return RoleResolution::NoClaims;
157        };
158
159        // Check standard scope field first
160        if self.claim == "scope" {
161            let scopes = claims.scopes();
162            for scope in &scopes {
163                if let Some(role) = self.claim_to_role.get(scope) {
164                    return RoleResolution::Role(role.clone());
165                }
166            }
167            return RoleResolution::Unmapped;
168        }
169
170        // Check extra claims
171        if let Some(value) = claims.extra.get(&self.claim) {
172            let claim_str = match value {
173                serde_json::Value::String(s) => s.clone(),
174                other => other.to_string(),
175            };
176            // Try direct mapping
177            if let Some(role) = self.claim_to_role.get(&claim_str) {
178                return RoleResolution::Role(role.clone());
179            }
180            // Try space-delimited (like scope)
181            for part in claim_str.split_whitespace() {
182                if let Some(role) = self.claim_to_role.get(part) {
183                    return RoleResolution::Role(role.clone());
184                }
185            }
186        }
187
188        RoleResolution::Unmapped
189    }
190
191    /// Check if a tool is allowed for the given role.
192    fn is_tool_allowed(&self, role: &str, tool_name: &str) -> bool {
193        // If role has an allowlist, tool must be in it
194        if let Some(allowed) = self.role_allow.get(role)
195            && !allowed.contains(tool_name)
196        {
197            return false;
198        }
199        // If role has a denylist, tool must not be in it
200        if let Some(denied) = self.role_deny.get(role)
201            && denied.contains(tool_name)
202        {
203            return false;
204        }
205        true
206    }
207}
208
209/// Middleware that enforces RBAC on tool calls and list responses.
210#[derive(Clone)]
211pub struct RbacService<S> {
212    inner: S,
213    config: Arc<RbacConfig>,
214}
215
216impl<S> RbacService<S> {
217    /// Create a new RBAC enforcement service wrapping `inner`.
218    pub fn new(inner: S, config: RbacConfig) -> Self {
219        Self {
220            inner,
221            config: Arc::new(config),
222        }
223    }
224}
225
226impl<S> Service<RouterRequest> for RbacService<S>
227where
228    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
229        + Clone
230        + Send
231        + 'static,
232    S::Future: Send,
233{
234    type Response = RouterResponse;
235    type Error = Infallible;
236    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
237
238    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
239        self.inner.poll_ready(cx)
240    }
241
242    fn call(&mut self, req: RouterRequest) -> Self::Future {
243        let config = Arc::clone(&self.config);
244        let request_id = req.id.clone();
245
246        // Resolve role from extensions.
247        let role = match config.resolve_role(&req.extensions) {
248            // No TokenClaims at all (unauthenticated or bearer-token-only
249            // requests, already validated by the auth layer): pass through, no
250            // RBAC restriction applies.
251            RoleResolution::NoClaims => {
252                let fut = self.inner.call(req);
253                return Box::pin(fut);
254            }
255            // Valid claims whose scope is not in the mapping. Under default-deny
256            // this is rejected; otherwise it preserves the legacy pass-through.
257            RoleResolution::Unmapped => {
258                if config.default_deny {
259                    return Box::pin(async move {
260                        Ok(RouterResponse {
261                            id: request_id,
262                            inner: Err(JsonRpcError::invalid_params(
263                                "Authenticated principal carries no recognized role; \
264                                 access denied (rbac default_deny)"
265                                    .to_string(),
266                            )),
267                        })
268                    });
269                }
270                let fut = self.inner.call(req);
271                return Box::pin(fut);
272            }
273            RoleResolution::Role(role) => role,
274        };
275
276        let role_for_filter = role.clone();
277
278        // Check tool calls against RBAC
279        if let McpRequest::CallTool(ref params) = req.inner
280            && !config.is_tool_allowed(&role, &params.name)
281        {
282            let tool_name = params.name.clone();
283            return Box::pin(async move {
284                Ok(RouterResponse {
285                    id: request_id,
286                    inner: Err(JsonRpcError::invalid_params(format!(
287                        "Role '{}' is not authorized to call tool: {}",
288                        role, tool_name
289                    ))),
290                })
291            });
292        }
293
294        let fut = self.inner.call(req);
295
296        Box::pin(async move {
297            let mut resp = fut.await?;
298
299            // Filter list_tools response based on role
300            if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
301                result
302                    .tools
303                    .retain(|tool| config.is_tool_allowed(&role_for_filter, &tool.name));
304            }
305
306            Ok(resp)
307        })
308    }
309}
310
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse, RequestId};
    use tower_mcp::router::Extensions;

    use super::{RbacConfig, RbacService};
    use crate::config::{RoleConfig, RoleMappingConfig};
    use crate::test_util::MockService;

    /// Roles shared by every fixture: an unrestricted `admin`, an allow-list
    /// `reader` limited to `fs/read`, and a deny-list `developer` that may call
    /// anything except `fs/delete`.
    fn test_roles() -> Vec<RoleConfig> {
        vec![
            RoleConfig {
                name: "admin".into(),
                allow_tools: vec![],
                deny_tools: vec![],
            },
            RoleConfig {
                name: "reader".into(),
                allow_tools: vec!["fs/read".into()],
                deny_tools: vec![],
            },
            RoleConfig {
                name: "developer".into(),
                allow_tools: vec![],
                deny_tools: vec!["fs/delete".into()],
            },
        ]
    }

    /// Build an RBAC config that resolves roles from `claim`.
    fn rbac_config_for_claim(claim: &str, default_deny: bool) -> RbacConfig {
        let mapping = RoleMappingConfig {
            claim: claim.into(),
            mapping: HashMap::from([
                ("admin".into(), "admin".into()),
                ("read-only".into(), "reader".into()),
                ("dev".into(), "developer".into()),
            ]),
            default_deny,
        };
        RbacConfig::new(&test_roles(), &mapping)
    }

    fn test_rbac_config() -> RbacConfig {
        rbac_config_with_default_deny(false)
    }

    fn rbac_config_with_default_deny(default_deny: bool) -> RbacConfig {
        rbac_config_for_claim("scope", default_deny)
    }

    /// A `tools/call` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Wrap `inner` in a router request carrying token claims with the given
    /// `scope` and extra claims.
    fn request_with_claims(
        scope: Option<&str>,
        extra: HashMap<String, serde_json::Value>,
        inner: McpRequest,
    ) -> tower_mcp::RouterRequest {
        let mut extensions = Extensions::new();
        extensions.insert(TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: scope.map(String::from),
            client_id: None,
            extra,
        });
        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    fn request_with_scope(scope: &str, inner: McpRequest) -> tower_mcp::RouterRequest {
        request_with_claims(Some(scope), HashMap::new(), inner)
    }

    /// Wrap `inner` in a router request with no `TokenClaims` at all.
    fn request_without_claims(inner: McpRequest) -> tower_mcp::RouterRequest {
        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    /// Extract the tool names from a successful `tools/list` response.
    fn listed_tool_names(resp: tower_mcp::RouterResponse) -> Vec<String> {
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                result.tools.iter().map(|t| t.name.clone()).collect()
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rbac_admin_can_call_any_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("admin", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "admin should call any tool");
    }

    #[tokio::test]
    async fn test_rbac_reader_denied_write() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));
    }

    #[tokio::test]
    async fn test_rbac_reader_allowed_read() {
        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", call_tool("fs/read"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "reader should call allowed tools");
    }

    #[tokio::test]
    async fn test_rbac_filters_list_tools_for_role() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", McpRequest::ListTools(Default::default()));
        let names = listed_tool_names(svc.call(req).await.unwrap());

        assert!(names.contains(&"fs/read".to_string()));
        assert!(!names.contains(&"fs/write".to_string()));
        assert!(!names.contains(&"fs/delete".to_string()));
    }

    #[tokio::test]
    async fn test_rbac_admin_list_tools_unfiltered() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("admin", McpRequest::ListTools(Default::default()));
        let names = listed_tool_names(svc.call(req).await.unwrap());

        for tool in ["fs/read", "fs/write", "fs/delete"] {
            assert!(
                names.contains(&tool.to_string()),
                "admin should see {}, got: {:?}",
                tool,
                names
            );
        }
    }

    #[tokio::test]
    async fn test_rbac_deny_list_blocks_listed_tool() {
        let mock = MockService::with_tools(&["fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("dev", call_tool("fs/delete"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("Role 'developer'"),
            "developer should be denied a deny-listed tool, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_rbac_deny_list_allows_unlisted_tool() {
        let mock = MockService::with_tools(&["fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("dev", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(
            resp.inner.is_ok(),
            "developer should call tools not on its deny list"
        );
    }

    #[tokio::test]
    async fn test_rbac_deny_list_filters_list_tools() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("dev", McpRequest::ListTools(Default::default()));
        let names = listed_tool_names(svc.call(req).await.unwrap());

        assert!(names.contains(&"fs/read".to_string()));
        assert!(names.contains(&"fs/write".to_string()));
        assert!(!names.contains(&"fs/delete".to_string()));
    }

    #[tokio::test]
    async fn test_rbac_resolves_role_from_extra_claim() {
        // A non-`scope` claim is read from the token's extra claims. If the
        // role did not resolve, default_deny = false would pass the call
        // through, so a reader-specific denial proves resolution worked.
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, rbac_config_for_claim("role", false));

        let extra = HashMap::from([("role".to_string(), serde_json::json!("read-only"))]);
        let req = request_with_claims(None, extra, call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("Role 'reader'"),
            "extra claim should resolve to the reader role, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_rbac_resolves_space_delimited_extra_claim() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, rbac_config_for_claim("role", false));

        let extra = HashMap::from([(
            "role".to_string(),
            serde_json::json!("unknown read-only"),
        )]);
        let req = request_with_claims(None, extra, call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("Role 'reader'"),
            "space-delimited extra claim should resolve to the reader role, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_rbac_no_claims_passes_through() {
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        // No TokenClaims in extensions
        let req = request_without_claims(call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "no claims should pass through");
    }

    #[tokio::test]
    async fn test_rbac_unmapped_scope_passes_through_by_default() {
        // Valid claims, but the scope is not in the mapping. With default_deny
        // = false (the default), this preserves legacy pass-through behavior.
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, rbac_config_with_default_deny(false));

        let req = request_with_scope("unknown-scope", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(
            resp.inner.is_ok(),
            "unmapped scope should pass through when default_deny is false"
        );
    }

    #[tokio::test]
    async fn test_rbac_unmapped_scope_denied_with_default_deny() {
        // Valid claims, scope not in the mapping, default_deny = true: denied.
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, rbac_config_with_default_deny(true));

        let req = request_with_scope("unknown-scope", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("default_deny"),
            "unmapped scope should be denied when default_deny is true, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_rbac_mapped_scope_resolves_with_default_deny_enabled() {
        // A recognized scope must still resolve its role regardless of the
        // default_deny setting: reader can read, cannot write.
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, rbac_config_with_default_deny(true));

        let read_req = request_with_scope("read-only", call_tool("fs/read"));
        let resp = svc.call(read_req).await.unwrap();
        assert!(
            resp.inner.is_ok(),
            "mapped role should still resolve with default_deny enabled"
        );

        let write_req = request_with_scope("read-only", call_tool("fs/write"));
        let resp = svc.call(write_req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("not authorized"),
            "reader should be denied write via role policy, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_rbac_no_claims_passes_through_with_default_deny() {
        // default_deny must NOT affect the "no TokenClaims at all" case: a
        // request with no claims always passes through.
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, rbac_config_with_default_deny(true));

        let req = request_without_claims(call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(
            resp.inner.is_ok(),
            "no claims must pass through even when default_deny is true"
        );
    }
}