#[cfg(test)]
mod tests;
use crate::config::{McpPolicyConfig, OperatingMode};
use crate::error::StorageError;
use crate::mcp_policy::types::PolicyRateLimit;
use crate::mcp_policy::{McpPolicyEvaluator, PolicyDecision, PolicyDenialReason};
use crate::storage::mutation_audit;
use crate::storage::DbPool;
/// Look-back window (seconds, per the name) passed to
/// `mutation_audit::find_recent_duplicate` by `MutationGateway::evaluate`.
/// A matching earlier call inside this window makes the gateway return
/// `GatewayDecision::Duplicate` instead of issuing a new ticket.
const IDEMPOTENCY_WINDOW_SECS: u32 = 300;
/// Entry point for gating tool mutations: policy evaluation, approval routing,
/// idempotency (duplicate) detection and mutation-audit bookkeeping.
///
/// This is a stateless unit type; all functionality lives in associated
/// functions that take the database pool explicitly.
pub struct MutationGateway;
/// Borrowed inputs for a single [`MutationGateway::evaluate`] call.
pub struct MutationRequest<'a> {
    /// Database pool used for policy evaluation, approval queueing and audit rows.
    pub pool: &'a DbPool,
    /// Policy configuration handed to `McpPolicyEvaluator::evaluate`.
    pub policy_config: &'a McpPolicyConfig,
    /// Operating mode handed to `McpPolicyEvaluator::evaluate`.
    pub mode: &'a OperatingMode,
    /// Name of the tool whose mutation is being gated.
    pub tool_name: &'a str,
    /// Tool parameters as a JSON string. They are hashed for duplicate
    /// detection, truncated into the audit summary, and forwarded to the
    /// approval queue when the policy routes the call there.
    pub params_json: &'a str,
}
/// Outcome of [`MutationGateway::evaluate`].
#[derive(Debug)]
pub enum GatewayDecision {
    /// The policy allowed the call and no recent duplicate was found. A pending
    /// audit row was inserted; the ticket identifies it and is what
    /// `complete_success` / `complete_failure` expect.
    Proceed(MutationTicket),
    /// The policy denied the call.
    Denied(GatewayDenial),
    /// The policy requires approval, so the request was enqueued in the
    /// approval queue instead of proceeding.
    RoutedToApproval {
        /// Id returned by the approval queue for the new entry.
        queue_id: i64,
        /// Reason carried by the policy decision.
        reason: String,
        /// Policy rule behind the routing, if the evaluator reported one.
        rule_id: Option<String>,
    },
    /// The policy decision was a dry run. The gateway inserted no audit row and
    /// enqueued nothing for this outcome.
    DryRun { rule_id: Option<String> },
    /// The policy allowed the call, but an equivalent call (same tool name and
    /// params hash) was found within the idempotency window.
    Duplicate(DuplicateInfo),
}
/// Details of a policy denial, copied from `PolicyDecision::Deny`.
#[derive(Debug, Clone)]
pub struct GatewayDenial {
    /// Why the policy evaluator denied the call.
    pub reason: PolicyDenialReason,
    /// Policy rule responsible for the denial, if the evaluator reported one.
    pub rule_id: Option<String>,
}
/// Information returned when a request is recognised as a duplicate of a
/// recent mutation.
#[derive(Debug, Clone)]
pub struct DuplicateInfo {
    /// Correlation id of the earlier audit record that this call duplicates.
    pub original_correlation_id: String,
    /// The earlier record's `result_summary`, if it has one.
    pub cached_result: Option<String>,
    /// Id of the new audit row inserted for this duplicate attempt.
    pub audit_id: i64,
}
/// Handle for an allowed mutation whose pending audit row has been created.
///
/// Pass it to `MutationGateway::complete_success` or
/// `MutationGateway::complete_failure` once the mutation has run.
#[derive(Debug)]
pub struct MutationTicket {
    /// Id of the pending audit row that the completion calls update.
    pub audit_id: i64,
    /// Correlation id stored on that audit row.
    pub correlation_id: String,
    /// Tool name, used by `complete_success` for rate-limit bookkeeping.
    pub tool_name: String,
}
impl MutationGateway {
    /// Decides whether a proposed mutation may run.
    ///
    /// The MCP policy for `req.tool_name` is evaluated first, and the decision
    /// is logged on a best-effort basis. Then:
    ///
    /// * `Deny` → [`GatewayDecision::Denied`].
    /// * `DryRun` → [`GatewayDecision::DryRun`].
    /// * `RouteToApproval` → the request is enqueued in the approval queue, and
    ///   [`GatewayDecision::RoutedToApproval`] carries the new queue id.
    /// * `Allow` → the idempotency check runs. If
    ///   `mutation_audit::find_recent_duplicate` reports an earlier call with the
    ///   same tool name and params hash within `IDEMPOTENCY_WINDOW_SECS`, an
    ///   audit row is recorded for this attempt and marked as a duplicate, and
    ///   [`GatewayDecision::Duplicate`] is returned. Otherwise a pending audit
    ///   row is inserted and its [`MutationTicket`] is returned in
    ///   [`GatewayDecision::Proceed`].
    ///
    /// # Errors
    ///
    /// Returns [`StorageError`] if any of the following fails:
    /// * policy evaluation;
    /// * approval enqueueing (wrapped as `StorageError::Query`);
    /// * the duplicate lookup;
    /// * an audit-row insert.
    pub async fn evaluate(req: &MutationRequest<'_>) -> Result<GatewayDecision, StorageError> {
        let decision =
            McpPolicyEvaluator::evaluate(req.pool, req.policy_config, req.mode, req.tool_name)
                .await?;
        // Best-effort: failing to persist the policy log must not change the
        // outcome of the request, so its result is deliberately discarded.
        let _ = McpPolicyEvaluator::log_decision(req.pool, req.tool_name, &decision).await;

        match decision {
            PolicyDecision::Allow => Self::admit(req).await,
            PolicyDecision::Deny { reason, rule_id } => {
                Ok(GatewayDecision::Denied(GatewayDenial { reason, rule_id }))
            }
            PolicyDecision::DryRun { rule_id } => Ok(GatewayDecision::DryRun { rule_id }),
            PolicyDecision::RouteToApproval { reason, rule_id } => {
                let queue_id =
                    Self::enqueue_for_approval(req, &reason, rule_id.as_deref()).await?;
                Ok(GatewayDecision::RoutedToApproval {
                    queue_id,
                    reason,
                    rule_id,
                })
            }
        }
    }

