Skip to main content

rmcp_server_kit/
rbac.rs

1//! Role-Based Access Control (RBAC) policy engine.
2//!
3//! Evaluates `(role, operation, host)` tuples against a set of role
4//! definitions loaded from config.  Deny-overrides-allow semantics:
5//! an explicit deny entry always wins over a wildcard allow.
6//!
7//! Includes an axum middleware that inspects MCP JSON-RPC tool calls
8//! and enforces RBAC and per-IP tool rate limiting before the request
9//! reaches the handler.
10
11use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
12
13use axum::{
14    body::Body,
15    extract::ConnectInfo,
16    http::{Method, Request, StatusCode},
17    middleware::Next,
18    response::{IntoResponse, Response},
19};
20use hmac::{Hmac, Mac};
21use http_body_util::BodyExt;
22use secrecy::{ExposeSecret, SecretString};
23use serde::Deserialize;
24use sha2::Sha256;
25
26use crate::{
27    auth::{AuthIdentity, TlsConnInfo},
28    bounded_limiter::BoundedKeyedLimiter,
29    error::McpxError,
30};
31
/// Per-source-IP rate limiter for tool invocations. Memory-bounded against
/// IP-spray `DoS` via [`BoundedKeyedLimiter`]. Keys are the caller's
/// [`IpAddr`].
pub(crate) type ToolRateLimiter = BoundedKeyedLimiter<IpAddr>;

/// Default tool rate limit: 120 invocations per minute per source IP.
///
/// Also used as the fallback when a configured per-minute rate of `0` is
/// passed to [`build_tool_rate_limiter_with_bounds`].
// `unwrap()` cannot panic at runtime: this is a `const` initializer, so the
// expression is evaluated at compile time, and the literal 120 is non-zero.
const DEFAULT_TOOL_RATE: NonZeroU32 = NonZeroU32::new(120).unwrap();

/// Default cap on the number of distinct source IPs tracked by the tool
/// rate limiter. Bounded to defend against IP-spray `DoS` exhausting memory.
const DEFAULT_TOOL_MAX_TRACKED_KEYS: usize = 10_000;

