use super::{Decision, McpContext, Middleware};
use crate::{config::tool_matches, live_config::LiveConfig};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::watch;
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{FilterMode, make_agent};
use std::collections::HashMap;
fn make_mw(agents: HashMap<String, crate::config::AgentPolicy>) -> AuthMiddleware {
let live = Arc::new(LiveConfig::new(
agents,
vec![],
vec![],
None,
FilterMode::Block,
None,
));
let (_, rx) = watch::channel(live);
AuthMiddleware::new(rx)
}
fn ctx(agent: &str, method: &str, tool: Option<&str>) -> McpContext {
McpContext {
agent_id: agent.to_string(),
method: method.to_string(),
tool_name: tool.map(String::from),
arguments: None,
client_ip: None,
}
}
#[tokio::test]
async fn non_tools_call_always_allowed() {
let mw = make_mw(HashMap::new()); assert!(matches!(
mw.check(&ctx("nobody", "initialize", None)).await,
Decision::Allow { rl: None }
));
assert!(matches!(
mw.check(&ctx("nobody", "notifications/initialized", None))
.await,
Decision::Allow { rl: None }
));
}
#[tokio::test]
async fn unknown_agent_blocked_on_tools_call() {
let mw = make_mw(HashMap::new());
assert!(matches!(
mw.check(&ctx("ghost", "tools/call", Some("echo"))).await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn denied_tool_blocked() {
let mut agents = HashMap::new();
agents.insert(
"cursor".to_string(),
make_agent(None, vec!["write_file"], 60),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("cursor", "tools/call", Some("write_file")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn non_denied_tool_allowed_without_allowlist() {
let mut agents = HashMap::new();
agents.insert(
"cursor".to_string(),
make_agent(None, vec!["write_file"], 60),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("cursor", "tools/call", Some("read_file")))
.await,
Decision::Allow { rl: None }
));
}
#[tokio::test]
async fn allowlist_permits_listed_tool() {
let mut agents = HashMap::new();
agents.insert(
"claude".to_string(),
make_agent(Some(vec!["read_file"]), vec![], 60),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("claude", "tools/call", Some("read_file")))
.await,
Decision::Allow { rl: None }
));
}
#[tokio::test]
async fn allowlist_blocks_unlisted_tool() {
let mut agents = HashMap::new();
agents.insert(
"claude".to_string(),
make_agent(Some(vec!["read_file"]), vec![], 60),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("claude", "tools/call", Some("delete_file")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn denied_takes_priority_over_allowlist() {
let mut agents = HashMap::new();
agents.insert(
"cursor".to_string(),
make_agent(
Some(vec!["read_file", "write_file"]),
vec!["write_file"],
60,
),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("cursor", "tools/call", Some("write_file")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn glob_denylist_blocks_matching_tools() {
let mut agents = HashMap::new();
agents.insert("agent".to_string(), make_agent(None, vec!["write_*"], 60));
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("write_file")))
.await,
Decision::Block { .. }
));
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("write_dir")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn glob_denylist_allows_non_matching_tools() {
let mut agents = HashMap::new();
agents.insert("agent".to_string(), make_agent(None, vec!["write_*"], 60));
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("read_file")))
.await,
Decision::Allow { .. }
));
}
#[tokio::test]
async fn glob_denylist_star_blocks_all_tools() {
let mut agents = HashMap::new();
agents.insert("agent".to_string(), make_agent(None, vec!["*"], 60));
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("any_tool")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn glob_allowlist_permits_matching_tools() {
let mut agents = HashMap::new();
agents.insert(
"agent".to_string(),
make_agent(Some(vec!["read_*", "list_*"]), vec![], 60),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("read_file")))
.await,
Decision::Allow { .. }
));
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("list_dir")))
.await,
Decision::Allow { .. }
));
}
#[tokio::test]
async fn glob_allowlist_blocks_non_matching_tools() {
let mut agents = HashMap::new();
agents.insert(
"agent".to_string(),
make_agent(Some(vec!["read_*"]), vec![], 60),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("write_file")))
.await,
Decision::Block { .. }
));
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("delete_file")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn glob_allowlist_star_permits_all_tools() {
let mut agents = HashMap::new();
agents.insert("agent".to_string(), make_agent(Some(vec!["*"]), vec![], 60));
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("anything")))
.await,
Decision::Allow { .. }
));
}
#[tokio::test]
async fn glob_deny_overrides_glob_allowlist() {
let mut agents = HashMap::new();
agents.insert(
"agent".to_string(),
make_agent(Some(vec!["read_*"]), vec!["read_file"], 60),
);
let mw = make_mw(agents);
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("read_file")))
.await,
Decision::Block { .. }
));
assert!(matches!(
mw.check(&ctx("agent", "tools/call", Some("read_dir")))
.await,
Decision::Allow { .. }
));
}
fn make_mw_with_default(
agents: HashMap<String, crate::config::AgentPolicy>,
default: crate::config::AgentPolicy,
) -> AuthMiddleware {
use crate::config::FilterMode;
let live = Arc::new(LiveConfig::new(
agents,
vec![],
vec![],
None,
FilterMode::Block,
Some(default),
));
let (_, rx) = watch::channel(live);
AuthMiddleware::new(rx)
}
#[tokio::test]
async fn unknown_agent_falls_back_to_default_policy() {
let default = make_agent(None, vec!["delete_*"], 60);
let mw = make_mw_with_default(HashMap::new(), default);
assert!(matches!(
mw.check(&ctx("unknown-agent", "tools/call", Some("read_file")))
.await,
Decision::Allow { .. }
));
assert!(matches!(
mw.check(&ctx("unknown-agent", "tools/call", Some("delete_db")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn named_agent_takes_precedence_over_default_policy() {
let mut agents = HashMap::new();
agents.insert(
"strict-agent".to_string(),
make_agent(Some(vec!["read_file"]), vec![], 60),
);
let default = make_agent(Some(vec!["*"]), vec![], 60);
let mw = make_mw_with_default(agents, default);
assert!(matches!(
mw.check(&ctx("strict-agent", "tools/call", Some("write_file")))
.await,
Decision::Block { .. }
));
}
#[tokio::test]
async fn tools_call_without_tool_name_blocked() {
let mw = make_mw(HashMap::new());
let ctx = McpContext {
agent_id: "ghost".to_string(),
method: "tools/call".to_string(),
tool_name: None,
arguments: None,
client_ip: None,
};
assert!(matches!(mw.check(&ctx).await, Decision::Block { .. }));
}
#[tokio::test]
async fn block_reason_contains_tool_name() {
let mut agents = HashMap::new();
agents.insert("agent".to_string(), make_agent(None, vec!["delete_db"], 60));
let mw = make_mw(agents);
if let Decision::Block { reason, .. } = mw
.check(&ctx("agent", "tools/call", Some("delete_db")))
.await
{
assert!(reason.contains("delete_db"));
} else {
panic!("expected Block");
}
}
#[tokio::test]
async fn block_reason_contains_agent_name_for_unknown_agent() {
let mw = make_mw(HashMap::new());
if let Decision::Block { reason, .. } = mw
.check(&ctx("mystery-agent", "tools/call", Some("echo")))
.await
{
assert!(reason.contains("mystery-agent"));
} else {
panic!("expected Block");
}
}
}
pub struct AuthMiddleware {
config: watch::Receiver<Arc<LiveConfig>>,
}
impl AuthMiddleware {
pub fn new(config: watch::Receiver<Arc<LiveConfig>>) -> Self {
Self { config }
}
}
#[async_trait]
impl Middleware for AuthMiddleware {
fn name(&self) -> &'static str {
"auth"
}
async fn check(&self, ctx: &McpContext) -> Decision {
if ctx.method != "tools/call" {
return Decision::Allow { rl: None };
}
let tool = ctx.tool_name.as_deref().unwrap_or("");
let cfg = self.config.borrow();
let Some(policy) = cfg
.agents
.get(&ctx.agent_id)
.or(cfg.default_policy.as_ref())
else {
return Decision::Block {
reason: format!("unknown agent '{}'", ctx.agent_id),
rl: None,
};
};
if policy.denied_tools.iter().any(|t| tool_matches(t, tool)) {
return Decision::Block {
reason: format!("tool '{tool}' explicitly denied"),
rl: None,
};
}
if let Some(allowed) = &policy.allowed_tools
&& !allowed.iter().any(|t| tool_matches(t, tool))
{
return Decision::Block {
reason: format!("tool '{tool}' not in allowlist"),
rl: None,
};
}
Decision::Allow { rl: None }
}
}