use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use serde::{Deserialize, Serialize};
use tower::Service;
use tower_mcp::{McpRequest, McpResponse, RouterRequest, RouterResponse};
use crate::policy::ToolsetKind;
/// How much tool activity the audit layer records.
///
/// (De)serialized as kebab-case strings: `"all"`, `"writes"`, `"destructive"`, `"denied"`.
/// [`AuditLevel::All`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum AuditLevel {
    /// Log every audited tool event.
    #[default]
    All,
    /// Presumably intended to restrict logging to write operations.
    /// NOTE(review): `should_log` currently treats this exactly like `All`.
    Writes,
    /// Presumably intended to restrict logging to destructive operations.
    /// NOTE(review): `should_log` currently treats this exactly like `All`.
    Destructive,
    /// Log only denied tool calls and tool errors.
    Denied,
}
/// Settings for the audit layer.
///
/// Because of `#[serde(default)]`, any key missing from the serialized form falls back to
/// the corresponding value from [`AuditConfig::default`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct AuditConfig {
    /// Whether auditing is enabled (defaults to `false`).
    pub enabled: bool,
    /// Which audit events are emitted; see [`AuditLevel`].
    pub level: AuditLevel,
    /// When `true`, tool-call arguments are included in the log line after redaction.
    pub include_args: bool,
    /// Object keys whose values are replaced with `"[REDACTED]"` in logged arguments.
    pub redact_fields: Vec<String>,
}
impl Default for AuditConfig {
    /// Auditing off, level `All`, arguments omitted, and a small built-in list of
    /// credential-like keys to redact should arguments ever be logged.
    fn default() -> Self {
        let redact_fields = ["password", "api_key", "api_secret", "secret"]
            .iter()
            .map(|&name| String::from(name))
            .collect();

        Self {
            enabled: false,
            level: AuditLevel::All,
            include_args: false,
            redact_fields,
        }
    }
}
/// Tower layer that wraps a router service in an [`AuditService`].
///
/// Cloning is cheap: both fields are reference-counted.
#[derive(Clone)]
pub struct AuditLayer {
    /// Audit settings shared with every service this layer builds.
    config: Arc<AuditConfig>,
    /// Tool name -> toolset lookup used to label audit events.
    tool_toolset: Arc<HashMap<String, ToolsetKind>>,
}
impl AuditLayer {
pub fn new(config: Arc<AuditConfig>, tool_toolset: Arc<HashMap<String, ToolsetKind>>) -> Self {
Self {
config,
tool_toolset,
}
}
}
impl<S> tower::Layer<S> for AuditLayer {
type Service = AuditService<S>;
fn layer(&self, inner: S) -> Self::Service {
AuditService {
inner,
config: self.config.clone(),
tool_toolset: self.tool_toolset.clone(),
}
}
}
/// Service produced by [`AuditLayer`].
///
/// Forwards requests to `inner` and emits `tracing` audit events for tool calls.
#[derive(Clone)]
pub struct AuditService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Shared audit settings.
    config: Arc<AuditConfig>,
    /// Tool name -> toolset lookup used to label audit events.
    tool_toolset: Arc<HashMap<String, ToolsetKind>>,
}
impl<S> Service<RouterRequest> for AuditService<S>
where
S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send,
{
type Response = RouterResponse;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: RouterRequest) -> Self::Future {
let tool_call_info = match &req.inner {
McpRequest::CallTool(params) => {
let toolset = self
.tool_toolset
.get(¶ms.name)
.map(|k| k.to_string())
.unwrap_or_else(|| "unknown".to_string());
let args = if self.config.include_args {
let redacted = redact_value(¶ms.arguments, &self.config.redact_fields);
Some(redacted.to_string())
} else {
None
};
Some((params.name.clone(), toolset, args))
}
_ => None,
};
let config = self.config.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
if let Some((tool_name, toolset, args)) = tool_call_info {
let start = Instant::now();
let response = inner.call(req).await?;
let duration_ms = start.elapsed().as_millis() as u64;
let (event, result_status) = match &response.inner {
Ok(McpResponse::CallTool(_)) => ("tool_invocation", "success"),
Ok(_) => ("tool_invocation", "success"),
Err(err) if err.code == -32007 => ("tool_denied", "denied"),
Err(_) => ("tool_error", "error"),
};
if should_log(config.level, event, &toolset) {
if let Some(args) = args {
tracing::info!(
target: "audit",
event,
tool = %tool_name,
toolset = %toolset,
result = result_status,
duration_ms,
arguments = %args,
);
} else {
tracing::info!(
target: "audit",
event,
tool = %tool_name,
toolset = %toolset,
result = result_status,
duration_ms,
);
}
}
Ok(response)
} else {
inner.call(req).await
}
})
}
}
/// Decides whether an audit `event` is emitted at the configured `level`.
///
/// `Denied` keeps only `"tool_denied"` and `"tool_error"` events; every other level keeps
/// everything.
///
/// NOTE(review): `_toolset` is ignored, so `Writes` and `Destructive` currently behave
/// exactly like `All`. Confirm whether toolset-based filtering is still planned.
fn should_log(level: AuditLevel, event: &str, _toolset: &str) -> bool {
    match level {
        AuditLevel::Denied => matches!(event, "tool_denied" | "tool_error"),
        AuditLevel::All | AuditLevel::Writes | AuditLevel::Destructive => true,
    }
}
/// Returns a copy of `value` with sensitive object members masked.
///
/// Keys are compared against `redact_fields` ASCII case-insensitively, so `"Password"`
/// and `"API_KEY"` match `"password"` and `"api_key"`. A matching member has its whole
/// value replaced by the string `"[REDACTED]"`, even when that value is itself an object
/// or array.
///
/// All other members, and every array element, are processed recursively. Scalars are
/// returned unchanged. The input is never mutated.
pub fn redact_value(value: &serde_json::Value, redact_fields: &[String]) -> serde_json::Value {
    match value {
        serde_json::Value::Object(map) => serde_json::Value::Object(
            map.iter()
                .map(|(key, child)| {
                    let masked = if is_redacted_key(key, redact_fields) {
                        serde_json::Value::String("[REDACTED]".to_string())
                    } else {
                        redact_value(child, redact_fields)
                    };
                    (key.clone(), masked)
                })
                .collect(),
        ),
        serde_json::Value::Array(items) => serde_json::Value::Array(
            items
                .iter()
                .map(|item| redact_value(item, redact_fields))
                .collect(),
        ),
        other => other.clone(),
    }
}