    /// Enqueues the request in the approval queue on behalf of the MCP policy
    /// and returns the queue id.
    ///
    /// `reason` is passed to `enqueue_with_context`. When `rule_id` is present,
    /// it is passed as the JSON array `["policy_rule:<id>"]`; otherwise `[]`
    /// is passed.
    async fn enqueue_for_approval(
        req: &MutationRequest<'_>,
        reason: &str,
        rule_id: Option<&str>,
    ) -> Result<i64, StorageError> {
        // NOTE(review): `rid` is interpolated without JSON escaping; this assumes
        // policy rule ids never contain `"` or `\` — confirm against the config schema.
        let rule_refs = match rule_id {
            Some(rid) => format!(r#"["policy_rule:{rid}"]"#),
            None => "[]".to_string(),
        };
        crate::storage::approval_queue::enqueue_with_context(
            req.pool,
            req.tool_name,
            "",
            "",
            req.params_json,
            "mcp_policy",
            req.tool_name,
            0.0,
            "[]",
            Some(reason),
            Some(&rule_refs),
        )
        .await
        .map_err(|e| StorageError::Query {
            source: sqlx::Error::Protocol(format!("Failed to enqueue for approval: {e}")),
        })
    }

    /// Handles a policy-allowed request.
    ///
    /// If the same call was seen within the idempotency window, returns a
    /// duplicate notice. Otherwise opens a pending audit row and returns a
    /// ticket for it.
    async fn admit(req: &MutationRequest<'_>) -> Result<GatewayDecision, StorageError> {
        let params_hash = mutation_audit::compute_params_hash(req.tool_name, req.params_json);
        let params_summary = mutation_audit::truncate_summary(req.params_json, 500);

        let recent = mutation_audit::find_recent_duplicate(
            req.pool,
            req.tool_name,
            &params_hash,
            IDEMPOTENCY_WINDOW_SECS,
        )
        .await?;

        if let Some(existing) = recent {
            // The repeated attempt still gets its own audit row, linked to the original.
            let dup_corr = generate_correlation_id();
            let dup_id = mutation_audit::insert_pending(
                req.pool,
                &dup_corr,
                None,
                req.tool_name,
                &params_hash,
                &params_summary,
            )
            .await?;
            // Best-effort: if marking fails, the row presumably stays in its
            // pending state, but the caller still receives the original's cached
            // result rather than an error.
            let _ =
                mutation_audit::mark_duplicate(req.pool, dup_id, &existing.correlation_id).await;
            return Ok(GatewayDecision::Duplicate(DuplicateInfo {
                original_correlation_id: existing.correlation_id,
                cached_result: existing.result_summary,
                audit_id: dup_id,
            }));
        }

        let correlation_id = generate_correlation_id();
        let audit_id = mutation_audit::insert_pending(
            req.pool,
            &correlation_id,
            None,
            req.tool_name,
            &params_hash,
            &params_summary,
        )
        .await?;
        Ok(GatewayDecision::Proceed(MutationTicket {
            audit_id,
            correlation_id,
            tool_name: req.tool_name.to_string(),
        }))
    }

    /// Marks the ticket's audit row as successful, then records the mutation
    /// for policy rate limiting.
    ///
    /// `result_summary` is truncated with `mutation_audit::truncate_summary(_, 500)`
    /// before storage. `rollback_action` and `elapsed_ms` are passed through to
    /// the audit update.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError`] if the audit update or
    /// `McpPolicyEvaluator::record_mutation` fails.
    pub async fn complete_success(
        pool: &DbPool,
        ticket: &MutationTicket,
        result_summary: &str,
        rollback_action: Option<&str>,
        elapsed_ms: u64,
        rate_limit_configs: &[PolicyRateLimit],
    ) -> Result<(), StorageError> {
        let summary = mutation_audit::truncate_summary(result_summary, 500);
        mutation_audit::complete_success(
            pool,
            ticket.audit_id,
            &summary,
            rollback_action,
            elapsed_ms,
        )
        .await?;
        // NOTE(review): the audit row is already marked successful here. If
        // rate-limit recording fails, the caller gets an error for a mutation
        // that did succeed — confirm callers do not then call `complete_failure`.
        McpPolicyEvaluator::record_mutation(pool, &ticket.tool_name, rate_limit_configs).await?;
        Ok(())
    }

    /// Marks the ticket's audit row as failed, storing `error_message` and
    /// `elapsed_ms`.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError`] if the audit update fails.
    pub async fn complete_failure(
        pool: &DbPool,
        ticket: &MutationTicket,
        error_message: &str,
        elapsed_ms: u64,
    ) -> Result<(), StorageError> {
        mutation_audit::complete_failure(pool, ticket.audit_id, error_message, elapsed_ms).await
    }
}
/// Generates a fresh correlation id shaped like an RFC 4122 version-4 UUID.
///
/// The shape is `xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx`: lowercase hex, with
/// `y` in `8..=b`.
///
/// The value hashes the following inputs with a randomly keyed SipHash:
/// * wall-clock nanoseconds;
/// * a process-wide counter;
/// * the current thread id;
/// * the process id.
///
/// It is meant to be unique, not unpredictable, so do not use it as a secret.
fn generate_correlation_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::SystemTime;

    // Ensures distinct hash inputs for every call in this process, even when
    // the clock does not advance between calls.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // `DefaultHasher::new()` uses fixed keys, so two processes that hit the
    // same clock tick with the same counter value and thread id would produce
    // identical ids. Randomly keyed `RandomState` hashers plus the process id
    // remove that cross-process collision.
    let mut hasher = RandomState::new().build_hasher();
    nanos.hash(&mut hasher);
    count.hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    std::process::id().hash(&mut hasher);
    let h1 = hasher.finish();
    // Keep feeding the same state to derive a second 64-bit word.
    count.wrapping_add(1).hash(&mut hasher);
    let h2 = hasher.finish();

    format!(
        "{:08x}-{:04x}-4{:03x}-{:04x}-{:012x}",
        (h1 >> 32) as u32,
        (h1 >> 16) as u16,
        (h1 as u16) & 0x0fff,
        // RFC 4122 variant: top two bits of this group are `10`.
        ((h2 >> 48) as u16 & 0x3fff) | 0x8000,
        h2 & 0x0000_ffff_ffff_ffff,
    )
}