use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
use axum::{
body::Body,
extract::ConnectInfo,
http::{Method, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use hmac::{Hmac, Mac};
use http_body_util::BodyExt;
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
use sha2::Sha256;
use crate::{
auth::{AuthIdentity, TlsConnInfo},
bounded_limiter::BoundedKeyedLimiter,
error::McpxError,
};
pub(crate) type ToolRateLimiter = BoundedKeyedLimiter<IpAddr>;
const DEFAULT_TOOL_RATE: NonZeroU32 = NonZeroU32::new(120).unwrap();
const DEFAULT_TOOL_MAX_TRACKED_KEYS: usize = 10_000;
const DEFAULT_TOOL_IDLE_EVICTION: Duration = Duration::from_mins(15);
#[must_use]
pub(crate) fn build_tool_rate_limiter(max_per_minute: u32) -> Arc<ToolRateLimiter> {
build_tool_rate_limiter_with_bounds(
max_per_minute,
DEFAULT_TOOL_MAX_TRACKED_KEYS,
DEFAULT_TOOL_IDLE_EVICTION,
)
}
#[must_use]
pub(crate) fn build_tool_rate_limiter_with_bounds(
max_per_minute: u32,
max_tracked_keys: usize,
idle_eviction: Duration,
) -> Arc<ToolRateLimiter> {
let quota =
governor::Quota::per_minute(NonZeroU32::new(max_per_minute).unwrap_or(DEFAULT_TOOL_RATE));
Arc::new(BoundedKeyedLimiter::new(
quota,
max_tracked_keys,
idle_eviction,
))
}
tokio::task_local! {
static CURRENT_ROLE: String;
static CURRENT_IDENTITY: String;
static CURRENT_TOKEN: SecretString;
static CURRENT_SUB: String;
}
#[must_use]
pub fn current_role() -> Option<String> {
CURRENT_ROLE.try_with(Clone::clone).ok()
}
#[must_use]
pub fn current_identity() -> Option<String> {
CURRENT_IDENTITY.try_with(Clone::clone).ok()
}
#[must_use]
pub fn current_token() -> Option<SecretString> {
CURRENT_TOKEN
.try_with(|t| {
if t.expose_secret().is_empty() {
None
} else {
Some(t.clone())
}
})
.ok()
.flatten()
}
#[must_use]
pub fn current_sub() -> Option<String> {
CURRENT_SUB
.try_with(Clone::clone)
.ok()
.filter(|s| !s.is_empty())
}
pub async fn with_token_scope<F: Future>(token: SecretString, f: F) -> F::Output {
CURRENT_TOKEN.scope(token, f).await
}
pub async fn with_rbac_scope<F: Future>(
role: String,
identity: String,
token: SecretString,
sub: String,
f: F,
) -> F::Output {
CURRENT_ROLE
.scope(
role,
CURRENT_IDENTITY.scope(
identity,
CURRENT_TOKEN.scope(token, CURRENT_SUB.scope(sub, f)),
),
)
.await
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleConfig {
pub name: String,
#[serde(default)]
pub description: Option<String>,
#[serde(default)]
pub allow: Vec<String>,
#[serde(default)]
pub deny: Vec<String>,
#[serde(default = "default_hosts")]
pub hosts: Vec<String>,
#[serde(default)]
pub argument_allowlists: Vec<ArgumentAllowlist>,
}
impl RoleConfig {
#[must_use]
pub fn new(name: impl Into<String>, allow: Vec<String>, hosts: Vec<String>) -> Self {
Self {
name: name.into(),
description: None,
allow,
deny: vec![],
hosts,
argument_allowlists: vec![],
}
}
#[must_use]
pub fn with_argument_allowlists(mut self, allowlists: Vec<ArgumentAllowlist>) -> Self {
self.argument_allowlists = allowlists;
self
}
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ArgumentAllowlist {
pub tool: String,
pub argument: String,
#[serde(default)]
pub allowed: Vec<String>,
}
impl ArgumentAllowlist {
#[must_use]
pub fn new(tool: impl Into<String>, argument: impl Into<String>, allowed: Vec<String>) -> Self {
Self {
tool: tool.into(),
argument: argument.into(),
allowed,
}
}
}
fn default_hosts() -> Vec<String> {
vec!["*".into()]
}
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct RbacConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default)]
pub roles: Vec<RoleConfig>,
#[serde(default)]
pub redaction_salt: Option<SecretString>,
}
impl RbacConfig {
#[must_use]
pub fn with_roles(roles: Vec<RoleConfig>) -> Self {
Self {
enabled: true,
roles,
redaction_salt: None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RbacDecision {
Allow,
Deny,
}
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacRoleSummary {
pub name: String,
pub allow: usize,
pub deny: usize,
pub hosts: usize,
pub argument_allowlists: usize,
}
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct RbacPolicySummary {
pub enabled: bool,
pub roles: Vec<RbacRoleSummary>,
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RbacPolicy {
roles: Vec<RoleConfig>,
enabled: bool,
redaction_salt: Arc<SecretString>,
}
impl RbacPolicy {
#[must_use]
pub fn new(config: &RbacConfig) -> Self {
let salt = config
.redaction_salt
.clone()
.unwrap_or_else(|| process_redaction_salt().clone());
Self {
roles: config.roles.clone(),
enabled: config.enabled,
redaction_salt: Arc::new(salt),
}
}
#[must_use]
pub fn disabled() -> Self {
Self {
roles: Vec::new(),
enabled: false,
redaction_salt: Arc::new(process_redaction_salt().clone()),
}
}
#[must_use]
pub fn is_enabled(&self) -> bool {
self.enabled
}
#[must_use]
pub fn summary(&self) -> RbacPolicySummary {
let roles = self
.roles
.iter()
.map(|r| RbacRoleSummary {
name: r.name.clone(),
allow: r.allow.len(),
deny: r.deny.len(),
hosts: r.hosts.len(),
argument_allowlists: r.argument_allowlists.len(),
})
.collect();
RbacPolicySummary {
enabled: self.enabled,
roles,
}
}
#[must_use]
pub fn check_operation(&self, role: &str, operation: &str) -> RbacDecision {
if !self.enabled {
return RbacDecision::Allow;
}
let Some(role_cfg) = self.find_role(role) else {
return RbacDecision::Deny;
};
if role_cfg.deny.iter().any(|d| d == operation) {
return RbacDecision::Deny;
}
if role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
return RbacDecision::Allow;
}
RbacDecision::Deny
}
#[must_use]
pub fn check(&self, role: &str, operation: &str, host: &str) -> RbacDecision {
if !self.enabled {
return RbacDecision::Allow;
}
let Some(role_cfg) = self.find_role(role) else {
return RbacDecision::Deny;
};
if role_cfg.deny.iter().any(|d| d == operation) {
return RbacDecision::Deny;
}
if !role_cfg.allow.iter().any(|a| a == "*" || a == operation) {
return RbacDecision::Deny;
}
if !Self::host_matches(&role_cfg.hosts, host) {
return RbacDecision::Deny;
}
RbacDecision::Allow
}
#[must_use]
pub fn host_visible(&self, role: &str, host: &str) -> bool {
if !self.enabled {
return true;
}
let Some(role_cfg) = self.find_role(role) else {
return false;
};
Self::host_matches(&role_cfg.hosts, host)
}
#[must_use]
pub fn host_patterns(&self, role: &str) -> Option<&[String]> {
self.find_role(role).map(|r| r.hosts.as_slice())
}
#[must_use]
pub fn argument_allowed(&self, role: &str, tool: &str, argument: &str, value: &str) -> bool {
if !self.enabled {
return true;
}
let Some(role_cfg) = self.find_role(role) else {
return false;
};
for al in &role_cfg.argument_allowlists {
if al.tool != tool && !glob_match(&al.tool, tool) {
continue;
}
if al.argument != argument {
continue;
}
if al.allowed.is_empty() {
continue;
}
let first_token = value.split_whitespace().next().unwrap_or(value);
let basename = first_token.rsplit('/').next().unwrap_or(first_token);
if !al.allowed.iter().any(|a| a == first_token || a == basename) {
return false;
}
}
true
}
fn find_role(&self, name: &str) -> Option<&RoleConfig> {
self.roles.iter().find(|r| r.name == name)
}
fn host_matches(patterns: &[String], host: &str) -> bool {
patterns.iter().any(|p| glob_match(p, host))
}
#[must_use]
pub fn redact_arg(&self, value: &str) -> String {
redact_with_salt(self.redaction_salt.expose_secret().as_bytes(), value)
}
}
fn process_redaction_salt() -> &'static SecretString {
use base64::{Engine as _, engine::general_purpose::STANDARD_NO_PAD};
static PROCESS_SALT: std::sync::OnceLock<SecretString> = std::sync::OnceLock::new();
PROCESS_SALT.get_or_init(|| {
let mut bytes = [0u8; 32];
rand::fill(&mut bytes);
SecretString::from(STANDARD_NO_PAD.encode(bytes))
})
}
fn redact_with_salt(salt: &[u8], value: &str) -> String {
use std::fmt::Write as _;
use sha2::Digest as _;
type HmacSha256 = Hmac<Sha256>;
let mut mac = if let Ok(m) = HmacSha256::new_from_slice(salt) {
m
} else {
let digest = Sha256::digest(salt);
#[allow(clippy::expect_used)] HmacSha256::new_from_slice(&digest).expect("32-byte SHA256 digest is valid HMAC key")
};
mac.update(value.as_bytes());
let bytes = mac.finalize().into_bytes();
let prefix = bytes.get(..4).unwrap_or(&[0; 4]);
let mut out = String::with_capacity(8);
for b in prefix {
let _ = write!(out, "{b:02x}");
}
out
}
#[allow(clippy::too_many_lines)]
pub(crate) async fn rbac_middleware(
policy: Arc<RbacPolicy>,
tool_limiter: Option<Arc<ToolRateLimiter>>,
req: Request<Body>,
next: Next,
) -> Response {
if req.method() != Method::POST {
return next.run(req).await;
}
let peer_ip: Option<IpAddr> = req
.extensions()
.get::<ConnectInfo<std::net::SocketAddr>>()
.map(|ci| ci.0.ip())
.or_else(|| {
req.extensions()
.get::<ConnectInfo<TlsConnInfo>>()
.map(|ci| ci.0.addr.ip())
});
let identity = req.extensions().get::<AuthIdentity>();
let identity_name = identity.map(|id| id.name.clone()).unwrap_or_default();
let role = identity.map(|id| id.role.clone()).unwrap_or_default();
let raw_token: SecretString = identity
.and_then(|id| id.raw_token.clone())
.unwrap_or_else(|| SecretString::from(String::new()));
let sub = identity.and_then(|id| id.sub.clone()).unwrap_or_default();
if policy.is_enabled() && identity.is_none() {
return McpxError::Rbac("no authenticated identity".into()).into_response();
}
let (parts, body) = req.into_parts();
let bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
tracing::error!(error = %e, "failed to read request body");
return (
StatusCode::INTERNAL_SERVER_ERROR,
"failed to read request body",
)
.into_response();
}
};
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&bytes) {
let tool_calls = extract_tool_calls(&json);
if !tool_calls.is_empty() {
for params in tool_calls {
if let Some(resp) = enforce_rate_limit(tool_limiter.as_deref(), peer_ip) {
return resp;
}
if policy.is_enabled()
&& let Some(resp) = enforce_tool_policy(&policy, &identity_name, &role, params)
{
return resp;
}
}
}
}
let req = Request::from_parts(parts, Body::from(bytes));
if role.is_empty() {
next.run(req).await
} else {
CURRENT_ROLE
.scope(
role,
CURRENT_IDENTITY.scope(
identity_name,
CURRENT_TOKEN.scope(raw_token, CURRENT_SUB.scope(sub, next.run(req))),
),
)
.await
}
}
fn extract_tool_calls(value: &serde_json::Value) -> Vec<&serde_json::Value> {
match value {
serde_json::Value::Object(map) => map
.get("method")
.and_then(serde_json::Value::as_str)
.filter(|method| *method == "tools/call")
.and_then(|_| map.get("params"))
.into_iter()
.collect(),
serde_json::Value::Array(items) => items
.iter()
.filter_map(|item| match item {
serde_json::Value::Object(map) => map
.get("method")
.and_then(serde_json::Value::as_str)
.filter(|method| *method == "tools/call")
.and_then(|_| map.get("params")),
serde_json::Value::Null
| serde_json::Value::Bool(_)
| serde_json::Value::Number(_)
| serde_json::Value::String(_)
| serde_json::Value::Array(_) => None,
})
.collect(),
serde_json::Value::Null
| serde_json::Value::Bool(_)
| serde_json::Value::Number(_)
| serde_json::Value::String(_) => Vec::new(),
}
}
fn enforce_rate_limit(
tool_limiter: Option<&ToolRateLimiter>,
peer_ip: Option<IpAddr>,
) -> Option<Response> {
let limiter = tool_limiter?;
let ip = peer_ip?;
if limiter.check_key(&ip).is_err() {
tracing::warn!(%ip, "tool invocation rate limited");
return Some(McpxError::RateLimited("too many tool invocations".into()).into_response());
}
None
}
fn enforce_tool_policy(
policy: &RbacPolicy,
identity_name: &str,
role: &str,
params: &serde_json::Value,
) -> Option<Response> {
let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
let host = params
.get("arguments")
.and_then(|a| a.get("host"))
.and_then(|h| h.as_str());
let decision = if let Some(host) = host {
policy.check(role, tool_name, host)
} else {
policy.check_operation(role, tool_name)
};
if decision == RbacDecision::Deny {
tracing::warn!(
user = %identity_name,
role = %role,
tool = tool_name,
host = host.unwrap_or("-"),
"RBAC denied"
);
return Some(
McpxError::Rbac(format!("{tool_name} denied for role '{role}'")).into_response(),
);
}
let args = params.get("arguments").and_then(|a| a.as_object())?;
for (arg_key, arg_val) in args {
if let Some(val_str) = arg_val.as_str()
&& !policy.argument_allowed(role, tool_name, arg_key, val_str)
{
tracing::warn!(
user = %identity_name,
role = %role,
tool = tool_name,
argument = arg_key,
arg_hmac = %policy.redact_arg(val_str),
"argument not in allowlist"
);
return Some(
McpxError::Rbac(format!(
"argument '{arg_key}' value not in allowlist for tool '{tool_name}'"
))
.into_response(),
);
}
}
None
}
fn glob_match(pattern: &str, text: &str) -> bool {
let parts: Vec<&str> = pattern.split('*').collect();
if parts.len() == 1 {
return pattern == text;
}
let mut pos = 0;
if let Some(&first) = parts.first()
&& !first.is_empty()
{
if !text.starts_with(first) {
return false;
}
pos = first.len();
}
if let Some(&last) = parts.last()
&& !last.is_empty()
{
if !text[pos..].ends_with(last) {
return false;
}
let end = text.len() - last.len();
if pos > end {
return false;
}
let middle = &text[pos..end];
let middle_parts = parts.get(1..parts.len() - 1).unwrap_or_default();
return match_middle(middle, middle_parts);
}
let middle = &text[pos..];
let middle_parts = parts.get(1..parts.len() - 1).unwrap_or_default();
match_middle(middle, middle_parts)
}
fn match_middle(mut text: &str, parts: &[&str]) -> bool {
for part in parts {
if part.is_empty() {
continue;
}
if let Some(idx) = text.find(part) {
text = &text[idx + part.len()..];
} else {
return false;
}
}
true
}
#[cfg(test)]
mod tests {
use super::*;
fn test_policy() -> RbacPolicy {
RbacPolicy::new(&RbacConfig {
enabled: true,
roles: vec![
RoleConfig {
name: "viewer".into(),
description: Some("Read-only".into()),
allow: vec![
"list_hosts".into(),
"resource_list".into(),
"resource_inspect".into(),
"resource_logs".into(),
"system_info".into(),
],
deny: vec![],
hosts: vec!["*".into()],
argument_allowlists: vec![],
},
RoleConfig {
name: "deploy".into(),
description: Some("Lifecycle management".into()),
allow: vec![
"list_hosts".into(),
"resource_list".into(),
"resource_run".into(),
"resource_start".into(),
"resource_stop".into(),
"resource_restart".into(),
"resource_logs".into(),
"image_pull".into(),
],
deny: vec!["resource_delete".into(), "resource_exec".into()],
hosts: vec!["web-*".into(), "api-*".into()],
argument_allowlists: vec![],
},
RoleConfig {
name: "ops".into(),
description: Some("Full access".into()),
allow: vec!["*".into()],
deny: vec![],
hosts: vec!["*".into()],
argument_allowlists: vec![],
},
RoleConfig {
name: "restricted-exec".into(),
description: Some("Exec with argument allowlist".into()),
allow: vec!["resource_exec".into()],
deny: vec![],
hosts: vec!["dev-*".into()],
argument_allowlists: vec![ArgumentAllowlist {
tool: "resource_exec".into(),
argument: "cmd".into(),
allowed: vec![
"sh".into(),
"bash".into(),
"cat".into(),
"ls".into(),
"ps".into(),
],
}],
},
],
redaction_salt: None,
})
}
#[test]
fn glob_exact_match() {
assert!(glob_match("web-prod-1", "web-prod-1"));
assert!(!glob_match("web-prod-1", "web-prod-2"));
}
#[test]
fn glob_star_suffix() {
assert!(glob_match("web-*", "web-prod-1"));
assert!(glob_match("web-*", "web-staging"));
assert!(!glob_match("web-*", "api-prod"));
}
#[test]
fn glob_star_prefix() {
assert!(glob_match("*-prod", "web-prod"));
assert!(glob_match("*-prod", "api-prod"));
assert!(!glob_match("*-prod", "web-staging"));
}
#[test]
fn glob_star_middle() {
assert!(glob_match("web-*-prod", "web-us-prod"));
assert!(glob_match("web-*-prod", "web-eu-east-prod"));
assert!(!glob_match("web-*-prod", "web-staging"));
}
#[test]
fn glob_star_only() {
assert!(glob_match("*", "anything"));
assert!(glob_match("*", ""));
}
#[test]
fn glob_multiple_stars() {
assert!(glob_match("*web*prod*", "my-web-us-prod-1"));
assert!(!glob_match("*web*prod*", "my-api-us-staging"));
}
#[test]
fn disabled_policy_allows_everything() {
let policy = RbacPolicy::new(&RbacConfig {
enabled: false,
roles: vec![],
redaction_salt: None,
});
assert_eq!(
policy.check("nonexistent", "resource_delete", "any-host"),
RbacDecision::Allow
);
}
#[test]
fn unknown_role_denied() {
let policy = test_policy();
assert_eq!(
policy.check("unknown", "resource_list", "web-prod-1"),
RbacDecision::Deny
);
}
#[test]
fn viewer_allowed_read_ops() {
let policy = test_policy();
assert_eq!(
policy.check("viewer", "resource_list", "web-prod-1"),
RbacDecision::Allow
);
assert_eq!(
policy.check("viewer", "system_info", "db-host"),
RbacDecision::Allow
);
}
#[test]
fn viewer_denied_write_ops() {
let policy = test_policy();
assert_eq!(
policy.check("viewer", "resource_run", "web-prod-1"),
RbacDecision::Deny
);
assert_eq!(
policy.check("viewer", "resource_delete", "web-prod-1"),
RbacDecision::Deny
);
}
#[test]
fn deploy_allowed_on_matching_hosts() {
let policy = test_policy();
assert_eq!(
policy.check("deploy", "resource_run", "web-prod-1"),
RbacDecision::Allow
);
assert_eq!(
policy.check("deploy", "resource_start", "api-staging"),
RbacDecision::Allow
);
}
#[test]
fn deploy_denied_on_non_matching_host() {
let policy = test_policy();
assert_eq!(
policy.check("deploy", "resource_run", "db-prod-1"),
RbacDecision::Deny
);
}
#[test]
fn deny_overrides_allow() {
let policy = test_policy();
assert_eq!(
policy.check("deploy", "resource_delete", "web-prod-1"),
RbacDecision::Deny
);
assert_eq!(
policy.check("deploy", "resource_exec", "web-prod-1"),
RbacDecision::Deny
);
}
#[test]
fn ops_wildcard_allows_everything() {
let policy = test_policy();
assert_eq!(
policy.check("ops", "resource_delete", "any-host"),
RbacDecision::Allow
);
assert_eq!(
policy.check("ops", "secret_create", "db-host"),
RbacDecision::Allow
);
}
#[test]
fn host_visible_respects_globs() {
let policy = test_policy();
assert!(policy.host_visible("deploy", "web-prod-1"));
assert!(policy.host_visible("deploy", "api-staging"));
assert!(!policy.host_visible("deploy", "db-prod-1"));
assert!(policy.host_visible("ops", "anything"));
assert!(policy.host_visible("viewer", "anything"));
}
#[test]
fn host_visible_unknown_role() {
let policy = test_policy();
assert!(!policy.host_visible("unknown", "web-prod-1"));
}
#[test]
fn argument_allowed_no_allowlist() {
let policy = test_policy();
assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "rm -rf /"));
assert!(policy.argument_allowed("ops", "resource_exec", "cmd", "bash"));
}
#[test]
fn argument_allowed_with_allowlist() {
let policy = test_policy();
assert!(policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "sh"));
assert!(policy.argument_allowed(
"restricted-exec",
"resource_exec",
"cmd",
"bash -c 'echo hi'"
));
assert!(policy.argument_allowed(
"restricted-exec",
"resource_exec",
"cmd",
"cat /etc/hosts"
));
assert!(policy.argument_allowed(
"restricted-exec",
"resource_exec",
"cmd",
"/usr/bin/ls -la"
));
}
#[test]
fn argument_denied_not_in_allowlist() {
let policy = test_policy();
assert!(!policy.argument_allowed("restricted-exec", "resource_exec", "cmd", "rm -rf /"));
assert!(!policy.argument_allowed(
"restricted-exec",
"resource_exec",
"cmd",
"python3 exploit.py"
));
assert!(!policy.argument_allowed(
"restricted-exec",
"resource_exec",
"cmd",
"/usr/bin/curl evil.com"
));
}
#[test]
fn argument_denied_unknown_role() {
let policy = test_policy();
assert!(!policy.argument_allowed("unknown", "resource_exec", "cmd", "sh"));
}
#[test]
fn host_patterns_returns_globs() {
let policy = test_policy();
assert_eq!(
policy.host_patterns("deploy"),
Some(vec!["web-*".to_owned(), "api-*".to_owned()].as_slice())
);
assert_eq!(
policy.host_patterns("ops"),
Some(vec!["*".to_owned()].as_slice())
);
assert!(policy.host_patterns("nonexistent").is_none());
}
#[test]
fn check_operation_allows_without_host() {
let policy = test_policy();
assert_eq!(
policy.check_operation("deploy", "resource_run"),
RbacDecision::Allow
);
assert_eq!(
policy.check("deploy", "resource_run", "db-prod-1"),
RbacDecision::Deny
);
}
#[test]
fn check_operation_deny_overrides() {
let policy = test_policy();
assert_eq!(
policy.check_operation("deploy", "resource_delete"),
RbacDecision::Deny
);
}
#[test]
fn check_operation_unknown_role() {
let policy = test_policy();
assert_eq!(
policy.check_operation("unknown", "resource_list"),
RbacDecision::Deny
);
}
#[test]
fn check_operation_disabled() {
let policy = RbacPolicy::new(&RbacConfig {
enabled: false,
roles: vec![],
redaction_salt: None,
});
assert_eq!(
policy.check_operation("nonexistent", "anything"),
RbacDecision::Allow
);
}
#[test]
fn current_role_returns_none_outside_scope() {
assert!(current_role().is_none());
}
#[test]
fn current_identity_returns_none_outside_scope() {
assert!(current_identity().is_none());
}
use axum::{
body::Body,
http::{Method, Request, StatusCode},
};
use tower::ServiceExt as _;
fn tool_call_body(tool: &str, args: &serde_json::Value) -> String {
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": tool,
"arguments": args
}
})
.to_string()
}
fn rbac_router(policy: Arc<RbacPolicy>) -> axum::Router {
axum::Router::new()
.route("/mcp", axum::routing::post(|| async { "ok" }))
.layer(axum::middleware::from_fn(move |req, next| {
let p = Arc::clone(&policy);
rbac_middleware(p, None, req, next)
}))
}
fn rbac_router_with_identity(policy: Arc<RbacPolicy>, identity: AuthIdentity) -> axum::Router {
axum::Router::new()
.route("/mcp", axum::routing::post(|| async { "ok" }))
.layer(axum::middleware::from_fn(
move |mut req: Request<Body>, next: Next| {
let p = Arc::clone(&policy);
let id = identity.clone();
async move {
req.extensions_mut().insert(id);
rbac_middleware(p, None, req, next).await
}
},
))
}
#[tokio::test]
async fn middleware_passes_non_post() {
let policy = Arc::new(test_policy());
let app = rbac_router(policy);
let req = Request::builder()
.method(Method::GET)
.uri("/mcp")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
}
#[tokio::test]
async fn middleware_denies_without_identity() {
let policy = Arc::new(test_policy());
let app = rbac_router(policy);
let body = tool_call_body("resource_list", &serde_json::json!({}));
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn middleware_allows_permitted_tool() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "alice".into(),
role: "viewer".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(policy, id);
let body = tool_call_body("resource_list", &serde_json::json!({}));
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn middleware_denies_unpermitted_tool() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "alice".into(),
role: "viewer".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(policy, id);
let body = tool_call_body("resource_delete", &serde_json::json!({}));
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn middleware_passes_non_tool_call_post() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "alice".into(),
role: "viewer".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(policy, id);
let body = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "resources/list"
})
.to_string();
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn middleware_enforces_argument_allowlist() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "dev".into(),
role: "restricted-exec".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(Arc::clone(&policy), id.clone());
let body = tool_call_body(
"resource_exec",
&serde_json::json!({"cmd": "ls -la", "host": "dev-1"}),
);
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let app = rbac_router_with_identity(policy, id);
let body = tool_call_body(
"resource_exec",
&serde_json::json!({"cmd": "rm -rf /", "host": "dev-1"}),
);
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn middleware_disabled_policy_passes_everything() {
let policy = Arc::new(RbacPolicy::disabled());
let app = rbac_router(policy);
let body = tool_call_body("anything", &serde_json::json!({}));
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn middleware_batch_all_allowed_passes() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "alice".into(),
role: "viewer".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(policy, id);
let body = serde_json::json!([
{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": { "name": "resource_list", "arguments": {} }
},
{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": { "name": "system_info", "arguments": {} }
}
])
.to_string();
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn middleware_batch_with_denied_call_rejects_entire_batch() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "alice".into(),
role: "viewer".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(policy, id);
let body = serde_json::json!([
{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": { "name": "resource_list", "arguments": {} }
},
{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": { "name": "resource_delete", "arguments": {} }
}
])
.to_string();
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn middleware_batch_mixed_allowed_and_denied_rejects() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "dev".into(),
role: "restricted-exec".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(policy, id);
let body = serde_json::json!([
{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "resource_exec",
"arguments": { "cmd": "ls -la", "host": "dev-1" }
}
},
{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": {
"name": "resource_exec",
"arguments": { "cmd": "rm -rf /", "host": "dev-1" }
}
}
])
.to_string();
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
#[test]
fn redact_with_salt_is_deterministic_per_salt() {
let salt = b"unit-test-salt";
let a = redact_with_salt(salt, "rm -rf /");
let b = redact_with_salt(salt, "rm -rf /");
assert_eq!(a, b, "same input + salt must yield identical hash");
assert_eq!(a.len(), 8, "redacted hash is 8 hex chars (4 bytes)");
assert!(
a.chars().all(|c| c.is_ascii_hexdigit()),
"redacted hash must be lowercase hex: {a}"
);
}
#[test]
fn redact_with_salt_differs_across_salts() {
let v = "the-same-value";
let h1 = redact_with_salt(b"salt-one", v);
let h2 = redact_with_salt(b"salt-two", v);
assert_ne!(
h1, h2,
"different salts must produce different hashes for the same value"
);
}
#[test]
fn redact_with_salt_distinguishes_values() {
let salt = b"k";
let h1 = redact_with_salt(salt, "alpha");
let h2 = redact_with_salt(salt, "beta");
assert_ne!(h1, h2, "different values must produce different hashes");
}
#[test]
fn policy_with_configured_salt_redacts_consistently() {
let cfg = RbacConfig {
enabled: true,
roles: vec![],
redaction_salt: Some(SecretString::from("my-stable-salt")),
};
let p1 = RbacPolicy::new(&cfg);
let p2 = RbacPolicy::new(&cfg);
assert_eq!(
p1.redact_arg("payload"),
p2.redact_arg("payload"),
"policies built from the same configured salt must agree"
);
}
#[test]
fn policy_without_configured_salt_uses_process_salt() {
let cfg = RbacConfig {
enabled: true,
roles: vec![],
redaction_salt: None,
};
let p1 = RbacPolicy::new(&cfg);
let p2 = RbacPolicy::new(&cfg);
assert_eq!(
p1.redact_arg("payload"),
p2.redact_arg("payload"),
"process-wide salt must be consistent within one process"
);
}
#[test]
fn redact_arg_is_fast_enough() {
let salt = b"perf-sanity-salt-32-bytes-padded";
let value = "x".repeat(256);
let start = std::time::Instant::now();
let _ = redact_with_salt(salt, &value);
let elapsed = start.elapsed();
assert!(
elapsed < Duration::from_millis(5),
"single redact_with_salt took {elapsed:?}, expected <5 ms even in debug"
);
}
#[tokio::test]
async fn deny_path_uses_explicit_identity_not_task_local() {
let policy = Arc::new(test_policy());
let id = AuthIdentity {
method: crate::auth::AuthMethod::BearerToken,
name: "alice-the-auditor".into(),
role: "viewer".into(),
raw_token: None,
sub: None,
};
let app = rbac_router_with_identity(policy, id);
let body = tool_call_body("resource_delete", &serde_json::json!({}));
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
}
}