1use std::collections::{HashMap, HashSet};
81use std::convert::Infallible;
82use std::future::Future;
83use std::pin::Pin;
84use std::sync::Arc;
85use std::task::{Context, Poll};
86
87use tower::Service;
88
89use tower_mcp::protocol::{McpRequest, McpResponse};
90use tower_mcp::{RouterRequest, RouterResponse};
91use tower_mcp_types::JsonRpcError;
92
93use crate::config::{RoleConfig, RoleMappingConfig};
94
/// Outcome of mapping a request's token claims to an RBAC role.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RoleResolution {
    /// The request carries no `TokenClaims`; `RbacService` forwards it unchanged.
    NoClaims,
    /// The configured claim mapped to this role name.
    Role(String),
    /// Claims are present, but no value of the configured claim maps to a role.
    Unmapped,
}
106
/// Compiled RBAC policy: how to derive a role from token claims, and which tools each
/// role may call.
#[derive(Clone, Debug)]
pub struct RbacConfig {
    /// Name of the token claim inspected for role resolution (`"scope"` is handled
    /// specially via the token's scope list).
    claim: String,
    /// Claim value -> role name.
    claim_to_role: HashMap<String, String>,
    /// Per-role allow-lists. A role with no entry here is not restricted by an allow-list.
    role_allow: HashMap<String, HashSet<String>>,
    /// Per-role deny-lists. A tool listed here is refused even if it is also allow-listed.
    role_deny: HashMap<String, HashSet<String>>,
    /// When true, authenticated requests whose claims map to no role are rejected.
    default_deny: bool,
}
121
122impl RbacConfig {
123 pub fn new(roles: &[RoleConfig], mapping: &RoleMappingConfig) -> Self {
125 let mut role_allow = HashMap::new();
126 let mut role_deny = HashMap::new();
127
128 for role in roles {
129 if !role.allow_tools.is_empty() {
130 role_allow.insert(
131 role.name.clone(),
132 role.allow_tools.iter().cloned().collect(),
133 );
134 }
135 if !role.deny_tools.is_empty() {
136 role_deny.insert(role.name.clone(), role.deny_tools.iter().cloned().collect());
137 }
138 }
139
140 Self {
141 claim: mapping.claim.clone(),
142 claim_to_role: mapping.mapping.clone(),
143 role_allow,
144 role_deny,
145 default_deny: mapping.default_deny,
146 }
147 }
148
149 fn resolve_role(&self, extensions: &tower_mcp::router::Extensions) -> RoleResolution {
155 let Some(claims) = extensions.get::<tower_mcp::oauth::token::TokenClaims>() else {
156 return RoleResolution::NoClaims;
157 };
158
159 if self.claim == "scope" {
161 let scopes = claims.scopes();
162 for scope in &scopes {
163 if let Some(role) = self.claim_to_role.get(scope) {
164 return RoleResolution::Role(role.clone());
165 }
166 }
167 return RoleResolution::Unmapped;
168 }
169
170 if let Some(value) = claims.extra.get(&self.claim) {
172 let claim_str = match value {
173 serde_json::Value::String(s) => s.clone(),
174 other => other.to_string(),
175 };
176 if let Some(role) = self.claim_to_role.get(&claim_str) {
178 return RoleResolution::Role(role.clone());
179 }
180 for part in claim_str.split_whitespace() {
182 if let Some(role) = self.claim_to_role.get(part) {
183 return RoleResolution::Role(role.clone());
184 }
185 }
186 }
187
188 RoleResolution::Unmapped
189 }
190
191 fn is_tool_allowed(&self, role: &str, tool_name: &str) -> bool {
193 if let Some(allowed) = self.role_allow.get(role)
195 && !allowed.contains(tool_name)
196 {
197 return false;
198 }
199 if let Some(denied) = self.role_deny.get(role)
201 && denied.contains(tool_name)
202 {
203 return false;
204 }
205 true
206 }
207}
208
209#[derive(Clone)]
211pub struct RbacService<S> {
212 inner: S,
213 config: Arc<RbacConfig>,
214}
215
216impl<S> RbacService<S> {
217 pub fn new(inner: S, config: RbacConfig) -> Self {
219 Self {
220 inner,
221 config: Arc::new(config),
222 }
223 }
224}
225
226impl<S> Service<RouterRequest> for RbacService<S>
227where
228 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
229 + Clone
230 + Send
231 + 'static,
232 S::Future: Send,
233{
234 type Response = RouterResponse;
235 type Error = Infallible;
236 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
237
238 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
239 self.inner.poll_ready(cx)
240 }
241
242 fn call(&mut self, req: RouterRequest) -> Self::Future {
243 let config = Arc::clone(&self.config);
244 let request_id = req.id.clone();
245
246 let role = match config.resolve_role(&req.extensions) {
248 RoleResolution::NoClaims => {
252 let fut = self.inner.call(req);
253 return Box::pin(fut);
254 }
255 RoleResolution::Unmapped => {
258 if config.default_deny {
259 return Box::pin(async move {
260 Ok(RouterResponse {
261 id: request_id,
262 inner: Err(JsonRpcError::invalid_params(
263 "Authenticated principal carries no recognized role; \
264 access denied (rbac default_deny)"
265 .to_string(),
266 )),
267 })
268 });
269 }
270 let fut = self.inner.call(req);
271 return Box::pin(fut);
272 }
273 RoleResolution::Role(role) => role,
274 };
275
276 let role_for_filter = role.clone();
277
278 if let McpRequest::CallTool(ref params) = req.inner
280 && !config.is_tool_allowed(&role, ¶ms.name)
281 {
282 let tool_name = params.name.clone();
283 return Box::pin(async move {
284 Ok(RouterResponse {
285 id: request_id,
286 inner: Err(JsonRpcError::invalid_params(format!(
287 "Role '{}' is not authorized to call tool: {}",
288 role, tool_name
289 ))),
290 })
291 });
292 }
293
294 let fut = self.inner.call(req);
295
296 Box::pin(async move {
297 let mut resp = fut.await?;
298
299 if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
301 result
302 .tools
303 .retain(|tool| config.is_tool_allowed(&role_for_filter, &tool.name));
304 }
305
306 Ok(resp)
307 })
308 }
309}
310
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::RouterRequest;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse, RequestId};
    use tower_mcp::router::Extensions;

    use super::{RbacConfig, RbacService};
    use crate::config::{RoleConfig, RoleMappingConfig};
    use crate::test_util::MockService;

    /// Default policy for most tests: `default_deny` disabled.
    fn test_rbac_config() -> RbacConfig {
        rbac_config_with_default_deny(false)
    }

    /// Policy with two roles resolved from the `scope` claim:
    /// `admin` -> unrestricted `admin`, `read-only` -> `reader` (may only call `fs/read`).
    fn rbac_config_with_default_deny(default_deny: bool) -> RbacConfig {
        let role = |name: &str, allow: &[&str]| RoleConfig {
            name: name.into(),
            allow_tools: allow.iter().map(|tool| tool.to_string()).collect(),
            deny_tools: vec![],
        };
        let roles = [role("admin", &[]), role("reader", &["fs/read"])];

        let mapping = RoleMappingConfig {
            claim: "scope".into(),
            mapping: HashMap::from([
                ("admin".into(), "admin".into()),
                ("read-only".into(), "reader".into()),
            ]),
            default_deny,
        };

        RbacConfig::new(&roles, &mapping)
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Wrap `inner` in a router request whose token carries the given `scope` string.
    fn request_with_scope(scope: &str, inner: McpRequest) -> RouterRequest {
        let claims = TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: Some(scope.to_string()),
            client_id: None,
            extra: HashMap::new(),
        };
        let mut extensions = Extensions::new();
        extensions.insert(claims);

        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    /// Wrap `inner` in a router request that carries no token claims at all.
    fn request_without_claims(inner: McpRequest) -> RouterRequest {
        RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions: Extensions::new(),
        }
    }

    #[tokio::test]
    async fn test_rbac_admin_can_call_any_tool() {
        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/read", "fs/write"]),
            test_rbac_config(),
        );

        let resp = svc
            .call(request_with_scope("admin", call_tool("fs/write")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok(), "admin should call any tool");
    }

    #[tokio::test]
    async fn test_rbac_reader_denied_write() {
        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/read", "fs/write"]),
            test_rbac_config(),
        );

        let resp = svc
            .call(request_with_scope("read-only", call_tool("fs/write")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));
    }

    #[tokio::test]
    async fn test_rbac_reader_allowed_read() {
        let mut svc = RbacService::new(MockService::with_tools(&["fs/read"]), test_rbac_config());

        let resp = svc
            .call(request_with_scope("read-only", call_tool("fs/read")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok(), "reader should call allowed tools");
    }

    #[tokio::test]
    async fn test_rbac_filters_list_tools_for_role() {
        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]),
            test_rbac_config(),
        );

        let req = request_with_scope("read-only", McpRequest::ListTools(Default::default()));
        let resp = svc.call(req).await.unwrap();

        let result = match resp.inner.unwrap() {
            McpResponse::ListTools(result) => result,
            other => panic!("expected ListTools, got: {:?}", other),
        };
        let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"fs/read"));
        assert!(!names.contains(&"fs/write"));
        assert!(!names.contains(&"fs/delete"));
    }

    #[tokio::test]
    async fn test_rbac_no_claims_passes_through() {
        let mut svc = RbacService::new(MockService::with_tools(&["fs/write"]), test_rbac_config());

        let resp = svc
            .call(request_without_claims(call_tool("fs/write")))
            .await
            .unwrap();
        assert!(resp.inner.is_ok(), "no claims should pass through");
    }

    #[tokio::test]
    async fn test_rbac_unmapped_scope_passes_through_by_default() {
        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/write"]),
            rbac_config_with_default_deny(false),
        );

        let resp = svc
            .call(request_with_scope("unknown-scope", call_tool("fs/write")))
            .await
            .unwrap();
        assert!(
            resp.inner.is_ok(),
            "unmapped scope should pass through when default_deny is false"
        );
    }

    #[tokio::test]
    async fn test_rbac_unmapped_scope_denied_with_default_deny() {
        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/write"]),
            rbac_config_with_default_deny(true),
        );

        let resp = svc
            .call(request_with_scope("unknown-scope", call_tool("fs/write")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("default_deny"),
            "unmapped scope should be denied when default_deny is true, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_rbac_mapped_scope_resolves_with_default_deny_enabled() {
        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/read", "fs/write"]),
            rbac_config_with_default_deny(true),
        );

        let read_resp = svc
            .call(request_with_scope("read-only", call_tool("fs/read")))
            .await
            .unwrap();
        assert!(
            read_resp.inner.is_ok(),
            "mapped role should still resolve with default_deny enabled"
        );

        let write_resp = svc
            .call(request_with_scope("read-only", call_tool("fs/write")))
            .await
            .unwrap();
        let err = write_resp.inner.unwrap_err();
        assert!(
            err.message.contains("not authorized"),
            "reader should be denied write via role policy, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_rbac_no_claims_passes_through_with_default_deny() {
        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/write"]),
            rbac_config_with_default_deny(true),
        );

        let resp = svc
            .call(request_without_claims(call_tool("fs/write")))
            .await
            .unwrap();
        assert!(
            resp.inner.is_ok(),
            "no claims must pass through even when default_deny is true"
        );
    }
}