Skip to main content

mcp_proxy/
rbac.rs

//! Role-based access control middleware for the proxy.
//!
//! Reads JWT claims from `RouterRequest.extensions`, maps them to roles
//! via config, and applies per-role tool allow/deny lists.
//! Runs on top of static capability filtering (it can only further restrict).
//! Requests for which no role can be resolved are passed through unchanged.
6
7use std::collections::{HashMap, HashSet};
8use std::convert::Infallible;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use tower::Service;
15
16use tower_mcp::protocol::{McpRequest, McpResponse};
17use tower_mcp::{RouterRequest, RouterResponse};
18use tower_mcp_types::JsonRpcError;
19
20use crate::config::{RoleConfig, RoleMappingConfig};
21
/// Resolved RBAC rules, built once from config by [`RbacConfig::new`].
///
/// Roles that appear in neither `role_allow` nor `role_deny` are unrestricted.
#[derive(Debug, Clone)]
pub struct RbacConfig {
    /// Claim name to read from `TokenClaims` (e.g. "scope", "role").
    claim: String,
    /// Map of claim value -> role name.
    claim_to_role: HashMap<String, String>,
    /// Map of role name -> allowed tools. Roles without an entry may call any
    /// tool (subject to `role_deny`).
    role_allow: HashMap<String, HashSet<String>>,
    /// Map of role name -> denied tools.
    role_deny: HashMap<String, HashSet<String>>,
}
34
35impl RbacConfig {
36    /// Build RBAC config from role definitions and claim-to-role mapping.
37    pub fn new(roles: &[RoleConfig], mapping: &RoleMappingConfig) -> Self {
38        let mut role_allow = HashMap::new();
39        let mut role_deny = HashMap::new();
40
41        for role in roles {
42            if !role.allow_tools.is_empty() {
43                role_allow.insert(
44                    role.name.clone(),
45                    role.allow_tools.iter().cloned().collect(),
46                );
47            }
48            if !role.deny_tools.is_empty() {
49                role_deny.insert(role.name.clone(), role.deny_tools.iter().cloned().collect());
50            }
51        }
52
53        Self {
54            claim: mapping.claim.clone(),
55            claim_to_role: mapping.mapping.clone(),
56            role_allow,
57            role_deny,
58        }
59    }
60
61    /// Resolve the role for the current request from TokenClaims.
62    fn resolve_role(&self, extensions: &tower_mcp::router::Extensions) -> Option<String> {
63        let claims = extensions.get::<tower_mcp::oauth::token::TokenClaims>()?;
64
65        // Check standard scope field first
66        if self.claim == "scope" {
67            let scopes = claims.scopes();
68            for scope in &scopes {
69                if let Some(role) = self.claim_to_role.get(scope) {
70                    return Some(role.clone());
71                }
72            }
73            return None;
74        }
75
76        // Check extra claims
77        if let Some(value) = claims.extra.get(&self.claim) {
78            let claim_str = match value {
79                serde_json::Value::String(s) => s.clone(),
80                other => other.to_string(),
81            };
82            // Try direct mapping
83            if let Some(role) = self.claim_to_role.get(&claim_str) {
84                return Some(role.clone());
85            }
86            // Try space-delimited (like scope)
87            for part in claim_str.split_whitespace() {
88                if let Some(role) = self.claim_to_role.get(part) {
89                    return Some(role.clone());
90                }
91            }
92        }
93
94        None
95    }
96
97    /// Check if a tool is allowed for the given role.
98    fn is_tool_allowed(&self, role: &str, tool_name: &str) -> bool {
99        // If role has an allowlist, tool must be in it
100        if let Some(allowed) = self.role_allow.get(role)
101            && !allowed.contains(tool_name)
102        {
103            return false;
104        }
105        // If role has a denylist, tool must not be in it
106        if let Some(denied) = self.role_deny.get(role)
107            && denied.contains(tool_name)
108        {
109            return false;
110        }
111        true
112    }
113}
114
115/// Middleware that enforces RBAC on tool calls and list responses.
116#[derive(Clone)]
117pub struct RbacService<S> {
118    inner: S,
119    config: Arc<RbacConfig>,
120}
121
122impl<S> RbacService<S> {
123    /// Create a new RBAC enforcement service wrapping `inner`.
124    pub fn new(inner: S, config: RbacConfig) -> Self {
125        Self {
126            inner,
127            config: Arc::new(config),
128        }
129    }
130}
131
132impl<S> Service<RouterRequest> for RbacService<S>
133where
134    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
135        + Clone
136        + Send
137        + 'static,
138    S::Future: Send,
139{
140    type Response = RouterResponse;
141    type Error = Infallible;
142    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
143
144    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145        self.inner.poll_ready(cx)
146    }
147
148    fn call(&mut self, req: RouterRequest) -> Self::Future {
149        let config = Arc::clone(&self.config);
150        let request_id = req.id.clone();
151
152        // Resolve role from extensions
153        let role = config.resolve_role(&req.extensions);
154
155        // If no role resolved, pass through (no RBAC restriction applies)
156        // This allows unauthenticated or bearer-auth requests to proceed
157        // (they're already validated by the auth layer)
158        let Some(role) = role else {
159            let fut = self.inner.call(req);
160            return Box::pin(fut);
161        };
162
163        let role_for_filter = role.clone();
164
165        // Check tool calls against RBAC
166        if let McpRequest::CallTool(ref params) = req.inner
167            && !config.is_tool_allowed(&role, &params.name)
168        {
169            let tool_name = params.name.clone();
170            return Box::pin(async move {
171                Ok(RouterResponse {
172                    id: request_id,
173                    inner: Err(JsonRpcError::invalid_params(format!(
174                        "Role '{}' is not authorized to call tool: {}",
175                        role, tool_name
176                    ))),
177                })
178            });
179        }
180
181        let fut = self.inner.call(req);
182
183        Box::pin(async move {
184            let mut resp = fut.await?;
185
186            // Filter list_tools response based on role
187            if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
188                result
189                    .tools
190                    .retain(|tool| config.is_tool_allowed(&role_for_filter, &tool.name));
191            }
192
193            Ok(resp)
194        })
195    }
196}
197
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse, RequestId};
    use tower_mcp::router::Extensions;

    use super::{RbacConfig, RbacService};
    use crate::config::{RoleConfig, RoleMappingConfig};
    use crate::test_util::MockService;

    /// Config keyed on the `scope` claim: `admin` is unrestricted and
    /// `reader` may only call `fs/read`.
    fn test_rbac_config() -> RbacConfig {
        let roles = vec![
            RoleConfig {
                name: "admin".into(),
                allow_tools: vec![],
                deny_tools: vec![],
            },
            RoleConfig {
                name: "reader".into(),
                allow_tools: vec!["fs/read".into()],
                deny_tools: vec![],
            },
        ];
        let mapping = RoleMappingConfig {
            claim: "scope".into(),
            mapping: HashMap::from([
                ("admin".into(), "admin".into()),
                ("read-only".into(), "reader".into()),
            ]),
        };
        RbacConfig::new(&roles, &mapping)
    }

    /// Token claims with the given `scope` and extra claims; everything else empty.
    fn token_claims(
        scope: Option<&str>,
        extra: HashMap<String, serde_json::Value>,
    ) -> TokenClaims {
        TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: scope.map(str::to_string),
            client_id: None,
            extra,
        }
    }

    /// Wrap `inner` in a router request whose extensions carry `claims`.
    fn request_with_claims(claims: TokenClaims, inner: McpRequest) -> tower_mcp::RouterRequest {
        let mut extensions = Extensions::new();
        extensions.insert(claims);
        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    /// Wrap `inner` in a router request authenticated with the given scope.
    fn request_with_scope(scope: &str, inner: McpRequest) -> tower_mcp::RouterRequest {
        request_with_claims(token_claims(Some(scope), HashMap::new()), inner)
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Tool names from a `ListTools` response; panics on any other response.
    fn listed_tool_names(resp: tower_mcp::RouterResponse) -> Vec<String> {
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => result.tools.into_iter().map(|t| t.name).collect(),
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rbac_admin_can_call_any_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let resp = svc
            .call(request_with_scope("admin", call_tool("fs/write")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok(), "admin should call any tool");
    }

    #[tokio::test]
    async fn test_rbac_reader_denied_write() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let resp = svc
            .call(request_with_scope("read-only", call_tool("fs/write")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));
    }

    #[tokio::test]
    async fn test_rbac_reader_allowed_read() {
        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let resp = svc
            .call(request_with_scope("read-only", call_tool("fs/read")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok(), "reader should call allowed tools");
    }

    #[tokio::test]
    async fn test_rbac_filters_list_tools_for_role() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", McpRequest::ListTools(Default::default()));
        let names = listed_tool_names(svc.call(req).await.unwrap());
        assert!(names.iter().any(|n| n == "fs/read"));
        assert!(!names.iter().any(|n| n == "fs/write"));
        assert!(!names.iter().any(|n| n == "fs/delete"));
    }

    #[tokio::test]
    async fn test_rbac_no_claims_passes_through() {
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        // No TokenClaims in extensions
        let req = tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner: call_tool("fs/write"),
            extensions: Extensions::new(),
        };
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "no claims should pass through");
    }

    #[tokio::test]
    async fn test_rbac_unmapped_scope_passes_through() {
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let resp = svc
            .call(request_with_scope("unknown", call_tool("fs/write")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok(), "unmapped scope should pass through");
    }

    #[tokio::test]
    async fn test_rbac_deny_list_blocks_call_and_hides_tool() {
        let roles = vec![RoleConfig {
            name: "operator".into(),
            allow_tools: vec![],
            deny_tools: vec!["fs/delete".into()],
        }];
        let mapping = RoleMappingConfig {
            claim: "scope".into(),
            mapping: HashMap::from([("ops".into(), "operator".into())]),
        };
        let config = RbacConfig::new(&roles, &mapping);

        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/read", "fs/delete"]),
            config.clone(),
        );
        let resp = svc
            .call(request_with_scope("ops", call_tool("fs/delete")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));

        let mut svc = RbacService::new(MockService::with_tools(&["fs/read", "fs/delete"]), config);
        let req = request_with_scope("ops", McpRequest::ListTools(Default::default()));
        let names = listed_tool_names(svc.call(req).await.unwrap());
        assert!(names.iter().any(|n| n == "fs/read"));
        assert!(!names.iter().any(|n| n == "fs/delete"));
    }

    #[tokio::test]
    async fn test_rbac_extra_claim_maps_role() {
        let roles = vec![RoleConfig {
            name: "editor".into(),
            allow_tools: vec!["fs/read".into()],
            deny_tools: vec![],
        }];
        let mapping = RoleMappingConfig {
            claim: "role".into(),
            mapping: HashMap::from([("editor".into(), "editor".into())]),
        };
        let config = RbacConfig::new(&roles, &mapping);
        let mut svc = RbacService::new(MockService::with_tools(&["fs/write"]), config);

        // Space-delimited extra claim: "editor" is found after splitting.
        let extra = HashMap::from([("role".to_string(), serde_json::json!("guest editor"))]);
        let req = request_with_claims(token_claims(None, extra), call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("Role 'editor'"));
        assert!(err.message.contains("not authorized"));
    }
}