Skip to main content

mcp_proxy/
rbac.rs

1//! Role-based access control (RBAC) middleware for the proxy.
2//!
3//! This module enforces per-role tool access policies on authenticated requests.
4//! It reads JWT claims from [`RouterRequest`] extensions, maps claim values to
5//! named roles via a configurable mapping, and applies per-role allow/deny lists
6//! to both `tools/call` and `tools/list` requests.
7//!
8//! # How role resolution works
9//!
10//! 1. The auth layer (JWT or introspection) validates the token and inserts
11//!    [`TokenClaims`](tower_mcp::oauth::token::TokenClaims) into the request
12//!    extensions.
13//! 2. [`RbacConfig`] reads a configured claim (e.g. `"scope"`, `"role"`,
14//!    `"groups"`) from those claims.
15//! 3. The claim value is matched against `role_mapping.mapping` to resolve a
16//!    role name (e.g. `"read-only"` -> `"reader"`).
17//! 4. The resolved role's `allow_tools` and `deny_tools` lists determine access.
18//!
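//! If no [`TokenClaims`](tower_mcp::oauth::token::TokenClaims) are present in
//! the request extensions (e.g. unauthenticated or bearer-token-only requests),
//! the RBAC layer passes the request through without restriction. The same
//! happens when claims are present but none of the configured claim's values
//! maps to a role in `role_mapping.mapping`.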
22//!
23//! # Allow/deny list semantics
24//!
25//! Each role can define an allow list, a deny list, or both:
26//!
27//! - **Allow list only**: only the listed tools are accessible. All others are
28//!   denied.
29//! - **Deny list only**: all tools are accessible except those listed.
30//! - **Both**: a tool must appear in the allow list AND not appear in the deny
31//!   list.
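//! - **Neither** (empty lists): the role has unrestricted access (e.g. an admin
//!   role). A mapped role name with no matching `[[auth.roles]]` entry is
//!   treated the same way, so double-check role names for typos.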
34//!
35//! # Interaction with capability filtering
36//!
37//! RBAC runs **on top of** the static capability filter configured per backend.
38//! The final set of visible tools is the **intersection** of what the backend
39//! exposes and what the role permits -- RBAC can only further restrict, never
40//! widen, the tools a client can see or call.
41//!
42//! # Configuration example
43//!
44//! ```toml
45//! [auth]
46//! type = "jwt"
47//! issuer = "https://auth.example.com"
48//! audience = "mcp-proxy"
49//! jwks_uri = "https://auth.example.com/.well-known/jwks.json"
50//!
51//! [[auth.roles]]
52//! name = "admin"
53//! # Empty allow/deny = unrestricted access
54//!
55//! [[auth.roles]]
56//! name = "reader"
57//! allow_tools = ["files/read_file", "files/list_dir"]
58//!
59//! [[auth.roles]]
60//! name = "developer"
61//! deny_tools = ["admin/restart", "admin/shutdown"]
62//!
63//! [auth.role_mapping]
64//! claim = "scope"
65//!
66//! [auth.role_mapping.mapping]
67//! admin = "admin"
68//! read-only = "reader"
69//! dev = "developer"
70//! ```
71//!
72//! # Enforcement
73//!
74//! [`RbacService`] is a Tower middleware that wraps the proxy's inner service.
75//! On `tools/call` requests, it checks the tool name against the resolved role
76//! before forwarding. On `tools/list` responses, it filters out tools the role
77//! cannot access. Denied calls receive a JSON-RPC `InvalidParams` error with
78//! a message identifying the role and tool.
79
80use std::collections::{HashMap, HashSet};
81use std::convert::Infallible;
82use std::future::Future;
83use std::pin::Pin;
84use std::sync::Arc;
85use std::task::{Context, Poll};
86
87use tower::Service;
88
89use tower_mcp::protocol::{McpRequest, McpResponse};
90use tower_mcp::{RouterRequest, RouterResponse};
91use tower_mcp_types::JsonRpcError;
92
93use crate::config::{RoleConfig, RoleMappingConfig};
94
95/// Resolved RBAC rules.
96#[derive(Clone)]
97pub struct RbacConfig {
98    /// Claim name to read from TokenClaims (e.g. "scope", "role")
99    claim: String,
100    /// Map of claim value -> role name
101    claim_to_role: HashMap<String, String>,
102    /// Map of role name -> allowed tools (empty = all allowed)
103    role_allow: HashMap<String, HashSet<String>>,
104    /// Map of role name -> denied tools
105    role_deny: HashMap<String, HashSet<String>>,
106}
107
108impl RbacConfig {
109    /// Build RBAC config from role definitions and claim-to-role mapping.
110    pub fn new(roles: &[RoleConfig], mapping: &RoleMappingConfig) -> Self {
111        let mut role_allow = HashMap::new();
112        let mut role_deny = HashMap::new();
113
114        for role in roles {
115            if !role.allow_tools.is_empty() {
116                role_allow.insert(
117                    role.name.clone(),
118                    role.allow_tools.iter().cloned().collect(),
119                );
120            }
121            if !role.deny_tools.is_empty() {
122                role_deny.insert(role.name.clone(), role.deny_tools.iter().cloned().collect());
123            }
124        }
125
126        Self {
127            claim: mapping.claim.clone(),
128            claim_to_role: mapping.mapping.clone(),
129            role_allow,
130            role_deny,
131        }
132    }
133
134    /// Resolve the role for the current request from TokenClaims.
135    fn resolve_role(&self, extensions: &tower_mcp::router::Extensions) -> Option<String> {
136        let claims = extensions.get::<tower_mcp::oauth::token::TokenClaims>()?;
137
138        // Check standard scope field first
139        if self.claim == "scope" {
140            let scopes = claims.scopes();
141            for scope in &scopes {
142                if let Some(role) = self.claim_to_role.get(scope) {
143                    return Some(role.clone());
144                }
145            }
146            return None;
147        }
148
149        // Check extra claims
150        if let Some(value) = claims.extra.get(&self.claim) {
151            let claim_str = match value {
152                serde_json::Value::String(s) => s.clone(),
153                other => other.to_string(),
154            };
155            // Try direct mapping
156            if let Some(role) = self.claim_to_role.get(&claim_str) {
157                return Some(role.clone());
158            }
159            // Try space-delimited (like scope)
160            for part in claim_str.split_whitespace() {
161                if let Some(role) = self.claim_to_role.get(part) {
162                    return Some(role.clone());
163                }
164            }
165        }
166
167        None
168    }
169
170    /// Check if a tool is allowed for the given role.
171    fn is_tool_allowed(&self, role: &str, tool_name: &str) -> bool {
172        // If role has an allowlist, tool must be in it
173        if let Some(allowed) = self.role_allow.get(role)
174            && !allowed.contains(tool_name)
175        {
176            return false;
177        }
178        // If role has a denylist, tool must not be in it
179        if let Some(denied) = self.role_deny.get(role)
180            && denied.contains(tool_name)
181        {
182            return false;
183        }
184        true
185    }
186}
187
188/// Middleware that enforces RBAC on tool calls and list responses.
189#[derive(Clone)]
190pub struct RbacService<S> {
191    inner: S,
192    config: Arc<RbacConfig>,
193}
194
195impl<S> RbacService<S> {
196    /// Create a new RBAC enforcement service wrapping `inner`.
197    pub fn new(inner: S, config: RbacConfig) -> Self {
198        Self {
199            inner,
200            config: Arc::new(config),
201        }
202    }
203}
204
205impl<S> Service<RouterRequest> for RbacService<S>
206where
207    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
208        + Clone
209        + Send
210        + 'static,
211    S::Future: Send,
212{
213    type Response = RouterResponse;
214    type Error = Infallible;
215    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
216
217    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
218        self.inner.poll_ready(cx)
219    }
220
221    fn call(&mut self, req: RouterRequest) -> Self::Future {
222        let config = Arc::clone(&self.config);
223        let request_id = req.id.clone();
224
225        // Resolve role from extensions
226        let role = config.resolve_role(&req.extensions);
227
228        // If no role resolved, pass through (no RBAC restriction applies)
229        // This allows unauthenticated or bearer-auth requests to proceed
230        // (they're already validated by the auth layer)
231        let Some(role) = role else {
232            let fut = self.inner.call(req);
233            return Box::pin(fut);
234        };
235
236        let role_for_filter = role.clone();
237
238        // Check tool calls against RBAC
239        if let McpRequest::CallTool(ref params) = req.inner
240            && !config.is_tool_allowed(&role, &params.name)
241        {
242            let tool_name = params.name.clone();
243            return Box::pin(async move {
244                Ok(RouterResponse {
245                    id: request_id,
246                    inner: Err(JsonRpcError::invalid_params(format!(
247                        "Role '{}' is not authorized to call tool: {}",
248                        role, tool_name
249                    ))),
250                })
251            });
252        }
253
254        let fut = self.inner.call(req);
255
256        Box::pin(async move {
257            let mut resp = fut.await?;
258
259            // Filter list_tools response based on role
260            if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
261                result
262                    .tools
263                    .retain(|tool| config.is_tool_allowed(&role_for_filter, &tool.name));
264            }
265
266            Ok(resp)
267        })
268    }
269}
270
271#[cfg(test)]
272mod tests {
273    use std::collections::HashMap;
274
275    use tower::Service;
276    use tower_mcp::oauth::token::TokenClaims;
277    use tower_mcp::protocol::{McpRequest, McpResponse, RequestId};
278    use tower_mcp::router::Extensions;
279
280    use super::{RbacConfig, RbacService};
281    use crate::config::{RoleConfig, RoleMappingConfig};
282    use crate::test_util::MockService;
283
284    fn test_rbac_config() -> RbacConfig {
285        let roles = vec![
286            RoleConfig {
287                name: "admin".into(),
288                allow_tools: vec![],
289                deny_tools: vec![],
290            },
291            RoleConfig {
292                name: "reader".into(),
293                allow_tools: vec!["fs/read".into()],
294                deny_tools: vec![],
295            },
296        ];
297        let mapping = RoleMappingConfig {
298            claim: "scope".into(),
299            mapping: HashMap::from([
300                ("admin".into(), "admin".into()),
301                ("read-only".into(), "reader".into()),
302            ]),
303        };
304        RbacConfig::new(&roles, &mapping)
305    }
306
307    fn request_with_scope(scope: &str, inner: McpRequest) -> tower_mcp::RouterRequest {
308        let mut extensions = Extensions::new();
309        extensions.insert(TokenClaims {
310            sub: None,
311            iss: None,
312            aud: None,
313            exp: None,
314            scope: Some(scope.to_string()),
315            client_id: None,
316            extra: HashMap::new(),
317        });
318        tower_mcp::RouterRequest {
319            id: RequestId::Number(1),
320            inner,
321            extensions,
322        }
323    }
324
325    #[tokio::test]
326    async fn test_rbac_admin_can_call_any_tool() {
327        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
328        let mut svc = RbacService::new(mock, test_rbac_config());
329
330        let req = request_with_scope(
331            "admin",
332            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
333                name: "fs/write".to_string(),
334                arguments: serde_json::json!({}),
335                meta: None,
336                task: None,
337            }),
338        );
339        let resp = svc.call(req).await.unwrap();
340        assert!(resp.inner.is_ok(), "admin should call any tool");
341    }
342
343    #[tokio::test]
344    async fn test_rbac_reader_denied_write() {
345        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
346        let mut svc = RbacService::new(mock, test_rbac_config());
347
348        let req = request_with_scope(
349            "read-only",
350            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
351                name: "fs/write".to_string(),
352                arguments: serde_json::json!({}),
353                meta: None,
354                task: None,
355            }),
356        );
357        let resp = svc.call(req).await.unwrap();
358        let err = resp.inner.unwrap_err();
359        assert!(err.message.contains("not authorized"));
360    }
361
362    #[tokio::test]
363    async fn test_rbac_reader_allowed_read() {
364        let mock = MockService::with_tools(&["fs/read"]);
365        let mut svc = RbacService::new(mock, test_rbac_config());
366
367        let req = request_with_scope(
368            "read-only",
369            McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
370                name: "fs/read".to_string(),
371                arguments: serde_json::json!({}),
372                meta: None,
373                task: None,
374            }),
375        );
376        let resp = svc.call(req).await.unwrap();
377        assert!(resp.inner.is_ok(), "reader should call allowed tools");
378    }
379
380    #[tokio::test]
381    async fn test_rbac_filters_list_tools_for_role() {
382        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
383        let mut svc = RbacService::new(mock, test_rbac_config());
384
385        let req = request_with_scope("read-only", McpRequest::ListTools(Default::default()));
386        let resp = svc.call(req).await.unwrap();
387
388        match resp.inner.unwrap() {
389            McpResponse::ListTools(result) => {
390                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
391                assert!(names.contains(&"fs/read"));
392                assert!(!names.contains(&"fs/write"));
393                assert!(!names.contains(&"fs/delete"));
394            }
395            other => panic!("expected ListTools, got: {:?}", other),
396        }
397    }
398
399    #[tokio::test]
400    async fn test_rbac_no_claims_passes_through() {
401        let mock = MockService::with_tools(&["fs/write"]);
402        let mut svc = RbacService::new(mock, test_rbac_config());
403
404        // No TokenClaims in extensions
405        let req = tower_mcp::RouterRequest {
406            id: RequestId::Number(1),
407            inner: McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
408                name: "fs/write".to_string(),
409                arguments: serde_json::json!({}),
410                meta: None,
411                task: None,
412            }),
413            extensions: Extensions::new(),
414        };
415        let resp = svc.call(req).await.unwrap();
416        assert!(resp.inner.is_ok(), "no claims should pass through");
417    }
418}