use std::collections::VecDeque;
use std::sync::Mutex;
use std::time::Instant;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use zeph_common::ToolName;
use crate::manager::McpTrustLevel;
use crate::tool::{DataSensitivity, McpTool};
/// Error returned by [`check_data_flow`] when a tool's declared data
/// sensitivity is not permitted for the trust level of the server exposing it.
#[derive(Debug, thiserror::Error)]
pub enum DataFlowViolation {
/// Produced by [`check_data_flow`] for a `High`-sensitivity tool on an
/// `Untrusted` or `Sandboxed` server.
#[error(
"tool '{tool_name}' (sensitivity={sensitivity:?}) on server '{server_id}' \
(trust={trust:?}) violates data-flow policy: \
high-sensitivity tools require trusted servers"
)]
SensitivityTrustMismatch {
/// Identifier of the server exposing the tool (copied from the tool).
server_id: String,
/// Name of the offending tool.
tool_name: ToolName,
/// Sensitivity taken from the tool's security metadata.
sensitivity: DataSensitivity,
/// Trust level that was passed to [`check_data_flow`] for the server.
trust: McpTrustLevel,
},
}
/// Checks a tool's declared data sensitivity against the trust level of the
/// server that exposes it.
///
/// * `High` sensitivity is rejected on `Untrusted` and `Sandboxed` servers.
/// * `Medium` sensitivity on a `Sandboxed` server is allowed, but a warning is
///   logged.
/// * Every other combination is allowed silently.
///
/// # Errors
///
/// Returns [`DataFlowViolation::SensitivityTrustMismatch`] when a
/// high-sensitivity tool is offered by an untrusted or sandboxed server.
pub fn check_data_flow(
    tool: &McpTool,
    server_trust: McpTrustLevel,
) -> Result<(), DataFlowViolation> {
    let sensitivity = tool.security_meta.data_sensitivity;
    let is_low_trust = matches!(
        server_trust,
        McpTrustLevel::Untrusted | McpTrustLevel::Sandboxed
    );

    // Hard rule: high-sensitivity tools are only usable from trusted servers.
    if matches!(sensitivity, DataSensitivity::High) && is_low_trust {
        return Err(DataFlowViolation::SensitivityTrustMismatch {
            server_id: tool.server_id.clone(),
            tool_name: tool.name.as_str().into(),
            sensitivity,
            trust: server_trust,
        });
    }

    // Soft rule: medium-sensitivity tools on sandboxed servers are only flagged.
    let is_medium_on_sandbox = matches!(sensitivity, DataSensitivity::Medium)
        && matches!(server_trust, McpTrustLevel::Sandboxed);
    if is_medium_on_sandbox {
        tracing::warn!(
            server_id = %tool.server_id,
            tool_name = %tool.name,
            "medium-sensitivity tool on sandboxed server — use with caution"
        );
    }

    Ok(())
}
/// Per-server call-rate limit enforced by [`PolicyEnforcer::check`].
///
/// Admitted calls are counted over a sliding 60-second window. A value of `0`
/// rejects every call to the server.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RateLimit {
/// Maximum number of calls admitted within any 60-second window.
pub max_calls_per_minute: u32,
}
/// Access policy for a single MCP server.
///
/// Every field may be omitted when deserializing (`#[serde(default)]`). The
/// default policy permits every tool and applies no rate limit.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct McpPolicy {
/// Exact tool names that may be called. `None` disables the allowlist;
/// `Some` with an empty list rejects every tool.
pub allowed_tools: Option<Vec<String>>,
/// Exact tool names that are always rejected. This list is checked before the
/// allowlist, so a tool present in both is denied.
pub denied_tools: Vec<String>,
/// Optional call-rate limit for this server.
pub rate_limit: Option<RateLimit>,
}
/// Reason a tool call was rejected by [`PolicyEnforcer::check`].
#[derive(Debug, thiserror::Error)]
pub enum PolicyViolation {
/// The tool is listed in the server's `denied_tools`.
#[error("tool '{tool_name}' is denied on server '{server_id}'")]
ToolDenied {
server_id: String,
tool_name: ToolName,
},
/// The server has an allowlist (`allowed_tools`) that does not name the tool.
#[error("tool '{tool_name}' is not in the allowlist for server '{server_id}'")]
ToolNotAllowed {
server_id: String,
tool_name: ToolName,
},
/// The server already admitted `max_calls_per_minute` calls in the current
/// 60-second window.
#[error("rate limit exceeded for server '{server_id}' (max {max_calls_per_minute}/min)")]
RateLimitExceeded {
server_id: String,
max_calls_per_minute: u32,
},
}
/// Applies per-server [`McpPolicy`] rules (deny list, allowlist, rate limit)
/// to tool calls.
///
/// State is held in `DashMap`s and per-server `Mutex`es, so checks take `&self`.
pub struct PolicyEnforcer {
/// Policies keyed by server id.
policies: DashMap<String, McpPolicy>,
/// Timestamps of admitted calls, keyed by server id. An entry exists only for
/// servers whose policy sets `rate_limit`.
call_windows: DashMap<String, Mutex<VecDeque<Instant>>>,
}
impl PolicyEnforcer {
    /// Builds an enforcer from `(server_id, policy)` pairs.
    ///
    /// A rate-limit window is allocated for every server whose policy sets
    /// `rate_limit`. If the same server id appears more than once, the last
    /// policy wins.
    #[must_use]
    pub fn new(entries: Vec<(String, McpPolicy)>) -> Self {
        let policies = DashMap::new();
        let call_windows = DashMap::new();
        for (id, policy) in entries {
            if policy.rate_limit.is_some() {
                call_windows.insert(id.clone(), Mutex::new(VecDeque::new()));
            }
            policies.insert(id, policy);
        }
        Self {
            policies,
            call_windows,
        }
    }