/// Default idle-eviction window for the tool rate limiter (15 minutes).
const DEFAULT_TOOL_IDLE_EVICTION: Duration = Duration::from_mins(15);
46
47/// Build a per-IP tool rate limiter from a max-calls-per-minute value.
48///
49/// Memory-bounded with `DEFAULT_TOOL_MAX_TRACKED_KEYS` tracked keys and
50/// `DEFAULT_TOOL_IDLE_EVICTION` idle eviction. Use
51/// [`build_tool_rate_limiter_with_bounds`] to override.
52#[must_use]
53pub(crate) fn build_tool_rate_limiter(max_per_minute: u32) -> Arc<ToolRateLimiter> {
54    build_tool_rate_limiter_with_bounds(
55        max_per_minute,
56        DEFAULT_TOOL_MAX_TRACKED_KEYS,
57        DEFAULT_TOOL_IDLE_EVICTION,
58    )
59}
60
61/// Build a per-IP tool rate limiter with explicit memory-bound parameters.
62#[must_use]
63pub(crate) fn build_tool_rate_limiter_with_bounds(
64    max_per_minute: u32,
65    max_tracked_keys: usize,
66    idle_eviction: Duration,
67) -> Arc<ToolRateLimiter> {
68    let quota =
69        governor::Quota::per_minute(NonZeroU32::new(max_per_minute).unwrap_or(DEFAULT_TOOL_RATE));
70    Arc::new(BoundedKeyedLimiter::new(
71        quota,
72        max_tracked_keys,
73        idle_eviction,
74    ))
75}
76
// Task-local storage for the current caller's RBAC role, identity name,
// bearer token and JWT subject.
// Set by the RBAC middleware (or by `with_rbac_scope` / `with_token_scope`
// in spawned tasks), read by tool handlers (e.g. list_hosts filtering,
// audit logging) through the `current_*` accessor functions.
//
// `CURRENT_TOKEN` holds a [`SecretString`] so the raw bearer token is never
// printed via `Debug` (it formats as `"[REDACTED alloc::string::String]"`)
// and is zeroized on drop by the `secrecy` crate.
//
// Empty values of `CURRENT_TOKEN` and `CURRENT_SUB` act as "absent"
// sentinels: `current_token()` and `current_sub()` map them to `None`.
tokio::task_local! {
    static CURRENT_ROLE: String;
    static CURRENT_IDENTITY: String;
    static CURRENT_TOKEN: SecretString;
    static CURRENT_SUB: String;
}
89
90/// Get the current caller's RBAC role (set by RBAC middleware).
91/// Returns `None` outside an RBAC-scoped request context.
92#[must_use]
93pub fn current_role() -> Option<String> {
94    CURRENT_ROLE.try_with(Clone::clone).ok()
95}
96
97/// Get the current caller's identity name (set by RBAC middleware).
98/// Returns `None` outside an RBAC-scoped request context.
99#[must_use]
100pub fn current_identity() -> Option<String> {
101    CURRENT_IDENTITY.try_with(Clone::clone).ok()
102}
103
104/// Get the raw bearer token for the current request as a [`SecretString`].
105/// Returns `None` outside a request context or when auth used mTLS/API-key.
106/// Tool handlers use this for downstream token passthrough.
107///
108/// The returned value is wrapped in [`SecretString`] so it does not leak
109/// via `Debug`/`Display`/serde. Call `.expose_secret()` only when the
110/// raw value is actually needed (e.g. as the `Authorization` header on
111/// an outbound HTTP request).
112///
113/// An empty token is treated as absent (returns `None`); this preserves
114/// backward compatibility with the prior `Option<String>` API where the
115/// empty default sentinel meant "no token".
116#[must_use]
117pub fn current_token() -> Option<SecretString> {
118    CURRENT_TOKEN
119        .try_with(|t| {
120            if t.expose_secret().is_empty() {
121                None
122            } else {
123                Some(t.clone())
124            }
125        })
126        .ok()
127        .flatten()
128}
129
130/// Get the JWT `sub` claim (stable user ID, e.g. Keycloak UUID).
131/// Returns `None` outside a request context or for non-JWT auth.
132/// Use for stable per-user keying (token store, etc.).
133#[must_use]
134pub fn current_sub() -> Option<String> {
135    CURRENT_SUB
136        .try_with(Clone::clone)
137        .ok()
138        .filter(|s| !s.is_empty())
139}
140
141/// Run a future with `CURRENT_TOKEN` set so that [`current_token()`] returns
142/// the given value inside the future. Useful when MCP tool handlers need the
143/// raw bearer token but run in a spawned task where the RBAC middleware's
144/// task-local scope is no longer active.
145pub async fn with_token_scope<F: Future>(token: SecretString, f: F) -> F::Output {
146    CURRENT_TOKEN.scope(token, f).await
147}
148
149/// Run a future with all task-locals (`CURRENT_ROLE`, `CURRENT_IDENTITY`,
150/// `CURRENT_TOKEN`, `CURRENT_SUB`) set.  Use this when re-establishing the
151/// full RBAC context in spawned tasks (e.g. rmcp session tasks) where the
152/// middleware's scope is no longer active.
153pub async fn with_rbac_scope<F: Future>(
154    role: String,
155    identity: String,
156    token: SecretString,
157    sub: String,
158    f: F,
159) -> F::Output {
160    CURRENT_ROLE
161        .scope(
162            role,
163            CURRENT_IDENTITY.scope(
164                identity,
165                CURRENT_TOKEN.scope(token, CURRENT_SUB.scope(sub, f)),
166            ),
167        )
168        .await
169}
170
/// A single role definition.
///
/// Deserializable from config. Only `name` is required; every other field
/// has a serde default.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleConfig {
    /// Role identifier referenced from identities (API keys, mTLS, JWT claims).
    pub name: String,
    /// Human-readable description, surfaced in diagnostics only.
    #[serde(default)]
    pub description: Option<String>,
    /// Allowed operations.  `["*"]` means all operations. Entries other
    /// than `"*"` are compared to the operation name by exact equality.
    #[serde(default)]
    pub allow: Vec<String>,
    /// Explicitly denied operations (overrides allow). Compared to the
    /// operation name by exact equality.
    #[serde(default)]
    pub deny: Vec<String>,
    /// Host name glob patterns this role can access. `["*"]` means all hosts
    /// and is the default when the field is omitted.
    #[serde(default = "default_hosts")]
    pub hosts: Vec<String>,
    /// Per-tool argument constraints. When a tool call matches, the
    /// specified argument is tokenized with POSIX-shell-like rules
    /// (`shlex::split`) and its first token (or that token's `/`-basename)
    /// must appear in the allowlist.
    #[serde(default)]
    pub argument_allowlists: Vec<ArgumentAllowlist>,
}
195
196impl RoleConfig {
197    /// Create a role with the given name, allowed operations, and host patterns.
198    #[must_use]
199    pub fn new(name: impl Into<String>, allow: Vec<String>, hosts: Vec<String>) -> Self {
200        Self {
201            name: name.into(),
202            description: None,
203            allow,
204            deny: vec![],
205            hosts,
206            argument_allowlists: vec![],
207        }
208    }
209
210    /// Attach argument allowlists to this role.
211    #[must_use]
212    pub fn with_argument_allowlists(mut self, allowlists: Vec<ArgumentAllowlist>) -> Self {
213        self.argument_allowlists = allowlists;
214        self
215    }
216}
217
/// Per-tool argument allowlist entry.
///
/// When the middleware sees a `tools/call` for `tool`, it extracts the
/// string value at `argument` from the call's arguments object and checks
/// its first token against `allowed`. If the token is not in the list
/// the call is rejected with 403.
///
/// Only string-valued arguments are checked by the middleware; an argument
/// whose JSON value is not a string is not evaluated against this entry.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ArgumentAllowlist {
    /// Tool name to match (exact or glob, e.g. `"run_query"`).
    pub tool: String,
    /// Argument key whose value is checked (e.g. `"cmd"`, `"query"`).
    /// Matched by exact equality.
    pub argument: String,
    /// Permitted first-token values. Empty means unrestricted.
    #[serde(default)]
    pub allowed: Vec<String>,
}
235
236impl ArgumentAllowlist {
237    /// Create an argument allowlist for a tool.
238    #[must_use]
239    pub fn new(tool: impl Into<String>, argument: impl Into<String>, allowed: Vec<String>) -> Self {
240        Self {
241            tool: tool.into(),
242            argument: argument.into(),
243            allowed,
244        }
245    }
246}
247
/// Serde default for `RoleConfig::hosts`: a single `"*"` pattern, i.e. all
/// hosts.
fn default_hosts() -> Vec<String> {
    vec![String::from("*")]
}
251
/// Top-level RBAC configuration (deserializable from TOML).
///
/// Every field has a serde default, so an empty table deserializes to a
/// disabled config with no roles and no redaction salt.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct RbacConfig {
    /// Master switch -- when false, the RBAC middleware is not installed.
    #[serde(default)]
    pub enabled: bool,
    /// Role definitions available to identities.
    #[serde(default)]
    pub roles: Vec<RoleConfig>,
    /// Optional stable HMAC key (any length) used to redact argument
    /// values in deny logs. When set, redacted hashes are stable across
    /// process restarts (useful for log correlation across deploys).
    /// When `None`, a random 32-byte key is generated per process at
    /// first use; redacted hashes change every restart.
    ///
    /// The key is wrapped in [`SecretString`] so it never leaks via
    /// `Debug`/`Display`/serde and is zeroized on drop.
    #[serde(default)]
    pub redaction_salt: Option<SecretString>,
}
273
274impl RbacConfig {
275    /// Create an enabled RBAC config with the given roles.
276    #[must_use]
277    pub fn with_roles(roles: Vec<RoleConfig>) -> Self {
278        Self {
279            enabled: true,
280            roles,
281            redaction_salt: None,
282        }
283    }
284}
285
/// Result of an RBAC policy check.
///
/// Returned by [`RbacPolicy::check`] and [`RbacPolicy::check_operation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RbacDecision {
    /// Caller is permitted to perform the requested operation.
    Allow,
    /// Caller is denied access.
    Deny,
}
295
/// Summary of a single role, produced by [`RbacPolicy::summary`].
///
/// Carries only the role name and entry counts, not the entries themselves.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacRoleSummary {
    /// Role name.
    pub name: String,
    /// Number of allow entries.
    pub allow: usize,
    /// Number of deny entries.
    pub deny: usize,
    /// Number of host patterns.
    pub hosts: usize,
    /// Number of argument allowlist entries.
    pub argument_allowlists: usize,
}
311
/// Summary of the whole RBAC policy, produced by [`RbacPolicy::summary`].
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacPolicySummary {
    /// Whether RBAC enforcement is active.
    pub enabled: bool,
    /// Per-role summaries, one per configured role.
    pub roles: Vec<RbacRoleSummary>,
}
321
/// Compiled RBAC policy for fast lookup.
///
/// Built from [`RbacConfig`] at startup.  All lookups are O(n) over the
/// role's allow/deny/host lists, which is fine for the expected cardinality
/// (a handful of roles with tens of entries each).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RbacPolicy {
    /// Role definitions copied from the config; looked up by exact name.
    roles: Vec<RoleConfig>,
    /// When `false`, every check allows.
    enabled: bool,
    /// HMAC key used to redact argument values in deny logs.
    /// Either a configured stable salt or a per-process random salt.
    redaction_salt: Arc<SecretString>,
}
336
337impl RbacPolicy {
338    /// Build a policy from config.  When `config.enabled` is false, all
339    /// checks return [`RbacDecision::Allow`].
340    #[must_use]
341    pub fn new(config: &RbacConfig) -> Self {
342        let salt = config
343            .redaction_salt
344            .clone()
345            .unwrap_or_else(|| process_redaction_salt().clone());
346        Self {
347            roles: config.roles.clone(),
348            enabled: config.enabled,
349            redaction_salt: Arc::new(salt),
350        }
351    }
352
353    /// Create a policy that always allows (RBAC disabled).
354    #[must_use]
355    pub fn disabled() -> Self {
356        Self {
357            roles: Vec::new(),
358            enabled: false,
359            redaction_salt: Arc::new(process_redaction_salt().clone()),
360        }
361    }
362
363    /// Whether RBAC enforcement is active.
364    #[must_use]
365    pub fn is_enabled(&self) -> bool {
366        self.enabled
367    }
368
369    /// Summarize the policy for diagnostics (admin endpoint).
370    ///
371    /// Returns `(enabled, role_count, per_role_stats)` where each stat is
372    /// `(name, allow_count, deny_count, host_count, argument_allowlist_count)`.
373    #[must_use]
374    pub fn summary(&self) -> RbacPolicySummary {
375        let roles = self
376            .roles
377            .iter()
378            .map(|r| RbacRoleSummary {
379                name: r.name.clone(),
380                allow: r.allow.len(),
381                deny: r.deny.len(),
382                hosts: r.hosts.len(),
383                argument_allowlists: r.argument_allowlists.len(),
384            })
385            .collect();
386        RbacPolicySummary {
387            enabled: self.enabled,
388            roles,
389        }
390    }
391
392    /// Check whether `role` may perform `operation` (ignoring host).
393    ///
394    /// Use this for tools that don't target a specific host (e.g. `ping`,
395    /// `list_hosts`).
396    #[must_use]
397    pub fn check_operation(&self, role: &str, operation: &str) -> RbacDecision {
398        if !self.enabled {
399            return RbacDecision::Allow;
400        }
401        let Some(role_cfg) = self.find_role(role) else {
402            return RbacDecision::Deny;
403        };
404        if role_cfg.deny.iter().any(|d| d == operation) {
405            return RbacDecision::Deny;
406        }
407        if role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
408            return RbacDecision::Allow;
409        }
410        RbacDecision::Deny
411    }
412
413    /// Check whether `role` may perform `operation` on `host`.
414    ///
415    /// Evaluation order:
416    /// 1. If RBAC is disabled, allow.
417    /// 2. Check operation permission (deny overrides allow).
418    /// 3. Check host visibility via glob matching.
419    #[must_use]
420    pub fn check(&self, role: &str, operation: &str, host: &str) -> RbacDecision {
421        if !self.enabled {
422            return RbacDecision::Allow;
423        }
424        let Some(role_cfg) = self.find_role(role) else {
425            return RbacDecision::Deny;
426        };
427        if role_cfg.deny.iter().any(|d| d == operation) {
428            return RbacDecision::Deny;
429        }
430        if !role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
431            return RbacDecision::Deny;
432        }
433        if !Self::host_matches(&role_cfg.hosts, host) {
434            return RbacDecision::Deny;
435        }
436        RbacDecision::Allow
437    }
438
439    /// Check whether `role` can see `host` at all (for `list_hosts` filtering).
440    #[must_use]
441    pub fn host_visible(&self, role: &str, host: &str) -> bool {
442        if !self.enabled {
443            return true;
444        }
445        let Some(role_cfg) = self.find_role(role) else {
446            return false;
447        };
448        Self::host_matches(&role_cfg.hosts, host)
449    }
450
451    /// Get the list of hosts patterns for a role.
452    #[must_use]
453    pub fn host_patterns(&self, role: &str) -> Option<&[String]> {
454        self.find_role(role).map(|r| r.hosts.as_slice())
455    }
456
457    /// Check whether `value` passes the argument allowlists for `tool` under `role`.
458    ///
459    /// If the role has no matching `argument_allowlists` entry for the tool,
460    /// all values are allowed. When a matching entry exists, `value` is
461    /// tokenized using POSIX-shell-like lexical rules ([`shlex::split`])
462    /// and its first argv element (or the `/`-basename of that element)
463    /// must appear in the `allowed` list.
464    ///
465    /// **Scope of the contract.** This matcher targets consumers that
466    /// interpret string arguments as POSIX-shell-like command lines on
467    /// Unix-like systems (e.g. anything that subsequently feeds the value
468    /// through `shlex` or an equivalent splitter before `execve`). It
469    /// does **not** model real shell *execution* grammar (`FOO=1 cmd`,
470    /// expansion, command substitution, redirection, operators) or
471    /// Windows command-line tokenization (`CommandLineToArgvW`,
472    /// `cmd.exe`, PowerShell). Consumers in those regimes remain subject
473    /// to a parser differential and must validate at their own boundary.
474    ///
475    /// **Fail-closed cases (all return `false` when a matching allowlist
476    /// entry exists):**
477    ///
478    /// - `value` fails to parse as a POSIX-shell-like command line
479    ///   (e.g. unbalanced quotes, dangling escape).
480    /// - `value` parses to zero tokens (empty input).
481    /// - The first parsed token is the empty string (e.g.
482    ///   `value = r#""""#` parses to `Some(vec![""])`). An empty argv
483    ///   element is never a runnable executable, so we reject even when
484    ///   `""` is in the allowlist.
485    #[must_use]
486    pub fn argument_allowed(&self, role: &str, tool: &str, argument: &str, value: &str) -> bool {
487        if !self.enabled {
488            return true;
489        }
490        let Some(role_cfg) = self.find_role(role) else {
491            return false;
492        };
493        for al in &role_cfg.argument_allowlists {
494            if al.tool != tool && !glob_match(&al.tool, tool) {
495                continue;
496            }
497            if al.argument != argument {
498                continue;
499            }
500            if al.allowed.is_empty() {
501                continue;
502            }
503            // Tokenize per POSIX-shell-like rules so quoted paths with
504            // spaces match what an equivalently-tokenizing consumer
505            // would actually run, and malformed shell syntax (unbalanced
506            // quotes, dangling escapes) fails closed.
507            let Some(tokens) = shlex::split(value) else {
508                return false;
509            };
510            let Some(first_token) = tokens.first() else {
511                return false;
512            };
513            // A well-formed but empty first argv element (e.g.
514            // value = r#""""#) is never a runnable executable. Fail
515            // closed even if "" appears in the allowlist.
516            if first_token.is_empty() {
517                return false;
518            }
519            // Also match against the basename if it's a path. POSIX
520            // separator only; Windows-style backslash paths are out of
521            // scope and will not basename-match (see crate-level docs).
522            let basename = first_token
523                .rsplit('/')
524                .next()
525                .unwrap_or(first_token.as_str());
526            if !al.allowed.iter().any(|a| a == first_token || a == basename) {
527                return false;
528            }
529        }
530        true
531    }
532
533    /// Return the role config for a given role name.
534    fn find_role(&self, name: &str) -> Option<&RoleConfig> {
535        self.roles.iter().find(|r| r.name == name)
536    }
537
538    /// Check if a host name matches any of the given glob patterns.
539    fn host_matches(patterns: &[String], host: &str) -> bool {
540        patterns.iter().any(|p| glob_match(p, host))
541    }
542
543    /// HMAC-SHA256 the given argument value with this policy's redaction
544    /// salt and return the first 8 hex characters (4 bytes / 32 bits).
545    ///
546    /// 32 bits is enough entropy for log correlation (1-in-4-billion
547    /// collision per pair) while being far short of any preimage attack
548    /// surface for an attacker reading logs. The HMAC construction
549    /// guarantees that even short or low-entropy values cannot be
550    /// recovered without the key.
551    #[must_use]
552    pub fn redact_arg(&self, value: &str) -> String {
553        redact_with_salt(self.redaction_salt.expose_secret().as_bytes(), value)
554    }
555}
556
557/// Process-wide random redaction salt, lazily generated on first use.
558/// Used when [`RbacConfig::redaction_salt`] is `None`.
559fn process_redaction_salt() -> &'static SecretString {
560    use base64::{Engine as _, engine::general_purpose::STANDARD_NO_PAD};
561    static PROCESS_SALT: std::sync::OnceLock<SecretString> = std::sync::OnceLock::new();
562    PROCESS_SALT.get_or_init(|| {
563        let mut bytes = [0u8; 32];
564        rand::fill(&mut bytes);
565        // base64-encode so the SecretString is valid UTF-8; the HMAC
566        // accepts arbitrary key bytes regardless.
567        SecretString::from(STANDARD_NO_PAD.encode(bytes))
568    })
569}
570
571/// HMAC-SHA256(`salt`, `value`) → first 8 hex chars.
572///
573/// Pulled out as a free function so it can be unit-tested and benchmarked
574/// without constructing a full [`RbacPolicy`].
575fn redact_with_salt(salt: &[u8], value: &str) -> String {
576    use std::fmt::Write as _;
577
578    use sha2::Digest as _;
579
580    type HmacSha256 = Hmac<Sha256>;
581    // HMAC-SHA256 accepts keys of any byte length: the spec pads short
582    // keys with zeros and hashes long keys, so `new_from_slice` is
583    // infallible here. We still defensively re-key with a SHA-256 of
584    // the salt if construction ever fails (e.g. future hmac upstream
585    // tightens the contract); both branches produce a valid keyed MAC.
586    let mut mac = if let Ok(m) = HmacSha256::new_from_slice(salt) {
587        m
588    } else {
589        let digest = Sha256::digest(salt);
590        #[allow(clippy::expect_used)] // 32-byte digest always valid as HMAC key
591        HmacSha256::new_from_slice(&digest).expect("32-byte SHA256 digest is valid HMAC key")
592    };
593    mac.update(value.as_bytes());
594    let bytes = mac.finalize().into_bytes();
595    // 4 bytes → 8 hex chars.
596    let prefix = bytes.get(..4).unwrap_or(&[0; 4]);
597    let mut out = String::with_capacity(8);
598    for b in prefix {
599        let _ = write!(out, "{b:02x}");
600    }
601    out
602}
603
604// -- RBAC middleware --
605
/// Axum middleware that enforces RBAC and per-IP tool rate limiting on
/// MCP tool calls.
///
/// Inspects POST request bodies for `tools/call` JSON-RPC messages,
/// extracts the tool name and `host` argument, and checks the
/// [`RbacPolicy`] against the [`AuthIdentity`] set by the auth middleware.
///
/// When a `tool_limiter` is provided, tool invocations are rate-limited
/// per source IP regardless of whether RBAC is enabled (MCP spec: servers
/// MUST rate limit tool invocations).
///
/// Non-POST requests and non-tool-call messages pass through unchanged.
/// The caller's role is stored in task-local storage for use by tool
/// handlers (e.g. `list_hosts` host filtering via [`current_role()`]).
///
/// Rejections are returned directly as responses:
/// - RBAC enabled but no [`AuthIdentity`] on the request → RBAC error.
/// - Request body cannot be read → `500 Internal Server Error`.
/// - Any tool call in the body fails the rate limit or the policy → the
///   whole request is rejected with that call's error response.
// TODO(refactor): cognitive complexity reduced from 43/25 by extracting
// `enforce_tool_policy` and `enforce_rate_limit`. Remaining flow is a
// linear body-collect + JSON-RPC parse + dispatch, intentionally left
// inline to keep the request lifecycle visible at a glance.
#[allow(clippy::too_many_lines)]
pub(crate) async fn rbac_middleware(
    policy: Arc<RbacPolicy>,
    tool_limiter: Option<Arc<ToolRateLimiter>>,
    req: Request<Body>,
    next: Next,
) -> Response {
    // Only inspect POST requests - tool calls are POSTs.
    if req.method() != Method::POST {
        return next.run(req).await;
    }

    // Extract peer IP for rate limiting. Prefer the plain socket address
    // and fall back to the TLS connection info. When neither extension is
    // present `peer_ip` is `None` and `enforce_rate_limit` skips the check.
    let peer_ip: Option<IpAddr> = req
        .extensions()
        .get::<ConnectInfo<std::net::SocketAddr>>()
        .map(|ci| ci.0.ip())
        .or_else(|| {
            req.extensions()
                .get::<ConnectInfo<TlsConnInfo>>()
                .map(|ci| ci.0.addr.ip())
        });

    // Extract caller identity and role (may be absent when auth is off).
    // Absent values become empty strings.
    let identity = req.extensions().get::<AuthIdentity>();
    let identity_name = identity.map(|id| id.name.clone()).unwrap_or_default();
    let role = identity.map(|id| id.role.clone()).unwrap_or_default();
    // Clone the SecretString end-to-end; an absent token becomes an empty
    // SecretString sentinel (current_token() filters this out as None).
    let raw_token: SecretString = identity
        .and_then(|id| id.raw_token.clone())
        .unwrap_or_else(|| SecretString::from(String::new()));
    let sub = identity.and_then(|id| id.sub.clone()).unwrap_or_default();

    // RBAC requires an authenticated identity.
    if policy.is_enabled() && identity.is_none() {
        return McpxError::Rbac("no authenticated identity".into()).into_response();
    }

    // Read the body for JSON-RPC inspection. The whole body is buffered
    // here and re-attached to the request further down.
    let (parts, body) = req.into_parts();
    let bytes = match body.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) => {
            tracing::error!(error = %e, "failed to read request body");
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "failed to read request body",
            )
                .into_response();
        }
    };

    // Try to parse as JSON and inspect JSON-RPC tool calls, including batch arrays.
    if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&bytes) {
        let tool_calls = extract_tool_calls(&json);
        if !tool_calls.is_empty() {
            // Every tool call in a batch is rate-limited and policy-checked
            // individually, rate limit first. The first rejection aborts
            // the entire request.
            for params in tool_calls {
                if let Some(resp) = enforce_rate_limit(tool_limiter.as_deref(), peer_ip) {
                    return resp;
                }
                if policy.is_enabled()
                    && let Some(resp) = enforce_tool_policy(&policy, &identity_name, &role, params)
                {
                    return resp;
                }
            }
        }
    }
    // Non-parseable or non-tool-call requests pass through.

    // Reconstruct the request with the consumed body.
    let req = Request::from_parts(parts, Body::from(bytes));

    // Set the caller's role and identity in task-local storage for the handler.
    // With an empty role none of the task-locals are installed, so the
    // `current_*` accessors return `None` inside the handler.
    if role.is_empty() {
        next.run(req).await
    } else {
        CURRENT_ROLE
            .scope(
                role,
                CURRENT_IDENTITY.scope(
                    identity_name,
                    CURRENT_TOKEN.scope(raw_token, CURRENT_SUB.scope(sub, next.run(req))),
                ),
            )
            .await
    }
}
713
714/// Extract the `params` object for every top-level `tools/call` message.
715///
716/// Supports either a single JSON-RPC object or a JSON-RPC batch array. Any
717/// malformed elements are ignored so non-RPC payloads continue to pass through
718/// unchanged.
719fn extract_tool_calls(value: &serde_json::Value) -> Vec<&serde_json::Value> {
720    match value {
721        serde_json::Value::Object(map) => map
722            .get("method")
723            .and_then(serde_json::Value::as_str)
724            .filter(|method| *method == "tools/call")
725            .and_then(|_| map.get("params"))
726            .into_iter()
727            .collect(),
728        serde_json::Value::Array(items) => items
729            .iter()
730            .filter_map(|item| match item {
731                serde_json::Value::Object(map) => map
732                    .get("method")
733                    .and_then(serde_json::Value::as_str)
734                    .filter(|method| *method == "tools/call")
735                    .and_then(|_| map.get("params")),
736                serde_json::Value::Null
737                | serde_json::Value::Bool(_)
738                | serde_json::Value::Number(_)
739                | serde_json::Value::String(_)
740                | serde_json::Value::Array(_) => None,
741            })
742            .collect(),
743        serde_json::Value::Null
744        | serde_json::Value::Bool(_)
745        | serde_json::Value::Number(_)
746        | serde_json::Value::String(_) => Vec::new(),
747    }
748}
749
750/// Per-IP rate limit check for tool invocations. Returns `Some(response)`
751/// if the caller should be rejected.
752fn enforce_rate_limit(
753    tool_limiter: Option<&ToolRateLimiter>,
754    peer_ip: Option<IpAddr>,
755) -> Option<Response> {
756    let limiter = tool_limiter?;
757    let ip = peer_ip?;
758    if limiter.check_key(&ip).is_err() {
759        tracing::warn!(%ip, "tool invocation rate limited");
760        return Some(McpxError::RateLimited("too many tool invocations".into()).into_response());
761    }
762    None
763}
764
765/// Apply RBAC tool/host + argument-allowlist checks. Returns `Some(response)`
766/// when the caller must be rejected. Assumes `policy.is_enabled()`.
767///
768/// `identity_name` is passed explicitly (rather than read from
769/// [`current_identity()`]) because this function runs *before* the
770/// task-local context is installed by the middleware. Reading the
771/// task-local here would always yield `None`, producing deny logs with
772/// an empty `user` field.
773fn enforce_tool_policy(
774    policy: &RbacPolicy,
775    identity_name: &str,
776    role: &str,
777    params: &serde_json::Value,
778) -> Option<Response> {
779    let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
780    let host = params
781        .get("arguments")
782        .and_then(|a| a.get("host"))
783        .and_then(|h| h.as_str());
784
785    let decision = if let Some(host) = host {
786        policy.check(role, tool_name, host)
787    } else {
788        policy.check_operation(role, tool_name)
789    };
790    if decision == RbacDecision::Deny {
791        tracing::warn!(
792            user = %identity_name,
793            role = %role,
794            tool = tool_name,
795            host = host.unwrap_or("-"),
796            "RBAC denied"
797        );
798        return Some(
799            McpxError::Rbac(format!("{tool_name} denied for role '{role}'")).into_response(),
800        );
801    }
802
803    let args = params.get("arguments").and_then(|a| a.as_object())?;
804    for (arg_key, arg_val) in args {
805        if let Some(val_str) = arg_val.as_str()
806            && !policy.argument_allowed(role, tool_name, arg_key, val_str)
807        {
808            // Redact the raw value: log an HMAC-SHA256 prefix instead of
809            // the literal string. Operators correlate hashes across log
810            // lines without ever exposing potentially sensitive inputs
811            // (paths, IDs, tokens accidentally passed as args, etc.).
812            tracing::warn!(
813                user = %identity_name,
814                role = %role,
815                tool = tool_name,
816                argument = arg_key,
817                arg_hmac = %policy.redact_arg(val_str),
818                "argument not in allowlist"
819            );
820            return Some(
821                McpxError::Rbac(format!(
822                    "argument '{arg_key}' value not in allowlist for tool '{tool_name}'"
823                ))
824                .into_response(),
825            );
826        }
827    }
828    None
829}
830
/// Simple glob matching: `*` matches any (possibly empty) sequence of
/// characters.
///
/// Supports multiple `*` wildcards anywhere in the pattern.
/// No `?`, `[...]`, or other advanced glob features. Comparison is
/// byte-exact (case-sensitive); a pattern without `*` must equal `text`.
///
/// Returns `true` when `text` matches `pattern`.
fn glob_match(pattern: &str, text: &str) -> bool {
    // No wildcard at all: the pattern is a plain literal.
    let Some((prefix, after_first_star)) = pattern.split_once('*') else {
        return pattern == text;
    };

    // `suffix` is the literal after the last `*`; `middle_pattern` is
    // whatever sits between the first and last `*` (empty when the pattern
    // holds a single `*`).
    let (middle_pattern, suffix) = after_first_star
        .rsplit_once('*')
        .unwrap_or(("", after_first_star));

    // Anchor the literal prefix first, then anchor the suffix on what is
    // left, so the two anchors can never overlap (`a*a` must not match `a`).
    // An empty prefix/suffix (pattern starts/ends with `*`) strips nothing.
    let Some(after_prefix) = text.strip_prefix(prefix) else {
        return false;
    };
    let Some(middle_text) = after_prefix.strip_suffix(suffix) else {
        return false;
    };

    // The remaining literals must appear, in order, inside the region
    // between the two anchors.
    let middle_parts: Vec<&str> = middle_pattern.split('*').collect();
    match_middle(middle_text, &middle_parts)
}

