1use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
12
13use axum::{
14 body::Body,
15 extract::ConnectInfo,
16 http::{Method, Request, StatusCode},
17 middleware::Next,
18 response::{IntoResponse, Response},
19};
20use hmac::{Hmac, Mac};
21use http_body_util::BodyExt;
22use secrecy::{ExposeSecret, SecretString};
23use serde::Deserialize;
24use sha2::Sha256;
25
26use crate::{
27 auth::{AuthIdentity, TlsConnInfo},
28 bounded_limiter::BoundedKeyedLimiter,
29 error::McpxError,
30};
31
32pub(crate) type ToolRateLimiter = BoundedKeyedLimiter<IpAddr>;
35
/// Per-minute tool-call quota used when the configured rate is `0`.
const DEFAULT_TOOL_RATE: NonZeroU32 = NonZeroU32::new(120).unwrap();

/// `max_tracked_keys` handed to the limiter built by `build_tool_rate_limiter`.
const DEFAULT_TOOL_MAX_TRACKED_KEYS: usize = 10_000;

/// `idle_eviction` handed to the limiter built by `build_tool_rate_limiter` (15 minutes).
const DEFAULT_TOOL_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
46
47#[must_use]
53pub(crate) fn build_tool_rate_limiter(max_per_minute: u32) -> Arc<ToolRateLimiter> {
54 build_tool_rate_limiter_with_bounds(
55 max_per_minute,
56 DEFAULT_TOOL_MAX_TRACKED_KEYS,
57 DEFAULT_TOOL_IDLE_EVICTION,
58 )
59}
60
61#[must_use]
63pub(crate) fn build_tool_rate_limiter_with_bounds(
64 max_per_minute: u32,
65 max_tracked_keys: usize,
66 idle_eviction: Duration,
67) -> Arc<ToolRateLimiter> {
68 let quota =
69 governor::Quota::per_minute(NonZeroU32::new(max_per_minute).unwrap_or(DEFAULT_TOOL_RATE));
70 Arc::new(BoundedKeyedLimiter::new(
71 quota,
72 max_tracked_keys,
73 idle_eviction,
74 ))
75}
76
// Per-task request context. `rbac_middleware` and `with_rbac_scope` set these for the
// duration of a future; the `current_*` accessors read them and return `None` when
// called outside such a scope.
tokio::task_local! {
    /// Role of the authenticated caller.
    static CURRENT_ROLE: String;
    /// Name of the authenticated identity.
    static CURRENT_IDENTITY: String;
    /// Raw token of the caller (`AuthIdentity::raw_token`); the middleware stores an
    /// empty secret when there is none.
    static CURRENT_TOKEN: SecretString;
    /// Subject of the caller (`AuthIdentity::sub`); empty when absent.
    static CURRENT_SUB: String;
}
89
90#[must_use]
93pub fn current_role() -> Option<String> {
94 CURRENT_ROLE.try_with(Clone::clone).ok()
95}
96
97#[must_use]
100pub fn current_identity() -> Option<String> {
101 CURRENT_IDENTITY.try_with(Clone::clone).ok()
102}
103
104#[must_use]
117pub fn current_token() -> Option<SecretString> {
118 CURRENT_TOKEN
119 .try_with(|t| {
120 if t.expose_secret().is_empty() {
121 None
122 } else {
123 Some(t.clone())
124 }
125 })
126 .ok()
127 .flatten()
128}
129
130#[must_use]
134pub fn current_sub() -> Option<String> {
135 CURRENT_SUB
136 .try_with(Clone::clone)
137 .ok()
138 .filter(|s| !s.is_empty())
139}
140
141pub async fn with_token_scope<F: Future>(token: SecretString, f: F) -> F::Output {
146 CURRENT_TOKEN.scope(token, f).await
147}
148
149pub async fn with_rbac_scope<F: Future>(
154 role: String,
155 identity: String,
156 token: SecretString,
157 sub: String,
158 f: F,
159) -> F::Output {
160 CURRENT_ROLE
161 .scope(
162 role,
163 CURRENT_IDENTITY.scope(
164 identity,
165 CURRENT_TOKEN.scope(token, CURRENT_SUB.scope(sub, f)),
166 ),
167 )
168 .await
169}
170
/// One named role: which tools it may call, on which hosts, and with which argument values.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleConfig {
    /// Role name, compared by exact string equality against the caller's role.
    pub name: String,
    /// Optional free-form description; not consulted by the [`RbacPolicy`] checks.
    #[serde(default)]
    pub description: Option<String>,
    /// Tool names this role may invoke. `"*"` allows every tool; any other entry must
    /// match the tool name exactly.
    #[serde(default)]
    pub allow: Vec<String>,
    /// Tool names that are always denied, even when also covered by `allow`.
    /// Exact match only.
    #[serde(default)]
    pub deny: Vec<String>,
    /// Host glob patterns (`*` wildcard) this role may target. Defaults to `["*"]`
    /// when omitted from the configuration.
    #[serde(default = "default_hosts")]
    pub hosts: Vec<String>,
    /// Per-tool restrictions on argument values; see [`ArgumentAllowlist`].
    #[serde(default)]
    pub argument_allowlists: Vec<ArgumentAllowlist>,
}
195
196impl RoleConfig {
197 #[must_use]
199 pub fn new(name: impl Into<String>, allow: Vec<String>, hosts: Vec<String>) -> Self {
200 Self {
201 name: name.into(),
202 description: None,
203 allow,
204 deny: vec![],
205 hosts,
206 argument_allowlists: vec![],
207 }
208 }
209
210 #[must_use]
212 pub fn with_argument_allowlists(mut self, allowlists: Vec<ArgumentAllowlist>) -> Self {
213 self.argument_allowlists = allowlists;
214 self
215 }
216}
217
/// Restricts the values a role may pass for one argument of matching tools.
///
/// A value is accepted when its first whitespace-separated token, or that token's final
/// `/`-separated component, equals one of the `allowed` entries. For example,
/// `/usr/bin/ls -la` matches an entry `ls`. Nothing after the first token is examined.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ArgumentAllowlist {
    /// Tool name, or a `*` glob pattern, that this allowlist applies to.
    pub tool: String,
    /// Argument key (exact match) whose value is restricted.
    pub argument: String,
    /// Accepted values. An empty list disables this allowlist (every value passes).
    #[serde(default)]
    pub allowed: Vec<String>,
}
235
236impl ArgumentAllowlist {
237 #[must_use]
239 pub fn new(tool: impl Into<String>, argument: impl Into<String>, allowed: Vec<String>) -> Self {
240 Self {
241 tool: tool.into(),
242 argument: argument.into(),
243 allowed,
244 }
245 }
246}
247
/// Serde default for `RoleConfig::hosts`: a single `*` pattern matching every host.
fn default_hosts() -> Vec<String> {
    let match_all = String::from("*");
    vec![match_all]
}
251
/// Deserializable RBAC configuration, turned into an [`RbacPolicy`] via [`RbacPolicy::new`].
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct RbacConfig {
    /// Master switch. When `false` (the default), every policy check allows.
    #[serde(default)]
    pub enabled: bool,
    /// Role definitions. Lookups use the first role whose name matches.
    #[serde(default)]
    pub roles: Vec<RoleConfig>,
    /// Salt for the HMAC used to redact argument values in logs.
    ///
    /// When unset, a random salt generated once per process is used instead, so
    /// redacted values differ between processes.
    #[serde(default)]
    pub redaction_salt: Option<SecretString>,
}
273
274impl RbacConfig {
275 #[must_use]
277 pub fn with_roles(roles: Vec<RoleConfig>) -> Self {
278 Self {
279 enabled: true,
280 roles,
281 redaction_salt: None,
282 }
283 }
284}
285
/// Outcome of an RBAC check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RbacDecision {
    /// The operation is permitted.
    Allow,
    /// The operation is rejected.
    Deny,
}
295
/// Per-role entry of [`RbacPolicySummary`]: rule counts only, never rule contents.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacRoleSummary {
    /// Role name.
    pub name: String,
    /// Number of `allow` entries.
    pub allow: usize,
    /// Number of `deny` entries.
    pub deny: usize,
    /// Number of host patterns.
    pub hosts: usize,
    /// Number of argument allowlists.
    pub argument_allowlists: usize,
}
311
/// Serializable overview of an [`RbacPolicy`], produced by [`RbacPolicy::summary`].
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacPolicySummary {
    /// Whether the policy is enforced.
    pub enabled: bool,
    /// One summary per configured role, in configuration order.
    pub roles: Vec<RbacRoleSummary>,
}
321
/// RBAC policy evaluated by the middleware for each `tools/call` request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RbacPolicy {
    // Roles copied from the configuration; searched linearly by name.
    roles: Vec<RoleConfig>,
    // When false, every check allows.
    enabled: bool,
    // HMAC key for `redact_arg`; either configured or the per-process random salt.
    redaction_salt: Arc<SecretString>,
}
336
337impl RbacPolicy {
338 #[must_use]
341 pub fn new(config: &RbacConfig) -> Self {
342 let salt = config
343 .redaction_salt
344 .clone()
345 .unwrap_or_else(|| process_redaction_salt().clone());
346 Self {
347 roles: config.roles.clone(),
348 enabled: config.enabled,
349 redaction_salt: Arc::new(salt),
350 }
351 }
352
353 #[must_use]
355 pub fn disabled() -> Self {
356 Self {
357 roles: Vec::new(),
358 enabled: false,
359 redaction_salt: Arc::new(process_redaction_salt().clone()),
360 }
361 }
362
363 #[must_use]
365 pub fn is_enabled(&self) -> bool {
366 self.enabled
367 }
368
369 #[must_use]
374 pub fn summary(&self) -> RbacPolicySummary {
375 let roles = self
376 .roles
377 .iter()
378 .map(|r| RbacRoleSummary {
379 name: r.name.clone(),
380 allow: r.allow.len(),
381 deny: r.deny.len(),
382 hosts: r.hosts.len(),
383 argument_allowlists: r.argument_allowlists.len(),
384 })
385 .collect();
386 RbacPolicySummary {
387 enabled: self.enabled,
388 roles,
389 }
390 }
391
392 #[must_use]
397 pub fn check_operation(&self, role: &str, operation: &str) -> RbacDecision {
398 if !self.enabled {
399 return RbacDecision::Allow;
400 }
401 let Some(role_cfg) = self.find_role(role) else {
402 return RbacDecision::Deny;
403 };
404 if role_cfg.deny.iter().any(|d| d == operation) {
405 return RbacDecision::Deny;
406 }
407 if role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
408 return RbacDecision::Allow;
409 }
410 RbacDecision::Deny
411 }
412
413 #[must_use]
420 pub fn check(&self, role: &str, operation: &str, host: &str) -> RbacDecision {
421 if !self.enabled {
422 return RbacDecision::Allow;
423 }
424 let Some(role_cfg) = self.find_role(role) else {
425 return RbacDecision::Deny;
426 };
427 if role_cfg.deny.iter().any(|d| d == operation) {
428 return RbacDecision::Deny;
429 }
430 if !role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
431 return RbacDecision::Deny;
432 }
433 if !Self::host_matches(&role_cfg.hosts, host) {
434 return RbacDecision::Deny;
435 }
436 RbacDecision::Allow
437 }
438
439 #[must_use]
441 pub fn host_visible(&self, role: &str, host: &str) -> bool {
442 if !self.enabled {
443 return true;
444 }
445 let Some(role_cfg) = self.find_role(role) else {
446 return false;
447 };
448 Self::host_matches(&role_cfg.hosts, host)
449 }
450
451 #[must_use]
453 pub fn host_patterns(&self, role: &str) -> Option<&[String]> {
454 self.find_role(role).map(|r| r.hosts.as_slice())
455 }
456
457 #[must_use]
464 pub fn argument_allowed(&self, role: &str, tool: &str, argument: &str, value: &str) -> bool {
465 if !self.enabled {
466 return true;
467 }
468 let Some(role_cfg) = self.find_role(role) else {
469 return false;
470 };
471 for al in &role_cfg.argument_allowlists {
472 if al.tool != tool && !glob_match(&al.tool, tool) {
473 continue;
474 }
475 if al.argument != argument {
476 continue;
477 }
478 if al.allowed.is_empty() {
479 continue;
480 }
481 let first_token = value.split_whitespace().next().unwrap_or(value);
483 let basename = first_token.rsplit('/').next().unwrap_or(first_token);
485 if !al.allowed.iter().any(|a| a == first_token || a == basename) {
486 return false;
487 }
488 }
489 true
490 }
491
492 fn find_role(&self, name: &str) -> Option<&RoleConfig> {
494 self.roles.iter().find(|r| r.name == name)
495 }
496
497 fn host_matches(patterns: &[String], host: &str) -> bool {
499 patterns.iter().any(|p| glob_match(p, host))
500 }
501
502 #[must_use]
511 pub fn redact_arg(&self, value: &str) -> String {
512 redact_with_salt(self.redaction_salt.expose_secret().as_bytes(), value)
513 }
514}
515
516fn process_redaction_salt() -> &'static SecretString {
519 use base64::{Engine as _, engine::general_purpose::STANDARD_NO_PAD};
520 static PROCESS_SALT: std::sync::OnceLock<SecretString> = std::sync::OnceLock::new();
521 PROCESS_SALT.get_or_init(|| {
522 let mut bytes = [0u8; 32];
523 rand::fill(&mut bytes);
524 SecretString::from(STANDARD_NO_PAD.encode(bytes))
527 })
528}
529
530fn redact_with_salt(salt: &[u8], value: &str) -> String {
535 use std::fmt::Write as _;
536
537 use sha2::Digest as _;
538
539 type HmacSha256 = Hmac<Sha256>;
540 let mut mac = if let Ok(m) = HmacSha256::new_from_slice(salt) {
546 m
547 } else {
548 let digest = Sha256::digest(salt);
549 #[allow(clippy::expect_used)] HmacSha256::new_from_slice(&digest).expect("32-byte SHA256 digest is valid HMAC key")
551 };
552 mac.update(value.as_bytes());
553 let bytes = mac.finalize().into_bytes();
554 let prefix = bytes.get(..4).unwrap_or(&[0; 4]);
556 let mut out = String::with_capacity(8);
557 for b in prefix {
558 let _ = write!(out, "{b:02x}");
559 }
560 out
561}
562
/// Axum middleware that enforces per-IP tool rate limiting and the RBAC policy on
/// JSON-RPC `tools/call` messages.
///
/// Steps, in order:
/// - Non-`POST` requests are forwarded untouched.
/// - The peer IP comes from `ConnectInfo<SocketAddr>`, falling back to
///   `ConnectInfo<TlsConnInfo>`. With neither present, no rate limiting is applied.
/// - If the policy is enabled and no [`AuthIdentity`] is attached, the request is
///   rejected with [`McpxError::Rbac`] before the body is read.
/// - The body is buffered in full. A read error yields a 500 response.
/// - If the body parses as JSON, every `tools/call` (single message or batch) is
///   rate-limited and, with the policy enabled, checked by `enforce_tool_policy`.
///   The first failure rejects the whole request.
/// - Bodies that are not valid JSON are forwarded without tool checks.
/// - The buffered body is re-attached. With a non-empty role, the request runs with
///   the role, identity name, raw token and subject exposed through the task-locals
///   read by `current_role()` and friends.
#[allow(clippy::too_many_lines)]
pub(crate) async fn rbac_middleware(
    policy: Arc<RbacPolicy>,
    tool_limiter: Option<Arc<ToolRateLimiter>>,
    req: Request<Body>,
    next: Next,
) -> Response {
    // Only POST bodies carry JSON-RPC messages; other methods pass straight through.
    if req.method() != Method::POST {
        return next.run(req).await;
    }

    // Peer IP for rate limiting: plain-TCP `ConnectInfo` first, then the TLS variant.
    let peer_ip: Option<IpAddr> = req
        .extensions()
        .get::<ConnectInfo<std::net::SocketAddr>>()
        .map(|ci| ci.0.ip())
        .or_else(|| {
            req.extensions()
                .get::<ConnectInfo<TlsConnInfo>>()
                .map(|ci| ci.0.addr.ip())
        });

    // Identity attached by an earlier auth layer; missing fields default to empty values.
    let identity = req.extensions().get::<AuthIdentity>();
    let identity_name = identity.map(|id| id.name.clone()).unwrap_or_default();
    let role = identity.map(|id| id.role.clone()).unwrap_or_default();
    // An empty secret stands in for "no token"; `current_token()` filters it back out.
    let raw_token: SecretString = identity
        .and_then(|id| id.raw_token.clone())
        .unwrap_or_else(|| SecretString::from(String::new()));
    let sub = identity.and_then(|id| id.sub.clone()).unwrap_or_default();

    // Reject unauthenticated requests before spending effort buffering the body.
    if policy.is_enabled() && identity.is_none() {
        return McpxError::Rbac("no authenticated identity".into()).into_response();
    }

    // Buffer the whole body so it can be inspected and then re-attached.
    // NOTE(review): no size cap is applied here; confirm an outer body-limit layer exists.
    let (parts, body) = req.into_parts();
    let bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            tracing::error!(error = %e, "failed to read request body");
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "failed to read request body",
            )
                .into_response();
        }
    };

    // Non-JSON bodies skip tool checks entirely and are forwarded as-is.
    if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&bytes) {
        let tool_calls = extract_tool_calls(&json);
        if !tool_calls.is_empty() {
            // Each call in a batch is rate-limited, then policy-checked; the first
            // rejection fails the whole request.
            for params in tool_calls {
                if let Some(resp) = enforce_rate_limit(tool_limiter.as_deref(), peer_ip) {
                    return resp;
                }
                if policy.is_enabled()
                    && let Some(resp) = enforce_tool_policy(&policy, &identity_name, &role, params)
                {
                    return resp;
                }
            }
        }
    }
    let req = Request::from_parts(parts, Body::from(bytes));

    // An empty role means no identity, or an identity without a role: forward without
    // setting any task-locals.
    // NOTE(review): in that case the token and subject are not exposed either;
    // confirm this is intended.
    if role.is_empty() {
        next.run(req).await
    } else {
        CURRENT_ROLE
            .scope(
                role,
                CURRENT_IDENTITY.scope(
                    identity_name,
                    CURRENT_TOKEN.scope(raw_token, CURRENT_SUB.scope(sub, next.run(req))),
                ),
            )
            .await
    }
}
672
673fn extract_tool_calls(value: &serde_json::Value) -> Vec<&serde_json::Value> {
679 match value {
680 serde_json::Value::Object(map) => map
681 .get("method")
682 .and_then(serde_json::Value::as_str)
683 .filter(|method| *method == "tools/call")
684 .and_then(|_| map.get("params"))
685 .into_iter()
686 .collect(),
687 serde_json::Value::Array(items) => items
688 .iter()
689 .filter_map(|item| match item {
690 serde_json::Value::Object(map) => map
691 .get("method")
692 .and_then(serde_json::Value::as_str)
693 .filter(|method| *method == "tools/call")
694 .and_then(|_| map.get("params")),
695 serde_json::Value::Null
696 | serde_json::Value::Bool(_)
697 | serde_json::Value::Number(_)
698 | serde_json::Value::String(_)
699 | serde_json::Value::Array(_) => None,
700 })
701 .collect(),
702 serde_json::Value::Null
703 | serde_json::Value::Bool(_)
704 | serde_json::Value::Number(_)
705 | serde_json::Value::String(_) => Vec::new(),
706 }
707}
708
709fn enforce_rate_limit(
712 tool_limiter: Option<&ToolRateLimiter>,
713 peer_ip: Option<IpAddr>,
714) -> Option<Response> {
715 let limiter = tool_limiter?;
716 let ip = peer_ip?;
717 if limiter.check_key(&ip).is_err() {
718 tracing::warn!(%ip, "tool invocation rate limited");
719 return Some(McpxError::RateLimited("too many tool invocations".into()).into_response());
720 }
721 None
722}
723
724fn enforce_tool_policy(
733 policy: &RbacPolicy,
734 identity_name: &str,
735 role: &str,
736 params: &serde_json::Value,
737) -> Option<Response> {
738 let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
739 let host = params
740 .get("arguments")
741 .and_then(|a| a.get("host"))
742 .and_then(|h| h.as_str());
743
744 let decision = if let Some(host) = host {
745 policy.check(role, tool_name, host)
746 } else {
747 policy.check_operation(role, tool_name)
748 };
749 if decision == RbacDecision::Deny {
750 tracing::warn!(
751 user = %identity_name,
752 role = %role,
753 tool = tool_name,
754 host = host.unwrap_or("-"),
755 "RBAC denied"
756 );
757 return Some(
758 McpxError::Rbac(format!("{tool_name} denied for role '{role}'")).into_response(),
759 );
760 }
761
762 let args = params.get("arguments").and_then(|a| a.as_object())?;
763 for (arg_key, arg_val) in args {
764 if let Some(val_str) = arg_val.as_str()
765 && !policy.argument_allowed(role, tool_name, arg_key, val_str)
766 {
767 tracing::warn!(
772 user = %identity_name,
773 role = %role,
774 tool = tool_name,
775 argument = arg_key,
776 arg_hmac = %policy.redact_arg(val_str),
777 "argument not in allowlist"
778 );
779 return Some(
780 McpxError::Rbac(format!(
781 "argument '{arg_key}' value not in allowlist for tool '{tool_name}'"
782 ))
783 .into_response(),
784 );
785 }
786 }
787 None
788}
789
790fn glob_match(pattern: &str, text: &str) -> bool {
795 let parts: Vec<&str> = pattern.split('*').collect();
796 if parts.len() == 1 {
797 return pattern == text;
799 }
800
801 let mut pos = 0;
802
803 if let Some(&first) = parts.first()
805 && !first.is_empty()
806 {
807 if !text.starts_with(first) {
808 return false;
809 }
810 pos = first.len();
811 }
812
813 if let Some(&last) = parts.last()
815 && !last.is_empty()
816 {
817 if !text[pos..].ends_with(last) {
818 return false;
819 }
820 let end = text.len() - last.len();
822 if pos > end {
823 return false;
824 }
825 let middle = &text[pos..end];
827 let middle_parts = parts.get(1..parts.len() - 1).unwrap_or_default();
828 return match_middle(middle, middle_parts);
829 }
830
831 let middle = &text[pos..];
833 let middle_parts = parts.get(1..parts.len() - 1).unwrap_or_default();
834 match_middle(middle, middle_parts)
835}
836
/// Checks that the non-empty `parts` occur in `text` in order, without overlapping.
///
/// Each part is matched at its leftmost occurrence after the end of the previous match.
/// Empty parts (from consecutive `*`s) are ignored.
fn match_middle(text: &str, parts: &[&str]) -> bool {
    parts
        .iter()
        .filter(|segment| !segment.is_empty())
        .try_fold(text, |remaining, segment| {
            remaining
                .find(*segment)
                .map(|at| &remaining[at + segment.len()..])
        })
        .is_some()
}
851
852#[cfg(test)]
853mod tests {
854 use super::*;
855
856 fn test_policy() -> RbacPolicy {
857 RbacPolicy::new(&RbacConfig {
858 enabled: true,
859 roles: vec![
860 RoleConfig {
861 name: "viewer".into(),
862 description: Some("Read-only".into()),
863 allow: vec![
864 "list_hosts".into(),
865 "resource_list".into(),
866 "resource_inspect".into(),
867 "resource_logs".into(),
868 "system_info".into(),
869 ],
870 deny: vec![],
871 hosts: vec!["*".into()],
872 argument_allowlists: vec![],
873 },
874 RoleConfig {
875 name: "deploy".into(),
876 description: Some("Lifecycle management".into()),
877 allow: vec![
878 "list_hosts".into(),
879 "resource_list".into(),
880 "resource_run".into(),
881 "resource_start".into(),
882 "resource_stop".into(),
883 "resource_restart".into(),
884 "resource_logs".into(),
885 "image_pull".into(),
886 ],
887 deny: vec!["resource_delete".into(), "resource_exec".into()],
888 hosts: vec!["web-*".into(), "api-*".into()],
889 argument_allowlists: vec![],
890 },
891 RoleConfig {
892 name: "ops".into(),
893 description: Some("Full access".into()),
894 allow: vec!["*".into()],
895 deny: vec![],
896 hosts: vec!["*".into()],
897 argument_allowlists: vec![],
898 },
899 RoleConfig {
900 name: "restricted-exec".into(),
901 description: Some("Exec with argument allowlist".into()),
902 allow: vec!["resource_exec".into()],
903 deny: vec![],
904 hosts: vec!["dev-*".into()],
905 argument_allowlists: vec![ArgumentAllowlist {
906 tool: "resource_exec".into(),
907 argument: "cmd".into(),
908 allowed: vec![
909 "sh".into(),
910 "bash".into(),
911 "cat".into(),
912 "ls".into(),
913 "ps".into(),
914 ],
915 }],
916 },
917 ],
918 redaction_salt: None,
919 })
920 }
921
922 #[test]
925 fn glob_exact_match() {
926 assert!(glob_match("web-prod-1", "web-prod-1"));
927 assert!(!glob_match("web-prod-1", "web-prod-2"));
928 }
929
930 #[test]
931 fn glob_star_suffix() {
932 assert!(glob_match("web-*", "web-prod-1"));
933 assert!(glob_match("web-*", "web-staging"));
934 assert!(!glob_match("web-*", "api-prod"));
935 }
936
937 #[test]
938 fn glob_star_prefix() {
939 assert!(glob_match("*-prod", "web-prod"));
940 assert!(glob_match("*-prod", "api-prod"));
941 assert!(!glob_match("*-prod", "web-staging"));
942 }
943
944 #[test]
945 fn glob_star_middle() {
946 assert!(glob_match("web-*-prod", "web-us-prod"));
947 assert!(glob_match("web-*-prod", "web-eu-east-prod"));
948 assert!(!glob_match("web-*-prod", "web-staging"));
949 }
950
951 #[test]
952 fn glob_star_only() {
953 assert!(glob_match("*", "anything"));
954 assert!(glob_match("*", ""));
955 }
956
957 #[test]
958 fn glob_multiple_stars() {
959 assert!(glob_match("*web*prod*", "my-web-us-prod-1"));
960 assert!(!glob_match("*web*prod*", "my-api-us-staging"));
961 }
962
963 #[test]
966 fn disabled_policy_allows_everything() {
967 let policy = RbacPolicy::new(&RbacConfig {
968 enabled: false,
969 roles: vec![],
970 redaction_salt: None,
971 });
972 assert_eq!(
973 policy.check("nonexistent", "resource_delete", "any-host"),
974 RbacDecision::Allow
975 );
976 }
977
978 #[test]
979 fn unknown_role_denied() {
980 let policy = test_policy();
981 assert_eq!(
982 policy.check("unknown", "resource_list", "web-prod-1"),
983 RbacDecision::Deny
984 );
985 }
986
987 #[test]
988 fn viewer_allowed_read_ops() {
989 let policy = test_policy();
990 assert_eq!(
991 policy.check("viewer", "resource_list", "web-prod-1"),
992 RbacDecision::Allow
993 );
994 assert_eq!(
995 policy.check("viewer", "system_info", "db-host"),
996 RbacDecision::Allow
997 );
998 }
999
1000 #[test]
1001 fn viewer_denied_write_ops() {
1002 let policy = test_policy();
1003 assert_eq!(
1004 policy.check("viewer", "resource_run", "web-prod-1"),
1005 RbacDecision::Deny
1006 );
1007 assert_eq!(
1008 policy.check("viewer", "resource_delete", "web-prod-1"),
1009 RbacDecision::Deny
1010 );
1011 }
1012
1013 #[test]
1014 fn deploy_allowed_on_matching_hosts() {
1015 let policy = test_policy();
1016 assert_eq!(
1017 policy.check("deploy", "resource_run", "web-prod-1"),
1018 RbacDecision::Allow
1019 );
1020 assert_eq!(
1021 policy.check("deploy", "resource_start", "api-staging"),
1022 RbacDecision::Allow
1023 );
1024 }
1025
1026 #[test]
1027 fn deploy_denied_on_non_matching_host() {
1028 let policy = test_policy();
1029 assert_eq!(
1030 policy.check("deploy", "resource_run", "db-prod-1"),
1031 RbacDecision::Deny
1032 );
1033 }
1034
1035 #[test]
1036 fn deny_overrides_allow() {
1037 let policy = test_policy();
1038 assert_eq!(
1039 policy.check("deploy", "resource_delete", "web-prod-1"),
1040 RbacDecision::Deny
1041 );
1042 assert_eq!(
1043 policy.check("deploy", "resource_exec", "web-prod-1"),
1044 RbacDecision::Deny
1045 );
1046 }
1047
1048 #[test]
1049 fn ops_wildcard_allows_everything() {
1050 let policy = test_policy();
1051 assert_eq!(
1052 policy.check("ops", "resource_delete", "any-host"),
1053 RbacDecision::Allow
1054 );
1055 assert_eq!(
1056 policy.check("ops", "secret_create", "db-host"),
1057 RbacDecision::Allow
1058 );
1059 }
1060
1061 #[test]
1064 fn host_visible_respects_globs() {
1065 let policy = test_policy();
1066 assert!(policy.host_visible("deploy", "web-prod-1"));
1067 assert!(policy.host_visible("deploy", "api-staging"));
1068 assert!(!policy.host_visible("deploy", "db-prod-1"));
1069 assert!(policy.host_visible("ops", "anything"));
1070 assert!(policy.host_visible("viewer", "anything"));
1071 }
1072
1073 #[test]
1074 fn host_visible_unknown_role() {
1075 let policy = test_policy();
1076 assert!(!policy.host_visible("unknown", "web-prod-1"));
1077 }
1078
1079 #[test]
1082 fn argument_allowed_no_allowlist() {
1083 let policy = test_policy();
1084 assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "rm -rf /"));
1086 assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "bash"));
1087 }
1088
1089 #[test]
1090 fn argument_allowed_with_allowlist() {
1091 let policy = test_policy();
1092 assert!(policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "sh"));
1093 assert!(policy.argument_allowed(
1094 "restricted-exec",
1095 "resource_exec",
1096 "cmd",
1097 "bash -c 'echo hi'"
1098 ));
1099 assert!(policy.argument_allowed(
1100 "restricted-exec",
1101 "resource_exec",
1102 "cmd",
1103 "cat /etc/hosts"
1104 ));
1105 assert!(policy.argument_allowed(
1106 "restricted-exec",
1107 "resource_exec",
1108 "cmd",
1109 "/usr/bin/ls -la"
1110 ));
1111 }
1112
1113 #[test]
1114 fn argument_denied_not_in_allowlist() {
1115 let policy = test_policy();
1116 assert!(!policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "rm -rf /"));
1117 assert!(!policy.argument_allowed(
1118 "restricted-exec",
1119 "resource_exec",
1120 "cmd",
1121 "python3 exploit.py"
1122 ));
1123 assert!(!policy.argument_allowed(
1124 "restricted-exec",
1125 "resource_exec",
1126 "cmd",
1127 "/usr/bin/curl evil.com"
1128 ));
1129 }
1130
1131 #[test]
1132 fn argument_denied_unknown_role() {
1133 let policy = test_policy();
1134 assert!(!policy.argument_allowed("unknown", "resource_exec", "cmd", "sh"));
1135 }
1136
1137 #[test]
1140 fn host_patterns_returns_globs() {
1141 let policy = test_policy();
1142 assert_eq!(
1143 policy.host_patterns("deploy"),
1144 Some(vec!["web-*".to_owned(), "api-*".to_owned()].as_slice())
1145 );
1146 assert_eq!(
1147 policy.host_patterns("ops"),
1148 Some(vec!["*".to_owned()].as_slice())
1149 );
1150 assert!(policy.host_patterns("nonexistent").is_none());
1151 }
1152
1153 #[test]
1156 fn check_operation_allows_without_host() {
1157 let policy = test_policy();
1158 assert_eq!(
1159 policy.check_operation("deploy", "resource_run"),
1160 RbacDecision::Allow
1161 );
1162 assert_eq!(
1164 policy.check("deploy", "resource_run", "db-prod-1"),
1165 RbacDecision::Deny
1166 );
1167 }
1168
1169 #[test]
1170 fn check_operation_deny_overrides() {
1171 let policy = test_policy();
1172 assert_eq!(
1173 policy.check_operation("deploy", "resource_delete"),
1174 RbacDecision::Deny
1175 );
1176 }
1177
1178 #[test]
1179 fn check_operation_unknown_role() {
1180 let policy = test_policy();
1181 assert_eq!(
1182 policy.check_operation("unknown", "resource_list"),
1183 RbacDecision::Deny
1184 );
1185 }
1186
1187 #[test]
1188 fn check_operation_disabled() {
1189 let policy = RbacPolicy::new(&RbacConfig {
1190 enabled: false,
1191 roles: vec![],
1192 redaction_salt: None,
1193 });
1194 assert_eq!(
1195 policy.check_operation("nonexistent", "anything"),
1196 RbacDecision::Allow
1197 );
1198 }
1199
1200 #[test]
1203 fn current_role_returns_none_outside_scope() {
1204 assert!(current_role().is_none());
1205 }
1206
1207 #[test]
1208 fn current_identity_returns_none_outside_scope() {
1209 assert!(current_identity().is_none());
1210 }
1211
1212 use axum::{
1215 body::Body,
1216 http::{Method, Request, StatusCode},
1217 };
1218 use tower::ServiceExt as _;
1219
1220 fn tool_call_body(tool: &str, args: &serde_json::Value) -> String {
1221 serde_json::json!({
1222 "jsonrpc": "2.0",
1223 "id": 1,
1224 "method": "tools/call",
1225 "params": {
1226 "name": tool,
1227 "arguments": args
1228 }
1229 })
1230 .to_string()
1231 }
1232
1233 fn rbac_router(policy: Arc<RbacPolicy>) -> axum::Router {
1234 axum::Router::new()
1235 .route("/mcp", axum::routing::post(|| async { "ok" }))
1236 .layer(axum::middleware::from_fn(move |req, next| {
1237 let p = Arc::clone(&policy);
1238 rbac_middleware(p, None, req, next)
1239 }))
1240 }
1241
1242 fn rbac_router_with_identity(policy: Arc<RbacPolicy>, identity: AuthIdentity) -> axum::Router {
1243 axum::Router::new()
1244 .route("/mcp", axum::routing::post(|| async { "ok" }))
1245 .layer(axum::middleware::from_fn(
1246 move |mut req: Request<Body>, next: Next| {
1247 let p = Arc::clone(&policy);
1248 let id = identity.clone();
1249 async move {
1250 req.extensions_mut().insert(id);
1251 rbac_middleware(p, None, req, next).await
1252 }
1253 },
1254 ))
1255 }
1256
1257 #[tokio::test]
1258 async fn middleware_passes_non_post() {
1259 let policy = Arc::new(test_policy());
1260 let app = rbac_router(policy);
1261 let req = Request::builder()
1263 .method(Method::GET)
1264 .uri("/mcp")
1265 .body(Body::empty())
1266 .unwrap();
1267 let resp = app.oneshot(req).await.unwrap();
1270 assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
1271 }
1272
1273 #[tokio::test]
1274 async fn middleware_denies_without_identity() {
1275 let policy = Arc::new(test_policy());
1276 let app = rbac_router(policy);
1277 let body = tool_call_body("resource_list", &serde_json::json!({}));
1278 let req = Request::builder()
1279 .method(Method::POST)
1280 .uri("/mcp")
1281 .header("content-type", "application/json")
1282 .body(Body::from(body))
1283 .unwrap();
1284 let resp = app.oneshot(req).await.unwrap();
1285 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1286 }
1287
1288 #[tokio::test]
1289 async fn middleware_allows_permitted_tool() {
1290 let policy = Arc::new(test_policy());
1291 let id = AuthIdentity {
1292 method: crate::auth::AuthMethod::BearerToken,
1293 name: "alice".into(),
1294 role: "viewer".into(),
1295 raw_token: None,
1296 sub: None,
1297 };
1298 let app = rbac_router_with_identity(policy, id);
1299 let body = tool_call_body("resource_list", &serde_json::json!({}));
1300 let req = Request::builder()
1301 .method(Method::POST)
1302 .uri("/mcp")
1303 .header("content-type", "application/json")
1304 .body(Body::from(body))
1305 .unwrap();
1306 let resp = app.oneshot(req).await.unwrap();
1307 assert_eq!(resp.status(), StatusCode::OK);
1308 }
1309
1310 #[tokio::test]
1311 async fn middleware_denies_unpermitted_tool() {
1312 let policy = Arc::new(test_policy());
1313 let id = AuthIdentity {
1314 method: crate::auth::AuthMethod::BearerToken,
1315 name: "alice".into(),
1316 role: "viewer".into(),
1317 raw_token: None,
1318 sub: None,
1319 };
1320 let app = rbac_router_with_identity(policy, id);
1321 let body = tool_call_body("resource_delete", &serde_json::json!({}));
1322 let req = Request::builder()
1323 .method(Method::POST)
1324 .uri("/mcp")
1325 .header("content-type", "application/json")
1326 .body(Body::from(body))
1327 .unwrap();
1328 let resp = app.oneshot(req).await.unwrap();
1329 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1330 }
1331
1332 #[tokio::test]
1333 async fn middleware_passes_non_tool_call_post() {
1334 let policy = Arc::new(test_policy());
1335 let id = AuthIdentity {
1336 method: crate::auth::AuthMethod::BearerToken,
1337 name: "alice".into(),
1338 role: "viewer".into(),
1339 raw_token: None,
1340 sub: None,
1341 };
1342 let app = rbac_router_with_identity(policy, id);
1343 let body = serde_json::json!({
1345 "jsonrpc": "2.0",
1346 "id": 1,
1347 "method": "resources/list"
1348 })
1349 .to_string();
1350 let req = Request::builder()
1351 .method(Method::POST)
1352 .uri("/mcp")
1353 .header("content-type", "application/json")
1354 .body(Body::from(body))
1355 .unwrap();
1356 let resp = app.oneshot(req).await.unwrap();
1357 assert_eq!(resp.status(), StatusCode::OK);
1358 }
1359
1360 #[tokio::test]
1361 async fn middleware_enforces_argument_allowlist() {
1362 let policy = Arc::new(test_policy());
1363 let id = AuthIdentity {
1364 method: crate::auth::AuthMethod::BearerToken,
1365 name: "dev".into(),
1366 role: "restricted-exec".into(),
1367 raw_token: None,
1368 sub: None,
1369 };
1370 let app = rbac_router_with_identity(Arc::clone(&policy), id.clone());
1372 let body = tool_call_body(
1373 "resource_exec",
1374 &serde_json::json!({"cmd": "ls -la", "host": "dev-1"}),
1375 );
1376 let req = Request::builder()
1377 .method(Method::POST)
1378 .uri("/mcp")
1379 .body(Body::from(body))
1380 .unwrap();
1381 let resp = app.oneshot(req).await.unwrap();
1382 assert_eq!(resp.status(), StatusCode::OK);
1383
1384 let app = rbac_router_with_identity(policy, id);
1386 let body = tool_call_body(
1387 "resource_exec",
1388 &serde_json::json!({"cmd": "rm -rf /", "host": "dev-1"}),
1389 );
1390 let req = Request::builder()
1391 .method(Method::POST)
1392 .uri("/mcp")
1393 .body(Body::from(body))
1394 .unwrap();
1395 let resp = app.oneshot(req).await.unwrap();
1396 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1397 }
1398
1399 #[tokio::test]
1400 async fn middleware_disabled_policy_passes_everything() {
1401 let policy = Arc::new(RbacPolicy::disabled());
1402 let app = rbac_router(policy);
1403 let body = tool_call_body("anything", &serde_json::json!({}));
1405 let req = Request::builder()
1406 .method(Method::POST)
1407 .uri("/mcp")
1408 .body(Body::from(body))
1409 .unwrap();
1410 let resp = app.oneshot(req).await.unwrap();
1411 assert_eq!(resp.status(), StatusCode::OK);
1412 }
1413
1414 #[tokio::test]
1415 async fn middleware_batch_all_allowed_passes() {
1416 let policy = Arc::new(test_policy());
1417 let id = AuthIdentity {
1418 method: crate::auth::AuthMethod::BearerToken,
1419 name: "alice".into(),
1420 role: "viewer".into(),
1421 raw_token: None,
1422 sub: None,
1423 };
1424 let app = rbac_router_with_identity(policy, id);
1425 let body = serde_json::json!([
1426 {
1427 "jsonrpc": "2.0",
1428 "id": 1,
1429 "method": "tools/call",
1430 "params": { "name": "resource_list", "arguments": {} }
1431 },
1432 {
1433 "jsonrpc": "2.0",
1434 "id": 2,
1435 "method": "tools/call",
1436 "params": { "name": "system_info", "arguments": {} }
1437 }
1438 ])
1439 .to_string();
1440 let req = Request::builder()
1441 .method(Method::POST)
1442 .uri("/mcp")
1443 .header("content-type", "application/json")
1444 .body(Body::from(body))
1445 .unwrap();
1446 let resp = app.oneshot(req).await.unwrap();
1447 assert_eq!(resp.status(), StatusCode::OK);
1448 }
1449
1450 #[tokio::test]
1451 async fn middleware_batch_with_denied_call_rejects_entire_batch() {
1452 let policy = Arc::new(test_policy());
1453 let id = AuthIdentity {
1454 method: crate::auth::AuthMethod::BearerToken,
1455 name: "alice".into(),
1456 role: "viewer".into(),
1457 raw_token: None,
1458 sub: None,
1459 };
1460 let app = rbac_router_with_identity(policy, id);
1461 let body = serde_json::json!([
1462 {
1463 "jsonrpc": "2.0",
1464 "id": 1,
1465 "method": "tools/call",
1466 "params": { "name": "resource_list", "arguments": {} }
1467 },
1468 {
1469 "jsonrpc": "2.0",
1470 "id": 2,
1471 "method": "tools/call",
1472 "params": { "name": "resource_delete", "arguments": {} }
1473 }
1474 ])
1475 .to_string();
1476 let req = Request::builder()
1477 .method(Method::POST)
1478 .uri("/mcp")
1479 .header("content-type", "application/json")
1480 .body(Body::from(body))
1481 .unwrap();
1482 let resp = app.oneshot(req).await.unwrap();
1483 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1484 }
1485
1486 #[tokio::test]
1487 async fn middleware_batch_mixed_allowed_and_denied_rejects() {
1488 let policy = Arc::new(test_policy());
1489 let id = AuthIdentity {
1490 method: crate::auth::AuthMethod::BearerToken,
1491 name: "dev".into(),
1492 role: "restricted-exec".into(),
1493 raw_token: None,
1494 sub: None,
1495 };
1496 let app = rbac_router_with_identity(policy, id);
1497 let body = serde_json::json!([
1498 {
1499 "jsonrpc": "2.0",
1500 "id": 1,
1501 "method": "tools/call",
1502 "params": {
1503 "name": "resource_exec",
1504 "arguments": { "cmd": "ls -la", "host": "dev-1" }
1505 }
1506 },
1507 {
1508 "jsonrpc": "2.0",
1509 "id": 2,
1510 "method": "tools/call",
1511 "params": {
1512 "name": "resource_exec",
1513 "arguments": { "cmd": "rm -rf /", "host": "dev-1" }
1514 }
1515 }
1516 ])
1517 .to_string();
1518 let req = Request::builder()
1519 .method(Method::POST)
1520 .uri("/mcp")
1521 .header("content-type", "application/json")
1522 .body(Body::from(body))
1523 .unwrap();
1524 let resp = app.oneshot(req).await.unwrap();
1525 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1526 }
1527
1528 #[test]
1531 fn redact_with_salt_is_deterministic_per_salt() {
1532 let salt = b"unit-test-salt";
1533 let a = redact_with_salt(salt, "rm -rf /");
1534 let b = redact_with_salt(salt, "rm -rf /");
1535 assert_eq!(a, b, "same input + salt must yield identical hash");
1536 assert_eq!(a.len(), 8, "redacted hash is 8 hex chars (4 bytes)");
1537 assert!(
1538 a.chars().all(|c| c.is_ascii_hexdigit()),
1539 "redacted hash must be lowercase hex: {a}"
1540 );
1541 }
1542
1543 #[test]
1544 fn redact_with_salt_differs_across_salts() {
1545 let v = "the-same-value";
1546 let h1 = redact_with_salt(b"salt-one", v);
1547 let h2 = redact_with_salt(b"salt-two", v);
1548 assert_ne!(
1549 h1, h2,
1550 "different salts must produce different hashes for the same value"
1551 );
1552 }
1553
1554 #[test]
1555 fn redact_with_salt_distinguishes_values() {
1556 let salt = b"k";
1557 let h1 = redact_with_salt(salt, "alpha");
1558 let h2 = redact_with_salt(salt, "beta");
1559 assert_ne!(h1, h2, "different values must produce different hashes");
1561 }
1562
1563 #[test]
1564 fn policy_with_configured_salt_redacts_consistently() {
1565 let cfg = RbacConfig {
1566 enabled: true,
1567 roles: vec![],
1568 redaction_salt: Some(SecretString::from("my-stable-salt")),
1569 };
1570 let p1 = RbacPolicy::new(&cfg);
1571 let p2 = RbacPolicy::new(&cfg);
1572 assert_eq!(
1573 p1.redact_arg("payload"),
1574 p2.redact_arg("payload"),
1575 "policies built from the same configured salt must agree"
1576 );
1577 }
1578
1579 #[test]
1580 fn policy_without_configured_salt_uses_process_salt() {
1581 let cfg = RbacConfig {
1582 enabled: true,
1583 roles: vec![],
1584 redaction_salt: None,
1585 };
1586 let p1 = RbacPolicy::new(&cfg);
1587 let p2 = RbacPolicy::new(&cfg);
1588 assert_eq!(
1590 p1.redact_arg("payload"),
1591 p2.redact_arg("payload"),
1592 "process-wide salt must be consistent within one process"
1593 );
1594 }
1595
1596 #[test]
1597 fn redact_arg_is_fast_enough() {
1598 let salt = b"perf-sanity-salt-32-bytes-padded";
1602 let value = "x".repeat(256);
1603 let start = std::time::Instant::now();
1604 let _ = redact_with_salt(salt, &value);
1605 let elapsed = start.elapsed();
1606 assert!(
1607 elapsed < Duration::from_millis(5),
1608 "single redact_with_salt took {elapsed:?}, expected <5 ms even in debug"
1609 );
1610 }
1611
1612 #[tokio::test]
1624 async fn deny_path_uses_explicit_identity_not_task_local() {
1625 let policy = Arc::new(test_policy());
1626 let id = AuthIdentity {
1627 method: crate::auth::AuthMethod::BearerToken,
1628 name: "alice-the-auditor".into(),
1629 role: "viewer".into(),
1630 raw_token: None,
1631 sub: None,
1632 };
1633 let app = rbac_router_with_identity(policy, id);
1634 let body = tool_call_body("resource_delete", &serde_json::json!({}));
1636 let req = Request::builder()
1637 .method(Method::POST)
1638 .uri("/mcp")
1639 .header("content-type", "application/json")
1640 .body(Body::from(body))
1641 .unwrap();
1642 let resp = app.oneshot(req).await.unwrap();
1643 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1644 }
1645}