1use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
12
13use axum::{
14 body::Body,
15 extract::ConnectInfo,
16 http::{Method, Request, StatusCode},
17 middleware::Next,
18 response::{IntoResponse, Response},
19};
20use hmac::{Hmac, Mac};
21use http_body_util::BodyExt;
22use secrecy::{ExposeSecret, SecretString};
23use serde::Deserialize;
24use sha2::Sha256;
25
26use crate::{
27 auth::{AuthIdentity, TlsConnInfo},
28 bounded_limiter::BoundedKeyedLimiter,
29 error::McpxError,
30};
31
/// Rate limiter keyed by client IP address, used to throttle `tools/call`
/// invocations (consulted by `enforce_rate_limit`).
pub(crate) type ToolRateLimiter = BoundedKeyedLimiter<IpAddr>;
35
/// Per-minute tool-call quota used when a caller requests a rate of `0`.
const DEFAULT_TOOL_RATE: NonZeroU32 = NonZeroU32::new(120).unwrap();

/// Default cap on the number of distinct client keys the tool limiter tracks.
const DEFAULT_TOOL_MAX_TRACKED_KEYS: usize = 10_000;

/// Default idle-eviction window handed to the tool limiter (15 minutes).
const DEFAULT_TOOL_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
46
47#[must_use]
53pub(crate) fn build_tool_rate_limiter(max_per_minute: u32) -> Arc<ToolRateLimiter> {
54 build_tool_rate_limiter_with_bounds(
55 max_per_minute,
56 DEFAULT_TOOL_MAX_TRACKED_KEYS,
57 DEFAULT_TOOL_IDLE_EVICTION,
58 )
59}
60
61#[must_use]
63pub(crate) fn build_tool_rate_limiter_with_bounds(
64 max_per_minute: u32,
65 max_tracked_keys: usize,
66 idle_eviction: Duration,
67) -> Arc<ToolRateLimiter> {
68 let quota =
69 governor::Quota::per_minute(NonZeroU32::new(max_per_minute).unwrap_or(DEFAULT_TOOL_RATE));
70 Arc::new(BoundedKeyedLimiter::new(
71 quota,
72 max_tracked_keys,
73 idle_eviction,
74 ))
75}
76
tokio::task_local! {
    /// Role of the authenticated caller, installed for the request task by
    /// `rbac_middleware` / `with_rbac_scope`.
    static CURRENT_ROLE: String;
    /// Name of the authenticated caller (`AuthIdentity::name`).
    static CURRENT_IDENTITY: String;
    /// Raw bearer token of the caller; the middleware installs an empty
    /// string when the identity carries no token.
    static CURRENT_TOKEN: SecretString;
    /// `sub` of the caller; empty when the identity carries none.
    static CURRENT_SUB: String;
}
89
90#[must_use]
93pub fn current_role() -> Option<String> {
94 CURRENT_ROLE.try_with(Clone::clone).ok()
95}
96
97#[must_use]
100pub fn current_identity() -> Option<String> {
101 CURRENT_IDENTITY.try_with(Clone::clone).ok()
102}
103
104#[must_use]
117pub fn current_token() -> Option<SecretString> {
118 CURRENT_TOKEN
119 .try_with(|t| {
120 if t.expose_secret().is_empty() {
121 None
122 } else {
123 Some(t.clone())
124 }
125 })
126 .ok()
127 .flatten()
128}
129
130#[must_use]
134pub fn current_sub() -> Option<String> {
135 CURRENT_SUB
136 .try_with(Clone::clone)
137 .ok()
138 .filter(|s| !s.is_empty())
139}
140
141pub async fn with_token_scope<F: Future>(token: SecretString, f: F) -> F::Output {
146 CURRENT_TOKEN.scope(token, f).await
147}
148
149pub async fn with_rbac_scope<F: Future>(
154 role: String,
155 identity: String,
156 token: SecretString,
157 sub: String,
158 f: F,
159) -> F::Output {
160 CURRENT_ROLE
161 .scope(
162 role,
163 CURRENT_IDENTITY.scope(
164 identity,
165 CURRENT_TOKEN.scope(token, CURRENT_SUB.scope(sub, f)),
166 ),
167 )
168 .await
169}
170
/// One named role in the RBAC configuration.
///
/// `allow` grants tool operations and `deny` revokes them; deny entries are
/// checked first. `hosts` limits which hosts the role may target, and
/// `argument_allowlists` can constrain individual tool arguments.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleConfig {
    /// Role name, compared exactly with the authenticated identity's role.
    pub name: String,
    /// Optional human-readable description; not read by `RbacPolicy` checks.
    #[serde(default)]
    pub description: Option<String>,
    /// Tool names this role may call; the entry `"*"` allows every tool.
    #[serde(default)]
    pub allow: Vec<String>,
    /// Tool names this role may never call (exact string match, no wildcard).
    #[serde(default)]
    pub deny: Vec<String>,
    /// Host glob patterns (`*` wildcard) this role may target; defaults to
    /// `["*"]` when the field is omitted.
    #[serde(default = "default_hosts")]
    pub hosts: Vec<String>,
    /// Per-tool, per-argument restrictions on the executable named in a
    /// string argument value.
    #[serde(default)]
    pub argument_allowlists: Vec<ArgumentAllowlist>,
}
195
196impl RoleConfig {
197 #[must_use]
199 pub fn new(name: impl Into<String>, allow: Vec<String>, hosts: Vec<String>) -> Self {
200 Self {
201 name: name.into(),
202 description: None,
203 allow,
204 deny: vec![],
205 hosts,
206 argument_allowlists: vec![],
207 }
208 }
209
210 #[must_use]
212 pub fn with_argument_allowlists(mut self, allowlists: Vec<ArgumentAllowlist>) -> Self {
213 self.argument_allowlists = allowlists;
214 self
215 }
216}
217
/// Restricts the values a role may pass for one tool argument.
///
/// A string value for `argument` on a tool matching `tool` is split with
/// shell-like quoting rules (`shlex`). Its first token, or that token's last
/// `/`-separated component, must equal an entry of `allowed`. Only the first
/// token is checked; later tokens are not restricted.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ArgumentAllowlist {
    /// Tool name, or a glob pattern using `*` wildcards.
    pub tool: String,
    /// Argument name, compared exactly.
    pub argument: String,
    /// Permitted executables; an empty list imposes no restriction.
    #[serde(default)]
    pub allowed: Vec<String>,
}
235
236impl ArgumentAllowlist {
237 #[must_use]
239 pub fn new(tool: impl Into<String>, argument: impl Into<String>, allowed: Vec<String>) -> Self {
240 Self {
241 tool: tool.into(),
242 argument: argument.into(),
243 allowed,
244 }
245 }
246}
247
/// Serde default for `RoleConfig::hosts`: a single match-everything glob.
fn default_hosts() -> Vec<String> {
    vec![String::from("*")]
}
251
/// Deserializable RBAC configuration; turned into an [`RbacPolicy`] via
/// `RbacPolicy::new`.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct RbacConfig {
    /// When `false` (the default), the resulting policy allows every operation.
    #[serde(default)]
    pub enabled: bool,
    /// Roles available to authenticated identities, looked up by name.
    #[serde(default)]
    pub roles: Vec<RoleConfig>,
    /// Key used to HMAC argument values before they are logged. When unset,
    /// a random salt generated once per process is used.
    #[serde(default)]
    pub redaction_salt: Option<SecretString>,
}
273
274impl RbacConfig {
275 #[must_use]
277 pub fn with_roles(roles: Vec<RoleConfig>) -> Self {
278 Self {
279 enabled: true,
280 roles,
281 redaction_salt: None,
282 }
283 }
284}
285
/// Outcome of an RBAC check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RbacDecision {
    /// The operation may proceed.
    Allow,
    /// The operation must be rejected.
    Deny,
}
295
/// Per-role entry counts reported by `RbacPolicy::summary`; role contents
/// themselves are not exposed.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacRoleSummary {
    /// Role name.
    pub name: String,
    /// Number of entries in the role's `allow` list.
    pub allow: usize,
    /// Number of entries in the role's `deny` list.
    pub deny: usize,
    /// Number of host glob patterns.
    pub hosts: usize,
    /// Number of argument allowlists.
    pub argument_allowlists: usize,
}
311
/// Serializable overview of an RBAC policy, produced by `RbacPolicy::summary`.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacPolicySummary {
    /// Whether the policy is enforced.
    pub enabled: bool,
    /// One summary per configured role, in configuration order.
    pub roles: Vec<RbacRoleSummary>,
}
321
/// RBAC policy built from an [`RbacConfig`] and evaluated by `rbac_middleware`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RbacPolicy {
    /// Roles copied from the configuration, searched by exact name.
    roles: Vec<RoleConfig>,
    /// When `false`, every check allows.
    enabled: bool,
    /// HMAC key used by `redact_arg` before argument values are logged.
    redaction_salt: Arc<SecretString>,
}
336
337impl RbacPolicy {
338 #[must_use]
341 pub fn new(config: &RbacConfig) -> Self {
342 let salt = config
343 .redaction_salt
344 .clone()
345 .unwrap_or_else(|| process_redaction_salt().clone());
346 Self {
347 roles: config.roles.clone(),
348 enabled: config.enabled,
349 redaction_salt: Arc::new(salt),
350 }
351 }
352
353 #[must_use]
355 pub fn disabled() -> Self {
356 Self {
357 roles: Vec::new(),
358 enabled: false,
359 redaction_salt: Arc::new(process_redaction_salt().clone()),
360 }
361 }
362
363 #[must_use]
365 pub fn is_enabled(&self) -> bool {
366 self.enabled
367 }
368
369 #[must_use]
374 pub fn summary(&self) -> RbacPolicySummary {
375 let roles = self
376 .roles
377 .iter()
378 .map(|r| RbacRoleSummary {
379 name: r.name.clone(),
380 allow: r.allow.len(),
381 deny: r.deny.len(),
382 hosts: r.hosts.len(),
383 argument_allowlists: r.argument_allowlists.len(),
384 })
385 .collect();
386 RbacPolicySummary {
387 enabled: self.enabled,
388 roles,
389 }
390 }
391
392 #[must_use]
397 pub fn check_operation(&self, role: &str, operation: &str) -> RbacDecision {
398 if !self.enabled {
399 return RbacDecision::Allow;
400 }
401 let Some(role_cfg) = self.find_role(role) else {
402 return RbacDecision::Deny;
403 };
404 if role_cfg.deny.iter().any(|d| d == operation) {
405 return RbacDecision::Deny;
406 }
407 if role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
408 return RbacDecision::Allow;
409 }
410 RbacDecision::Deny
411 }
412
413 #[must_use]
420 pub fn check(&self, role: &str, operation: &str, host: &str) -> RbacDecision {
421 if !self.enabled {
422 return RbacDecision::Allow;
423 }
424 let Some(role_cfg) = self.find_role(role) else {
425 return RbacDecision::Deny;
426 };
427 if role_cfg.deny.iter().any(|d| d == operation) {
428 return RbacDecision::Deny;
429 }
430 if !role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
431 return RbacDecision::Deny;
432 }
433 if !Self::host_matches(&role_cfg.hosts, host) {
434 return RbacDecision::Deny;
435 }
436 RbacDecision::Allow
437 }
438
439 #[must_use]
441 pub fn host_visible(&self, role: &str, host: &str) -> bool {
442 if !self.enabled {
443 return true;
444 }
445 let Some(role_cfg) = self.find_role(role) else {
446 return false;
447 };
448 Self::host_matches(&role_cfg.hosts, host)
449 }
450
451 #[must_use]
453 pub fn host_patterns(&self, role: &str) -> Option<&[String]> {
454 self.find_role(role).map(|r| r.hosts.as_slice())
455 }
456
457 #[must_use]
486 pub fn argument_allowed(&self, role: &str, tool: &str, argument: &str, value: &str) -> bool {
487 if !self.enabled {
488 return true;
489 }
490 let Some(role_cfg) = self.find_role(role) else {
491 return false;
492 };
493 for al in &role_cfg.argument_allowlists {
494 if al.tool != tool && !glob_match(&al.tool, tool) {
495 continue;
496 }
497 if al.argument != argument {
498 continue;
499 }
500 if al.allowed.is_empty() {
501 continue;
502 }
503 let Some(tokens) = shlex::split(value) else {
508 return false;
509 };
510 let Some(first_token) = tokens.first() else {
511 return false;
512 };
513 if first_token.is_empty() {
517 return false;
518 }
519 let basename = first_token
523 .rsplit('/')
524 .next()
525 .unwrap_or(first_token.as_str());
526 if !al.allowed.iter().any(|a| a == first_token || a == basename) {
527 return false;
528 }
529 }
530 true
531 }
532
533 fn find_role(&self, name: &str) -> Option<&RoleConfig> {
535 self.roles.iter().find(|r| r.name == name)
536 }
537
538 fn host_matches(patterns: &[String], host: &str) -> bool {
540 patterns.iter().any(|p| glob_match(p, host))
541 }
542
543 #[must_use]
552 pub fn redact_arg(&self, value: &str) -> String {
553 redact_with_salt(self.redaction_salt.expose_secret().as_bytes(), value)
554 }
555}
556
557fn process_redaction_salt() -> &'static SecretString {
560 use base64::{Engine as _, engine::general_purpose::STANDARD_NO_PAD};
561 static PROCESS_SALT: std::sync::OnceLock<SecretString> = std::sync::OnceLock::new();
562 PROCESS_SALT.get_or_init(|| {
563 let mut bytes = [0u8; 32];
564 rand::fill(&mut bytes);
565 SecretString::from(STANDARD_NO_PAD.encode(bytes))
568 })
569}
570
571fn redact_with_salt(salt: &[u8], value: &str) -> String {
576 use std::fmt::Write as _;
577
578 use sha2::Digest as _;
579
580 type HmacSha256 = Hmac<Sha256>;
581 let mut mac = if let Ok(m) = HmacSha256::new_from_slice(salt) {
587 m
588 } else {
589 let digest = Sha256::digest(salt);
590 #[allow(clippy::expect_used)] HmacSha256::new_from_slice(&digest).expect("32-byte SHA256 digest is valid HMAC key")
592 };
593 mac.update(value.as_bytes());
594 let bytes = mac.finalize().into_bytes();
595 let prefix = bytes.get(..4).unwrap_or(&[0; 4]);
597 let mut out = String::with_capacity(8);
598 for b in prefix {
599 let _ = write!(out, "{b:02x}");
600 }
601 out
602}
603
604#[allow(clippy::too_many_lines)]
625pub(crate) async fn rbac_middleware(
626 policy: Arc<RbacPolicy>,
627 tool_limiter: Option<Arc<ToolRateLimiter>>,
628 req: Request<Body>,
629 next: Next,
630) -> Response {
631 if req.method() != Method::POST {
633 return next.run(req).await;
634 }
635
636 let peer_ip: Option<IpAddr> = req
638 .extensions()
639 .get::<ConnectInfo<std::net::SocketAddr>>()
640 .map(|ci| ci.0.ip())
641 .or_else(|| {
642 req.extensions()
643 .get::<ConnectInfo<TlsConnInfo>>()
644 .map(|ci| ci.0.addr.ip())
645 });
646
647 let identity = req.extensions().get::<AuthIdentity>();
649 let identity_name = identity.map(|id| id.name.clone()).unwrap_or_default();
650 let role = identity.map(|id| id.role.clone()).unwrap_or_default();
651 let raw_token: SecretString = identity
654 .and_then(|id| id.raw_token.clone())
655 .unwrap_or_else(|| SecretString::from(String::new()));
656 let sub = identity.and_then(|id| id.sub.clone()).unwrap_or_default();
657
658 if policy.is_enabled() && identity.is_none() {
660 return McpxError::Rbac("no authenticated identity".into()).into_response();
661 }
662
663 let (parts, body) = req.into_parts();
665 let bytes = match body.collect().await {
666 Ok(collected) => collected.to_bytes(),
667 Err(e) => {
668 tracing::error!(error = %e, "failed to read request body");
669 return (
670 StatusCode::INTERNAL_SERVER_ERROR,
671 "failed to read request body",
672 )
673 .into_response();
674 }
675 };
676
677 if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&bytes) {
679 let tool_calls = extract_tool_calls(&json);
680 if !tool_calls.is_empty() {
681 for params in tool_calls {
682 if let Some(resp) = enforce_rate_limit(tool_limiter.as_deref(), peer_ip) {
683 return resp;
684 }
685 if policy.is_enabled()
686 && let Some(resp) = enforce_tool_policy(&policy, &identity_name, &role, params)
687 {
688 return resp;
689 }
690 }
691 }
692 }
693 let req = Request::from_parts(parts, Body::from(bytes));
697
698 if role.is_empty() {
700 next.run(req).await
701 } else {
702 CURRENT_ROLE
703 .scope(
704 role,
705 CURRENT_IDENTITY.scope(
706 identity_name,
707 CURRENT_TOKEN.scope(raw_token, CURRENT_SUB.scope(sub, next.run(req))),
708 ),
709 )
710 .await
711 }
712}
713
714fn extract_tool_calls(value: &serde_json::Value) -> Vec<&serde_json::Value> {
720 match value {
721 serde_json::Value::Object(map) => map
722 .get("method")
723 .and_then(serde_json::Value::as_str)
724 .filter(|method| *method == "tools/call")
725 .and_then(|_| map.get("params"))
726 .into_iter()
727 .collect(),
728 serde_json::Value::Array(items) => items
729 .iter()
730 .filter_map(|item| match item {
731 serde_json::Value::Object(map) => map
732 .get("method")
733 .and_then(serde_json::Value::as_str)
734 .filter(|method| *method == "tools/call")
735 .and_then(|_| map.get("params")),
736 serde_json::Value::Null
737 | serde_json::Value::Bool(_)
738 | serde_json::Value::Number(_)
739 | serde_json::Value::String(_)
740 | serde_json::Value::Array(_) => None,
741 })
742 .collect(),
743 serde_json::Value::Null
744 | serde_json::Value::Bool(_)
745 | serde_json::Value::Number(_)
746 | serde_json::Value::String(_) => Vec::new(),
747 }
748}
749
750fn enforce_rate_limit(
753 tool_limiter: Option<&ToolRateLimiter>,
754 peer_ip: Option<IpAddr>,
755) -> Option<Response> {
756 let limiter = tool_limiter?;
757 let ip = peer_ip?;
758 if limiter.check_key(&ip).is_err() {
759 tracing::warn!(%ip, "tool invocation rate limited");
760 return Some(McpxError::RateLimited("too many tool invocations".into()).into_response());
761 }
762 None
763}
764
765fn enforce_tool_policy(
774 policy: &RbacPolicy,
775 identity_name: &str,
776 role: &str,
777 params: &serde_json::Value,
778) -> Option<Response> {
779 let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
780 let host = params
781 .get("arguments")
782 .and_then(|a| a.get("host"))
783 .and_then(|h| h.as_str());
784
785 let decision = if let Some(host) = host {
786 policy.check(role, tool_name, host)
787 } else {
788 policy.check_operation(role, tool_name)
789 };
790 if decision == RbacDecision::Deny {
791 tracing::warn!(
792 user = %identity_name,
793 role = %role,
794 tool = tool_name,
795 host = host.unwrap_or("-"),
796 "RBAC denied"
797 );
798 return Some(
799 McpxError::Rbac(format!("{tool_name} denied for role '{role}'")).into_response(),
800 );
801 }
802
803 let args = params.get("arguments").and_then(|a| a.as_object())?;
804 for (arg_key, arg_val) in args {
805 if let Some(val_str) = arg_val.as_str()
806 && !policy.argument_allowed(role, tool_name, arg_key, val_str)
807 {
808 tracing::warn!(
813 user = %identity_name,
814 role = %role,
815 tool = tool_name,
816 argument = arg_key,
817 arg_hmac = %policy.redact_arg(val_str),
818 "argument not in allowlist"
819 );
820 return Some(
821 McpxError::Rbac(format!(
822 "argument '{arg_key}' value not in allowlist for tool '{tool_name}'"
823 ))
824 .into_response(),
825 );
826 }
827 }
828 None
829}
830
/// Matches `text` against `pattern`, where `*` matches any (possibly empty)
/// run of characters and every other character matches itself.
///
/// A pattern without `*` requires exact equality.
fn glob_match(pattern: &str, text: &str) -> bool {
    let segments: Vec<&str> = pattern.split('*').collect();
    // At least two segments means at least one `*` was present.
    let [head, inner @ .., tail] = segments.as_slice() else {
        return pattern == text;
    };
    let Some(rest) = text.strip_prefix(head) else {
        return false;
    };
    let Some(rest) = rest.strip_suffix(tail) else {
        return false;
    };
    match_middle(rest, inner)
}