/// Match middle glob segments sequentially in `text`.
///
/// Every non-empty entry of `parts` must occur in `text` after the end of
/// the previous entry's match; the leftmost occurrence is always taken.
/// Empty entries (produced by adjacent `*`) are ignored. Returns `true`
/// when all entries were found in order.
fn match_middle(mut text: &str, parts: &[&str]) -> bool {
    for part in parts.iter().copied().filter(|part| !part.is_empty()) {
        match text.split_once(part) {
            // Continue searching strictly after this match.
            Some((_, rest)) => text = rest,
            None => return false,
        }
    }
    true
}
892
893#[cfg(test)]
894mod tests {
895    use super::*;
896
897    fn test_policy() -> RbacPolicy {
898        RbacPolicy::new(&RbacConfig {
899            enabled: true,
900            roles: vec![
901                RoleConfig {
902                    name: "viewer".into(),
903                    description: Some("Read-only".into()),
904                    allow: vec![
905                        "list_hosts".into(),
906                        "resource_list".into(),
907                        "resource_inspect".into(),
908                        "resource_logs".into(),
909                        "system_info".into(),
910                    ],
911                    deny: vec![],
912                    hosts: vec!["*".into()],
913                    argument_allowlists: vec![],
914                },
915                RoleConfig {
916                    name: "deploy".into(),
917                    description: Some("Lifecycle management".into()),
918                    allow: vec![
919                        "list_hosts".into(),
920                        "resource_list".into(),
921                        "resource_run".into(),
922                        "resource_start".into(),
923                        "resource_stop".into(),
924                        "resource_restart".into(),
925                        "resource_logs".into(),
926                        "image_pull".into(),
927                    ],
928                    deny: vec!["resource_delete".into(), "resource_exec".into()],
929                    hosts: vec!["web-*".into(), "api-*".into()],
930                    argument_allowlists: vec![],
931                },
932                RoleConfig {
933                    name: "ops".into(),
934                    description: Some("Full access".into()),
935                    allow: vec!["*".into()],
936                    deny: vec![],
937                    hosts: vec!["*".into()],
938                    argument_allowlists: vec![],
939                },
940                RoleConfig {
941                    name: "restricted-exec".into(),
942                    description: Some("Exec with argument allowlist".into()),
943                    allow: vec!["resource_exec".into()],
944                    deny: vec![],
945                    hosts: vec!["dev-*".into()],
946                    argument_allowlists: vec![ArgumentAllowlist {
947                        tool: "resource_exec".into(),
948                        argument: "cmd".into(),
949                        allowed: vec![
950                            "sh".into(),
951                            "bash".into(),
952                            "cat".into(),
953                            "ls".into(),
954                            "ps".into(),
955                        ],
956                    }],
957                },
958            ],
959            redaction_salt: None,
960        })
961    }
962
963    // -- glob_match tests --
964
965    #[test]
966    fn glob_exact_match() {
967        assert!(glob_match("web-prod-1", "web-prod-1"));
968        assert!(!glob_match("web-prod-1", "web-prod-2"));
969    }
970
971    #[test]
972    fn glob_star_suffix() {
973        assert!(glob_match("web-*", "web-prod-1"));
974        assert!(glob_match("web-*", "web-staging"));
975        assert!(!glob_match("web-*", "api-prod"));
976    }
977
978    #[test]
979    fn glob_star_prefix() {
980        assert!(glob_match("*-prod", "web-prod"));
981        assert!(glob_match("*-prod", "api-prod"));
982        assert!(!glob_match("*-prod", "web-staging"));
983    }
984
985    #[test]
986    fn glob_star_middle() {
987        assert!(glob_match("web-*-prod", "web-us-prod"));
988        assert!(glob_match("web-*-prod", "web-eu-east-prod"));
989        assert!(!glob_match("web-*-prod", "web-staging"));
990    }
991
992    #[test]
993    fn glob_star_only() {
994        assert!(glob_match("*", "anything"));
995        assert!(glob_match("*", ""));
996    }
997
998    #[test]
999    fn glob_multiple_stars() {
1000        assert!(glob_match("*web*prod*", "my-web-us-prod-1"));
1001        assert!(!glob_match("*web*prod*", "my-api-us-staging"));
1002    }
1003
1004    // -- RbacPolicy::check tests --
1005
1006    #[test]
1007    fn disabled_policy_allows_everything() {
1008        let policy = RbacPolicy::new(&RbacConfig {
1009            enabled: false,
1010            roles: vec![],
1011            redaction_salt: None,
1012        });
1013        assert_eq!(
1014            policy.check("nonexistent", "resource_delete", "any-host"),
1015            RbacDecision::Allow
1016        );
1017    }
1018
1019    #[test]
1020    fn unknown_role_denied() {
1021        let policy = test_policy();
1022        assert_eq!(
1023            policy.check("unknown", "resource_list", "web-prod-1"),
1024            RbacDecision::Deny
1025        );
1026    }
1027
1028    #[test]
1029    fn viewer_allowed_read_ops() {
1030        let policy = test_policy();
1031        assert_eq!(
1032            policy.check("viewer", "resource_list", "web-prod-1"),
1033            RbacDecision::Allow
1034        );
1035        assert_eq!(
1036            policy.check("viewer", "system_info", "db-host"),
1037            RbacDecision::Allow
1038        );
1039    }
1040
1041    #[test]
1042    fn viewer_denied_write_ops() {
1043        let policy = test_policy();
1044        assert_eq!(
1045            policy.check("viewer", "resource_run", "web-prod-1"),
1046            RbacDecision::Deny
1047        );
1048        assert_eq!(
1049            policy.check("viewer", "resource_delete", "web-prod-1"),
1050            RbacDecision::Deny
1051        );
1052    }
1053
1054    #[test]
1055    fn deploy_allowed_on_matching_hosts() {
1056        let policy = test_policy();
1057        assert_eq!(
1058            policy.check("deploy", "resource_run", "web-prod-1"),
1059            RbacDecision::Allow
1060        );
1061        assert_eq!(
1062            policy.check("deploy", "resource_start", "api-staging"),
1063            RbacDecision::Allow
1064        );
1065    }
1066
1067    #[test]
1068    fn deploy_denied_on_non_matching_host() {
1069        let policy = test_policy();
1070        assert_eq!(
1071            policy.check("deploy", "resource_run", "db-prod-1"),
1072            RbacDecision::Deny
1073        );
1074    }
1075
1076    #[test]
1077    fn deny_overrides_allow() {
1078        let policy = test_policy();
1079        assert_eq!(
1080            policy.check("deploy", "resource_delete", "web-prod-1"),
1081            RbacDecision::Deny
1082        );
1083        assert_eq!(
1084            policy.check("deploy", "resource_exec", "web-prod-1"),
1085            RbacDecision::Deny
1086        );
1087    }
1088
1089    #[test]
1090    fn ops_wildcard_allows_everything() {
1091        let policy = test_policy();
1092        assert_eq!(
1093            policy.check("ops", "resource_delete", "any-host"),
1094            RbacDecision::Allow
1095        );
1096        assert_eq!(
1097            policy.check("ops", "secret_create", "db-host"),
1098            RbacDecision::Allow
1099        );
1100    }
1101
1102    // -- host_visible tests --
1103
1104    #[test]
1105    fn host_visible_respects_globs() {
1106        let policy = test_policy();
1107        assert!(policy.host_visible("deploy", "web-prod-1"));
1108        assert!(policy.host_visible("deploy", "api-staging"));
1109        assert!(!policy.host_visible("deploy", "db-prod-1"));
1110        assert!(policy.host_visible("ops", "anything"));
1111        assert!(policy.host_visible("viewer", "anything"));
1112    }
1113
1114    #[test]
1115    fn host_visible_unknown_role() {
1116        let policy = test_policy();
1117        assert!(!policy.host_visible("unknown", "web-prod-1"));
1118    }
1119
1120    // -- argument_allowed tests --
1121
1122    #[test]
1123    fn argument_allowed_no_allowlist() {
1124        let policy = test_policy();
1125        // ops has no argument_allowlists -- all values allowed
1126        assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "rm -rf /"));
1127        assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "bash"));
1128    }
1129
1130    #[test]
1131    fn argument_allowed_with_allowlist() {
1132        let policy = test_policy();
1133        assert!(policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "sh"));
1134        assert!(policy.argument_allowed(
1135            "restricted-exec",
1136            "resource_exec",
1137            "cmd",
1138            "bash -c 'echo hi'"
1139        ));
1140        assert!(policy.argument_allowed(
1141            "restricted-exec",
1142            "resource_exec",
1143            "cmd",
1144            "cat /etc/hosts"
1145        ));
1146        assert!(policy.argument_allowed(
1147            "restricted-exec",
1148            "resource_exec",
1149            "cmd",
1150            "/usr/bin/ls -la"
1151        ));
1152    }
1153
1154    #[test]
1155    fn argument_denied_not_in_allowlist() {
1156        let policy = test_policy();
1157        assert!(!policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "rm -rf /"));
1158        assert!(!policy.argument_allowed(
1159            "restricted-exec",
1160            "resource_exec",
1161            "cmd",
1162            "python3 exploit.py"
1163        ));
1164        assert!(!policy.argument_allowed(
1165            "restricted-exec",
1166            "resource_exec",
1167            "cmd",
1168            "/usr/bin/curl evil.com"
1169        ));
1170    }
1171
1172    #[test]
1173    fn argument_denied_unknown_role() {
1174        let policy = test_policy();
1175        assert!(!policy.argument_allowed("unknown", "resource_exec", "cmd", "sh"));
1176    }
1177
1178    // -- shlex-tokenization regression tests (1.4.1) --
1179    //
1180    // These tests pin the POSIX-shell-like tokenization contract added
1181    // in 1.4.1. See `RbacPolicy::argument_allowed` doc comment for the
1182    // full contract; see CHANGELOG.md `[1.4.1]` for the behavior matrix.
1183
1184    /// Helper: build a minimal enabled policy with a single argument
1185    /// allowlist on tool `run`, argument `cmd`.
1186    fn shlex_policy(allowed: Vec<String>) -> RbacPolicy {
1187        let role = RoleConfig::new("viewer", vec!["run".into()], vec!["*".into()])
1188            .with_argument_allowlists(vec![ArgumentAllowlist::new("run", "cmd", allowed)]);
1189        let mut config = RbacConfig::with_roles(vec![role]);
1190        config.enabled = true;
1191        RbacPolicy::new(&config)
1192    }
1193
1194    #[test]
1195    fn argument_allowed_matches_quoted_path_with_spaces() {
1196        let policy = shlex_policy(vec!["/usr/bin/my tool".into()]);
1197        assert!(policy.argument_allowed("viewer", "run", "cmd", r#""/usr/bin/my tool" --flag"#));
1198    }
1199
1200    #[test]
1201    fn argument_allowed_matches_basename_of_quoted_path() {
1202        let policy = shlex_policy(vec!["my tool".into()]);
1203        assert!(policy.argument_allowed("viewer", "run", "cmd", r#""/usr/bin/my tool" --flag"#));
1204    }
1205
1206    #[test]
1207    fn argument_allowed_fails_closed_on_unbalanced_quote() {
1208        let policy = shlex_policy(vec!["unbalanced".into()]);
1209        assert!(!policy.argument_allowed("viewer", "run", "cmd", r"unbalanced 'quote"));
1210    }
1211
1212    #[test]
1213    fn argument_allowed_fails_closed_on_empty_string() {
1214        let policy = shlex_policy(vec![String::new()]);
1215        assert!(!policy.argument_allowed("viewer", "run", "cmd", ""));
1216    }
1217
1218    #[test]
1219    fn argument_allowed_handles_single_quoted_executable() {
1220        let policy = shlex_policy(vec!["/bin/sh".into()]);
1221        assert!(policy.argument_allowed("viewer", "run", "cmd", r"'/bin/sh' -c 'echo hi'"));
1222    }
1223
1224    #[test]
1225    fn argument_allowed_handles_tab_separator() {
1226        let policy = shlex_policy(vec!["ls".into()]);
1227        assert!(policy.argument_allowed("viewer", "run", "cmd", "ls\t/etc/passwd"));
1228    }
1229
1230    #[test]
1231    fn argument_allowed_plain_token_unchanged() {
1232        let policy = shlex_policy(vec!["ls".into()]);
1233        assert!(policy.argument_allowed("viewer", "run", "cmd", "ls"));
1234    }
1235
1236    // Per Oracle review: the next four tests pin the cases the original
1237    // handoff missed. Each confirms the *new* (1.4.1) deny behavior so a
1238    // future regression to the old `split_whitespace` semantics would
1239    // surface as a test failure.
1240
1241    #[test]
1242    fn argument_allowed_fails_closed_on_quoted_empty_first_token() {
1243        // value r#""""# parses to Some(vec![""]). An empty argv element
1244        // is never a runnable executable; deny even when "" is
1245        // explicitly allowlisted.
1246        let policy = shlex_policy(vec![String::new()]);
1247        assert!(!policy.argument_allowed("viewer", "run", "cmd", r#""""#));
1248    }
1249
1250    #[test]
1251    fn argument_allowed_quoted_literal_token_no_longer_matches() {
1252        // 1.4.0 behavior: split_whitespace first token = "'bash'" --
1253        //                 matched literal allowlist entry "'bash'".
1254        // 1.4.1 behavior: shlex strips the surrounding quotes -> first
1255        //                 token = "bash" -- no match against allowlist
1256        //                 entry "'bash'". Deny.
1257        let policy = shlex_policy(vec!["'bash'".into()]);
1258        assert!(!policy.argument_allowed("viewer", "run", "cmd", "'bash' -c true"));
1259    }
1260
1261    #[test]
1262    fn argument_allowed_backslash_literal_token_no_longer_matches() {
1263        // 1.4.0 behavior: literal first token "foo\\bar" matched.
1264        // 1.4.1 behavior: POSIX shlex treats backslash as escape ->
1265        //                 first token = "foobar". Allowlist entry with
1266        //                 a literal backslash no longer matches. Deny.
1267        let policy = shlex_policy(vec![r"foo\bar".into()]);
1268        assert!(!policy.argument_allowed("viewer", "run", "cmd", r"foo\bar --x"));
1269    }
1270
1271    #[test]
1272    fn argument_allowed_windows_path_no_longer_matches() {
1273        // 1.4.0 behavior: literal Windows path matched.
1274        // 1.4.1 behavior: POSIX shlex eats backslashes -> path identity
1275        //                 changes; allowlist entry no longer matches.
1276        //                 Deny. Documented in CHANGELOG operator notes.
1277        let policy = shlex_policy(vec![r"C:\Windows\System32\cmd.exe".into()]);
1278        assert!(!policy.argument_allowed(
1279            "viewer",
1280            "run",
1281            "cmd",
1282            r"C:\Windows\System32\cmd.exe /c dir"
1283        ));
1284    }
1285
1286    // -- host_patterns tests --
1287
1288    #[test]
1289    fn host_patterns_returns_globs() {
1290        let policy = test_policy();
1291        assert_eq!(
1292            policy.host_patterns("deploy"),
1293            Some(vec!["web-*".to_owned(), "api-*".to_owned()].as_slice())
1294        );
1295        assert_eq!(
1296            policy.host_patterns("ops"),
1297            Some(vec!["*".to_owned()].as_slice())
1298        );
1299        assert!(policy.host_patterns("nonexistent").is_none());
1300    }
1301
1302    // -- check_operation tests (no host check) --
1303
1304    #[test]
1305    fn check_operation_allows_without_host() {
1306        let policy = test_policy();
1307        assert_eq!(
1308            policy.check_operation("deploy", "resource_run"),
1309            RbacDecision::Allow
1310        );
1311        // but check() with a non-matching host denies
1312        assert_eq!(
1313            policy.check("deploy", "resource_run", "db-prod-1"),
1314            RbacDecision::Deny
1315        );
1316    }
1317
1318    #[test]
1319    fn check_operation_deny_overrides() {
1320        let policy = test_policy();
1321        assert_eq!(
1322            policy.check_operation("deploy", "resource_delete"),
1323            RbacDecision::Deny
1324        );
1325    }
1326
1327    #[test]
1328    fn check_operation_unknown_role() {
1329        let policy = test_policy();
1330        assert_eq!(
1331            policy.check_operation("unknown", "resource_list"),
1332            RbacDecision::Deny
1333        );
1334    }
1335
1336    #[test]
1337    fn check_operation_disabled() {
1338        let policy = RbacPolicy::new(&RbacConfig {
1339            enabled: false,
1340            roles: vec![],
1341            redaction_salt: None,
1342        });
1343        assert_eq!(
1344            policy.check_operation("nonexistent", "anything"),
1345            RbacDecision::Allow
1346        );
1347    }
1348
1349    // -- current_role / current_identity tests --
1350
1351    #[test]
1352    fn current_role_returns_none_outside_scope() {
1353        assert!(current_role().is_none());
1354    }
1355
1356    #[test]
1357    fn current_identity_returns_none_outside_scope() {
1358        assert!(current_identity().is_none());
1359    }
1360
1361    // -- rbac_middleware integration tests --
1362
1363    use axum::{
1364        body::Body,
1365        http::{Method, Request, StatusCode},
1366    };
1367    use tower::ServiceExt as _;
1368
1369    fn tool_call_body(tool: &str, args: &serde_json::Value) -> String {
1370        serde_json::json!({
1371            "jsonrpc": "2.0",
1372            "id": 1,
1373            "method": "tools/call",
1374            "params": {
1375                "name": tool,
1376                "arguments": args
1377            }
1378        })
1379        .to_string()
1380    }
1381
1382    fn rbac_router(policy: Arc<RbacPolicy>) -> axum::Router {
1383        axum::Router::new()
1384            .route("/mcp", axum::routing::post(|| async { "ok" }))
1385            .layer(axum::middleware::from_fn(move |req, next| {
1386                let p = Arc::clone(&policy);
1387                rbac_middleware(p, None, req, next)
1388            }))
1389    }
1390
1391    fn rbac_router_with_identity(policy: Arc<RbacPolicy>, identity: AuthIdentity) -> axum::Router {
1392        axum::Router::new()
1393            .route("/mcp", axum::routing::post(|| async { "ok" }))
1394            .layer(axum::middleware::from_fn(
1395                move |mut req: Request<Body>, next: Next| {
1396                    let p = Arc::clone(&policy);
1397                    let id = identity.clone();
1398                    async move {
1399                        req.extensions_mut().insert(id);
1400                        rbac_middleware(p, None, req, next).await
1401                    }
1402                },
1403            ))
1404    }
1405
1406    #[tokio::test]
1407    async fn middleware_passes_non_post() {
1408        let policy = Arc::new(test_policy());
1409        let app = rbac_router(policy);
1410        // GET passes through even without identity.
1411        let req = Request::builder()
1412            .method(Method::GET)
1413            .uri("/mcp")
1414            .body(Body::empty())
1415            .unwrap();
1416        // GET on a POST-only route returns 405, but the middleware itself
1417        // doesn't block it -- it returns next.run(req).
1418        let resp = app.oneshot(req).await.unwrap();
1419        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
1420    }
1421
1422    #[tokio::test]
1423    async fn middleware_denies_without_identity() {
1424        let policy = Arc::new(test_policy());
1425        let app = rbac_router(policy);
1426        let body = tool_call_body("resource_list", &serde_json::json!({}));
1427        let req = Request::builder()
1428            .method(Method::POST)
1429            .uri("/mcp")
1430            .header("content-type", "application/json")
1431            .body(Body::from(body))
1432            .unwrap();
1433        let resp = app.oneshot(req).await.unwrap();
1434        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1435    }
1436
1437    #[tokio::test]
1438    async fn middleware_allows_permitted_tool() {
1439        let policy = Arc::new(test_policy());
1440        let id = AuthIdentity {
1441            method: crate::auth::AuthMethod::BearerToken,
1442            name: "alice".into(),
1443            role: "viewer".into(),
1444            raw_token: None,
1445            sub: None,
1446        };
1447        let app = rbac_router_with_identity(policy, id);
1448        let body = tool_call_body("resource_list", &serde_json::json!({}));
1449        let req = Request::builder()
1450            .method(Method::POST)
1451            .uri("/mcp")
1452            .header("content-type", "application/json")
1453            .body(Body::from(body))
1454            .unwrap();
1455        let resp = app.oneshot(req).await.unwrap();
1456        assert_eq!(resp.status(), StatusCode::OK);
1457    }
1458
1459    #[tokio::test]
1460    async fn middleware_denies_unpermitted_tool() {
1461        let policy = Arc::new(test_policy());
1462        let id = AuthIdentity {
1463            method: crate::auth::AuthMethod::BearerToken,
1464            name: "alice".into(),
1465            role: "viewer".into(),
1466            raw_token: None,
1467            sub: None,
1468        };
1469        let app = rbac_router_with_identity(policy, id);
1470        let body = tool_call_body("resource_delete", &serde_json::json!({}));
1471        let req = Request::builder()
1472            .method(Method::POST)
1473            .uri("/mcp")
1474            .header("content-type", "application/json")
1475            .body(Body::from(body))
1476            .unwrap();
1477        let resp = app.oneshot(req).await.unwrap();
1478        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1479    }
1480
1481    #[tokio::test]
1482    async fn middleware_passes_non_tool_call_post() {
1483        let policy = Arc::new(test_policy());
1484        let id = AuthIdentity {
1485            method: crate::auth::AuthMethod::BearerToken,
1486            name: "alice".into(),
1487            role: "viewer".into(),
1488            raw_token: None,
1489            sub: None,
1490        };
1491        let app = rbac_router_with_identity(policy, id);
1492        // A non-tools/call JSON-RPC (e.g. resources/list) passes through.
1493        let body = serde_json::json!({
1494            "jsonrpc": "2.0",
1495            "id": 1,
1496            "method": "resources/list"
1497        })
1498        .to_string();
1499        let req = Request::builder()
1500            .method(Method::POST)
1501            .uri("/mcp")
1502            .header("content-type", "application/json")
1503            .body(Body::from(body))
1504            .unwrap();
1505        let resp = app.oneshot(req).await.unwrap();
1506        assert_eq!(resp.status(), StatusCode::OK);
1507    }
1508
1509    #[tokio::test]
1510    async fn middleware_enforces_argument_allowlist() {
1511        let policy = Arc::new(test_policy());
1512        let id = AuthIdentity {
1513            method: crate::auth::AuthMethod::BearerToken,
1514            name: "dev".into(),
1515            role: "restricted-exec".into(),
1516            raw_token: None,
1517            sub: None,
1518        };
1519        // Allowed command
1520        let app = rbac_router_with_identity(Arc::clone(&policy), id.clone());
1521        let body = tool_call_body(
1522            "resource_exec",
1523            &serde_json::json!({"cmd": "ls -la", "host": "dev-1"}),
1524        );
1525        let req = Request::builder()
1526            .method(Method::POST)
1527            .uri("/mcp")
1528            .body(Body::from(body))
1529            .unwrap();
1530        let resp = app.oneshot(req).await.unwrap();
1531        assert_eq!(resp.status(), StatusCode::OK);
1532
1533        // Denied command
1534        let app = rbac_router_with_identity(policy, id);
1535        let body = tool_call_body(
1536            "resource_exec",
1537            &serde_json::json!({"cmd": "rm -rf /", "host": "dev-1"}),
1538        );
1539        let req = Request::builder()
1540            .method(Method::POST)
1541            .uri("/mcp")
1542            .body(Body::from(body))
1543            .unwrap();
1544        let resp = app.oneshot(req).await.unwrap();
1545        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1546    }
1547
1548    #[tokio::test]
1549    async fn middleware_disabled_policy_passes_everything() {
1550        let policy = Arc::new(RbacPolicy::disabled());
1551        let app = rbac_router(policy);
1552        // No identity, disabled policy -- should pass.
1553        let body = tool_call_body("anything", &serde_json::json!({}));
1554        let req = Request::builder()
1555            .method(Method::POST)
1556            .uri("/mcp")
1557            .body(Body::from(body))
1558            .unwrap();
1559        let resp = app.oneshot(req).await.unwrap();
1560        assert_eq!(resp.status(), StatusCode::OK);
1561    }
1562
1563    #[tokio::test]
1564    async fn middleware_batch_all_allowed_passes() {
1565        let policy = Arc::new(test_policy());
1566        let id = AuthIdentity {
1567            method: crate::auth::AuthMethod::BearerToken,
1568            name: "alice".into(),
1569            role: "viewer".into(),
1570            raw_token: None,
1571            sub: None,
1572        };
1573        let app = rbac_router_with_identity(policy, id);
1574        let body = serde_json::json!([
1575            {
1576                "jsonrpc": "2.0",
1577                "id": 1,
1578                "method": "tools/call",
1579                "params": { "name": "resource_list", "arguments": {} }
1580            },
1581            {
1582                "jsonrpc": "2.0",
1583                "id": 2,
1584                "method": "tools/call",
1585                "params": { "name": "system_info", "arguments": {} }
1586            }
1587        ])
1588        .to_string();
1589        let req = Request::builder()
1590            .method(Method::POST)
1591            .uri("/mcp")
1592            .header("content-type", "application/json")
1593            .body(Body::from(body))
1594            .unwrap();
1595        let resp = app.oneshot(req).await.unwrap();
1596        assert_eq!(resp.status(), StatusCode::OK);
1597    }
1598
1599    #[tokio::test]
1600    async fn middleware_batch_with_denied_call_rejects_entire_batch() {
1601        let policy = Arc::new(test_policy());
1602        let id = AuthIdentity {
1603            method: crate::auth::AuthMethod::BearerToken,
1604            name: "alice".into(),
1605            role: "viewer".into(),
1606            raw_token: None,
1607            sub: None,
1608        };
1609        let app = rbac_router_with_identity(policy, id);
1610        let body = serde_json::json!([
1611            {
1612                "jsonrpc": "2.0",
1613                "id": 1,
1614                "method": "tools/call",
1615                "params": { "name": "resource_list", "arguments": {} }
1616            },
1617            {
1618                "jsonrpc": "2.0",
1619                "id": 2,
1620                "method": "tools/call",
1621                "params": { "name": "resource_delete", "arguments": {} }
1622            }
1623        ])
1624        .to_string();
1625        let req = Request::builder()
1626            .method(Method::POST)
1627            .uri("/mcp")
1628            .header("content-type", "application/json")
1629            .body(Body::from(body))
1630            .unwrap();
1631        let resp = app.oneshot(req).await.unwrap();
1632        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1633    }
1634
1635    #[tokio::test]
1636    async fn middleware_batch_mixed_allowed_and_denied_rejects() {
1637        let policy = Arc::new(test_policy());
1638        let id = AuthIdentity {
1639            method: crate::auth::AuthMethod::BearerToken,
1640            name: "dev".into(),
1641            role: "restricted-exec".into(),
1642            raw_token: None,
1643            sub: None,
1644        };
1645        let app = rbac_router_with_identity(policy, id);
1646        let body = serde_json::json!([
1647            {
1648                "jsonrpc": "2.0",
1649                "id": 1,
1650                "method": "tools/call",
1651                "params": {
1652                    "name": "resource_exec",
1653                    "arguments": { "cmd": "ls -la", "host": "dev-1" }
1654                }
1655            },
1656            {
1657                "jsonrpc": "2.0",
1658                "id": 2,
1659                "method": "tools/call",
1660                "params": {
1661                    "name": "resource_exec",
1662                    "arguments": { "cmd": "rm -rf /", "host": "dev-1" }
1663                }
1664            }
1665        ])
1666        .to_string();
1667        let req = Request::builder()
1668            .method(Method::POST)
1669            .uri("/mcp")
1670            .header("content-type", "application/json")
1671            .body(Body::from(body))
1672            .unwrap();
1673        let resp = app.oneshot(req).await.unwrap();
1674        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1675    }
1676
1677    // -- redact_arg / redaction_salt tests --
1678
/// Hashing the same value twice under the same salt must give the same
/// output, and that output must be exactly 8 lowercase hex characters.
#[test]
fn redact_with_salt_is_deterministic_per_salt() {
    let salt = b"unit-test-salt";
    let a = redact_with_salt(salt, "rm -rf /");
    let b = redact_with_salt(salt, "rm -rf /");
    assert_eq!(a, b, "same input + salt must yield identical hash");
    assert_eq!(a.len(), 8, "redacted hash is 8 hex chars (4 bytes)");
    // `is_ascii_hexdigit` would also accept `A-F`; the contract stated in the
    // failure message is *lowercase* hex, so match that alphabet exactly.
    assert!(
        a.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
        "redacted hash must be lowercase hex: {a}"
    );
}
1691
/// The salt must influence the output: one plaintext hashed under two
/// distinct salts may not produce the same redaction.
#[test]
fn redact_with_salt_differs_across_salts() {
    let plaintext = "the-same-value";
    let under_first_salt = redact_with_salt(b"salt-one", plaintext);
    let under_second_salt = redact_with_salt(b"salt-two", plaintext);
    assert_ne!(
        under_first_salt, under_second_salt,
        "different salts must produce different hashes for the same value"
    );
}
1702
/// Under a fixed salt, two different plaintexts must redact differently.
#[test]
fn redact_with_salt_distinguishes_values() {
    let key = b"k";
    let [first, second] = ["alpha", "beta"].map(|plaintext| redact_with_salt(key, plaintext));
    // A 32-bit output collides with probability ~1 in 4 billion for a given
    // pair, so asserting inequality on these two fixed inputs is safe.
    assert_ne!(first, second, "different values must produce different hashes");
}
1711
/// Two policies constructed from one config that carries an explicit
/// `redaction_salt` must redact the same argument to the same value.
#[test]
fn policy_with_configured_salt_redacts_consistently() {
    let config = RbacConfig {
        redaction_salt: Some(SecretString::from("my-stable-salt")),
        roles: Vec::new(),
        enabled: true,
    };
    let first_policy = RbacPolicy::new(&config);
    let second_policy = RbacPolicy::new(&config);

    let from_first = first_policy.redact_arg("payload");
    let from_second = second_policy.redact_arg("payload");
    assert_eq!(
        from_first, from_second,
        "policies built from the same configured salt must agree"
    );
}
1727
/// With no `redaction_salt` configured, two policies built in the same
/// process must still agree on how they redact a given argument.
#[test]
fn policy_without_configured_salt_uses_process_salt() {
    let config = RbacConfig {
        redaction_salt: None,
        roles: Vec::new(),
        enabled: true,
    };
    let first_policy = RbacPolicy::new(&config);
    let second_policy = RbacPolicy::new(&config);

    // Both policies fall back to the lazily initialised, process-wide salt
    // (a `OnceLock`), so their outputs must match within this process.
    let from_first = first_policy.redact_arg("payload");
    let from_second = second_policy.redact_arg("payload");
    assert_eq!(
        from_first, from_second,
        "process-wide salt must be consistent within one process"
    );
}
1744
/// Coarse performance sanity check for the redaction primitive.
///
/// This is a smoke test, not a benchmark: it only guards against a
/// pathological slowdown. The asserted ceiling (5 ms) is deliberately far
/// looser than the expected cost so that unoptimized debug builds and noisy
/// CI hosts do not trip it. The production criterion bench (see H-T4 plan)
/// is where the strict <10 µs threshold is meant to be asserted.
#[test]
fn redact_arg_is_fast_enough() {
    /// Number of timed calls; the fastest one is compared to the ceiling so
    /// a single scheduler preemption cannot fail the test.
    const ATTEMPTS: usize = 5;
    /// Upper bound for one redaction, even in a debug build.
    const CEILING: Duration = Duration::from_millis(5);

    let salt = b"perf-sanity-salt-32-bytes-padded";
    let value = "x".repeat(256);

    let mut best = Duration::MAX;
    for _ in 0..ATTEMPTS {
        let start = std::time::Instant::now();
        // `black_box` keeps the optimizer from eliding or hoisting the call.
        let _ = std::hint::black_box(redact_with_salt(
            salt,
            std::hint::black_box(value.as_str()),
        ));
        best = best.min(start.elapsed());
    }

    assert!(
        best < CEILING,
        "fastest of {ATTEMPTS} redact_with_salt calls took {best:?}, expected <5 ms even in debug"
    );
}
1760
1761    // -- enforce_tool_policy identity propagation regression test (BUG H-S3) --
1762
1763    /// Regression: when `enforce_tool_policy` denied a request, the deny
1764    /// log used to read `current_identity()`, which was always `None` at
1765    /// that point because the task-local context is installed *after*
1766    /// policy enforcement. The fix passes `identity_name` explicitly.
1767    ///
1768    /// We assert the deny path returns 403 (the visible behaviour).
1769    /// The log-content assertion lives behind tracing-test which we have
1770    /// not yet added as a dev-dep; the explicit-parameter signature alone
1771    /// makes the previous bug structurally impossible.
1772    #[tokio::test]
1773    async fn deny_path_uses_explicit_identity_not_task_local() {
1774        let policy = Arc::new(test_policy());
1775        let id = AuthIdentity {
1776            method: crate::auth::AuthMethod::BearerToken,
1777            name: "alice-the-auditor".into(),
1778            role: "viewer".into(),
1779            raw_token: None,
1780            sub: None,
1781        };
1782        let app = rbac_router_with_identity(policy, id);
1783        // viewer is not allowed to call resource_delete -> 403.
1784        let body = tool_call_body("resource_delete", &serde_json::json!({}));
1785        let req = Request::builder()
1786            .method(Method::POST)
1787            .uri("/mcp")
1788            .header("content-type", "application/json")
1789            .body(Body::from(body))
1790            .unwrap();
1791        let resp = app.oneshot(req).await.unwrap();
1792        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
1793    }
1794}