1use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
12
13use axum::{
14 body::Body,
15 extract::ConnectInfo,
16 http::{Method, Request, StatusCode},
17 middleware::Next,
18 response::{IntoResponse, Response},
19};
20use hmac::{Hmac, Mac};
21use http_body_util::BodyExt;
22use secrecy::{ExposeSecret, SecretString};
23use serde::Deserialize;
24use sha2::Sha256;
25
26use crate::{
27 auth::{AuthIdentity, TlsConnInfo},
28 bounded_limiter::BoundedKeyedLimiter,
29 error::McpxError,
30};
31
32pub(crate) type ToolRateLimiter = BoundedKeyedLimiter<IpAddr>;
35
/// Per-minute tool-call quota substituted when the configured rate is zero.
const DEFAULT_TOOL_RATE: NonZeroU32 = NonZeroU32::new(120).unwrap();

/// Default `max_tracked_keys` handed to `BoundedKeyedLimiter::new`
/// (presumably the cap on distinct client IPs tracked at once — confirm in `bounded_limiter`).
const DEFAULT_TOOL_MAX_TRACKED_KEYS: usize = 10_000;

/// Default idle-eviction window handed to `BoundedKeyedLimiter::new`: 15 minutes.
const DEFAULT_TOOL_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
46
47#[must_use]
53pub(crate) fn build_tool_rate_limiter(max_per_minute: u32) -> Arc<ToolRateLimiter> {
54 build_tool_rate_limiter_with_bounds(
55 max_per_minute,
56 DEFAULT_TOOL_MAX_TRACKED_KEYS,
57 DEFAULT_TOOL_IDLE_EVICTION,
58 )
59}
60
61#[must_use]
63pub(crate) fn build_tool_rate_limiter_with_bounds(
64 max_per_minute: u32,
65 max_tracked_keys: usize,
66 idle_eviction: Duration,
67) -> Arc<ToolRateLimiter> {
68 let quota =
69 governor::Quota::per_minute(NonZeroU32::new(max_per_minute).unwrap_or(DEFAULT_TOOL_RATE));
70 Arc::new(BoundedKeyedLimiter::new(
71 quota,
72 max_tracked_keys,
73 idle_eviction,
74 ))
75}
76
77tokio::task_local! {
84 static CURRENT_ROLE: String;
85 static CURRENT_IDENTITY: String;
86 static CURRENT_TOKEN: SecretString;
87 static CURRENT_SUB: String;
88}
89
90#[must_use]
93pub fn current_role() -> Option<String> {
94 CURRENT_ROLE.try_with(Clone::clone).ok()
95}
96
97#[must_use]
100pub fn current_identity() -> Option<String> {
101 CURRENT_IDENTITY.try_with(Clone::clone).ok()
102}
103
104#[must_use]
117pub fn current_token() -> Option<SecretString> {
118 CURRENT_TOKEN
119 .try_with(|t| {
120 if t.expose_secret().is_empty() {
121 None
122 } else {
123 Some(t.clone())
124 }
125 })
126 .ok()
127 .flatten()
128}
129
130#[must_use]
134pub fn current_sub() -> Option<String> {
135 CURRENT_SUB
136 .try_with(Clone::clone)
137 .ok()
138 .filter(|s| !s.is_empty())
139}
140
141pub async fn with_token_scope<F: Future>(token: SecretString, f: F) -> F::Output {
146 CURRENT_TOKEN.scope(token, f).await
147}
148
149pub async fn with_rbac_scope<F: Future>(
154 role: String,
155 identity: String,
156 token: SecretString,
157 sub: String,
158 f: F,
159) -> F::Output {
160 CURRENT_ROLE
161 .scope(
162 role,
163 CURRENT_IDENTITY.scope(
164 identity,
165 CURRENT_TOKEN.scope(token, CURRENT_SUB.scope(sub, f)),
166 ),
167 )
168 .await
169}
170
171#[derive(Debug, Clone, Deserialize)]
173#[non_exhaustive]
174pub struct RoleConfig {
175 pub name: String,
177 #[serde(default)]
179 pub description: Option<String>,
180 #[serde(default)]
182 pub allow: Vec<String>,
183 #[serde(default)]
185 pub deny: Vec<String>,
186 #[serde(default = "default_hosts")]
188 pub hosts: Vec<String>,
189 #[serde(default)]
193 pub argument_allowlists: Vec<ArgumentAllowlist>,
194}
195
196impl RoleConfig {
197 #[must_use]
199 pub fn new(name: impl Into<String>, allow: Vec<String>, hosts: Vec<String>) -> Self {
200 Self {
201 name: name.into(),
202 description: None,
203 allow,
204 deny: vec![],
205 hosts,
206 argument_allowlists: vec![],
207 }
208 }
209
210 #[must_use]
212 pub fn with_argument_allowlists(mut self, allowlists: Vec<ArgumentAllowlist>) -> Self {
213 self.argument_allowlists = allowlists;
214 self
215 }
216}
217
218#[derive(Debug, Clone, Deserialize)]
225#[non_exhaustive]
226pub struct ArgumentAllowlist {
227 pub tool: String,
229 pub argument: String,
231 #[serde(default)]
233 pub allowed: Vec<String>,
234}
235
236impl ArgumentAllowlist {
237 #[must_use]
239 pub fn new(tool: impl Into<String>, argument: impl Into<String>, allowed: Vec<String>) -> Self {
240 Self {
241 tool: tool.into(),
242 argument: argument.into(),
243 allowed,
244 }
245 }
246}
247
/// Serde default for `RoleConfig::hosts`: a single `*` pattern, which matches any host.
fn default_hosts() -> Vec<String> {
    vec![String::from("*")]
}
251
252#[derive(Debug, Clone, Default, Deserialize)]
254#[non_exhaustive]
255pub struct RbacConfig {
256 #[serde(default)]
258 pub enabled: bool,
259 #[serde(default)]
261 pub roles: Vec<RoleConfig>,
262 #[serde(default)]
271 pub redaction_salt: Option<SecretString>,
272}
273
274impl RbacConfig {
275 #[must_use]
277 pub fn with_roles(roles: Vec<RoleConfig>) -> Self {
278 Self {
279 enabled: true,
280 roles,
281 redaction_salt: None,
282 }
283 }
284}
285
/// Outcome of an RBAC check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum RbacDecision {
    /// The operation is permitted.
    Allow,
    /// The operation is rejected.
    Deny,
}
295
296#[derive(Debug, Clone, serde::Serialize)]
298#[non_exhaustive]
299pub struct RbacRoleSummary {
300 pub name: String,
302 pub allow: usize,
304 pub deny: usize,
306 pub hosts: usize,
308 pub argument_allowlists: usize,
310}
311
312#[derive(Debug, Clone, serde::Serialize)]
314#[non_exhaustive]
315pub struct RbacPolicySummary {
316 pub enabled: bool,
318 pub roles: Vec<RbacRoleSummary>,
320}
321
322#[derive(Debug, Clone)]
328#[non_exhaustive]
329pub struct RbacPolicy {
330 roles: Vec<RoleConfig>,
331 enabled: bool,
332 redaction_salt: Arc<SecretString>,
335}
336
337impl RbacPolicy {
338 #[must_use]
341 pub fn new(config: &RbacConfig) -> Self {
342 let salt = config
343 .redaction_salt
344 .clone()
345 .unwrap_or_else(|| process_redaction_salt().clone());
346 Self {
347 roles: config.roles.clone(),
348 enabled: config.enabled,
349 redaction_salt: Arc::new(salt),
350 }
351 }
352
353 #[must_use]
355 pub fn disabled() -> Self {
356 Self {
357 roles: Vec::new(),
358 enabled: false,
359 redaction_salt: Arc::new(process_redaction_salt().clone()),
360 }
361 }
362
363 #[must_use]
365 pub fn is_enabled(&self) -> bool {
366 self.enabled
367 }
368
369 #[must_use]
374 pub fn summary(&self) -> RbacPolicySummary {
375 let roles = self
376 .roles
377 .iter()
378 .map(|r| RbacRoleSummary {
379 name: r.name.clone(),
380 allow: r.allow.len(),
381 deny: r.deny.len(),
382 hosts: r.hosts.len(),
383 argument_allowlists: r.argument_allowlists.len(),
384 })
385 .collect();
386 RbacPolicySummary {
387 enabled: self.enabled,
388 roles,
389 }
390 }
391
392 #[must_use]
397 pub fn check_operation(&self, role: &str, operation: &str) -> RbacDecision {
398 if !self.enabled {
399 return RbacDecision::Allow;
400 }
401 let Some(role_cfg) = self.find_role(role) else {
402 return RbacDecision::Deny;
403 };
404 if role_cfg.deny.iter().any(|d| d == operation) {
405 return RbacDecision::Deny;
406 }
407 if role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
408 return RbacDecision::Allow;
409 }
410 RbacDecision::Deny
411 }
412
413 #[must_use]
420 pub fn check(&self, role: &str, operation: &str, host: &str) -> RbacDecision {
421 if !self.enabled {
422 return RbacDecision::Allow;
423 }
424 let Some(role_cfg) = self.find_role(role) else {
425 return RbacDecision::Deny;
426 };
427 if role_cfg.deny.iter().any(|d| d == operation) {
428 return RbacDecision::Deny;
429 }
430 if !role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
431 return RbacDecision::Deny;
432 }
433 if !Self::host_matches(&role_cfg.hosts, host) {
434 return RbacDecision::Deny;
435 }
436 RbacDecision::Allow
437 }
438
439 #[must_use]
441 pub fn host_visible(&self, role: &str, host: &str) -> bool {
442 if !self.enabled {
443 return true;
444 }
445 let Some(role_cfg) = self.find_role(role) else {
446 return false;
447 };
448 Self::host_matches(&role_cfg.hosts, host)
449 }
450
451 #[must_use]
453 pub fn host_patterns(&self, role: &str) -> Option<&[String]> {
454 self.find_role(role).map(|r| r.hosts.as_slice())
455 }
456
457 #[must_use]
464 pub fn argument_allowed(&self, role: &str, tool: &str, argument: &str, value: &str) -> bool {
465 if !self.enabled {
466 return true;
467 }
468 let Some(role_cfg) = self.find_role(role) else {
469 return false;
470 };
471 for al in &role_cfg.argument_allowlists {
472 if al.tool != tool && !glob_match(&al.tool, tool) {
473 continue;
474 }
475 if al.argument != argument {
476 continue;
477 }
478 if al.allowed.is_empty() {
479 continue;
480 }
481 let first_token = value.split_whitespace().next().unwrap_or(value);
483 let basename = first_token.rsplit('/').next().unwrap_or(first_token);
485 if !al.allowed.iter().any(|a| a == first_token || a == basename) {
486 return false;
487 }
488 }
489 true
490 }
491
492 fn find_role(&self, name: &str) -> Option<&RoleConfig> {
494 self.roles.iter().find(|r| r.name == name)
495 }
496
497 fn host_matches(patterns: &[String], host: &str) -> bool {
499 patterns.iter().any(|p| glob_match(p, host))
500 }
501
502 #[must_use]
511 pub fn redact_arg(&self, value: &str) -> String {
512 redact_with_salt(self.redaction_salt.expose_secret().as_bytes(), value)
513 }
514}
515
516fn process_redaction_salt() -> &'static SecretString {
519 use base64::{Engine as _, engine::general_purpose::STANDARD_NO_PAD};
520 static PROCESS_SALT: std::sync::OnceLock<SecretString> = std::sync::OnceLock::new();
521 PROCESS_SALT.get_or_init(|| {
522 let mut bytes = [0u8; 32];
523 rand::fill(&mut bytes);
524 SecretString::from(STANDARD_NO_PAD.encode(bytes))
527 })
528}
529
530fn redact_with_salt(salt: &[u8], value: &str) -> String {
535 use std::fmt::Write as _;
536
537 use sha2::Digest as _;
538
539 type HmacSha256 = Hmac<Sha256>;
540 let mut mac = if let Ok(m) = HmacSha256::new_from_slice(salt) {
546 m
547 } else {
548 let digest = Sha256::digest(salt);
549 #[allow(clippy::expect_used)] HmacSha256::new_from_slice(&digest).expect("32-byte SHA256 digest is valid HMAC key")
551 };
552 mac.update(value.as_bytes());
553 let bytes = mac.finalize().into_bytes();
554 let prefix = bytes.get(..4).unwrap_or(&[0; 4]);
556 let mut out = String::with_capacity(8);
557 for b in prefix {
558 let _ = write!(out, "{b:02x}");
559 }
560 out
561}
562
563#[allow(clippy::too_many_lines)]
584pub(crate) async fn rbac_middleware(
585 policy: Arc<RbacPolicy>,
586 tool_limiter: Option<Arc<ToolRateLimiter>>,
587 req: Request<Body>,
588 next: Next,
589) -> Response {
590 if req.method() != Method::POST {
592 return next.run(req).await;
593 }
594
595 let peer_ip: Option<IpAddr> = req
597 .extensions()
598 .get::<ConnectInfo<std::net::SocketAddr>>()
599 .map(|ci| ci.0.ip())
600 .or_else(|| {
601 req.extensions()
602 .get::<ConnectInfo<TlsConnInfo>>()
603 .map(|ci| ci.0.addr.ip())
604 });
605
606 let identity = req.extensions().get::<AuthIdentity>();
608 let identity_name = identity.map(|id| id.name.clone()).unwrap_or_default();
609 let role = identity.map(|id| id.role.clone()).unwrap_or_default();
610 let raw_token: SecretString = identity
613 .and_then(|id| id.raw_token.clone())
614 .unwrap_or_else(|| SecretString::from(String::new()));
615 let sub = identity.and_then(|id| id.sub.clone()).unwrap_or_default();
616
617 if policy.is_enabled() && identity.is_none() {
619 return McpxError::Rbac("no authenticated identity".into()).into_response();
620 }
621
622 let (parts, body) = req.into_parts();
624 let bytes = match body.collect().await {
625 Ok(collected) => collected.to_bytes(),
626 Err(e) => {
627 tracing::error!(error = %e, "failed to read request body");
628 return (
629 StatusCode::INTERNAL_SERVER_ERROR,
630 "failed to read request body",
631 )
632 .into_response();
633 }
634 };
635
636 if let Ok(msg) = serde_json::from_slice::<JsonRpcEnvelope>(&bytes)
638 && msg.method.as_deref() == Some("tools/call")
639 {
640 if let Some(resp) = enforce_rate_limit(tool_limiter.as_deref(), peer_ip) {
641 return resp;
642 }
643 if let Some(ref params) = msg.params
644 && policy.is_enabled()
645 && let Some(resp) = enforce_tool_policy(&policy, &identity_name, &role, params)
646 {
647 return resp;
648 }
649 }
650 let req = Request::from_parts(parts, Body::from(bytes));
654
655 if role.is_empty() {
657 next.run(req).await
658 } else {
659 CURRENT_ROLE
660 .scope(
661 role,
662 CURRENT_IDENTITY.scope(
663 identity_name,
664 CURRENT_TOKEN.scope(raw_token, CURRENT_SUB.scope(sub, next.run(req))),
665 ),
666 )
667 .await
668 }
669}
670
671#[derive(Deserialize)]
673struct JsonRpcEnvelope {
674 method: Option<String>,
675 params: Option<serde_json::Value>,
676}
677
678fn enforce_rate_limit(
681 tool_limiter: Option<&ToolRateLimiter>,
682 peer_ip: Option<IpAddr>,
683) -> Option<Response> {
684 let limiter = tool_limiter?;
685 let ip = peer_ip?;
686 if limiter.check_key(&ip).is_err() {
687 tracing::warn!(%ip, "tool invocation rate limited");
688 return Some(McpxError::RateLimited("too many tool invocations".into()).into_response());
689 }
690 None
691}
692
693fn enforce_tool_policy(
702 policy: &RbacPolicy,
703 identity_name: &str,
704 role: &str,
705 params: &serde_json::Value,
706) -> Option<Response> {
707 let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
708 let host = params
709 .get("arguments")
710 .and_then(|a| a.get("host"))
711 .and_then(|h| h.as_str());
712
713 let decision = if let Some(host) = host {
714 policy.check(role, tool_name, host)
715 } else {
716 policy.check_operation(role, tool_name)
717 };
718 if decision == RbacDecision::Deny {
719 tracing::warn!(
720 user = %identity_name,
721 role = %role,
722 tool = tool_name,
723 host = host.unwrap_or("-"),
724 "RBAC denied"
725 );
726 return Some(
727 McpxError::Rbac(format!("{tool_name} denied for role '{role}'")).into_response(),
728 );
729 }
730
731 let args = params.get("arguments").and_then(|a| a.as_object())?;
732 for (arg_key, arg_val) in args {
733 if let Some(val_str) = arg_val.as_str()
734 && !policy.argument_allowed(role, tool_name, arg_key, val_str)
735 {
736 tracing::warn!(
741 user = %identity_name,
742 role = %role,
743 tool = tool_name,
744 argument = arg_key,
745 arg_hmac = %policy.redact_arg(val_str),
746 "argument not in allowlist"
747 );
748 return Some(
749 McpxError::Rbac(format!(
750 "argument '{arg_key}' value not in allowlist for tool '{tool_name}'"
751 ))
752 .into_response(),
753 );
754 }
755 }
756 None
757}
758
/// Matches `text` against a glob `pattern`.
///
/// `*` matches any (possibly empty) run of characters; every other character matches
/// itself. There is no `?` wildcard and no escaping.
fn glob_match(pattern: &str, text: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((head, rest)) = pattern.split_once('*') else {
        return pattern == text;
    };
    // `rest` follows the first `*`. The literal after the last `*` is the required
    // suffix; whatever lies between the first and last `*` is the inner pattern.
    let (inner, tail) = rest.rsplit_once('*').unwrap_or(("", rest));
    let Some(after_head) = text.strip_prefix(head) else {
        return false;
    };
    // The suffix must fit after the prefix without overlapping it.
    let Some(middle) = after_head.strip_suffix(tail) else {
        return false;
    };
    let inner_parts: Vec<&str> = inner.split('*').collect();
    match_middle(middle, &inner_parts)
}

