Skip to main content

rmcp_server_kit/
rbac.rs

1//! Role-Based Access Control (RBAC) policy engine.
2//!
3//! Evaluates `(role, operation, host)` tuples against a set of role
4//! definitions loaded from config.  Deny-overrides-allow semantics:
5//! an explicit deny entry always wins over a wildcard allow.
6//!
7//! Includes an axum middleware that inspects MCP JSON-RPC tool calls
8//! and enforces RBAC and per-IP tool rate limiting before the request
9//! reaches the handler.
10
11use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
12
13use axum::{
14    body::Body,
15    extract::ConnectInfo,
16    http::{Method, Request, StatusCode},
17    middleware::Next,
18    response::{IntoResponse, Response},
19};
20use hmac::{Hmac, Mac};
21use http_body_util::BodyExt;
22use secrecy::{ExposeSecret, SecretString};
23use serde::Deserialize;
24use sha2::Sha256;
25
26use crate::{
27    auth::{AuthIdentity, TlsConnInfo},
28    bounded_limiter::BoundedKeyedLimiter,
29    error::McpxError,
30};
31
32/// Per-source-IP rate limiter for tool invocations. Memory-bounded against
33/// IP-spray `DoS` via [`BoundedKeyedLimiter`].
34pub(crate) type ToolRateLimiter = BoundedKeyedLimiter<IpAddr>;
35
/// Tool invocations allowed per minute per source IP when the configured
/// limit is zero (and so cannot form a quota).
// `Option::unwrap` is const-evaluable and 120 is non-zero, so this is checked
// at compile time and can never panic at runtime.
const DEFAULT_TOOL_RATE: NonZeroU32 = NonZeroU32::new(120).unwrap();

/// Upper bound on the number of distinct source IPs the tool rate limiter
/// tracks at once, so an IP-spray `DoS` cannot exhaust memory.
const DEFAULT_TOOL_MAX_TRACKED_KEYS: usize = 10_000;