/// Whether `key` names one of `redact_fields`, ignoring ASCII case.
///
/// Case-insensitive matching prevents a secret from leaking into audit logs just because
/// the client capitalised the key differently from the configuration.
fn is_redacted_key(key: &str, redact_fields: &[String]) -> bool {
    redact_fields
        .iter()
        .any(|field| field.eq_ignore_ascii_case(key))
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Owned redaction list built from string literals.
    fn redact_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn default_config_is_disabled() {
        let AuditConfig {
            enabled,
            level,
            include_args,
            redact_fields,
        } = AuditConfig::default();
        assert!(!enabled);
        assert_eq!(level, AuditLevel::All);
        assert!(!include_args);
        assert!(!redact_fields.is_empty());
    }

    #[test]
    fn toml_minimal() {
        let parsed: AuditConfig = toml::from_str("enabled = true").unwrap();
        assert!(parsed.enabled);
        assert_eq!(parsed.level, AuditLevel::All);
    }

    #[test]
    fn toml_full() {
        const INPUT: &str = r#"
enabled = true
level = "denied"
include_args = true
redact_fields = ["password", "token"]
"#;
        let parsed = toml::from_str::<AuditConfig>(INPUT).unwrap();
        assert!(parsed.enabled);
        assert_eq!(parsed.level, AuditLevel::Denied);
        assert!(parsed.include_args);
        assert_eq!(parsed.redact_fields, redact_list(&["password", "token"]));
    }

    #[test]
    fn toml_empty_is_default() {
        let parsed = toml::from_str::<AuditConfig>("").unwrap();
        assert!(!parsed.enabled);
    }

    #[test]
    fn toml_roundtrip() {
        let original = AuditConfig {
            enabled: true,
            level: AuditLevel::Writes,
            include_args: true,
            redact_fields: redact_list(&["secret"]),
        };
        let serialized = toml::to_string_pretty(&original).unwrap();
        let reparsed: AuditConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(reparsed.enabled, original.enabled);
        assert_eq!(reparsed.level, original.level);
        assert_eq!(reparsed.include_args, original.include_args);
        assert_eq!(reparsed.redact_fields, original.redact_fields);
    }

    #[test]
    fn redact_top_level_fields() {
        let input = json!({
            "name": "my-db",
            "password": "secret123",
            "api_key": "ak_123"
        });
        let out = redact_value(&input, &redact_list(&["password", "api_key"]));
        assert_eq!(out["name"], "my-db");
        assert_eq!(out["password"], "[REDACTED]");
        assert_eq!(out["api_key"], "[REDACTED]");
    }

    #[test]
    fn redact_nested_fields() {
        let input = json!({
            "config": {
                "name": "test",
                "credentials": {
                    "password": "secret",
                    "username": "admin"
                }
            }
        });
        let out = redact_value(&input, &redact_list(&["password"]));
        let credentials = &out["config"]["credentials"];
        assert_eq!(out["config"]["name"], "test");
        assert_eq!(credentials["password"], "[REDACTED]");
        assert_eq!(credentials["username"], "admin");
    }

    #[test]
    fn redact_in_array() {
        let input = json!([
            {"name": "a", "secret": "s1"},
            {"name": "b", "secret": "s2"}
        ]);
        let out = redact_value(&input, &redact_list(&["secret"]));
        assert_eq!(out[0]["name"], "a");
        assert_eq!(out[0]["secret"], "[REDACTED]");
        assert_eq!(out[1]["secret"], "[REDACTED]");
    }

    #[test]
    fn redact_no_matching_fields() {
        let input = json!({"name": "test", "count": 42});
        assert_eq!(redact_value(&input, &redact_list(&["password"])), input);
    }

    #[test]
    fn redact_scalar_passthrough() {
        let input = json!("just a string");
        assert_eq!(redact_value(&input, &redact_list(&["password"])), input);
    }

    #[test]
    fn all_level_logs_everything() {
        for event in ["tool_invocation", "tool_denied", "tool_error"] {
            assert!(
                should_log(AuditLevel::All, event, "cloud"),
                "event {:?} should be logged at All",
                event
            );
        }
    }

    #[test]
    fn denied_level_only_logs_denied_and_errors() {
        assert!(!should_log(AuditLevel::Denied, "tool_invocation", "cloud"));
        for event in ["tool_denied", "tool_error"] {
            assert!(
                should_log(AuditLevel::Denied, event, "cloud"),
                "event {:?} should be logged at Denied",
                event
            );
        }
    }
}