/// Whether the non-empty `parts` occur in `text` in order without
/// overlapping, consuming the leftmost occurrence of each one.
fn match_middle(text: &str, parts: &[&str]) -> bool {
    parts
        .iter()
        .filter(|part| !part.is_empty())
        .try_fold(text, |remaining, part| {
            remaining
                .find(*part)
                .map(|idx| &remaining[idx + part.len()..])
        })
        .is_some()
}
892
893#[cfg(test)]
894mod tests {
895 use super::*;
896
897 fn test_policy() -> RbacPolicy {
898 RbacPolicy::new(&RbacConfig {
899 enabled: true,
900 roles: vec![
901 RoleConfig {
902 name: "viewer".into(),
903 description: Some("Read-only".into()),
904 allow: vec![
905 "list_hosts".into(),
906 "resource_list".into(),
907 "resource_inspect".into(),
908 "resource_logs".into(),
909 "system_info".into(),
910 ],
911 deny: vec![],
912 hosts: vec!["*".into()],
913 argument_allowlists: vec![],
914 },
915 RoleConfig {
916 name: "deploy".into(),
917 description: Some("Lifecycle management".into()),
918 allow: vec![
919 "list_hosts".into(),
920 "resource_list".into(),
921 "resource_run".into(),
922 "resource_start".into(),
923 "resource_stop".into(),
924 "resource_restart".into(),
925 "resource_logs".into(),
926 "image_pull".into(),
927 ],
928 deny: vec!["resource_delete".into(), "resource_exec".into()],
929 hosts: vec!["web-*".into(), "api-*".into()],
930 argument_allowlists: vec![],
931 },
932 RoleConfig {
933 name: "ops".into(),
934 description: Some("Full access".into()),
935 allow: vec!["*".into()],
936 deny: vec![],
937 hosts: vec!["*".into()],
938 argument_allowlists: vec![],
939 },
940 RoleConfig {
941 name: "restricted-exec".into(),
942 description: Some("Exec with argument allowlist".into()),
943 allow: vec!["resource_exec".into()],
944 deny: vec![],
945 hosts: vec!["dev-*".into()],
946 argument_allowlists: vec![ArgumentAllowlist {
947 tool: "resource_exec".into(),
948 argument: "cmd".into(),
949 allowed: vec![
950 "sh".into(),
951 "bash".into(),
952 "cat".into(),
953 "ls".into(),
954 "ps".into(),
955 ],
956 }],
957 },
958 ],
959 redaction_salt: None,
960 })
961 }
962
963 #[test]
966 fn glob_exact_match() {
967 assert!(glob_match("web-prod-1", "web-prod-1"));
968 assert!(!glob_match("web-prod-1", "web-prod-2"));
969 }
970
971 #[test]
972 fn glob_star_suffix() {
973 assert!(glob_match("web-*", "web-prod-1"));
974 assert!(glob_match("web-*", "web-staging"));
975 assert!(!glob_match("web-*", "api-prod"));
976 }
977
978 #[test]
979 fn glob_star_prefix() {
980 assert!(glob_match("*-prod", "web-prod"));
981 assert!(glob_match("*-prod", "api-prod"));
982 assert!(!glob_match("*-prod", "web-staging"));
983 }
984
985 #[test]
986 fn glob_star_middle() {
987 assert!(glob_match("web-*-prod", "web-us-prod"));
988 assert!(glob_match("web-*-prod", "web-eu-east-prod"));
989 assert!(!glob_match("web-*-prod", "web-staging"));
990 }
991
992 #[test]
993 fn glob_star_only() {
994 assert!(glob_match("*", "anything"));
995 assert!(glob_match("*", ""));
996 }
997
998 #[test]
999 fn glob_multiple_stars() {
1000 assert!(glob_match("*web*prod*", "my-web-us-prod-1"));
1001 assert!(!glob_match("*web*prod*", "my-api-us-staging"));
1002 }
1003
1004 #[test]
1007 fn disabled_policy_allows_everything() {
1008 let policy = RbacPolicy::new(&RbacConfig {
1009 enabled: false,
1010 roles: vec![],
1011 redaction_salt: None,
1012 });
1013 assert_eq!(
1014 policy.check("nonexistent", "resource_delete", "any-host"),
1015 RbacDecision::Allow
1016 );
1017 }
1018
1019 #[test]
1020 fn unknown_role_denied() {
1021 let policy = test_policy();
1022 assert_eq!(
1023 policy.check("unknown", "resource_list", "web-prod-1"),
1024 RbacDecision::Deny
1025 );
1026 }
1027
1028 #[test]
1029 fn viewer_allowed_read_ops() {
1030 let policy = test_policy();
1031 assert_eq!(
1032 policy.check("viewer", "resource_list", "web-prod-1"),
1033 RbacDecision::Allow
1034 );
1035 assert_eq!(
1036 policy.check("viewer", "system_info", "db-host"),
1037 RbacDecision::Allow
1038 );
1039 }
1040
1041 #[test]
1042 fn viewer_denied_write_ops() {
1043 let policy = test_policy();
1044 assert_eq!(
1045 policy.check("viewer", "resource_run", "web-prod-1"),
1046 RbacDecision::Deny
1047 );
1048 assert_eq!(
1049 policy.check("viewer", "resource_delete", "web-prod-1"),
1050 RbacDecision::Deny
1051 );
1052 }
1053
1054 #[test]
1055 fn deploy_allowed_on_matching_hosts() {
1056 let policy = test_policy();
1057 assert_eq!(
1058 policy.check("deploy", "resource_run", "web-prod-1"),
1059 RbacDecision::Allow
1060 );
1061 assert_eq!(
1062 policy.check("deploy", "resource_start", "api-staging"),
1063 RbacDecision::Allow
1064 );
1065 }
1066
1067 #[test]
1068 fn deploy_denied_on_non_matching_host() {
1069 let policy = test_policy();
1070 assert_eq!(
1071 policy.check("deploy", "resource_run", "db-prod-1"),
1072 RbacDecision::Deny
1073 );
1074 }
1075
1076 #[test]
1077 fn deny_overrides_allow() {
1078 let policy = test_policy();
1079 assert_eq!(
1080 policy.check("deploy", "resource_delete", "web-prod-1"),
1081 RbacDecision::Deny
1082 );
1083 assert_eq!(
1084 policy.check("deploy", "resource_exec", "web-prod-1"),
1085 RbacDecision::Deny
1086 );
1087 }
1088
1089 #[test]
1090 fn ops_wildcard_allows_everything() {
1091 let policy = test_policy();
1092 assert_eq!(
1093 policy.check("ops", "resource_delete", "any-host"),
1094 RbacDecision::Allow
1095 );
1096 assert_eq!(
1097 policy.check("ops", "secret_create", "db-host"),
1098 RbacDecision::Allow
1099 );
1100 }
1101
1102 #[test]
1105 fn host_visible_respects_globs() {
1106 let policy = test_policy();
1107 assert!(policy.host_visible("deploy", "web-prod-1"));
1108 assert!(policy.host_visible("deploy", "api-staging"));
1109 assert!(!policy.host_visible("deploy", "db-prod-1"));
1110 assert!(policy.host_visible("ops", "anything"));
1111 assert!(policy.host_visible("viewer", "anything"));
1112 }
1113
1114 #[test]
1115 fn host_visible_unknown_role() {
1116 let policy = test_policy();
1117 assert!(!policy.host_visible("unknown", "web-prod-1"));
1118 }
1119
1120 #[test]
1123 fn argument_allowed_no_allowlist() {
1124 let policy = test_policy();
1125 assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "rm -rf /"));
1127 assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "bash"));
1128 }
1129
1130 #[test]
1131 fn argument_allowed_with_allowlist() {
1132 let policy = test_policy();
1133 assert!(policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "sh"));
1134 assert!(policy.argument_allowed(
1135 "restricted-exec",
1136 "resource_exec",
1137 "cmd",
1138 "bash -c 'echo hi'"
1139 ));
1140 assert!(policy.argument_allowed(
1141 "restricted-exec",
1142 "resource_exec",
1143 "cmd",
1144 "cat /etc/hosts"
1145 ));
1146 assert!(policy.argument_allowed(
1147 "restricted-exec",
1148 "resource_exec",
1149 "cmd",
1150 "/usr/bin/ls -la"
1151 ));
1152 }
1153
1154 #[test]
1155 fn argument_denied_not_in_allowlist() {
1156 let policy = test_policy();
1157 assert!(!policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "rm -rf /"));
1158 assert!(!policy.argument_allowed(
1159 "restricted-exec",
1160 "resource_exec",
1161 "cmd",
1162 "python3 exploit.py"
1163 ));
1164 assert!(!policy.argument_allowed(
1165 "restricted-exec",
1166 "resource_exec",
1167 "cmd",
1168 "/usr/bin/curl evil.com"
1169 ));
1170 }
1171
1172 #[test]
1173 fn argument_denied_unknown_role() {
1174 let policy = test_policy();
1175 assert!(!policy.argument_allowed("unknown", "resource_exec", "cmd", "sh"));
1176 }
1177
1178 fn shlex_policy(allowed: Vec<String>) -> RbacPolicy {
1187 let role = RoleConfig::new("viewer", vec!["run".into()], vec!["*".into()])
1188 .with_argument_allowlists(vec![ArgumentAllowlist::new("run", "cmd", allowed)]);
1189 let mut config = RbacConfig::with_roles(vec![role]);
1190 config.enabled = true;
1191 RbacPolicy::new(&config)
1192 }
1193
1194 #[test]
1195 fn argument_allowed_matches_quoted_path_with_spaces() {
1196 let policy = shlex_policy(vec!["/usr/bin/my tool".into()]);
1197 assert!(policy.argument_allowed("viewer", "run", "cmd", r#""/usr/bin/my tool" --flag"#));
1198 }
1199
1200 #[test]
1201 fn argument_allowed_matches_basename_of_quoted_path() {
1202 let policy = shlex_policy(vec!["my tool".into()]);
1203 assert!(policy.argument_allowed("viewer", "run", "cmd", r#""/usr/bin/my tool" --flag"#));
1204 }
1205
1206 #[test]
1207 fn argument_allowed_fails_closed_on_unbalanced_quote() {
1208 let policy = shlex_policy(vec!["unbalanced".into()]);
1209 assert!(!policy.argument_allowed("viewer", "run", "cmd", r"unbalanced 'quote"));
1210 }
1211
1212 #[test]
1213 fn argument_allowed_fails_closed_on_empty_string() {
1214 let policy = shlex_policy(vec![String::new()]);
1215 assert!(!policy.argument_allowed("viewer", "run", "cmd", ""));
1216 }
1217
1218 #[test]
1219 fn argument_allowed_handles_single_quoted_executable() {
1220 let policy = shlex_policy(vec!["/bin/sh".into()]);
1221 assert!(policy.argument_allowed("viewer", "run", "cmd", r"'/bin/sh' -c 'echo hi'"));
1222 }
1223
1224 #[test]
1225 fn argument_allowed_handles_tab_separator() {
1226 let policy = shlex_policy(vec!["ls".into()]);
1227 assert!(policy.argument_allowed("viewer", "run", "cmd", "ls\t/etc/passwd"));
1228 }
1229
1230 #[test]
1231 fn argument_allowed_plain_token_unchanged() {
1232 let policy = shlex_policy(vec!["ls".into()]);
1233 assert!(policy.argument_allowed("viewer", "run", "cmd", "ls"));
1234 }
1235
1236 #[test]
1242 fn argument_allowed_fails_closed_on_quoted_empty_first_token() {
1243 let policy = shlex_policy(vec![String::new()]);
1247 assert!(!policy.argument_allowed("viewer", "run", "cmd", r#""""#));
1248 }
1249
1250 #[test]
1251 fn argument_allowed_quoted_literal_token_no_longer_matches() {
1252 let policy = shlex_policy(vec!["'bash'".into()]);
1258 assert!(!policy.argument_allowed("viewer", "run", "cmd", "'bash' -c true"));
1259 }
1260
1261 #[test]
1262 fn argument_allowed_backslash_literal_token_no_longer_matches() {
1263 let policy = shlex_policy(vec![r"foo\bar".into()]);
1268 assert!(!policy.argument_allowed("viewer", "run", "cmd", r"foo\bar --x"));
1269 }
1270
1271 #[test]
1272 fn argument_allowed_windows_path_no_longer_matches() {
1273 let policy = shlex_policy(vec![r"C:\Windows\System32\cmd.exe".into()]);
1278 assert!(!policy.argument_allowed(
1279 "viewer",
1280 "run",
1281 "cmd",
1282 r"C:\Windows\System32\cmd.exe /c dir"
1283 ));
1284 }
1285
1286 #[test]
1289 fn host_patterns_returns_globs() {
1290 let policy = test_policy();
1291 assert_eq!(
1292 policy.host_patterns("deploy"),
1293 Some(vec!["web-*".to_owned(), "api-*".to_owned()].as_slice())
1294 );
1295 assert_eq!(
1296 policy.host_patterns("ops"),
1297 Some(vec!["*".to_owned()].as_slice())
1298 );
1299 assert!(policy.host_patterns("nonexistent").is_none());
1300 }
1301
1302 #[test]
1305 fn check_operation_allows_without_host() {
1306 let policy = test_policy();
1307 assert_eq!(
1308 policy.check_operation("deploy", "resource_run"),
1309 RbacDecision::Allow
1310 );
1311 assert_eq!(
1313 policy.check("deploy", "resource_run", "db-prod-1"),
1314 RbacDecision::Deny
1315 );
1316 }
1317
1318 #[test]
1319 fn check_operation_deny_overrides() {
1320 let policy = test_policy();
1321 assert_eq!(
1322 policy.check_operation("deploy", "resource_delete"),
1323 RbacDecision::Deny
1324 );
1325 }
1326
1327 #[test]
1328 fn check_operation_unknown_role() {
1329 let policy = test_policy();
1330 assert_eq!(
1331 policy.check_operation("unknown", "resource_list"),
1332 RbacDecision::Deny
1333 );
1334 }
1335
1336 #[test]
1337 fn check_operation_disabled() {
1338 let policy = RbacPolicy::new(&RbacConfig {
1339 enabled: false,
1340 roles: vec![],
1341 redaction_salt: None,
1342 });
1343 assert_eq!(
1344 policy.check_operation("nonexistent", "anything"),
1345 RbacDecision::Allow
1346 );
1347 }
1348
1349 #[test]
1352 fn current_role_returns_none_outside_scope() {
1353 assert!(current_role().is_none());
1354 }
1355
1356 #[test]
1357 fn current_identity_returns_none_outside_scope() {
1358 assert!(current_identity().is_none());
1359 }
1360
1361 use axum::{
1364 body::Body,
1365 http::{Method, Request, StatusCode},
1366 };
1367 use tower::ServiceExt as _;
1368
1369 fn tool_call_body(tool: &str, args: &serde_json::Value) -> String {
1370 serde_json::json!({
1371 "jsonrpc": "2.0",
1372 "id": 1,
1373 "method": "tools/call",
1374 "params": {
1375 "name": tool,
1376 "arguments": args
1377 }
1378 })
1379 .to_string()
1380 }
1381
1382 fn rbac_router(policy: Arc<RbacPolicy>) -> axum::Router {
1383 axum::Router::new()
1384 .route("/mcp", axum::routing::post(|| async { "ok" }))
1385 .layer(axum::middleware::from_fn(move |req, next| {
1386 let p = Arc::clone(&policy);
1387 rbac_middleware(p, None, req, next)
1388 }))
1389 }
1390
1391 fn rbac_router_with_identity(policy: Arc<RbacPolicy>, identity: AuthIdentity) -> axum::Router {
1392 axum::Router::new()
1393 .route("/mcp", axum::routing::post(|| async { "ok" }))
1394 .layer(axum::middleware::from_fn(
1395 move |mut req: Request<Body>, next: Next| {
1396 let p = Arc::clone(&policy);
1397 let id = identity.clone();
1398 async move {
1399 req.extensions_mut().insert(id);
1400 rbac_middleware(p, None, req, next).await
1401 }
1402 },
1403 ))
1404 }
1405
1406 #[tokio::test]
1407 async fn middleware_passes_non_post() {
1408 let policy = Arc::new(test_policy());
1409 let app = rbac_router(policy);
1410 let req = Request::builder()
1412 .method(Method::GET)
1413 .uri("/mcp")
1414 .body(Body::empty())
1415 .unwrap();
1416 let resp = app.oneshot(req).await.unwrap();
1419 assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
1420 }
1421
1422 #[tokio::test]
1423 async fn middleware_denies_without_identity() {
1424 let policy = Arc::new(test_policy());
1425 let app = rbac_router(policy);
1426 let body = tool_call_body("resource_list", &serde_json::json!({}));
1427 let req = Request::builder()
1428 .method(Method::POST)
1429 .uri("/mcp")
1430 .header("content-type", "application/json")
1431 .body(Body::from(body))
1432 .unwrap();
1433 let resp = app.oneshot(req).await.unwrap();
1434 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1435 }
1436
1437 #[tokio::test]
1438 async fn middleware_allows_permitted_tool() {
1439 let policy = Arc::new(test_policy());
1440 let id = AuthIdentity {
1441 method: crate::auth::AuthMethod::BearerToken,
1442 name: "alice".into(),
1443 role: "viewer".into(),
1444 raw_token: None,
1445 sub: None,
1446 };
1447 let app = rbac_router_with_identity(policy, id);
1448 let body = tool_call_body("resource_list", &serde_json::json!({}));
1449 let req = Request::builder()
1450 .method(Method::POST)
1451 .uri("/mcp")
1452 .header("content-type", "application/json")
1453 .body(Body::from(body))
1454 .unwrap();
1455 let resp = app.oneshot(req).await.unwrap();
1456 assert_eq!(resp.status(), StatusCode::OK);
1457 }
1458
1459 #[tokio::test]
1460 async fn middleware_denies_unpermitted_tool() {
1461 let policy = Arc::new(test_policy());
1462 let id = AuthIdentity {
1463 method: crate::auth::AuthMethod::BearerToken,
1464 name: "alice".into(),
1465 role: "viewer".into(),
1466 raw_token: None,
1467 sub: None,
1468 };
1469 let app = rbac_router_with_identity(policy, id);
1470 let body = tool_call_body("resource_delete", &serde_json::json!({}));
1471 let req = Request::builder()
1472 .method(Method::POST)
1473 .uri("/mcp")
1474 .header("content-type", "application/json")
1475 .body(Body::from(body))
1476 .unwrap();
1477 let resp = app.oneshot(req).await.unwrap();
1478 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1479 }
1480
1481 #[tokio::test]
1482 async fn middleware_passes_non_tool_call_post() {
1483 let policy = Arc::new(test_policy());
1484 let id = AuthIdentity {
1485 method: crate::auth::AuthMethod::BearerToken,
1486 name: "alice".into(),
1487 role: "viewer".into(),
1488 raw_token: None,
1489 sub: None,
1490 };
1491 let app = rbac_router_with_identity(policy, id);
1492 let body = serde_json::json!({
1494 "jsonrpc": "2.0",
1495 "id": 1,
1496 "method": "resources/list"
1497 })
1498 .to_string();
1499 let req = Request::builder()
1500 .method(Method::POST)
1501 .uri("/mcp")
1502 .header("content-type", "application/json")
1503 .body(Body::from(body))
1504 .unwrap();
1505 let resp = app.oneshot(req).await.unwrap();
1506 assert_eq!(resp.status(), StatusCode::OK);
1507 }
1508
1509 #[tokio::test]
1510 async fn middleware_enforces_argument_allowlist() {
1511 let policy = Arc::new(test_policy());
1512 let id = AuthIdentity {
1513 method: crate::auth::AuthMethod::BearerToken,
1514 name: "dev".into(),
1515 role: "restricted-exec".into(),
1516 raw_token: None,
1517 sub: None,
1518 };
1519 let app = rbac_router_with_identity(Arc::clone(&policy), id.clone());
1521 let body = tool_call_body(
1522 "resource_exec",
1523 &serde_json::json!({"cmd": "ls -la", "host": "dev-1"}),
1524 );
1525 let req = Request::builder()
1526 .method(Method::POST)
1527 .uri("/mcp")
1528 .body(Body::from(body))
1529 .unwrap();
1530 let resp = app.oneshot(req).await.unwrap();
1531 assert_eq!(resp.status(), StatusCode::OK);
1532
1533 let app = rbac_router_with_identity(policy, id);
1535 let body = tool_call_body(
1536 "resource_exec",
1537 &serde_json::json!({"cmd": "rm -rf /", "host": "dev-1"}),
1538 );
1539 let req = Request::builder()
1540 .method(Method::POST)
1541 .uri("/mcp")
1542 .body(Body::from(body))
1543 .unwrap();
1544 let resp = app.oneshot(req).await.unwrap();
1545 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1546 }
1547
1548 #[tokio::test]
1549 async fn middleware_disabled_policy_passes_everything() {
1550 let policy = Arc::new(RbacPolicy::disabled());
1551 let app = rbac_router(policy);
1552 let body = tool_call_body("anything", &serde_json::json!({}));
1554 let req = Request::builder()
1555 .method(Method::POST)
1556 .uri("/mcp")
1557 .body(Body::from(body))
1558 .unwrap();
1559 let resp = app.oneshot(req).await.unwrap();
1560 assert_eq!(resp.status(), StatusCode::OK);
1561 }
1562
1563 #[tokio::test]
1564 async fn middleware_batch_all_allowed_passes() {
1565 let policy = Arc::new(test_policy());
1566 let id = AuthIdentity {
1567 method: crate::auth::AuthMethod::BearerToken,
1568 name: "alice".into(),
1569 role: "viewer".into(),
1570 raw_token: None,
1571 sub: None,
1572 };
1573 let app = rbac_router_with_identity(policy, id);
1574 let body = serde_json::json!([
1575 {
1576 "jsonrpc": "2.0",
1577 "id": 1,
1578 "method": "tools/call",
1579 "params": { "name": "resource_list", "arguments": {} }
1580 },
1581 {
1582 "jsonrpc": "2.0",
1583 "id": 2,
1584 "method": "tools/call",
1585 "params": { "name": "system_info", "arguments": {} }
1586 }
1587 ])
1588 .to_string();
1589 let req = Request::builder()
1590 .method(Method::POST)
1591 .uri("/mcp")
1592 .header("content-type", "application/json")
1593 .body(Body::from(body))
1594 .unwrap();
1595 let resp = app.oneshot(req).await.unwrap();
1596 assert_eq!(resp.status(), StatusCode::OK);
1597 }
1598
1599 #[tokio::test]
1600 async fn middleware_batch_with_denied_call_rejects_entire_batch() {
1601 let policy = Arc::new(test_policy());
1602 let id = AuthIdentity {
1603 method: crate::auth::AuthMethod::BearerToken,
1604 name: "alice".into(),
1605 role: "viewer".into(),
1606 raw_token: None,
1607 sub: None,
1608 };
1609 let app = rbac_router_with_identity(policy, id);
1610 let body = serde_json::json!([
1611 {
1612 "jsonrpc": "2.0",
1613 "id": 1,
1614 "method": "tools/call",
1615 "params": { "name": "resource_list", "arguments": {} }
1616 },
1617 {
1618 "jsonrpc": "2.0",
1619 "id": 2,
1620 "method": "tools/call",
1621 "params": { "name": "resource_delete", "arguments": {} }
1622 }
1623 ])
1624 .to_string();
1625 let req = Request::builder()
1626 .method(Method::POST)
1627 .uri("/mcp")
1628 .header("content-type", "application/json")
1629 .body(Body::from(body))
1630 .unwrap();
1631 let resp = app.oneshot(req).await.unwrap();
1632 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1633 }
1634
1635 #[tokio::test]
1636 async fn middleware_batch_mixed_allowed_and_denied_rejects() {
1637 let policy = Arc::new(test_policy());
1638 let id = AuthIdentity {
1639 method: crate::auth::AuthMethod::BearerToken,
1640 name: "dev".into(),
1641 role: "restricted-exec".into(),
1642 raw_token: None,
1643 sub: None,
1644 };
1645 let app = rbac_router_with_identity(policy, id);
1646 let body = serde_json::json!([
1647 {
1648 "jsonrpc": "2.0",
1649 "id": 1,
1650 "method": "tools/call",
1651 "params": {
1652 "name": "resource_exec",
1653 "arguments": { "cmd": "ls -la", "host": "dev-1" }
1654 }
1655 },
1656 {
1657 "jsonrpc": "2.0",
1658 "id": 2,
1659 "method": "tools/call",
1660 "params": {
1661 "name": "resource_exec",
1662 "arguments": { "cmd": "rm -rf /", "host": "dev-1" }
1663 }
1664 }
1665 ])
1666 .to_string();
1667 let req = Request::builder()
1668 .method(Method::POST)
1669 .uri("/mcp")
1670 .header("content-type", "application/json")
1671 .body(Body::from(body))
1672 .unwrap();
1673 let resp = app.oneshot(req).await.unwrap();
1674 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1675 }
1676
#[test]
fn redact_with_salt_is_deterministic_per_salt() {
    // Same salt + same value must always produce the same short hex digest.
    let salt = b"unit-test-salt";
    let a = redact_with_salt(salt, "rm -rf /");
    let b = redact_with_salt(salt, "rm -rf /");
    assert_eq!(a, b, "same input + salt must yield identical hash");
    assert_eq!(a.len(), 8, "redacted hash is 8 hex chars (4 bytes)");
    // `char::is_ascii_hexdigit` also accepts `A-F`, so it cannot enforce the
    // lowercase contract stated in the message; match the lowercase range explicitly.
    assert!(
        a.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
        "redacted hash must be lowercase hex: {a}"
    );
}
1691
#[test]
fn redact_with_salt_differs_across_salts() {
    // Changing only the salt must change the digest of an identical value.
    let value = "the-same-value";
    let (hash_one, hash_two) = (
        redact_with_salt(b"salt-one", value),
        redact_with_salt(b"salt-two", value),
    );
    assert_ne!(
        hash_one, hash_two,
        "different salts must produce different hashes for the same value"
    );
}
1702
#[test]
fn redact_with_salt_distinguishes_values() {
    // Under a fixed salt, distinct inputs must map to distinct digests.
    let hash_of = |value: &str| redact_with_salt(b"k", value);
    assert_ne!(
        hash_of("alpha"),
        hash_of("beta"),
        "different values must produce different hashes"
    );
}
1711
#[test]
fn policy_with_configured_salt_redacts_consistently() {
    // Two policies built from one config carrying an explicit salt must
    // redact the same argument identically.
    let cfg = RbacConfig {
        enabled: true,
        roles: Vec::new(),
        redaction_salt: Some(SecretString::from("my-stable-salt")),
    };
    let [first, second] = [RbacPolicy::new(&cfg), RbacPolicy::new(&cfg)];
    assert_eq!(
        first.redact_arg("payload"),
        second.redact_arg("payload"),
        "policies built from the same configured salt must agree"
    );
}
1727
#[test]
fn policy_without_configured_salt_uses_process_salt() {
    // With no configured salt, two policies in the same process must still
    // agree on the redaction of a given argument.
    let cfg = RbacConfig {
        enabled: true,
        roles: Vec::new(),
        redaction_salt: None,
    };
    let [first, second] = [RbacPolicy::new(&cfg), RbacPolicy::new(&cfg)];
    assert_eq!(
        first.redact_arg("payload"),
        second.redact_arg("payload"),
        "process-wide salt must be consistent within one process"
    );
}
1744
#[test]
fn redact_arg_is_fast_enough() {
    // Sanity bound on the cost of redacting one 256-byte argument.
    //
    // A single wall-clock sample can be inflated arbitrarily by scheduler
    // preemption or CPU contention on shared CI runners, which makes a
    // one-shot timing assertion flaky. Taking the fastest of a few samples
    // reflects the real cost of the call while keeping the test cheap.
    const ATTEMPTS: usize = 5;
    let salt = b"perf-sanity-salt-32-bytes-padded";
    let value = "x".repeat(256);
    let best = (0..ATTEMPTS)
        .map(|_| {
            let start = std::time::Instant::now();
            // `black_box` stops the optimizer from discarding a call whose
            // result is otherwise unused, so the measurement covers real work.
            std::hint::black_box(redact_with_salt(
                std::hint::black_box(salt),
                std::hint::black_box(value.as_str()),
            ));
            start.elapsed()
        })
        .min()
        .expect("ATTEMPTS is non-zero");
    assert!(
        best < Duration::from_millis(5),
        "fastest of {ATTEMPTS} redact_with_salt calls took {best:?}, expected <5 ms even in debug"
    );
}
1760
1761 #[tokio::test]
1773 async fn deny_path_uses_explicit_identity_not_task_local() {
1774 let policy = Arc::new(test_policy());
1775 let id = AuthIdentity {
1776 method: crate::auth::AuthMethod::BearerToken,
1777 name: "alice-the-auditor".into(),
1778 role: "viewer".into(),
1779 raw_token: None,
1780 sub: None,
1781 };
1782 let app = rbac_router_with_identity(policy, id);
1783 let body = tool_call_body("resource_delete", &serde_json::json!({}));
1785 let req = Request::builder()
1786 .method(Method::POST)
1787 .uri("/mcp")
1788 .header("content-type", "application/json")
1789 .body(Body::from(body))
1790 .unwrap();
1791 let resp = app.oneshot(req).await.unwrap();
1792 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1793 }
1794}