/// Idle-eviction window handed to the tool rate limiter: 15 minutes.
const DEFAULT_TOOL_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
46
47/// Build a per-IP tool rate limiter from a max-calls-per-minute value.
48///
49/// Memory-bounded with `DEFAULT_TOOL_MAX_TRACKED_KEYS` tracked keys and
50/// `DEFAULT_TOOL_IDLE_EVICTION` idle eviction. Use
51/// [`build_tool_rate_limiter_with_bounds`] to override.
52#[must_use]
53pub(crate) fn build_tool_rate_limiter(max_per_minute: u32) -> Arc<ToolRateLimiter> {
54    build_tool_rate_limiter_with_bounds(
55        max_per_minute,
56        DEFAULT_TOOL_MAX_TRACKED_KEYS,
57        DEFAULT_TOOL_IDLE_EVICTION,
58    )
59}
60
61/// Build a per-IP tool rate limiter with explicit memory-bound parameters.
62#[must_use]
63pub(crate) fn build_tool_rate_limiter_with_bounds(
64    max_per_minute: u32,
65    max_tracked_keys: usize,
66    idle_eviction: Duration,
67) -> Arc<ToolRateLimiter> {
68    let quota =
69        governor::Quota::per_minute(NonZeroU32::new(max_per_minute).unwrap_or(DEFAULT_TOOL_RATE));
70    Arc::new(BoundedKeyedLimiter::new(
71        quota,
72        max_tracked_keys,
73        idle_eviction,
74    ))
75}
76
77// Task-local storage for the current caller's RBAC role and identity name.
78// Set by the RBAC middleware, read by tool handlers (e.g. list_hosts filtering, audit logging).
79//
80// `CURRENT_TOKEN` holds a [`SecretString`] so the raw bearer token is never
81// printed via `Debug` (it formats as `"[REDACTED alloc::string::String]"`)
82// and is zeroized on drop by the `secrecy` crate.
83tokio::task_local! {
84    static CURRENT_ROLE: String;
85    static CURRENT_IDENTITY: String;
86    static CURRENT_TOKEN: SecretString;
87    static CURRENT_SUB: String;
88}
89
90/// Get the current caller's RBAC role (set by RBAC middleware).
91/// Returns `None` outside an RBAC-scoped request context.
92#[must_use]
93pub fn current_role() -> Option<String> {
94    CURRENT_ROLE.try_with(Clone::clone).ok()
95}
96
97/// Get the current caller's identity name (set by RBAC middleware).
98/// Returns `None` outside an RBAC-scoped request context.
99#[must_use]
100pub fn current_identity() -> Option<String> {
101    CURRENT_IDENTITY.try_with(Clone::clone).ok()
102}
103
104/// Get the raw bearer token for the current request as a [`SecretString`].
105/// Returns `None` outside a request context or when auth used mTLS/API-key.
106/// Tool handlers use this for downstream token passthrough.
107///
108/// The returned value is wrapped in [`SecretString`] so it does not leak
109/// via `Debug`/`Display`/serde. Call `.expose_secret()` only when the
110/// raw value is actually needed (e.g. as the `Authorization` header on
111/// an outbound HTTP request).
112///
113/// An empty token is treated as absent (returns `None`); this preserves
114/// backward compatibility with the prior `Option<String>` API where the
115/// empty default sentinel meant "no token".
116#[must_use]
117pub fn current_token() -> Option<SecretString> {
118    CURRENT_TOKEN
119        .try_with(|t| {
120            if t.expose_secret().is_empty() {
121                None
122            } else {
123                Some(t.clone())
124            }
125        })
126        .ok()
127        .flatten()
128}
129
130/// Get the JWT `sub` claim (stable user ID, e.g. Keycloak UUID).
131/// Returns `None` outside a request context or for non-JWT auth.
132/// Use for stable per-user keying (token store, etc.).
133#[must_use]
134pub fn current_sub() -> Option<String> {
135    CURRENT_SUB
136        .try_with(Clone::clone)
137        .ok()
138        .filter(|s| !s.is_empty())
139}
140
141/// Run a future with `CURRENT_TOKEN` set so that [`current_token()`] returns
142/// the given value inside the future. Useful when MCP tool handlers need the
143/// raw bearer token but run in a spawned task where the RBAC middleware's
144/// task-local scope is no longer active.
145pub async fn with_token_scope<F: Future>(token: SecretString, f: F) -> F::Output {
146    CURRENT_TOKEN.scope(token, f).await
147}
148
149/// Run a future with all task-locals (`CURRENT_ROLE`, `CURRENT_IDENTITY`,
150/// `CURRENT_TOKEN`, `CURRENT_SUB`) set.  Use this when re-establishing the
151/// full RBAC context in spawned tasks (e.g. rmcp session tasks) where the
152/// middleware's scope is no longer active.
153pub async fn with_rbac_scope<F: Future>(
154    role: String,
155    identity: String,
156    token: SecretString,
157    sub: String,
158    f: F,
159) -> F::Output {
160    CURRENT_ROLE
161        .scope(
162            role,
163            CURRENT_IDENTITY.scope(
164                identity,
165                CURRENT_TOKEN.scope(token, CURRENT_SUB.scope(sub, f)),
166            ),
167        )
168        .await
169}
170
171/// A single role definition.
172#[derive(Debug, Clone, Deserialize)]
173#[non_exhaustive]
174pub struct RoleConfig {
175    /// Role identifier referenced from identities (API keys, mTLS, JWT claims).
176    pub name: String,
177    /// Human-readable description, surfaced in diagnostics only.
178    #[serde(default)]
179    pub description: Option<String>,
180    /// Allowed operations.  `["*"]` means all operations.
181    #[serde(default)]
182    pub allow: Vec<String>,
183    /// Explicitly denied operations (overrides allow).
184    #[serde(default)]
185    pub deny: Vec<String>,
186    /// Host name glob patterns this role can access. `["*"]` means all hosts.
187    #[serde(default = "default_hosts")]
188    pub hosts: Vec<String>,
189    /// Per-tool argument constraints. When a tool call matches, the
190    /// specified argument's first whitespace-delimited token (or its
191    /// `/`-basename) must appear in the allowlist.
192    #[serde(default)]
193    pub argument_allowlists: Vec<ArgumentAllowlist>,
194}
195
196impl RoleConfig {
197    /// Create a role with the given name, allowed operations, and host patterns.
198    #[must_use]
199    pub fn new(name: impl Into<String>, allow: Vec<String>, hosts: Vec<String>) -> Self {
200        Self {
201            name: name.into(),
202            description: None,
203            allow,
204            deny: vec![],
205            hosts,
206            argument_allowlists: vec![],
207        }
208    }
209
210    /// Attach argument allowlists to this role.
211    #[must_use]
212    pub fn with_argument_allowlists(mut self, allowlists: Vec<ArgumentAllowlist>) -> Self {
213        self.argument_allowlists = allowlists;
214        self
215    }
216}
217
218/// Per-tool argument allowlist entry.
219///
220/// When the middleware sees a `tools/call` for `tool`, it extracts the
221/// string value at `argument` from the call's arguments object and checks
222/// its first token against `allowed`. If the token is not in the list
223/// the call is rejected with 403.
224#[derive(Debug, Clone, Deserialize)]
225#[non_exhaustive]
226pub struct ArgumentAllowlist {
227    /// Tool name to match (exact or glob, e.g. `"run_query"`).
228    pub tool: String,
229    /// Argument key whose value is checked (e.g. `"cmd"`, `"query"`).
230    pub argument: String,
231    /// Permitted first-token values. Empty means unrestricted.
232    #[serde(default)]
233    pub allowed: Vec<String>,
234}
235
236impl ArgumentAllowlist {
237    /// Create an argument allowlist for a tool.
238    #[must_use]
239    pub fn new(tool: impl Into<String>, argument: impl Into<String>, allowed: Vec<String>) -> Self {
240        Self {
241            tool: tool.into(),
242            argument: argument.into(),
243            allowed,
244        }
245    }
246}
247
/// Serde default for [`RoleConfig::hosts`]: a single `"*"` pattern (every host).
fn default_hosts() -> Vec<String> {
    vec![String::from("*")]
}
251
252/// Top-level RBAC configuration (deserializable from TOML).
253#[derive(Debug, Clone, Default, Deserialize)]
254#[non_exhaustive]
255pub struct RbacConfig {
256    /// Master switch -- when false, the RBAC middleware is not installed.
257    #[serde(default)]
258    pub enabled: bool,
259    /// Role definitions available to identities.
260    #[serde(default)]
261    pub roles: Vec<RoleConfig>,
262    /// Optional stable HMAC key (any length) used to redact argument
263    /// values in deny logs. When set, redacted hashes are stable across
264    /// process restarts (useful for log correlation across deploys).
265    /// When `None`, a random 32-byte key is generated per process at
266    /// first use; redacted hashes change every restart.
267    ///
268    /// The key is wrapped in [`SecretString`] so it never leaks via
269    /// `Debug`/`Display`/serde and is zeroized on drop.
270    #[serde(default)]
271    pub redaction_salt: Option<SecretString>,
272}
273
274impl RbacConfig {
275    /// Create an enabled RBAC config with the given roles.
276    #[must_use]
277    pub fn with_roles(roles: Vec<RoleConfig>) -> Self {
278        Self {
279            enabled: true,
280            roles,
281            redaction_salt: None,
282        }
283    }
284}
285
/// Outcome of evaluating an RBAC policy check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RbacDecision {
    /// The caller may perform the requested operation.
    Allow,
    /// The caller is not permitted to perform the requested operation.
    Deny,
}
295
296/// Summary of a single role, produced by [`RbacPolicy::summary`].
297#[derive(Debug, Clone, serde::Serialize)]
298#[non_exhaustive]
299pub struct RbacRoleSummary {
300    /// Role name.
301    pub name: String,
302    /// Number of allow entries.
303    pub allow: usize,
304    /// Number of deny entries.
305    pub deny: usize,
306    /// Number of host patterns.
307    pub hosts: usize,
308    /// Number of argument allowlist entries.
309    pub argument_allowlists: usize,
310}
311
312/// Summary of the whole RBAC policy, produced by [`RbacPolicy::summary`].
313#[derive(Debug, Clone, serde::Serialize)]
314#[non_exhaustive]
315pub struct RbacPolicySummary {
316    /// Whether RBAC enforcement is active.
317    pub enabled: bool,
318    /// Per-role summaries.
319    pub roles: Vec<RbacRoleSummary>,
320}
321
322/// Compiled RBAC policy for fast lookup.
323///
324/// Built from [`RbacConfig`] at startup.  All lookups are O(n) over the
325/// role's allow/deny/host lists, which is fine for the expected cardinality
326/// (a handful of roles with tens of entries each).
327#[derive(Debug, Clone)]
328#[non_exhaustive]
329pub struct RbacPolicy {
330    roles: Vec<RoleConfig>,
331    enabled: bool,
332    /// HMAC key used to redact argument values in deny logs.
333    /// Either a configured stable salt or a per-process random salt.
334    redaction_salt: Arc<SecretString>,
335}
336
337impl RbacPolicy {
338    /// Build a policy from config.  When `config.enabled` is false, all
339    /// checks return [`RbacDecision::Allow`].
340    #[must_use]
341    pub fn new(config: &RbacConfig) -> Self {
342        let salt = config
343            .redaction_salt
344            .clone()
345            .unwrap_or_else(|| process_redaction_salt().clone());
346        Self {
347            roles: config.roles.clone(),
348            enabled: config.enabled,
349            redaction_salt: Arc::new(salt),
350        }
351    }
352
353    /// Create a policy that always allows (RBAC disabled).
354    #[must_use]
355    pub fn disabled() -> Self {
356        Self {
357            roles: Vec::new(),
358            enabled: false,
359            redaction_salt: Arc::new(process_redaction_salt().clone()),
360        }
361    }
362
363    /// Whether RBAC enforcement is active.
364    #[must_use]
365    pub fn is_enabled(&self) -> bool {
366        self.enabled
367    }
368
369    /// Summarize the policy for diagnostics (admin endpoint).
370    ///
371    /// Returns `(enabled, role_count, per_role_stats)` where each stat is
372    /// `(name, allow_count, deny_count, host_count, argument_allowlist_count)`.
373    #[must_use]
374    pub fn summary(&self) -> RbacPolicySummary {
375        let roles = self
376            .roles
377            .iter()
378            .map(|r| RbacRoleSummary {
379                name: r.name.clone(),
380                allow: r.allow.len(),
381                deny: r.deny.len(),
382                hosts: r.hosts.len(),
383                argument_allowlists: r.argument_allowlists.len(),
384            })
385            .collect();
386        RbacPolicySummary {
387            enabled: self.enabled,
388            roles,
389        }
390    }
391
392    /// Check whether `role` may perform `operation` (ignoring host).
393    ///
394    /// Use this for tools that don't target a specific host (e.g. `ping`,
395    /// `list_hosts`).
396    #[must_use]
397    pub fn check_operation(&self, role: &str, operation: &str) -> RbacDecision {
398        if !self.enabled {
399            return RbacDecision::Allow;
400        }
401        let Some(role_cfg) = self.find_role(role) else {
402            return RbacDecision::Deny;
403        };
404        if role_cfg.deny.iter().any(|d| d == operation) {
405            return RbacDecision::Deny;
406        }
407        if role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
408            return RbacDecision::Allow;
409        }
410        RbacDecision::Deny
411    }
412
413    /// Check whether `role` may perform `operation` on `host`.
414    ///
415    /// Evaluation order:
416    /// 1. If RBAC is disabled, allow.
417    /// 2. Check operation permission (deny overrides allow).
418    /// 3. Check host visibility via glob matching.
419    #[must_use]
420    pub fn check(&self, role: &str, operation: &str, host: &str) -> RbacDecision {
421        if !self.enabled {
422            return RbacDecision::Allow;
423        }
424        let Some(role_cfg) = self.find_role(role) else {
425            return RbacDecision::Deny;
426        };
427        if role_cfg.deny.iter().any(|d| d == operation) {
428            return RbacDecision::Deny;
429        }
430        if !role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
431            return RbacDecision::Deny;
432        }
433        if !Self::host_matches(&role_cfg.hosts, host) {
434            return RbacDecision::Deny;
435        }
436        RbacDecision::Allow
437    }
438
439    /// Check whether `role` can see `host` at all (for `list_hosts` filtering).
440    #[must_use]
441    pub fn host_visible(&self, role: &str, host: &str) -> bool {
442        if !self.enabled {
443            return true;
444        }
445        let Some(role_cfg) = self.find_role(role) else {
446            return false;
447        };
448        Self::host_matches(&role_cfg.hosts, host)
449    }
450
451    /// Get the list of hosts patterns for a role.
452    #[must_use]
453    pub fn host_patterns(&self, role: &str) -> Option<&[String]> {
454        self.find_role(role).map(|r| r.hosts.as_slice())
455    }
456
457    /// Check whether `value` passes the argument allowlists for `tool` under `role`.
458    ///
459    /// If the role has no matching `argument_allowlists` entry for the tool,
460    /// all values are allowed. When a matching entry exists, the first
461    /// whitespace-delimited token of `value` (or its `/`-basename) must
462    /// appear in the `allowed` list.
463    #[must_use]
464    pub fn argument_allowed(&self, role: &str, tool: &str, argument: &str, value: &str) -> bool {
465        if !self.enabled {
466            return true;
467        }
468        let Some(role_cfg) = self.find_role(role) else {
469            return false;
470        };
471        for al in &role_cfg.argument_allowlists {
472            if al.tool != tool && !glob_match(&al.tool, tool) {
473                continue;
474            }
475            if al.argument != argument {
476                continue;
477            }
478            if al.allowed.is_empty() {
479                continue;
480            }
481            // Match the first token (the executable / keyword).
482            let first_token = value.split_whitespace().next().unwrap_or(value);
483            // Also match against the basename if it's a path.
484            let basename = first_token.rsplit('/').next().unwrap_or(first_token);
485            if !al.allowed.iter().any(|a| a == first_token || a == basename) {
486                return false;
487            }
488        }
489        true
490    }
491
492    /// Return the role config for a given role name.
493    fn find_role(&self, name: &str) -> Option<&RoleConfig> {
494        self.roles.iter().find(|r| r.name == name)
495    }
496
497    /// Check if a host name matches any of the given glob patterns.
498    fn host_matches(patterns: &[String], host: &str) -> bool {
499        patterns.iter().any(|p| glob_match(p, host))
500    }
501
502    /// HMAC-SHA256 the given argument value with this policy's redaction
503    /// salt and return the first 8 hex characters (4 bytes / 32 bits).
504    ///
505    /// 32 bits is enough entropy for log correlation (1-in-4-billion
506    /// collision per pair) while being far short of any preimage attack
507    /// surface for an attacker reading logs. The HMAC construction
508    /// guarantees that even short or low-entropy values cannot be
509    /// recovered without the key.
510    #[must_use]
511    pub fn redact_arg(&self, value: &str) -> String {
512        redact_with_salt(self.redaction_salt.expose_secret().as_bytes(), value)
513    }
514}
515
516/// Process-wide random redaction salt, lazily generated on first use.
517/// Used when [`RbacConfig::redaction_salt`] is `None`.
518fn process_redaction_salt() -> &'static SecretString {
519    use base64::{Engine as _, engine::general_purpose::STANDARD_NO_PAD};
520    static PROCESS_SALT: std::sync::OnceLock<SecretString> = std::sync::OnceLock::new();
521    PROCESS_SALT.get_or_init(|| {
522        let mut bytes = [0u8; 32];
523        rand::fill(&mut bytes);
524        // base64-encode so the SecretString is valid UTF-8; the HMAC
525        // accepts arbitrary key bytes regardless.
526        SecretString::from(STANDARD_NO_PAD.encode(bytes))
527    })
528}
529
530/// HMAC-SHA256(`salt`, `value`) → first 8 hex chars.
531///
532/// Pulled out as a free function so it can be unit-tested and benchmarked
533/// without constructing a full [`RbacPolicy`].
534fn redact_with_salt(salt: &[u8], value: &str) -> String {
535    use std::fmt::Write as _;
536
537    use sha2::Digest as _;
538
539    type HmacSha256 = Hmac<Sha256>;
540    // HMAC-SHA256 accepts keys of any byte length: the spec pads short
541    // keys with zeros and hashes long keys, so `new_from_slice` is
542    // infallible here. We still defensively re-key with a SHA-256 of
543    // the salt if construction ever fails (e.g. future hmac upstream
544    // tightens the contract); both branches produce a valid keyed MAC.
545    let mut mac = if let Ok(m) = HmacSha256::new_from_slice(salt) {
546        m
547    } else {
548        let digest = Sha256::digest(salt);
549        #[allow(clippy::expect_used)] // 32-byte digest always valid as HMAC key
550        HmacSha256::new_from_slice(&digest).expect("32-byte SHA256 digest is valid HMAC key")
551    };
552    mac.update(value.as_bytes());
553    let bytes = mac.finalize().into_bytes();
554    // 4 bytes → 8 hex chars.
555    let prefix = bytes.get(..4).unwrap_or(&[0; 4]);
556    let mut out = String::with_capacity(8);
557    for b in prefix {
558        let _ = write!(out, "{b:02x}");
559    }
560    out
561}
562
563// -- RBAC middleware --
564
565/// Axum middleware that enforces RBAC and per-IP tool rate limiting on
566/// MCP tool calls.
567///
568/// Inspects POST request bodies for `tools/call` JSON-RPC messages,
569/// extracts the tool name and `host` argument, and checks the
570/// [`RbacPolicy`] against the [`AuthIdentity`] set by the auth middleware.
571///
572/// When a `tool_limiter` is provided, tool invocations are rate-limited
573/// per source IP regardless of whether RBAC is enabled (MCP spec: servers
574/// MUST rate limit tool invocations).
575///
576/// Non-POST requests and non-tool-call messages pass through unchanged.
577/// The caller's role is stored in task-local storage for use by tool
578/// handlers (e.g. `list_hosts` host filtering via [`current_role()`]).
579// TODO(refactor): cognitive complexity reduced from 43/25 by extracting
580// `enforce_tool_policy` and `enforce_rate_limit`. Remaining flow is a
581// linear body-collect + JSON-RPC parse + dispatch, intentionally left
582// inline to keep the request lifecycle visible at a glance.
583#[allow(clippy::too_many_lines)]
584pub(crate) async fn rbac_middleware(
585    policy: Arc<RbacPolicy>,
586    tool_limiter: Option<Arc<ToolRateLimiter>>,
587    req: Request<Body>,
588    next: Next,
589) -> Response {
590    // Only inspect POST requests - tool calls are POSTs.
591    if req.method() != Method::POST {
592        return next.run(req).await;
593    }
594
595    // Extract peer IP for rate limiting.
596    let peer_ip: Option<IpAddr> = req
597        .extensions()
598        .get::<ConnectInfo<std::net::SocketAddr>>()
599        .map(|ci| ci.0.ip())
600        .or_else(|| {
601            req.extensions()
602                .get::<ConnectInfo<TlsConnInfo>>()
603                .map(|ci| ci.0.addr.ip())
604        });
605
606    // Extract caller identity and role (may be absent when auth is off).
607    let identity = req.extensions().get::<AuthIdentity>();
608    let identity_name = identity.map(|id| id.name.clone()).unwrap_or_default();
609    let role = identity.map(|id| id.role.clone()).unwrap_or_default();
610    // Clone the SecretString end-to-end; an absent token becomes an empty
611    // SecretString sentinel (current_token() filters this out as None).
612    let raw_token: SecretString = identity
613        .and_then(|id| id.raw_token.clone())
614        .unwrap_or_else(|| SecretString::from(String::new()));
615    let sub = identity.and_then(|id| id.sub.clone()).unwrap_or_default();
616
617    // RBAC requires an authenticated identity.
618    if policy.is_enabled() && identity.is_none() {
619        return McpxError::Rbac("no authenticated identity".into()).into_response();
620    }
621
622    // Read the body for JSON-RPC inspection.
623    let (parts, body) = req.into_parts();
624    let bytes = match body.collect().await {
625        Ok(collected) => collected.to_bytes(),
626        Err(e) => {
627            tracing::error!(error = %e, "failed to read request body");
628            return (
629                StatusCode::INTERNAL_SERVER_ERROR,
630                "failed to read request body",
631            )
632                .into_response();
633        }
634    };
635
636    // Try to parse as JSON and inspect JSON-RPC tool calls, including batch arrays.
637    if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&bytes) {
638        let tool_calls = extract_tool_calls(&json);
639        if !tool_calls.is_empty() {
640            for params in tool_calls {
641                if let Some(resp) = enforce_rate_limit(tool_limiter.as_deref(), peer_ip) {
642                    return resp;
643                }
644                if policy.is_enabled()
645                    && let Some(resp) = enforce_tool_policy(&policy, &identity_name, &role, params)
646                {
647                    return resp;
648                }
649            }
650        }
651    }
652    // Non-parseable or non-tool-call requests pass through.
653
654    // Reconstruct the request with the consumed body.
655    let req = Request::from_parts(parts, Body::from(bytes));
656
657    // Set the caller's role and identity in task-local storage for the handler.
658    if role.is_empty() {
659        next.run(req).await
660    } else {
661        CURRENT_ROLE
662            .scope(
663                role,
664                CURRENT_IDENTITY.scope(
665                    identity_name,
666                    CURRENT_TOKEN.scope(raw_token, CURRENT_SUB.scope(sub, next.run(req))),
667                ),
668            )
669            .await
670    }
671}
672
673/// Extract the `params` object for every top-level `tools/call` message.
674///
675/// Supports either a single JSON-RPC object or a JSON-RPC batch array. Any
676/// malformed elements are ignored so non-RPC payloads continue to pass through
677/// unchanged.
678fn extract_tool_calls(value: &serde_json::Value) -> Vec<&serde_json::Value> {
679    match value {
680        serde_json::Value::Object(map) => map
681            .get("method")
682            .and_then(serde_json::Value::as_str)
683            .filter(|method| *method == "tools/call")
684            .and_then(|_| map.get("params"))
685            .into_iter()
686            .collect(),
687        serde_json::Value::Array(items) => items
688            .iter()
689            .filter_map(|item| match item {
690                serde_json::Value::Object(map) => map
691                    .get("method")
692                    .and_then(serde_json::Value::as_str)
693                    .filter(|method| *method == "tools/call")
694                    .and_then(|_| map.get("params")),
695                serde_json::Value::Null
696                | serde_json::Value::Bool(_)
697                | serde_json::Value::Number(_)
698                | serde_json::Value::String(_)
699                | serde_json::Value::Array(_) => None,
700            })
701            .collect(),
702        serde_json::Value::Null
703        | serde_json::Value::Bool(_)
704        | serde_json::Value::Number(_)
705        | serde_json::Value::String(_) => Vec::new(),
706    }
707}
708
709/// Per-IP rate limit check for tool invocations. Returns `Some(response)`
710/// if the caller should be rejected.
711fn enforce_rate_limit(
712    tool_limiter: Option<&ToolRateLimiter>,
713    peer_ip: Option<IpAddr>,
714) -> Option<Response> {
715    let limiter = tool_limiter?;
716    let ip = peer_ip?;
717    if limiter.check_key(&ip).is_err() {
718        tracing::warn!(%ip, "tool invocation rate limited");
719        return Some(McpxError::RateLimited("too many tool invocations".into()).into_response());
720    }
721    None
722}
723
724/// Apply RBAC tool/host + argument-allowlist checks. Returns `Some(response)`
725/// when the caller must be rejected. Assumes `policy.is_enabled()`.
726///
727/// `identity_name` is passed explicitly (rather than read from
728/// [`current_identity()`]) because this function runs *before* the
729/// task-local context is installed by the middleware. Reading the
730/// task-local here would always yield `None`, producing deny logs with
731/// an empty `user` field.
732fn enforce_tool_policy(
733    policy: &RbacPolicy,
734    identity_name: &str,
735    role: &str,
736    params: &serde_json::Value,
737) -> Option<Response> {
738    let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
739    let host = params
740        .get("arguments")
741        .and_then(|a| a.get("host"))
742        .and_then(|h| h.as_str());
743
744    let decision = if let Some(host) = host {
745        policy.check(role, tool_name, host)
746    } else {
747        policy.check_operation(role, tool_name)
748    };
749    if decision == RbacDecision::Deny {
750        tracing::warn!(
751            user = %identity_name,
752            role = %role,
753            tool = tool_name,
754            host = host.unwrap_or("-"),
755            "RBAC denied"
756        );
757        return Some(
758            McpxError::Rbac(format!("{tool_name} denied for role '{role}'")).into_response(),
759        );
760    }
761
762    let args = params.get("arguments").and_then(|a| a.as_object())?;
763    for (arg_key, arg_val) in args {
764        if let Some(val_str) = arg_val.as_str()
765            && !policy.argument_allowed(role, tool_name, arg_key, val_str)
766        {
767            // Redact the raw value: log an HMAC-SHA256 prefix instead of
768            // the literal string. Operators correlate hashes across log
769            // lines without ever exposing potentially sensitive inputs
770            // (paths, IDs, tokens accidentally passed as args, etc.).
771            tracing::warn!(
772                user = %identity_name,
773                role = %role,
774                tool = tool_name,
775                argument = arg_key,
776                arg_hmac = %policy.redact_arg(val_str),
777                "argument not in allowlist"
778            );
779            return Some(
780                McpxError::Rbac(format!(
781                    "argument '{arg_key}' value not in allowlist for tool '{tool_name}'"
782                ))
783                .into_response(),
784            );
785        }
786    }
787    None
788}
789
/// Minimal glob matcher: `*` matches any (possibly empty) run of characters.
///
/// Any number of `*` may appear anywhere in `pattern`; every other character
/// is literal. There is no `?`, no `[...]` class and no escaping. A pattern
/// without `*` must equal `text` exactly.
fn glob_match(pattern: &str, text: &str) -> bool {
    let segments: Vec<&str> = pattern.split('*').collect();
    // Fewer than two segments means the pattern contains no wildcard.
    let [head, inner @ .., tail] = segments.as_slice() else {
        return pattern == text;
    };
    // The literal before the first `*` must start `text`...
    let Some(rest) = text.strip_prefix(*head) else {
        return false;
    };
    // ...and the literal after the last `*` must end what remains, so the
    // prefix and suffix can never overlap.
    let Some(rest) = rest.strip_suffix(*tail) else {
        return false;
    };
    // Whatever lies between them must contain the inner literals in order.
    match_middle(rest, inner)
}