    /// Checks whether `tool_name` may be called on `server_id` right now.
    ///
    /// Servers without a registered policy are unrestricted. Otherwise the
    /// checks run in this order: deny list, allowlist, rate limit. Only a call
    /// that passes the first two is counted against the rate limit, and it is
    /// counted as soon as it is admitted.
    ///
    /// # Errors
    ///
    /// - [`PolicyViolation::ToolDenied`] if the tool is in `denied_tools`. This
    ///   takes precedence over the allowlist.
    /// - [`PolicyViolation::ToolNotAllowed`] if an allowlist is set and does not
    ///   contain the tool.
    /// - [`PolicyViolation::RateLimitExceeded`] if the server already admitted
    ///   `max_calls_per_minute` calls within the last 60 seconds.
    pub fn check(&self, server_id: &str, tool_name: &str) -> Result<(), PolicyViolation> {
        let Some(policy) = self.policies.get(server_id) else {
            return Ok(());
        };
        if policy.denied_tools.iter().any(|t| t == tool_name) {
            return Err(PolicyViolation::ToolDenied {
                server_id: server_id.into(),
                tool_name: tool_name.into(),
            });
        }
        if policy
            .allowed_tools
            .as_ref()
            .is_some_and(|allowlist| !allowlist.iter().any(|t| t == tool_name))
        {
            return Err(PolicyViolation::ToolNotAllowed {
                server_id: server_id.into(),
                tool_name: tool_name.into(),
            });
        }
        if let Some(rl) = &policy.rate_limit {
            self.check_rate_limit(server_id, rl.max_calls_per_minute)?;
        }
        Ok(())
    }

