1use std::collections::{HashMap, HashSet};
81use std::convert::Infallible;
82use std::future::Future;
83use std::pin::Pin;
84use std::sync::Arc;
85use std::task::{Context, Poll};
86
87use tower::Service;
88
89use tower_mcp::protocol::{McpRequest, McpResponse};
90use tower_mcp::{RouterRequest, RouterResponse};
91use tower_mcp_types::JsonRpcError;
92
93use crate::config::{RoleConfig, RoleMappingConfig};
94
/// Role-based access control settings built from the configured roles and the
/// claim-to-role mapping.
///
/// A value of the configured token claim is mapped to a role name; that role's
/// tool allow/deny lists then decide which tools the caller may call or see.
#[derive(Clone, Debug)]
pub struct RbacConfig {
    /// Name of the token claim used to resolve a role (`"scope"` is read via the
    /// token's scope list; any other name is looked up in the extra claims).
    claim: String,
    /// Claim value -> role name.
    claim_to_role: HashMap<String, String>,
    /// Role name -> tools the role may use. Roles without an entry are not
    /// restricted by an allow list.
    role_allow: HashMap<String, HashSet<String>>,
    /// Role name -> tools the role may never use.
    role_deny: HashMap<String, HashSet<String>>,
}
107
108impl RbacConfig {
109 pub fn new(roles: &[RoleConfig], mapping: &RoleMappingConfig) -> Self {
111 let mut role_allow = HashMap::new();
112 let mut role_deny = HashMap::new();
113
114 for role in roles {
115 if !role.allow_tools.is_empty() {
116 role_allow.insert(
117 role.name.clone(),
118 role.allow_tools.iter().cloned().collect(),
119 );
120 }
121 if !role.deny_tools.is_empty() {
122 role_deny.insert(role.name.clone(), role.deny_tools.iter().cloned().collect());
123 }
124 }
125
126 Self {
127 claim: mapping.claim.clone(),
128 claim_to_role: mapping.mapping.clone(),
129 role_allow,
130 role_deny,
131 }
132 }
133
134 fn resolve_role(&self, extensions: &tower_mcp::router::Extensions) -> Option<String> {
136 let claims = extensions.get::<tower_mcp::oauth::token::TokenClaims>()?;
137
138 if self.claim == "scope" {
140 let scopes = claims.scopes();
141 for scope in &scopes {
142 if let Some(role) = self.claim_to_role.get(scope) {
143 return Some(role.clone());
144 }
145 }
146 return None;
147 }
148
149 if let Some(value) = claims.extra.get(&self.claim) {
151 let claim_str = match value {
152 serde_json::Value::String(s) => s.clone(),
153 other => other.to_string(),
154 };
155 if let Some(role) = self.claim_to_role.get(&claim_str) {
157 return Some(role.clone());
158 }
159 for part in claim_str.split_whitespace() {
161 if let Some(role) = self.claim_to_role.get(part) {
162 return Some(role.clone());
163 }
164 }
165 }
166
167 None
168 }
169
170 fn is_tool_allowed(&self, role: &str, tool_name: &str) -> bool {
172 if let Some(allowed) = self.role_allow.get(role)
174 && !allowed.contains(tool_name)
175 {
176 return false;
177 }
178 if let Some(denied) = self.role_deny.get(role)
180 && denied.contains(tool_name)
181 {
182 return false;
183 }
184 true
185 }
186}
187
188#[derive(Clone)]
190pub struct RbacService<S> {
191 inner: S,
192 config: Arc<RbacConfig>,
193}
194
195impl<S> RbacService<S> {
196 pub fn new(inner: S, config: RbacConfig) -> Self {
198 Self {
199 inner,
200 config: Arc::new(config),
201 }
202 }
203}
204
205impl<S> Service<RouterRequest> for RbacService<S>
206where
207 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
208 + Clone
209 + Send
210 + 'static,
211 S::Future: Send,
212{
213 type Response = RouterResponse;
214 type Error = Infallible;
215 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
216
217 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
218 self.inner.poll_ready(cx)
219 }
220
221 fn call(&mut self, req: RouterRequest) -> Self::Future {
222 let config = Arc::clone(&self.config);
223 let request_id = req.id.clone();
224
225 let role = config.resolve_role(&req.extensions);
227
228 let Some(role) = role else {
232 let fut = self.inner.call(req);
233 return Box::pin(fut);
234 };
235
236 let role_for_filter = role.clone();
237
238 if let McpRequest::CallTool(ref params) = req.inner
240 && !config.is_tool_allowed(&role, ¶ms.name)
241 {
242 let tool_name = params.name.clone();
243 return Box::pin(async move {
244 Ok(RouterResponse {
245 id: request_id,
246 inner: Err(JsonRpcError::invalid_params(format!(
247 "Role '{}' is not authorized to call tool: {}",
248 role, tool_name
249 ))),
250 })
251 });
252 }
253
254 let fut = self.inner.call(req);
255
256 Box::pin(async move {
257 let mut resp = fut.await?;
258
259 if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
261 result
262 .tools
263 .retain(|tool| config.is_tool_allowed(&role_for_filter, &tool.name));
264 }
265
266 Ok(resp)
267 })
268 }
269}
270
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tower::Service;
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{McpRequest, McpResponse, RequestId};
    use tower_mcp::router::Extensions;

    use super::{RbacConfig, RbacService};
    use crate::config::{RoleConfig, RoleMappingConfig};
    use crate::test_util::MockService;

    /// Two roles: `admin` (unrestricted) and `reader` (allow-listed to `fs/read`),
    /// resolved from the `scope` values `admin` and `read-only`.
    fn test_rbac_config() -> RbacConfig {
        let roles = vec![
            RoleConfig {
                name: "admin".into(),
                allow_tools: vec![],
                deny_tools: vec![],
            },
            RoleConfig {
                name: "reader".into(),
                allow_tools: vec!["fs/read".into()],
                deny_tools: vec![],
            },
        ];
        let mapping = RoleMappingConfig {
            claim: "scope".into(),
            mapping: HashMap::from([
                ("admin".into(), "admin".into()),
                ("read-only".into(), "reader".into()),
            ]),
        };
        RbacConfig::new(&roles, &mapping)
    }

    /// A `CallTool` request for `name` with empty arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    /// Token claims with only `scope` set (and no extra claims).
    fn token_claims(scope: Option<&str>) -> TokenClaims {
        TokenClaims {
            sub: None,
            iss: None,
            aud: None,
            exp: None,
            scope: scope.map(str::to_string),
            client_id: None,
            extra: HashMap::new(),
        }
    }

    /// Wraps `inner` in a router request whose extensions carry `claims`.
    fn request_with_claims(claims: TokenClaims, inner: McpRequest) -> tower_mcp::RouterRequest {
        let mut extensions = Extensions::new();
        extensions.insert(claims);
        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner,
            extensions,
        }
    }

    fn request_with_scope(scope: &str, inner: McpRequest) -> tower_mcp::RouterRequest {
        request_with_claims(token_claims(Some(scope)), inner)
    }

    #[tokio::test]
    async fn test_rbac_admin_can_call_any_tool() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("admin", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "admin should call any tool");
    }

    #[tokio::test]
    async fn test_rbac_reader_denied_write() {
        let mock = MockService::with_tools(&["fs/read", "fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", call_tool("fs/write"));
        let resp = svc.call(req).await.unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));
    }

    #[tokio::test]
    async fn test_rbac_reader_allowed_read() {
        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", call_tool("fs/read"));
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "reader should call allowed tools");
    }

    #[tokio::test]
    async fn test_rbac_filters_list_tools_for_role() {
        let mock = MockService::with_tools(&["fs/read", "fs/write", "fs/delete"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        let req = request_with_scope("read-only", McpRequest::ListTools(Default::default()));
        let resp = svc.call(req).await.unwrap();

        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"fs/read"));
                assert!(!names.contains(&"fs/write"));
                assert!(!names.contains(&"fs/delete"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rbac_no_claims_passes_through() {
        let mock = MockService::with_tools(&["fs/write"]);
        let mut svc = RbacService::new(mock, test_rbac_config());

        // No TokenClaims in the extensions at all.
        let req = tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner: call_tool("fs/write"),
            extensions: Extensions::new(),
        };
        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok(), "no claims should pass through");
    }

    #[tokio::test]
    async fn test_rbac_deny_list_blocks_and_hides_tool() {
        let roles = vec![RoleConfig {
            name: "ops".into(),
            allow_tools: vec![],
            deny_tools: vec!["fs/delete".into()],
        }];
        let mapping = RoleMappingConfig {
            claim: "scope".into(),
            mapping: HashMap::from([("ops".into(), "ops".into())]),
        };
        let config = RbacConfig::new(&roles, &mapping);

        let mut svc = RbacService::new(
            MockService::with_tools(&["fs/read", "fs/delete"]),
            config.clone(),
        );
        let resp = svc
            .call(request_with_scope("ops", call_tool("fs/delete")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));

        let mut svc = RbacService::new(MockService::with_tools(&["fs/read", "fs/delete"]), config);
        let resp = svc
            .call(request_with_scope("ops", McpRequest::ListTools(Default::default())))
            .await
            .unwrap();
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
                assert!(names.contains(&"fs/read"));
                assert!(!names.contains(&"fs/delete"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_rbac_resolves_role_from_custom_claim_token() {
        let roles = vec![RoleConfig {
            name: "reader".into(),
            allow_tools: vec!["fs/read".into()],
            deny_tools: vec![],
        }];
        let mapping = RoleMappingConfig {
            claim: "groups".into(),
            mapping: HashMap::from([("viewer".into(), "reader".into())]),
        };
        let config = RbacConfig::new(&roles, &mapping);

        // The whole value is unmapped, but its second whitespace token is.
        let mut claims = token_claims(None);
        claims
            .extra
            .insert("groups".into(), serde_json::json!("staff viewer"));

        let mut svc = RbacService::new(MockService::with_tools(&["fs/read", "fs/write"]), config);
        let resp = svc
            .call(request_with_claims(claims, call_tool("fs/write")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(err.message.contains("not authorized"));
    }
}