/// Return whether the non-empty `parts` occur in `text` in order, without
/// overlapping.
///
/// Each part is matched at its leftmost occurrence after the previous match.
/// For `*`-only globs this greedy choice never misses a valid match.
fn match_middle(mut text: &str, parts: &[&str]) -> bool {
    for &part in parts.iter().filter(|p| !p.is_empty()) {
        let Some(idx) = text.find(part) else {
            return false;
        };
        text = &text[idx + part.len()..];
    }
    true
}
851
#[cfg(test)]
mod tests {
    use super::*;

    use axum::{
        body::Body,
        http::{Method, Request, StatusCode},
    };
    use tower::ServiceExt as _;

    /// Fixture policy shared by most tests.
    ///
    /// - `viewer`: read-only tools on every host.
    /// - `deploy`: lifecycle tools on `web-*` / `api-*`, with explicit denies
    ///   for `resource_delete` and `resource_exec`.
    /// - `ops`: wildcard allow on every host.
    /// - `restricted-exec`: `resource_exec` on `dev-*`, with an allowlist on
    ///   the `cmd` argument.
    fn test_policy() -> RbacPolicy {
        RbacPolicy::new(&RbacConfig {
            enabled: true,
            roles: vec![
                RoleConfig {
                    name: "viewer".into(),
                    description: Some("Read-only".into()),
                    allow: vec![
                        "list_hosts".into(),
                        "resource_list".into(),
                        "resource_inspect".into(),
                        "resource_logs".into(),
                        "system_info".into(),
                    ],
                    deny: vec![],
                    hosts: vec!["*".into()],
                    argument_allowlists: vec![],
                },
                RoleConfig {
                    name: "deploy".into(),
                    description: Some("Lifecycle management".into()),
                    allow: vec![
                        "list_hosts".into(),
                        "resource_list".into(),
                        "resource_run".into(),
                        "resource_start".into(),
                        "resource_stop".into(),
                        "resource_restart".into(),
                        "resource_logs".into(),
                        "image_pull".into(),
                    ],
                    deny: vec!["resource_delete".into(), "resource_exec".into()],
                    hosts: vec!["web-*".into(), "api-*".into()],
                    argument_allowlists: vec![],
                },
                RoleConfig {
                    name: "ops".into(),
                    description: Some("Full access".into()),
                    allow: vec!["*".into()],
                    deny: vec![],
                    hosts: vec!["*".into()],
                    argument_allowlists: vec![],
                },
                RoleConfig {
                    name: "restricted-exec".into(),
                    description: Some("Exec with argument allowlist".into()),
                    allow: vec!["resource_exec".into()],
                    deny: vec![],
                    hosts: vec!["dev-*".into()],
                    argument_allowlists: vec![ArgumentAllowlist {
                        tool: "resource_exec".into(),
                        argument: "cmd".into(),
                        allowed: vec![
                            "sh".into(),
                            "bash".into(),
                            "cat".into(),
                            "ls".into(),
                            "ps".into(),
                        ],
                    }],
                },
            ],
            redaction_salt: None,
        })
    }

