1use std::collections::{HashMap, HashSet};
8use std::convert::Infallible;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use tower::Service;
15
16use tower_mcp::protocol::{McpRequest, McpResponse};
17use tower_mcp::{RouterRequest, RouterResponse};
18use tower_mcp_types::JsonRpcError;
19
20use crate::config::{RoleConfig, RoleMappingConfig};
21
/// Role-based access-control tables for MCP tool calls.
///
/// A caller's role is resolved from one OAuth token claim (`claim`) via
/// `claim_to_role`; the role's allow/deny sets then decide which tools it may
/// call and which tools it sees in tool listings.
#[derive(Clone, Debug)]
pub struct RbacConfig {
    /// Name of the token claim used to resolve a role (`"scope"` is handled specially).
    claim: String,
    /// Claim value -> role name.
    claim_to_role: HashMap<String, String>,
    /// Role -> the only tools that role may use. Roles absent here have no
    /// allow-list restriction.
    role_allow: HashMap<String, HashSet<String>>,
    /// Role -> tools that role must not use, even if its allow-list permits them.
    role_deny: HashMap<String, HashSet<String>>,
}
34
35impl RbacConfig {
36 pub fn new(roles: &[RoleConfig], mapping: &RoleMappingConfig) -> Self {
38 let mut role_allow = HashMap::new();
39 let mut role_deny = HashMap::new();
40
41 for role in roles {
42 if !role.allow_tools.is_empty() {
43 role_allow.insert(
44 role.name.clone(),
45 role.allow_tools.iter().cloned().collect(),
46 );
47 }
48 if !role.deny_tools.is_empty() {
49 role_deny.insert(role.name.clone(), role.deny_tools.iter().cloned().collect());
50 }
51 }
52
53 Self {
54 claim: mapping.claim.clone(),
55 claim_to_role: mapping.mapping.clone(),
56 role_allow,
57 role_deny,
58 }
59 }
60
61 fn resolve_role(&self, extensions: &tower_mcp::router::Extensions) -> Option<String> {
63 let claims = extensions.get::<tower_mcp::oauth::token::TokenClaims>()?;
64
65 if self.claim == "scope" {
67 let scopes = claims.scopes();
68 for scope in &scopes {
69 if let Some(role) = self.claim_to_role.get(scope) {
70 return Some(role.clone());
71 }
72 }
73 return None;
74 }
75
76 if let Some(value) = claims.extra.get(&self.claim) {
78 let claim_str = match value {
79 serde_json::Value::String(s) => s.clone(),
80 other => other.to_string(),
81 };
82 if let Some(role) = self.claim_to_role.get(&claim_str) {
84 return Some(role.clone());
85 }
86 for part in claim_str.split_whitespace() {
88 if let Some(role) = self.claim_to_role.get(part) {
89 return Some(role.clone());
90 }
91 }
92 }
93
94 None
95 }
96
97 fn is_tool_allowed(&self, role: &str, tool_name: &str) -> bool {
99 if let Some(allowed) = self.role_allow.get(role)
101 && !allowed.contains(tool_name)
102 {
103 return false;
104 }
105 if let Some(denied) = self.role_deny.get(role)
107 && denied.contains(tool_name)
108 {
109 return false;
110 }
111 true
112 }
113}
114
115#[derive(Clone)]
117pub struct RbacService<S> {
118 inner: S,
119 config: Arc<RbacConfig>,
120}
121
122impl<S> RbacService<S> {
123 pub fn new(inner: S, config: RbacConfig) -> Self {
125 Self {
126 inner,
127 config: Arc::new(config),
128 }
129 }
130}
131
132impl<S> Service<RouterRequest> for RbacService<S>
133where
134 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
135 + Clone
136 + Send
137 + 'static,
138 S::Future: Send,
139{
140 type Response = RouterResponse;
141 type Error = Infallible;
142 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
143
144 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145 self.inner.poll_ready(cx)
146 }
147
148 fn call(&mut self, req: RouterRequest) -> Self::Future {
149 let config = Arc::clone(&self.config);
150 let request_id = req.id.clone();
151
152 let role = config.resolve_role(&req.extensions);
154
155 let Some(role) = role else {
159 let fut = self.inner.call(req);
160 return Box::pin(fut);
161 };
162
163 let role_for_filter = role.clone();
164
165 if let McpRequest::CallTool(ref params) = req.inner
167 && !config.is_tool_allowed(&role, ¶ms.name)
168 {
169 let tool_name = params.name.clone();
170 return Box::pin(async move {
171 Ok(RouterResponse {
172 id: request_id,
173 inner: Err(JsonRpcError::invalid_params(format!(
174 "Role '{}' is not authorized to call tool: {}",
175 role, tool_name
176 ))),
177 })
178 });
179 }
180
181 let fut = self.inner.call(req);
182
183 Box::pin(async move {
184 let mut resp = fut.await?;
185
186 if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
188 result
189 .tools
190 .retain(|tool| config.is_tool_allowed(&role_for_filter, &tool.name));
191 }
192
193 Ok(resp)
194 })
195 }
196}
197
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{McpRequest, McpResponse, RequestId};
    use tower_mcp::router::Extensions;

    use super::{RbacConfig, RbacService};
    use crate::config::{RoleConfig, RoleMappingConfig};
    use crate::test_util::MockService;

    /// `admin` (unrestricted) and `reader` (allow-list: `fs/read`), mapped from
    /// the `admin` and `read-only` scopes respectively.
    fn test_rbac_config() -> RbacConfig {
        let roles = vec![
            RoleConfig {
                name: "admin".into(),
                allow_tools: vec![],
                deny_tools: vec![],
            },
            RoleConfig {
                name: "reader".into(),
                allow_tools: vec!["fs/read".into()],
                deny_tools: vec![],
            },
        ];
        let mapping = RoleMappingConfig {
            claim: "scope".into(),
            mapping: HashMap::from([
                ("admin".into(), "admin".into()),
                ("read-only".into(), "reader".into()),
            ]),
        };
        RbacConfig::new(&roles, &mapping)
    }

    /// Extensions holding token claims with the given `scope` and extra claims.
    fn extensions_with_claims(
        scope: Option<&str>,
        extra: HashMap<String, serde_json::Value>,
    ) -> Extensions {
        let mut extensions = Extensions::new();
        extensions.insert(TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: scope.map(str::to_string),
            client_id: None,
            extra,
        });
        extensions
    }

    /// A router request carrying only a `scope` claim.
    fn request_with_scope(scope: &str, inner: McpRequest) -> tower_mcp::RouterRequest {
        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: extensions_with_claims(Some(scope), HashMap::new()),
        }
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_rbac_admin_can_call_any_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("admin", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "admin should call any tool");
    }

    #[tokio::test]
    async fn test_rbac_reader_denied_write() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));
    }

    #[tokio::test]
    async fn test_rbac_reader_allowed_read() {
        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", call_tool("fs/read"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "reader should call allowed tools");
    }

    #[tokio::test]
    async fn test_rbac_filters_list_tools_for_role() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", McpRequest::ListTools(Default::default()));
        let resp = svc.call(req).await.unwrap();

        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"fs/read"));
                assert!(!names.contains(&"fs/write"));
                assert!(!names.contains(&"fs/delete"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rbac_no_claims_passes_through() {
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        // No TokenClaims in the extensions: RBAC must not interfere.
        let req = tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner: call_tool("fs/write"),
            extensions: Extensions::new(),
        };
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "no claims should pass through");
    }

    #[test]
    fn test_rbac_deny_list_overrides_allow_list() {
        let roles = vec![
            RoleConfig {
                name: "ops".into(),
                allow_tools: vec![],
                deny_tools: vec!["fs/delete".into()],
            },
            RoleConfig {
                name: "editor".into(),
                allow_tools: vec!["fs/read".into(), "fs/write".into()],
                deny_tools: vec!["fs/write".into()],
            },
        ];
        let mapping = RoleMappingConfig {
            claim: "scope".into(),
            mapping: HashMap::new(),
        };
        let config = RbacConfig::new(&roles, &mapping);

        // Deny-list only: everything except the denied tool.
        assert!(!config.is_tool_allowed("ops", "fs/delete"));
        assert!(config.is_tool_allowed("ops", "fs/write"));
        // Allow-list and deny-list: the deny-list wins on overlap.
        assert!(config.is_tool_allowed("editor", "fs/read"));
        assert!(!config.is_tool_allowed("editor", "fs/write"));
        assert!(!config.is_tool_allowed("editor", "fs/delete"));
    }

    #[test]
    fn test_rbac_custom_claim_matches_whitespace_separated_token() {
        let roles = vec![RoleConfig {
            name: "admin".into(),
            allow_tools: vec![],
            deny_tools: vec![],
        }];
        let mapping = RoleMappingConfig {
            claim: "groups".into(),
            mapping: HashMap::from([("admins".into(), "admin".into())]),
        };
        let config = RbacConfig::new(&roles, &mapping);

        let extensions = extensions_with_claims(
            None,
            HashMap::from([("groups".to_string(), serde_json::json!("staff admins"))]),
        );
        assert_eq!(config.resolve_role(&extensions).as_deref(), Some("admin"));
    }

    #[test]
    fn test_rbac_unmapped_scope_resolves_no_role() {
        let config = test_rbac_config();
        let extensions = extensions_with_claims(Some("unknown"), HashMap::new());
        assert_eq!(config.resolve_role(&extensions), None);
    }
}