/// Returns whether the non-empty `parts` occur in `text` in order and without overlap.
///
/// Each part is searched leftmost, starting after the previous match.
fn match_middle(text: &str, parts: &[&str]) -> bool {
    let mut remaining = text;
    parts
        .iter()
        .copied()
        .filter(|part| !part.is_empty())
        .all(|part| match remaining.find(part) {
            Some(idx) => {
                remaining = &remaining[idx + part.len()..];
                true
            }
            None => false,
        })
}
820
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Method, Request, StatusCode},
    };
    use tower::ServiceExt as _;

    use super::*;

    /// Policy with four roles used throughout these tests:
    /// * `viewer` — read-only;
    /// * `deploy` — lifecycle operations with deny overrides, restricted to `web-*`/`api-*`;
    /// * `ops` — full access;
    /// * `restricted-exec` — exec with a `cmd` argument allowlist.
    fn test_policy() -> RbacPolicy {
        RbacPolicy::new(&RbacConfig {
            enabled: true,
            roles: vec![
                RoleConfig {
                    name: "viewer".into(),
                    description: Some("Read-only".into()),
                    allow: vec![
                        "list_hosts".into(),
                        "resource_list".into(),
                        "resource_inspect".into(),
                        "resource_logs".into(),
                        "system_info".into(),
                    ],
                    deny: vec![],
                    hosts: vec!["*".into()],
                    argument_allowlists: vec![],
                },
                RoleConfig {
                    name: "deploy".into(),
                    description: Some("Lifecycle management".into()),
                    allow: vec![
                        "list_hosts".into(),
                        "resource_list".into(),
                        "resource_run".into(),
                        "resource_start".into(),
                        "resource_stop".into(),
                        "resource_restart".into(),
                        "resource_logs".into(),
                        "image_pull".into(),
                    ],
                    deny: vec!["resource_delete".into(), "resource_exec".into()],
                    hosts: vec!["web-*".into(), "api-*".into()],
                    argument_allowlists: vec![],
                },
                RoleConfig {
                    name: "ops".into(),
                    description: Some("Full access".into()),
                    allow: vec!["*".into()],
                    deny: vec![],
                    hosts: vec!["*".into()],
                    argument_allowlists: vec![],
                },
                RoleConfig {
                    name: "restricted-exec".into(),
                    description: Some("Exec with argument allowlist".into()),
                    allow: vec!["resource_exec".into()],
                    deny: vec![],
                    hosts: vec!["dev-*".into()],
                    argument_allowlists: vec![ArgumentAllowlist {
                        tool: "resource_exec".into(),
                        argument: "cmd".into(),
                        allowed: vec![
                            "sh".into(),
                            "bash".into(),
                            "cat".into(),
                            "ls".into(),
                            "ps".into(),
                        ],
                    }],
                },
            ],
            redaction_salt: None,
        })
    }

    /// Bearer-token identity with no raw token and no `sub`.
    fn bearer_identity(name: &str, role: &str) -> AuthIdentity {
        AuthIdentity {
            method: crate::auth::AuthMethod::BearerToken,
            name: name.into(),
            role: role.into(),
            raw_token: None,
            sub: None,
        }
    }

    // --- glob matching -------------------------------------------------------------

    #[test]
    fn glob_exact_match() {
        assert!(glob_match("web-prod-1", "web-prod-1"));
        assert!(!glob_match("web-prod-1", "web-prod-2"));
    }

    #[test]
    fn glob_star_suffix() {
        assert!(glob_match("web-*", "web-prod-1"));
        assert!(glob_match("web-*", "web-staging"));
        assert!(!glob_match("web-*", "api-prod"));
    }

    #[test]
    fn glob_star_prefix() {
        assert!(glob_match("*-prod", "web-prod"));
        assert!(glob_match("*-prod", "api-prod"));
        assert!(!glob_match("*-prod", "web-staging"));
    }

    #[test]
    fn glob_star_middle() {
        assert!(glob_match("web-*-prod", "web-us-prod"));
        assert!(glob_match("web-*-prod", "web-eu-east-prod"));
        assert!(!glob_match("web-*-prod", "web-staging"));
    }

    #[test]
    fn glob_star_only() {
        assert!(glob_match("*", "anything"));
        assert!(glob_match("*", ""));
    }

    #[test]
    fn glob_multiple_stars() {
        assert!(glob_match("*web*prod*", "my-web-us-prod-1"));
        assert!(!glob_match("*web*prod*", "my-api-us-staging"));
    }

    #[test]
    fn glob_edge_cases() {
        assert!(glob_match("", ""));
        assert!(!glob_match("", "x"));
        assert!(glob_match("**", "anything"));
        assert!(glob_match("a**b", "ab"));
        // Prefix and suffix must not overlap.
        assert!(!glob_match("a*a", "a"));
        assert!(glob_match("a*a", "aa"));
    }

    // --- policy checks -------------------------------------------------------------

    #[test]
    fn disabled_policy_allows_everything() {
        let policy = RbacPolicy::new(&RbacConfig {
            enabled: false,
            roles: vec![],
            redaction_salt: None,
        });
        assert_eq!(
            policy.check("nonexistent", "resource_delete", "any-host"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn unknown_role_denied() {
        let policy = test_policy();
        assert_eq!(
            policy.check("unknown", "resource_list", "web-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn viewer_allowed_read_ops() {
        let policy = test_policy();
        assert_eq!(
            policy.check("viewer", "resource_list", "web-prod-1"),
            RbacDecision::Allow
        );
        assert_eq!(
            policy.check("viewer", "system_info", "db-host"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn viewer_denied_write_ops() {
        let policy = test_policy();
        assert_eq!(
            policy.check("viewer", "resource_run", "web-prod-1"),
            RbacDecision::Deny
        );
        assert_eq!(
            policy.check("viewer", "resource_delete", "web-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn deploy_allowed_on_matching_hosts() {
        let policy = test_policy();
        assert_eq!(
            policy.check("deploy", "resource_run", "web-prod-1"),
            RbacDecision::Allow
        );
        assert_eq!(
            policy.check("deploy", "resource_start", "api-staging"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn deploy_denied_on_non_matching_host() {
        let policy = test_policy();
        assert_eq!(
            policy.check("deploy", "resource_run", "db-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn deny_overrides_allow() {
        let policy = test_policy();
        assert_eq!(
            policy.check("deploy", "resource_delete", "web-prod-1"),
            RbacDecision::Deny
        );
        assert_eq!(
            policy.check("deploy", "resource_exec", "web-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn ops_wildcard_allows_everything() {
        let policy = test_policy();
        assert_eq!(
            policy.check("ops", "resource_delete", "any-host"),
            RbacDecision::Allow
        );
        assert_eq!(
            policy.check("ops", "secret_create", "db-host"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn host_visible_respects_globs() {
        let policy = test_policy();
        assert!(policy.host_visible("deploy", "web-prod-1"));
        assert!(policy.host_visible("deploy", "api-staging"));
        assert!(!policy.host_visible("deploy", "db-prod-1"));
        assert!(policy.host_visible("ops", "anything"));
        assert!(policy.host_visible("viewer", "anything"));
    }

    #[test]
    fn host_visible_unknown_role() {
        let policy = test_policy();
        assert!(!policy.host_visible("unknown", "web-prod-1"));
    }

    #[test]
    fn argument_allowed_no_allowlist() {
        let policy = test_policy();
        assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "rm -rf /"));
        assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "bash"));
    }

    #[test]
    fn argument_allowed_with_allowlist() {
        let policy = test_policy();
        assert!(policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "sh"));
        assert!(policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "bash -c 'echo hi'"
        ));
        assert!(policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "cat /etc/hosts"
        ));
        assert!(policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "/usr/bin/ls -la"
        ));
    }

    #[test]
    fn argument_denied_not_in_allowlist() {
        let policy = test_policy();
        assert!(!policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "rm -rf /"));
        assert!(!policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "python3 exploit.py"
        ));
        assert!(!policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "/usr/bin/curl evil.com"
        ));
    }

    #[test]
    fn argument_denied_unknown_role() {
        let policy = test_policy();
        assert!(!policy.argument_allowed("unknown", "resource_exec", "cmd", "sh"));
    }

    #[test]
    fn argument_allowlist_tool_glob_and_empty_values() {
        let policy = RbacPolicy::new(&RbacConfig::with_roles(vec![
            RoleConfig::new("globbed", vec!["*".into()], vec!["*".into()])
                .with_argument_allowlists(vec![
                    ArgumentAllowlist::new("resource_*", "cmd", vec!["ls".into()]),
                    ArgumentAllowlist::new("resource_exec", "path", vec![]),
                ]),
        ]));
        // A glob `tool` selects every matching tool.
        assert!(policy.argument_allowed("globbed", "resource_exec", "cmd", "ls -la"));
        assert!(!policy.argument_allowed("globbed", "resource_exec", "cmd", "rm -rf /"));
        // Tools outside the glob are unaffected.
        assert!(policy.argument_allowed("globbed", "image_pull", "cmd", "rm -rf /"));
        // An empty `allowed` list imposes no restriction.
        assert!(policy.argument_allowed("globbed", "resource_exec", "path", "/etc/shadow"));
    }

    #[test]
    fn host_patterns_returns_globs() {
        let policy = test_policy();
        assert_eq!(
            policy.host_patterns("deploy"),
            Some(vec!["web-*".to_owned(), "api-*".to_owned()].as_slice())
        );
        assert_eq!(
            policy.host_patterns("ops"),
            Some(vec!["*".to_owned()].as_slice())
        );
        assert!(policy.host_patterns("nonexistent").is_none());
    }

    #[test]
    fn check_operation_allows_without_host() {
        let policy = test_policy();
        assert_eq!(
            policy.check_operation("deploy", "resource_run"),
            RbacDecision::Allow
        );
        // The same operation is still host-restricted through `check`.
        assert_eq!(
            policy.check("deploy", "resource_run", "db-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn check_operation_deny_overrides() {
        let policy = test_policy();
        assert_eq!(
            policy.check_operation("deploy", "resource_delete"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn check_operation_unknown_role() {
        let policy = test_policy();
        assert_eq!(
            policy.check_operation("unknown", "resource_list"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn check_operation_disabled() {
        let policy = RbacPolicy::new(&RbacConfig {
            enabled: false,
            roles: vec![],
            redaction_salt: None,
        });
        assert_eq!(
            policy.check_operation("nonexistent", "anything"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn summary_reports_per_role_counts() {
        let summary = test_policy().summary();
        assert!(summary.enabled);
        assert_eq!(summary.roles.len(), 4);
        let deploy = summary.roles.iter().find(|r| r.name == "deploy").unwrap();
        assert_eq!(deploy.allow, 8);
        assert_eq!(deploy.deny, 2);
        assert_eq!(deploy.hosts, 2);
        assert_eq!(deploy.argument_allowlists, 0);
    }

    // --- task-locals ---------------------------------------------------------------

    #[test]
    fn current_role_returns_none_outside_scope() {
        assert!(current_role().is_none());
    }

    #[test]
    fn current_identity_returns_none_outside_scope() {
        assert!(current_identity().is_none());
    }

    #[tokio::test]
    async fn scoped_values_visible_and_empty_values_read_as_none() {
        let (role, identity, token, sub) = with_rbac_scope(
            "viewer".into(),
            "alice".into(),
            SecretString::from(String::new()),
            String::new(),
            async { (current_role(), current_identity(), current_token(), current_sub()) },
        )
        .await;
        assert_eq!(role.as_deref(), Some("viewer"));
        assert_eq!(identity.as_deref(), Some("alice"));
        assert!(token.is_none(), "empty token must read as absent");
        assert!(sub.is_none(), "empty sub must read as absent");
    }

    // --- middleware ----------------------------------------------------------------

    fn tool_call_body(tool: &str, args: &serde_json::Value) -> String {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {
                "name": tool,
                "arguments": args
            }
        })
        .to_string()
    }

    /// JSON POST to `/mcp` carrying `body`.
    fn post_mcp(body: String) -> Request<Body> {
        Request::builder()
            .method(Method::POST)
            .uri("/mcp")
            .header("content-type", "application/json")
            .body(Body::from(body))
            .unwrap()
    }

    fn rbac_router(policy: Arc<RbacPolicy>) -> axum::Router {
        axum::Router::new()
            .route("/mcp", axum::routing::post(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let p = Arc::clone(&policy);
                rbac_middleware(p, None, req, next)
            }))
    }

    fn rbac_router_with_identity(policy: Arc<RbacPolicy>, identity: AuthIdentity) -> axum::Router {
        axum::Router::new()
            .route("/mcp", axum::routing::post(|| async { "ok" }))
            .layer(axum::middleware::from_fn(
                move |mut req: Request<Body>, next: Next| {
                    let p = Arc::clone(&policy);
                    let id = identity.clone();
                    async move {
                        req.extensions_mut().insert(id);
                        rbac_middleware(p, None, req, next).await
                    }
                },
            ))
    }

    #[tokio::test]
    async fn middleware_passes_non_post() {
        let policy = Arc::new(test_policy());
        let app = rbac_router(policy);
        let req = Request::builder()
            .method(Method::GET)
            .uri("/mcp")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        // The middleware lets GET through; the POST-only route then answers 405.
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn middleware_denies_without_identity() {
        let app = rbac_router(Arc::new(test_policy()));
        let body = tool_call_body("resource_list", &serde_json::json!({}));
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn middleware_allows_permitted_tool() {
        let app = rbac_router_with_identity(
            Arc::new(test_policy()),
            bearer_identity("alice", "viewer"),
        );
        let body = tool_call_body("resource_list", &serde_json::json!({}));
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn middleware_denies_unpermitted_tool() {
        let app = rbac_router_with_identity(
            Arc::new(test_policy()),
            bearer_identity("alice", "viewer"),
        );
        let body = tool_call_body("resource_delete", &serde_json::json!({}));
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn middleware_passes_non_tool_call_post() {
        let app = rbac_router_with_identity(
            Arc::new(test_policy()),
            bearer_identity("alice", "viewer"),
        );
        let body = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "resources/list"
        })
        .to_string();
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn middleware_enforces_argument_allowlist() {
        let policy = Arc::new(test_policy());
        let id = bearer_identity("dev", "restricted-exec");

        let app = rbac_router_with_identity(Arc::clone(&policy), id.clone());
        let body = tool_call_body(
            "resource_exec",
            &serde_json::json!({"cmd": "ls -la", "host": "dev-1"}),
        );
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let app = rbac_router_with_identity(policy, id);
        let body = tool_call_body(
            "resource_exec",
            &serde_json::json!({"cmd": "rm -rf /", "host": "dev-1"}),
        );
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn middleware_disabled_policy_passes_everything() {
        let app = rbac_router(Arc::new(RbacPolicy::disabled()));
        let body = tool_call_body("anything", &serde_json::json!({}));
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn deny_path_uses_explicit_identity_not_task_local() {
        let app = rbac_router_with_identity(
            Arc::new(test_policy()),
            bearer_identity("alice-the-auditor", "viewer"),
        );
        let body = tool_call_body("resource_delete", &serde_json::json!({}));
        let resp = app.oneshot(post_mcp(body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // --- redaction -----------------------------------------------------------------

    #[test]
    fn redact_with_salt_is_deterministic_per_salt() {
        let salt = b"unit-test-salt";
        let a = redact_with_salt(salt, "rm -rf /");
        let b = redact_with_salt(salt, "rm -rf /");
        assert_eq!(a, b, "same input + salt must yield identical hash");
        assert_eq!(a.len(), 8, "redacted hash is 8 hex chars (4 bytes)");
        assert!(
            a.chars().all(|c| c.is_ascii_hexdigit()),
            "redacted hash must be lowercase hex: {a}"
        );
    }

    #[test]
    fn redact_with_salt_differs_across_salts() {
        let v = "the-same-value";
        let h1 = redact_with_salt(b"salt-one", v);
        let h2 = redact_with_salt(b"salt-two", v);
        assert_ne!(
            h1, h2,
            "different salts must produce different hashes for the same value"
        );
    }

    #[test]
    fn redact_with_salt_distinguishes_values() {
        let salt = b"k";
        let h1 = redact_with_salt(salt, "alpha");
        let h2 = redact_with_salt(salt, "beta");
        assert_ne!(h1, h2, "different values must produce different hashes");
    }

    #[test]
    fn policy_with_configured_salt_redacts_consistently() {
        let cfg = RbacConfig {
            enabled: true,
            roles: vec![],
            redaction_salt: Some(SecretString::from("my-stable-salt")),
        };
        let p1 = RbacPolicy::new(&cfg);
        let p2 = RbacPolicy::new(&cfg);
        assert_eq!(
            p1.redact_arg("payload"),
            p2.redact_arg("payload"),
            "policies built from the same configured salt must agree"
        );
    }

    #[test]
    fn policy_without_configured_salt_uses_process_salt() {
        let cfg = RbacConfig {
            enabled: true,
            roles: vec![],
            redaction_salt: None,
        };
        let p1 = RbacPolicy::new(&cfg);
        let p2 = RbacPolicy::new(&cfg);
        assert_eq!(
            p1.redact_arg("payload"),
            p2.redact_arg("payload"),
            "process-wide salt must be consistent within one process"
        );
    }

    #[test]
    fn redact_arg_is_fast_enough() {
        // Take the best of several samples against a generous bound. A single
        // wall-clock sample with a tight threshold flakes whenever the test thread is
        // descheduled on a busy CI runner or the build is instrumented.
        let salt = b"perf-sanity-salt-32-bytes-padded";
        let value = "x".repeat(256);
        let fastest = (0..5)
            .map(|_| {
                let start = std::time::Instant::now();
                let _ = std::hint::black_box(redact_with_salt(salt, std::hint::black_box(&value)));
                start.elapsed()
            })
            .min()
            .unwrap();
        assert!(
            fastest < Duration::from_millis(50),
            "fastest of 5 redact_with_salt calls took {fastest:?}, expected <50 ms even in debug"
        );
    }
}