    // -- glob_match tests --

    #[test]
    fn glob_exact_match() {
        assert!(glob_match("web-prod-1", "web-prod-1"));
        assert!(!glob_match("web-prod-1", "web-prod-2"));
    }

    #[test]
    fn glob_star_suffix() {
        assert!(glob_match("web-*", "web-prod-1"));
        assert!(glob_match("web-*", "web-staging"));
        assert!(!glob_match("web-*", "api-prod"));
    }

    #[test]
    fn glob_star_prefix() {
        assert!(glob_match("*-prod", "web-prod"));
        assert!(glob_match("*-prod", "api-prod"));
        assert!(!glob_match("*-prod", "web-staging"));
    }

    #[test]
    fn glob_star_middle() {
        assert!(glob_match("web-*-prod", "web-us-prod"));
        assert!(glob_match("web-*-prod", "web-eu-east-prod"));
        assert!(!glob_match("web-*-prod", "web-staging"));
    }

    #[test]
    fn glob_star_only() {
        assert!(glob_match("*", "anything"));
        assert!(glob_match("*", ""));
    }

    #[test]
    fn glob_multiple_stars() {
        assert!(glob_match("*web*prod*", "my-web-us-prod-1"));
        assert!(!glob_match("*web*prod*", "my-api-us-staging"));
    }

