use serde::Serialize;
use crate::config::{McpPolicyConfig, OperatingMode};
use crate::error::StorageError;
use crate::storage::rate_limits;
use crate::storage::DbPool;
use super::rules::{build_effective_rules, find_matching_rule, make_eval_context};
use super::types::{tool_category, PolicyAction, PolicyAuditRecordV2};
/// Result of running the MCP mutation policy against a single tool call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyDecision {
    /// The call may proceed.
    Allow,
    /// The call must go through approval before it proceeds.
    RouteToApproval {
        /// Reason text supplied by the rule that requested approval.
        reason: String,
        /// Identifier of the rule that requested approval, if known.
        rule_id: Option<String>,
    },
    /// The call is rejected.
    Deny {
        /// Category explaining why the call was rejected.
        reason: PolicyDenialReason,
        /// What caused the denial (a rule identifier or a rate-limit key), if known.
        rule_id: Option<String>,
    },
    /// The matching rule requested dry-run handling for this call.
    DryRun {
        /// Identifier of the rule that requested the dry run, if known.
        rule_id: Option<String>,
    },
}
/// Category attached to a [`PolicyDecision::Deny`].
///
/// `Display` renders these as snake_case labels; the derived `Serialize`
/// uses the variant names as written.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum PolicyDenialReason {
    /// A rule whose id starts with `v1:blocked:` matched.
    ToolBlocked,
    /// A per-policy or global rate limit was exceeded.
    RateLimited,
    /// A rule whose id starts with `hard:` matched.
    HardRule,
    /// Any other deny rule matched.
    UserRule,
}
impl std::fmt::Display for PolicyDenialReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PolicyDenialReason::ToolBlocked => write!(f, "tool_blocked"),
PolicyDenialReason::RateLimited => write!(f, "rate_limited"),
PolicyDenialReason::HardRule => write!(f, "hard_rule"),
PolicyDenialReason::UserRule => write!(f, "user_rule"),
}
}
}
/// Serializable audit entry describing one policy decision.
///
/// Field order is significant: it determines the key order of the
/// serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct PolicyAuditRecord {
    /// Name of the tool the decision applies to.
    pub tool_name: String,
    /// Decision label for the call.
    pub decision: String,
    /// Optional reason text for the decision.
    pub reason: Option<String>,
    /// Identifier of the rule that matched, if any.
    pub matched_rule_id: Option<String>,
    /// Human-readable label of the rule that matched, if any.
    pub matched_rule_label: Option<String>,
    /// Rate-limit key involved in the decision, if any.
    pub rate_limit_key: Option<String>,
}
pub struct McpPolicyEvaluator;
impl McpPolicyEvaluator {
pub async fn evaluate(
pool: &DbPool,
config: &McpPolicyConfig,
mode: &OperatingMode,
tool_name: &str,
) -> Result<PolicyDecision, StorageError> {
if !config.enforce_for_mutations {
return Ok(PolicyDecision::Allow);
}
let rules = build_effective_rules(config, mode);
let ctx = make_eval_context(tool_name, mode);
if let Some(rule) = find_matching_rule(&rules, &ctx) {
let rule_id = Some(rule.id.clone());
match &rule.action {
PolicyAction::Allow => {
}
PolicyAction::Deny { reason } => {
let denial = if rule.id.starts_with("v1:blocked:") {
PolicyDenialReason::ToolBlocked
} else if rule.id.starts_with("hard:") {
PolicyDenialReason::HardRule
} else {
PolicyDenialReason::UserRule
};
return Ok(PolicyDecision::Deny {
reason: denial,
rule_id: Some(format!("{} ({})", rule.id, reason)),
});
}
PolicyAction::RequireApproval { reason } => {
return Ok(PolicyDecision::RouteToApproval {
reason: reason.clone(),
rule_id,
});
}
PolicyAction::DryRun => {
return Ok(PolicyDecision::DryRun { rule_id });
}
}
}
if let Some(exceeded_key) = rate_limits::check_policy_rate_limits(
pool,
tool_name,
&ctx.category.to_string(),
&config.rate_limits,
)
.await?
{
return Ok(PolicyDecision::Deny {
reason: PolicyDenialReason::RateLimited,
rule_id: Some(exceeded_key),
});
}
let allowed = rate_limits::check_rate_limit(pool, "mcp_mutation").await?;
if !allowed {
return Ok(PolicyDecision::Deny {
reason: PolicyDenialReason::RateLimited,
rule_id: None,
});
}
Ok(PolicyDecision::Allow)
}
pub async fn log_decision(
pool: &DbPool,
tool_name: &str,
decision: &PolicyDecision,
) -> Result<(), StorageError> {
let (status, reason_str, rule_id) = match decision {
PolicyDecision::Allow => ("allowed", None, None),
PolicyDecision::RouteToApproval { reason, rule_id } => {
("routed_to_approval", Some(reason.clone()), rule_id.clone())
}
PolicyDecision::Deny { reason, rule_id } => {
("denied", Some(reason.to_string()), rule_id.clone())
}
PolicyDecision::DryRun { rule_id } => ("dry_run", None, rule_id.clone()),
};
let category = tool_category(tool_name);
let audit = PolicyAuditRecordV2 {
tool_name: tool_name.to_string(),
category: category.to_string(),
decision: status.to_string(),
reason: reason_str,
matched_rule_id: rule_id,
matched_rule_label: None,
rate_limit_key: None,
};
let metadata = serde_json::to_string(&audit).ok();
crate::storage::action_log::log_action(
pool,
"mcp_policy",
status,
Some(tool_name),
metadata.as_deref(),
)
.await
}
pub async fn record_mutation(
pool: &DbPool,
tool_name: &str,
rate_limit_configs: &[super::types::PolicyRateLimit],
) -> Result<(), StorageError> {
rate_limits::increment_rate_limit(pool, "mcp_mutation").await?;
let category = tool_category(tool_name).to_string();
rate_limits::record_policy_rate_limits(pool, tool_name, &category, rate_limit_configs).await
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the plain data types in this module. Anything that
    //! needs a database pool is not covered here.
    use super::*;

    #[test]
    fn denial_reason_display_tool_blocked() {
        assert_eq!(PolicyDenialReason::ToolBlocked.to_string(), "tool_blocked");
    }

    #[test]
    fn denial_reason_display_rate_limited() {
        assert_eq!(PolicyDenialReason::RateLimited.to_string(), "rate_limited");
    }

    #[test]
    fn denial_reason_display_hard_rule() {
        assert_eq!(PolicyDenialReason::HardRule.to_string(), "hard_rule");
    }

    #[test]
    fn denial_reason_display_user_rule() {
        assert_eq!(PolicyDenialReason::UserRule.to_string(), "user_rule");
    }

    #[test]
    fn policy_decision_allow_eq() {
        assert_eq!(PolicyDecision::Allow, PolicyDecision::Allow);
    }

    #[test]
    fn policy_decision_dry_run_pattern() {
        let d = PolicyDecision::DryRun {
            rule_id: Some("r1".to_string()),
        };
        // Check the payload as well as the variant.
        assert!(matches!(
            &d,
            PolicyDecision::DryRun { rule_id: Some(id) } if id == "r1"
        ));
    }

    #[test]
    fn policy_decision_route_to_approval_pattern() {
        let d = PolicyDecision::RouteToApproval {
            reason: "needs review".to_string(),
            rule_id: None,
        };
        assert!(matches!(
            &d,
            PolicyDecision::RouteToApproval { reason, rule_id: None } if reason == "needs review"
        ));
    }

    #[test]
    fn policy_decision_deny_pattern() {
        let d = PolicyDecision::Deny {
            reason: PolicyDenialReason::ToolBlocked,
            rule_id: None,
        };
        assert!(matches!(
            d,
            PolicyDecision::Deny {
                reason: PolicyDenialReason::ToolBlocked,
                rule_id: None,
            }
        ));
    }

    #[test]
    fn policy_decision_clone_is_equal() {
        let decisions = [
            PolicyDecision::Allow,
            PolicyDecision::RouteToApproval {
                reason: "r".to_string(),
                rule_id: Some("id".to_string()),
            },
            PolicyDecision::Deny {
                reason: PolicyDenialReason::HardRule,
                rule_id: Some("hard:x".to_string()),
            },
            PolicyDecision::DryRun { rule_id: None },
        ];
        for d in &decisions {
            assert_eq!(&d.clone(), d);
        }
    }

    #[test]
    fn policy_decision_payload_affects_equality() {
        let a = PolicyDecision::Deny {
            reason: PolicyDenialReason::RateLimited,
            rule_id: Some("k1".to_string()),
        };
        let b = PolicyDecision::Deny {
            reason: PolicyDenialReason::RateLimited,
            rule_id: Some("k2".to_string()),
        };
        assert_ne!(a, b);
        assert_ne!(a, PolicyDecision::Allow);
    }

    #[test]
    fn denial_reason_eq() {
        assert_eq!(
            PolicyDenialReason::RateLimited,
            PolicyDenialReason::RateLimited
        );
        assert_ne!(
            PolicyDenialReason::ToolBlocked,
            PolicyDenialReason::HardRule
        );
    }
}