    /// Sliding-window rate limiter for one server.
    ///
    /// Evicts timestamps older than 60 seconds. If fewer than
    /// `max_calls_per_minute` remain, the current call is recorded and
    /// admitted; otherwise it is rejected without being recorded.
    ///
    /// # Panics
    ///
    /// Panics if `server_id` has no window. [`PolicyEnforcer::new`] creates a
    /// window for every policy with a rate limit, and this method is only
    /// called for such policies.
    fn check_rate_limit(
        &self,
        server_id: &str,
        max_calls_per_minute: u32,
    ) -> Result<(), PolicyViolation> {
        const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

        let window_entry = self
            .call_windows
            .get(server_id)
            .expect("call_windows entry created alongside rate_limit policy");
        // Every mutation under this lock is a single VecDeque push/pop, so the
        // deque stays consistent even if a holder panicked. Recover the guard
        // instead of panicking on every later call for this server.
        let mut window = window_entry
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();
        // Compare ages instead of computing `now - 60s`: `Instant::checked_sub`
        // fails while the monotonic clock is younger than the window (e.g. soon
        // after boot). Falling back to `now` there would evict every timestamp
        // and silently disable the limit.
        while window
            .front()
            .is_some_and(|&t| now.saturating_duration_since(t) > WINDOW)
        {
            window.pop_front();
        }
        let limit = usize::try_from(max_calls_per_minute).unwrap_or(usize::MAX);
        if window.len() >= limit {
            return Err(PolicyViolation::RateLimitExceeded {
                server_id: server_id.into(),
                max_calls_per_minute,
            });
        }
        window.push_back(now);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enforcer with a single policy registered for `server_id`.
    fn enforcer_with_policy(server_id: &str, policy: McpPolicy) -> PolicyEnforcer {
        PolicyEnforcer::new(vec![(server_id.into(), policy)])
    }

    /// Policy that only sets a rate limit of `max` calls per minute.
    fn rate_limited(max: u32) -> McpPolicy {
        McpPolicy {
            rate_limit: Some(RateLimit {
                max_calls_per_minute: max,
            }),
            ..Default::default()
        }
    }

    #[test]
    fn no_policy_allows_any_tool() {
        let enforcer = PolicyEnforcer::new(vec![]);
        assert!(enforcer.check("any-server", "any-tool").is_ok());
    }

    #[test]
    fn denied_tool_blocked() {
        let policy = McpPolicy {
            denied_tools: vec!["rm".into()],
            ..Default::default()
        };
        let enforcer = enforcer_with_policy("srv", policy);
        let err = enforcer.check("srv", "rm").unwrap_err();
        assert!(matches!(err, PolicyViolation::ToolDenied { .. }));
    }

    #[test]
    fn deny_takes_precedence_over_allowlist() {
        let policy = McpPolicy {
            allowed_tools: Some(vec!["rm".into()]),
            denied_tools: vec!["rm".into()],
            ..Default::default()
        };
        let enforcer = enforcer_with_policy("srv", policy);
        let err = enforcer.check("srv", "rm").unwrap_err();
        assert!(matches!(err, PolicyViolation::ToolDenied { .. }));
    }

    #[test]
    fn allowlist_blocks_unlisted_tool() {
        let policy = McpPolicy {
            allowed_tools: Some(vec!["read_file".into()]),
            ..Default::default()
        };
        let enforcer = enforcer_with_policy("srv", policy);
        let err = enforcer.check("srv", "write_file").unwrap_err();
        assert!(matches!(err, PolicyViolation::ToolNotAllowed { .. }));
    }

    #[test]
    fn allowlist_permits_listed_tool() {
        let policy = McpPolicy {
            allowed_tools: Some(vec!["read_file".into()]),
            ..Default::default()
        };
        let enforcer = enforcer_with_policy("srv", policy);
        assert!(enforcer.check("srv", "read_file").is_ok());
    }

    #[test]
    fn empty_allowlist_blocks_every_tool() {
        let policy = McpPolicy {
            allowed_tools: Some(vec![]),
            ..Default::default()
        };
        let enforcer = enforcer_with_policy("srv", policy);
        let err = enforcer.check("srv", "read_file").unwrap_err();
        assert!(matches!(err, PolicyViolation::ToolNotAllowed { .. }));
    }

    #[test]
    fn rate_limit_blocks_after_threshold() {
        let enforcer = enforcer_with_policy("srv", rate_limited(2));
        assert!(enforcer.check("srv", "tool").is_ok());
        assert!(enforcer.check("srv", "tool").is_ok());
        let err = enforcer.check("srv", "tool").unwrap_err();
        assert!(matches!(err, PolicyViolation::RateLimitExceeded { .. }));
    }

    #[test]
    fn zero_rate_limit_blocks_every_call() {
        let enforcer = enforcer_with_policy("srv", rate_limited(0));
        let err = enforcer.check("srv", "tool").unwrap_err();
        assert!(matches!(err, PolicyViolation::RateLimitExceeded { .. }));
    }

    #[test]
    fn denied_calls_do_not_consume_rate_budget() {
        let policy = McpPolicy {
            denied_tools: vec!["rm".into()],
            ..rate_limited(1)
        };
        let enforcer = enforcer_with_policy("srv", policy);
        assert!(matches!(
            enforcer.check("srv", "rm"),
            Err(PolicyViolation::ToolDenied { .. })
        ));
        // The rejected call above must not have used the single slot.
        assert!(enforcer.check("srv", "ls").is_ok());
        assert!(matches!(
            enforcer.check("srv", "ls"),
            Err(PolicyViolation::RateLimitExceeded { .. })
        ));
    }

    #[test]
    fn rate_limit_windows_are_per_server() {
        let enforcer = PolicyEnforcer::new(vec![
            ("a".into(), rate_limited(1)),
            ("b".into(), rate_limited(1)),
        ]);
        assert!(enforcer.check("a", "tool").is_ok());
        assert!(enforcer.check("b", "tool").is_ok());
        assert!(matches!(
            enforcer.check("a", "tool"),
            Err(PolicyViolation::RateLimitExceeded { .. })
        ));
    }

    #[test]
    fn unknown_server_is_allowed() {
        let policy = McpPolicy {
            denied_tools: vec!["rm".into()],
            ..Default::default()
        };
        let enforcer = enforcer_with_policy("srv", policy);
        assert!(enforcer.check("other-srv", "rm").is_ok());
    }

    /// Builds a tool on server `"srv"` with the given name and sensitivity.
    fn make_tool_with_meta(
        name: &str,
        sensitivity: crate::tool::DataSensitivity,
    ) -> crate::tool::McpTool {
        use crate::tool::ToolSecurityMeta;
        crate::tool::McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: "test".into(),
            input_schema: serde_json::json!({}),
            security_meta: ToolSecurityMeta {
                data_sensitivity: sensitivity,
                capabilities: vec![],
                flagged_parameters: Vec::new(),
            },
        }
    }