    #[test]
    fn glob_prefix_and_suffix_must_not_overlap() {
        // "ab" + "ba" would only fit into "aba" by sharing the middle 'b'.
        assert!(!glob_match("ab*ba", "aba"));
        assert!(glob_match("ab*ba", "abba"));
        assert!(!glob_match("a*a", "a"));
        assert!(glob_match("a*a", "aa"));
    }

    #[test]
    fn glob_empty_and_repeated_stars() {
        assert!(glob_match("", ""));
        assert!(!glob_match("", "x"));
        assert!(glob_match("**", "anything"));
        assert!(glob_match("web-**-prod", "web-eu-prod"));
    }

    #[test]
    fn glob_handles_multibyte_text() {
        // Exercises the byte-offset slicing paths with non-ASCII characters.
        assert!(glob_match("café-*", "café-prod"));
        assert!(glob_match("*-ü", "host-ü"));
        assert!(!glob_match("café-*", "cafe-prod"));
    }

    // -- RbacPolicy::check tests --

    #[test]
    fn disabled_policy_allows_everything() {
        let policy = RbacPolicy::new(&RbacConfig {
            enabled: false,
            roles: vec![],
            redaction_salt: None,
        });
        assert_eq!(
            policy.check("nonexistent", "resource_delete", "any-host"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn unknown_role_denied() {
        let policy = test_policy();
        assert_eq!(
            policy.check("unknown", "resource_list", "web-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn viewer_allowed_read_ops() {
        let policy = test_policy();
        assert_eq!(
            policy.check("viewer", "resource_list", "web-prod-1"),
            RbacDecision::Allow
        );
        assert_eq!(
            policy.check("viewer", "system_info", "db-host"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn viewer_denied_write_ops() {
        let policy = test_policy();
        assert_eq!(
            policy.check("viewer", "resource_run", "web-prod-1"),
            RbacDecision::Deny
        );
        assert_eq!(
            policy.check("viewer", "resource_delete", "web-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn deploy_allowed_on_matching_hosts() {
        let policy = test_policy();
        assert_eq!(
            policy.check("deploy", "resource_run", "web-prod-1"),
            RbacDecision::Allow
        );
        assert_eq!(
            policy.check("deploy", "resource_start", "api-staging"),
            RbacDecision::Allow
        );
    }

    #[test]
    fn deploy_denied_on_non_matching_host() {
        let policy = test_policy();
        assert_eq!(
            policy.check("deploy", "resource_run", "db-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn deny_overrides_allow() {
        let policy = test_policy();
        assert_eq!(
            policy.check("deploy", "resource_delete", "web-prod-1"),
            RbacDecision::Deny
        );
        assert_eq!(
            policy.check("deploy", "resource_exec", "web-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn ops_wildcard_allows_everything() {
        let policy = test_policy();
        assert_eq!(
            policy.check("ops", "resource_delete", "any-host"),
            RbacDecision::Allow
        );
        assert_eq!(
            policy.check("ops", "secret_create", "db-host"),
            RbacDecision::Allow
        );
    }

    // -- host_visible tests --

    #[test]
    fn host_visible_respects_globs() {
        let policy = test_policy();
        assert!(policy.host_visible("deploy", "web-prod-1"));
        assert!(policy.host_visible("deploy", "api-staging"));
        assert!(!policy.host_visible("deploy", "db-prod-1"));
        assert!(policy.host_visible("ops", "anything"));
        assert!(policy.host_visible("viewer", "anything"));
    }

    #[test]
    fn host_visible_unknown_role() {
        let policy = test_policy();
        assert!(!policy.host_visible("unknown", "web-prod-1"));
    }

    // -- argument_allowed tests --

    #[test]
    fn argument_allowed_no_allowlist() {
        let policy = test_policy();
        // ops has no argument_allowlists -- all values allowed
        assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "rm -rf /"));
        assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "bash"));
    }

    #[test]
    fn argument_allowed_with_allowlist() {
        let policy = test_policy();
        assert!(policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "sh"));
        assert!(policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "bash -c 'echo hi'"
        ));
        assert!(policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "cat /etc/hosts"
        ));
        assert!(policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "/usr/bin/ls -la"
        ));
    }

    #[test]
    fn argument_denied_not_in_allowlist() {
        let policy = test_policy();
        assert!(!policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "rm -rf /"));
        assert!(!policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "python3 exploit.py"
        ));
        assert!(!policy.argument_allowed(
            "restricted-exec",
            "resource_exec",
            "cmd",
            "/usr/bin/curl evil.com"
        ));
    }

    #[test]
    fn argument_denied_unknown_role() {
        let policy = test_policy();
        assert!(!policy.argument_allowed("unknown", "resource_exec", "cmd", "sh"));
    }

    // -- host_patterns tests --

    #[test]
    fn host_patterns_returns_globs() {
        let policy = test_policy();
        assert_eq!(
            policy.host_patterns("deploy"),
            Some(&["web-*".to_owned(), "api-*".to_owned()][..])
        );
        assert_eq!(policy.host_patterns("ops"), Some(&["*".to_owned()][..]));
        assert!(policy.host_patterns("nonexistent").is_none());
    }

    // -- check_operation tests (no host check) --

    #[test]
    fn check_operation_allows_without_host() {
        let policy = test_policy();
        assert_eq!(
            policy.check_operation("deploy", "resource_run"),
            RbacDecision::Allow
        );
        // but check() with a non-matching host denies
        assert_eq!(
            policy.check("deploy", "resource_run", "db-prod-1"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn check_operation_deny_overrides() {
        let policy = test_policy();
        assert_eq!(
            policy.check_operation("deploy", "resource_delete"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn check_operation_unknown_role() {
        let policy = test_policy();
        assert_eq!(
            policy.check_operation("unknown", "resource_list"),
            RbacDecision::Deny
        );
    }

    #[test]
    fn check_operation_disabled() {
        let policy = RbacPolicy::new(&RbacConfig {
            enabled: false,
            roles: vec![],
            redaction_salt: None,
        });
        assert_eq!(
            policy.check_operation("nonexistent", "anything"),
            RbacDecision::Allow
        );
    }

    // -- current_role / current_identity tests --

    #[test]
    fn current_role_returns_none_outside_scope() {
        assert!(current_role().is_none());
    }

    #[test]
    fn current_identity_returns_none_outside_scope() {
        assert!(current_identity().is_none());
    }

    // -- rbac_middleware integration tests --

    /// Bearer-token identity with the given display name and role.
    fn bearer_identity(name: &'static str, role: &'static str) -> AuthIdentity {
        AuthIdentity {
            method: crate::auth::AuthMethod::BearerToken,
            name: name.into(),
            role: role.into(),
            raw_token: None,
            sub: None,
        }
    }

    /// A single JSON-RPC `tools/call` request object.
    fn tool_call(id: u64, tool: &str, args: &serde_json::Value) -> serde_json::Value {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": "tools/call",
            "params": {
                "name": tool,
                "arguments": args
            }
        })
    }

    /// Serialized single `tools/call` request with id 1.
    fn tool_call_body(tool: &str, args: &serde_json::Value) -> String {
        tool_call(1, tool, args).to_string()
    }

    /// Serialized JSON-RPC batch (a JSON array) of `calls`.
    fn batch_body(calls: &[serde_json::Value]) -> String {
        serde_json::Value::Array(calls.to_vec()).to_string()
    }

    /// `POST /mcp` with a JSON content-type and the given body.
    fn post_json(body: String) -> Request<Body> {
        Request::builder()
            .method(Method::POST)
            .uri("/mcp")
            .header("content-type", "application/json")
            .body(Body::from(body))
            .unwrap()
    }

    /// `POST /mcp` with the given body and no `content-type` header.
    fn post_untyped(body: String) -> Request<Body> {
        Request::builder()
            .method(Method::POST)
            .uri("/mcp")
            .body(Body::from(body))
            .unwrap()
    }

    /// Drives `req` through `app` and returns the response status.
    async fn status_of(app: axum::Router, req: Request<Body>) -> StatusCode {
        app.oneshot(req).await.unwrap().status()
    }

    fn rbac_router(policy: Arc<RbacPolicy>) -> axum::Router {
        axum::Router::new()
            .route("/mcp", axum::routing::post(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let p = Arc::clone(&policy);
                rbac_middleware(p, None, req, next)
            }))
    }

    fn rbac_router_with_identity(policy: Arc<RbacPolicy>, identity: AuthIdentity) -> axum::Router {
        axum::Router::new()
            .route("/mcp", axum::routing::post(|| async { "ok" }))
            .layer(axum::middleware::from_fn(
                move |mut req: Request<Body>, next: Next| {
                    let p = Arc::clone(&policy);
                    let id = identity.clone();
                    async move {
                        req.extensions_mut().insert(id);
                        rbac_middleware(p, None, req, next).await
                    }
                },
            ))
    }

    #[tokio::test]
    async fn middleware_passes_non_post() {
        let app = rbac_router(Arc::new(test_policy()));
        // GET passes through even without identity.
        let req = Request::builder()
            .method(Method::GET)
            .uri("/mcp")
            .body(Body::empty())
            .unwrap();
        // GET on a POST-only route returns 405, but the middleware itself
        // doesn't block it -- it returns next.run(req).
        assert_eq!(status_of(app, req).await, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn middleware_denies_without_identity() {
        let app = rbac_router(Arc::new(test_policy()));
        let body = tool_call_body("resource_list", &serde_json::json!({}));
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn middleware_allows_permitted_tool() {
        let policy = Arc::new(test_policy());
        let app = rbac_router_with_identity(policy, bearer_identity("alice", "viewer"));
        let body = tool_call_body("resource_list", &serde_json::json!({}));
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn middleware_denies_unpermitted_tool() {
        let policy = Arc::new(test_policy());
        let app = rbac_router_with_identity(policy, bearer_identity("alice", "viewer"));
        let body = tool_call_body("resource_delete", &serde_json::json!({}));
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn middleware_passes_non_tool_call_post() {
        let policy = Arc::new(test_policy());
        let app = rbac_router_with_identity(policy, bearer_identity("alice", "viewer"));
        // A non-tools/call JSON-RPC (e.g. resources/list) passes through.
        let body = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "resources/list"
        })
        .to_string();
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn middleware_enforces_argument_allowlist() {
        let policy = Arc::new(test_policy());
        let id = bearer_identity("dev", "restricted-exec");

        // These requests carry no content-type header; the FORBIDDEN
        // assertion below shows enforcement does not depend on it.

        // Allowed command
        let app = rbac_router_with_identity(Arc::clone(&policy), id.clone());
        let body = tool_call_body(
            "resource_exec",
            &serde_json::json!({"cmd": "ls -la", "host": "dev-1"}),
        );
        assert_eq!(status_of(app, post_untyped(body)).await, StatusCode::OK);

        // Denied command
        let app = rbac_router_with_identity(policy, id);
        let body = tool_call_body(
            "resource_exec",
            &serde_json::json!({"cmd": "rm -rf /", "host": "dev-1"}),
        );
        assert_eq!(
            status_of(app, post_untyped(body)).await,
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn middleware_disabled_policy_passes_everything() {
        let app = rbac_router(Arc::new(RbacPolicy::disabled()));
        // No identity, disabled policy -- should pass.
        let body = tool_call_body("anything", &serde_json::json!({}));
        assert_eq!(status_of(app, post_untyped(body)).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn middleware_batch_all_allowed_passes() {
        let policy = Arc::new(test_policy());
        let app = rbac_router_with_identity(policy, bearer_identity("alice", "viewer"));
        let no_args = serde_json::json!({});
        let body = batch_body(&[
            tool_call(1, "resource_list", &no_args),
            tool_call(2, "system_info", &no_args),
        ]);
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn middleware_batch_with_denied_call_rejects_entire_batch() {
        let policy = Arc::new(test_policy());
        let app = rbac_router_with_identity(policy, bearer_identity("alice", "viewer"));
        let no_args = serde_json::json!({});
        let body = batch_body(&[
            tool_call(1, "resource_list", &no_args),
            tool_call(2, "resource_delete", &no_args),
        ]);
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn middleware_batch_mixed_allowed_and_denied_rejects() {
        let policy = Arc::new(test_policy());
        let app = rbac_router_with_identity(policy, bearer_identity("dev", "restricted-exec"));
        let allowed_args = serde_json::json!({ "cmd": "ls -la", "host": "dev-1" });
        let denied_args = serde_json::json!({ "cmd": "rm -rf /", "host": "dev-1" });
        let body = batch_body(&[
            tool_call(1, "resource_exec", &allowed_args),
            tool_call(2, "resource_exec", &denied_args),
        ]);
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::FORBIDDEN);
    }

    // -- redact_arg / redaction_salt tests --

    #[test]
    fn redact_with_salt_is_deterministic_per_salt() {
        let salt = b"unit-test-salt";
        let a = redact_with_salt(salt, "rm -rf /");
        let b = redact_with_salt(salt, "rm -rf /");
        assert_eq!(a, b, "same input + salt must yield identical hash");
        assert_eq!(a.len(), 8, "redacted hash is 8 hex chars (4 bytes)");
        assert!(
            a.chars().all(|c| c.is_ascii_hexdigit()),
            "redacted hash must be lowercase hex: {a}"
        );
    }

    #[test]
    fn redact_with_salt_differs_across_salts() {
        let v = "the-same-value";
        let h1 = redact_with_salt(b"salt-one", v);
        let h2 = redact_with_salt(b"salt-two", v);
        assert_ne!(
            h1, h2,
            "different salts must produce different hashes for the same value"
        );
    }

    #[test]
    fn redact_with_salt_distinguishes_values() {
        let salt = b"k";
        let h1 = redact_with_salt(salt, "alpha");
        let h2 = redact_with_salt(salt, "beta");
        // Hash collisions on 32 bits are 1-in-4-billion; safe to assert.
        assert_ne!(h1, h2, "different values must produce different hashes");
    }

    #[test]
    fn policy_with_configured_salt_redacts_consistently() {
        let cfg = RbacConfig {
            enabled: true,
            roles: vec![],
            redaction_salt: Some(SecretString::from("my-stable-salt")),
        };
        let p1 = RbacPolicy::new(&cfg);
        let p2 = RbacPolicy::new(&cfg);
        assert_eq!(
            p1.redact_arg("payload"),
            p2.redact_arg("payload"),
            "policies built from the same configured salt must agree"
        );
    }

    #[test]
    fn policy_without_configured_salt_uses_process_salt() {
        let cfg = RbacConfig {
            enabled: true,
            roles: vec![],
            redaction_salt: None,
        };
        let p1 = RbacPolicy::new(&cfg);
        let p2 = RbacPolicy::new(&cfg);
        // Within one process, the lazy OnceLock salt is shared.
        assert_eq!(
            p1.redact_arg("payload"),
            p2.redact_arg("payload"),
            "process-wide salt must be consistent within one process"
        );
    }

    #[test]
    fn redact_arg_is_fast_enough() {
        // Sanity floor only: a redaction must finish well under 5 ms even in
        // unoptimized debug builds. The fastest of a few samples is checked,
        // so a single scheduler stall on a loaded CI machine cannot fail the
        // test. The stricter <10 µs target belongs to the release-mode
        // criterion bench (see H-T4 plan), not to this unit test.
        const SAMPLES: usize = 5;
        let salt = b"perf-sanity-salt-32-bytes-padded";
        let value = "x".repeat(256);
        let best = (0..SAMPLES)
            .map(|_| {
                let start = std::time::Instant::now();
                // black_box keeps the optimizer from eliding the measured work.
                let input = std::hint::black_box(value.as_str());
                let _ = std::hint::black_box(redact_with_salt(salt, input));
                start.elapsed()
            })
            .min()
            .expect("SAMPLES is non-zero");
        assert!(
            best < Duration::from_millis(5),
            "fastest of {SAMPLES} redact_with_salt calls took {best:?}, expected <5 ms even in debug"
        );
    }

    // -- enforce_tool_policy identity propagation regression test (BUG H-S3) --

    /// Regression: when `enforce_tool_policy` denied a request, the deny
    /// log used to read `current_identity()`, which was always `None` at
    /// that point because the task-local context is installed *after*
    /// policy enforcement. The fix passes `identity_name` explicitly.
    ///
    /// We assert the deny path returns 403 (the visible behaviour).
    /// The log-content assertion lives behind tracing-test which we have
    /// not yet added as a dev-dep; the explicit-parameter signature alone
    /// makes the previous bug structurally impossible.
    #[tokio::test]
    async fn deny_path_uses_explicit_identity_not_task_local() {
        let policy = Arc::new(test_policy());
        let id = bearer_identity("alice-the-auditor", "viewer");
        let app = rbac_router_with_identity(policy, id);
        // viewer is not allowed to call resource_delete -> 403.
        let body = tool_call_body("resource_delete", &serde_json::json!({}));
        assert_eq!(status_of(app, post_json(body)).await, StatusCode::FORBIDDEN);
    }
}