    #[test]
    fn data_flow_high_sensitivity_untrusted_blocked() {
        let tool = make_tool_with_meta("exec_shell", crate::tool::DataSensitivity::High);
        let result = check_data_flow(&tool, McpTrustLevel::Untrusted);
        assert!(matches!(
            result,
            Err(DataFlowViolation::SensitivityTrustMismatch { .. })
        ));
    }

    #[test]
    fn data_flow_high_sensitivity_sandboxed_blocked() {
        let tool = make_tool_with_meta("exec_shell", crate::tool::DataSensitivity::High);
        let result = check_data_flow(&tool, McpTrustLevel::Sandboxed);
        assert!(matches!(
            result,
            Err(DataFlowViolation::SensitivityTrustMismatch { .. })
        ));
    }

    #[test]
    fn data_flow_high_sensitivity_trusted_allowed() {
        let tool = make_tool_with_meta("exec_shell", crate::tool::DataSensitivity::High);
        assert!(check_data_flow(&tool, McpTrustLevel::Trusted).is_ok());
    }

    #[test]
    fn data_flow_medium_sensitivity_untrusted_allowed() {
        let tool = make_tool_with_meta("write_file", crate::tool::DataSensitivity::Medium);
        assert!(check_data_flow(&tool, McpTrustLevel::Untrusted).is_ok());
    }

    #[test]
    fn data_flow_medium_sensitivity_sandboxed_warns_but_allows() {
        let tool = make_tool_with_meta("write_file", crate::tool::DataSensitivity::Medium);
        assert!(check_data_flow(&tool, McpTrustLevel::Sandboxed).is_ok());
    }

    #[test]
    fn data_flow_medium_sensitivity_trusted_allowed() {
        let tool = make_tool_with_meta("write_file", crate::tool::DataSensitivity::Medium);
        assert!(check_data_flow(&tool, McpTrustLevel::Trusted).is_ok());
    }

    #[test]
    fn data_flow_low_sensitivity_any_trust_allowed() {
        let tool = make_tool_with_meta("get_info", crate::tool::DataSensitivity::Low);
        assert!(check_data_flow(&tool, McpTrustLevel::Untrusted).is_ok());
        assert!(check_data_flow(&tool, McpTrustLevel::Sandboxed).is_ok());
        assert!(check_data_flow(&tool, McpTrustLevel::Trusted).is_ok());
    }

    #[test]
    fn data_flow_none_sensitivity_any_trust_allowed() {
        let tool = make_tool_with_meta("read_info", crate::tool::DataSensitivity::None);
        assert!(check_data_flow(&tool, McpTrustLevel::Untrusted).is_ok());
        assert!(check_data_flow(&tool, McpTrustLevel::Sandboxed).is_ok());
        assert!(check_data_flow(&tool, McpTrustLevel::Trusted).is_ok());
    }

    #[test]
    fn data_flow_violation_message_descriptive() {
        let tool = make_tool_with_meta("exec_shell", crate::tool::DataSensitivity::High);
        let err = check_data_flow(&tool, McpTrustLevel::Untrusted).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("exec_shell"));
        assert!(msg.contains("high-sensitivity"));
    }

    #[test]
    fn policy_violation_messages_are_descriptive() {
        let denied = PolicyViolation::ToolDenied {
            server_id: "s".into(),
            tool_name: "t".into(),
        };
        assert!(denied.to_string().contains("denied"));
        let not_allowed = PolicyViolation::ToolNotAllowed {
            server_id: "s".into(),
            tool_name: "t".into(),
        };
        assert!(not_allowed.to_string().contains("allowlist"));
        let rate = PolicyViolation::RateLimitExceeded {
            server_id: "s".into(),
            max_calls_per_minute: 10,
        };
        assert!(rate.to_string().contains("rate limit"));
    }
}