use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Runtime approval policy consulted by `ApprovalGate::requires_approval`.
///
/// The policy only takes effect while the owning gate is enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApprovalPolicy {
/// No tool requires approval.
AlwaysAllow,
/// Every tool requires approval.
AlwaysRequire,
/// Only tools whose name matches one of the listed patterns require approval.
RequireForTools(Vec<String>),
/// Tools matching `ApprovalGate::default_dangerous_tools()` require approval.
///
/// Note: `ApprovalGate::new` rewrites this policy into `RequireForTools`
/// carrying the configured dangerous-tool list, so gates built through `new`
/// never hold this variant.
RequireForDangerous,
}
/// Serializable policy selector used inside `ApprovalConfig`.
///
/// Variants are (de)serialized in `snake_case` (e.g. `"require_for_dangerous"`).
/// Unlike `ApprovalPolicy`, the variants carry no data; the tool lists live in
/// the surrounding `ApprovalConfig`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ApprovalPolicyConfig {
/// Never require approval (the default).
#[default]
AlwaysAllow,
/// Require approval for every tool.
AlwaysRequire,
/// Require approval for tools matching `ApprovalConfig::require_for`.
RequireForTools,
/// Require approval for tools matching `ApprovalConfig::dangerous_tools`.
RequireForDangerous,
}
/// User-facing configuration from which an `ApprovalGate` is built.
///
/// Because of the container-level `#[serde(default)]`, any field missing from
/// the serialized input takes its value from `ApprovalConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ApprovalConfig {
/// Master switch; when `false` the gate never requires approval.
pub enabled: bool,
/// Which policy the gate applies.
pub policy: ApprovalPolicyConfig,
/// Tool-name patterns used when `policy` is `RequireForTools`.
pub require_for: Vec<String>,
/// Tool-name patterns used when `policy` is `RequireForDangerous`.
pub dangerous_tools: Vec<String>,
/// Seconds after which a pending request counts as auto-approved;
/// `0` disables auto-approval.
pub auto_approve_timeout_secs: u64,
}
impl Default for ApprovalConfig {
fn default() -> Self {
Self {
enabled: false,
policy: ApprovalPolicyConfig::AlwaysAllow,
require_for: Vec::new(),
dangerous_tools: ApprovalGate::default_dangerous_tools(),
auto_approve_timeout_secs: 0,
}
}
}
/// A pending request to approve a single tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRequest {
/// Name of the tool awaiting approval.
pub tool_name: String,
/// JSON arguments the tool would be invoked with.
pub arguments: Value,
/// UTC time at which the request was created.
pub timestamp: DateTime<Utc>,
/// Deadline after which the request counts as auto-approved;
/// `None` when auto-approval is disabled.
pub auto_approve_at: Option<DateTime<Utc>>,
}
impl ApprovalRequest {
    /// Creates a request for `tool_name`, stamped with the current UTC time.
    ///
    /// When `auto_approve_timeout_secs` is greater than zero the request
    /// records a deadline that many seconds after its timestamp. A timeout of
    /// zero disables auto-approval (`auto_approve_at` is `None`).
    ///
    /// A timeout too large to be represented as a `chrono` duration or
    /// timestamp also yields `None`. Such a deadline could never be reached,
    /// and treating it as "never" avoids both a panic and a wrapped, negative
    /// offset that would place the deadline in the past (which would make the
    /// request auto-approve immediately).
    pub fn new(tool_name: String, arguments: Value, auto_approve_timeout_secs: u64) -> Self {
        let timestamp = Utc::now();
        let auto_approve_at = if auto_approve_timeout_secs > 0 {
            // Checked conversion and addition instead of `as i64` / `+`:
            // both of the latter misbehave for very large timeouts.
            Duration::from_std(std::time::Duration::from_secs(auto_approve_timeout_secs))
                .ok()
                .and_then(|delay| timestamp.checked_add_signed(delay))
        } else {
            None
        };
        Self {
            tool_name,
            arguments,
            timestamp,
            auto_approve_at,
        }
    }

    /// Returns `true` once the auto-approve deadline has been reached.
    ///
    /// Always `false` when the request has no deadline.
    pub fn is_auto_approved(&self) -> bool {
        match self.auto_approve_at {
            Some(deadline) => Utc::now() >= deadline,
            None => false,
        }
    }
}
/// Outcome of an approval request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApprovalResponse {
/// The tool invocation was approved.
Approved,
/// The tool invocation was denied; the payload is presumably a
/// human-readable reason (confirm against callers).
Denied(String),
/// No decision was made in time.
TimedOut,
}
/// Decides whether a tool invocation needs approval and builds the
/// corresponding `ApprovalRequest`s.
pub struct ApprovalGate {
// When `false`, `requires_approval` always returns `false`.
enabled: bool,
// Resolved runtime policy derived from `ApprovalConfig` in `new`.
policy: ApprovalPolicy,
// Timeout (seconds) forwarded to every request created by this gate.
auto_approve_timeout_secs: u64,
}
impl ApprovalGate {
pub fn new(config: ApprovalConfig) -> Self {
let policy = match config.policy {
ApprovalPolicyConfig::AlwaysAllow => ApprovalPolicy::AlwaysAllow,
ApprovalPolicyConfig::AlwaysRequire => ApprovalPolicy::AlwaysRequire,
ApprovalPolicyConfig::RequireForTools => {
ApprovalPolicy::RequireForTools(config.require_for)
}
ApprovalPolicyConfig::RequireForDangerous => ApprovalPolicy::RequireForDangerous,
};
let policy = if policy == ApprovalPolicy::RequireForDangerous {
ApprovalPolicy::RequireForTools(config.dangerous_tools)
} else {
policy
};
Self {
enabled: config.enabled,
policy,
auto_approve_timeout_secs: config.auto_approve_timeout_secs,
}
}
pub fn requires_approval(&self, tool_name: &str) -> bool {
if !self.enabled {
return false;
}
match &self.policy {
ApprovalPolicy::AlwaysAllow => false,
ApprovalPolicy::AlwaysRequire => true,
ApprovalPolicy::RequireForTools(tools) => tools
.iter()
.any(|pattern| matches_tool_pattern(pattern, tool_name)),
ApprovalPolicy::RequireForDangerous => Self::default_dangerous_tools()
.iter()
.any(|pattern| matches_tool_pattern(pattern, tool_name)),
}
}
pub fn format_approval_request(&self, tool_name: &str, args: &Value) -> String {
let args_display = match serde_json::to_string_pretty(args) {
Ok(pretty) => pretty,
Err(_) => args.to_string(),
};
format!(
"[Approval Required]\n\
Tool: {tool_name}\n\
Arguments:\n{args_display}\n\n\
Approve execution? (yes/no)"
)
}
pub fn create_request(&self, tool_name: &str, args: &Value) -> ApprovalRequest {
ApprovalRequest::new(
tool_name.to_string(),
args.clone(),
self.auto_approve_timeout_secs,
)
}
pub fn default_dangerous_tools() -> Vec<String> {
vec![
"shell".to_string(),
"write_file".to_string(),
"edit_file".to_string(),
"google".to_string(),
]
}
pub fn policy(&self) -> &ApprovalPolicy {
&self.policy
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
}
/// Returns `true` when `tool_name` matches the glob-style `pattern`.
///
/// `*` matches any run of characters (including none) and may appear any
/// number of times, anywhere in the pattern: `"*"` matches everything,
/// `"shell*"` is a prefix match, `"*_file"` a suffix match, `"*file*"` a
/// substring match and `"fs_*_file"` anchors both ends. A pattern without `*`
/// must equal the tool name exactly. Matching is case-sensitive.
fn matches_tool_pattern(pattern: &str, tool_name: &str) -> bool {
    // No wildcard at all: plain equality.
    let (head, rest) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == tool_name,
    };

    // Text before the first `*` is anchored to the start of the name.
    let mut remaining = match tool_name.strip_prefix(head) {
        Some(after_head) => after_head,
        None => return false,
    };

    // Text after the last `*` is anchored to the end; anything between the
    // first and last `*` only has to appear in order.
    let (middle, tail) = rest.rsplit_once('*').unwrap_or(("", rest));

    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        // Taking the leftmost occurrence leaves the most text for later
        // segments, so this greedy scan never misses a possible match.
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }

    remaining.ends_with(tail)
}
// Unit tests for the approval gate: policy evaluation, wildcard matching,
// request construction, prompt formatting and config (de)serialization.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// An enabled gate with `AlwaysAllow` never asks for approval.
#[test]
fn test_always_allow_returns_false() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::AlwaysAllow,
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(!gate.requires_approval("shell"));
assert!(!gate.requires_approval("write_file"));
assert!(!gate.requires_approval("echo"));
assert!(!gate.requires_approval("anything"));
}
// An enabled gate with `AlwaysRequire` asks for approval for every tool.
#[test]
fn test_always_require_returns_true() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::AlwaysRequire,
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(gate.requires_approval("shell"));
assert!(gate.requires_approval("write_file"));
assert!(gate.requires_approval("echo"));
assert!(gate.requires_approval("web_search"));
}
// `RequireForTools` only gates tools named in `require_for`.
#[test]
fn test_require_for_tools_matches_listed_tools() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForTools,
require_for: vec!["shell".to_string(), "write_file".to_string()],
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(gate.requires_approval("shell"));
assert!(gate.requires_approval("write_file"));
assert!(!gate.requires_approval("echo"));
assert!(!gate.requires_approval("read_file"));
}
// An empty `require_for` list gates nothing.
#[test]
fn test_require_for_tools_empty_list() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForTools,
require_for: vec![],
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(!gate.requires_approval("shell"));
assert!(!gate.requires_approval("anything"));
}
// Every entry of a longer `require_for` list is honoured.
#[test]
fn test_require_for_tools_multiple_tools() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForTools,
require_for: vec![
"shell".to_string(),
"write_file".to_string(),
"edit_file".to_string(),
"web_fetch".to_string(),
],
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(gate.requires_approval("shell"));
assert!(gate.requires_approval("write_file"));
assert!(gate.requires_approval("edit_file"));
assert!(gate.requires_approval("web_fetch"));
assert!(!gate.requires_approval("echo"));
assert!(!gate.requires_approval("read_file"));
}
// With the default config, `RequireForDangerous` gates the built-in list.
#[test]
fn test_require_for_dangerous_matches_default_dangerous_tools() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForDangerous,
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(gate.requires_approval("shell"));
assert!(gate.requires_approval("write_file"));
assert!(gate.requires_approval("edit_file"));
assert!(!gate.requires_approval("echo"));
assert!(!gate.requires_approval("read_file"));
assert!(!gate.requires_approval("web_search"));
}
// A custom `dangerous_tools` list replaces the built-in one.
#[test]
fn test_require_for_dangerous_custom_list() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForDangerous,
dangerous_tools: vec!["web_fetch".to_string(), "message".to_string()],
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(gate.requires_approval("web_fetch"));
assert!(gate.requires_approval("message"));
assert!(!gate.requires_approval("shell"));
assert!(!gate.requires_approval("write_file"));
}
// A disabled gate ignores its policy entirely.
#[test]
fn test_disabled_config_bypasses_all_checks() {
let config = ApprovalConfig {
enabled: false,
policy: ApprovalPolicyConfig::AlwaysRequire,
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(!gate.requires_approval("shell"));
assert!(!gate.requires_approval("write_file"));
assert!(!gate.requires_approval("echo"));
}
// Pins every field of `ApprovalConfig::default()`.
#[test]
fn test_default_config() {
let config = ApprovalConfig::default();
assert!(!config.enabled);
assert_eq!(config.policy, ApprovalPolicyConfig::AlwaysAllow);
assert!(config.require_for.is_empty());
assert_eq!(
config.dangerous_tools,
vec!["shell", "write_file", "edit_file", "google"]
);
assert_eq!(config.auto_approve_timeout_secs, 0);
}
// Tool-name matching is case-sensitive.
#[test]
fn test_tool_name_case_sensitivity() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForTools,
require_for: vec!["shell".to_string()],
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(gate.requires_approval("shell"));
assert!(!gate.requires_approval("Shell"));
assert!(!gate.requires_approval("SHELL"));
}
// Trailing `*` is a prefix match, leading `*` a suffix match.
#[test]
fn test_wildcard_patterns() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForTools,
require_for: vec!["shell*".to_string(), "*_file".to_string()],
..Default::default()
};
let gate = ApprovalGate::new(config);
assert!(gate.requires_approval("shell"));
assert!(gate.requires_approval("shell_exec"));
assert!(gate.requires_approval("write_file"));
assert!(!gate.requires_approval("web_search"));
}
// The prompt contains the header, tool name, arguments and the question.
#[test]
fn test_format_approval_request_output() {
let config = ApprovalConfig::default();
let gate = ApprovalGate::new(config);
let args = json!({"command": "rm -rf /tmp/test"});
let output = gate.format_approval_request("shell", &args);
assert!(output.contains("[Approval Required]"));
assert!(output.contains("Tool: shell"));
assert!(output.contains("rm -rf /tmp/test"));
assert!(output.contains("Approve execution? (yes/no)"));
}
// Multi-field argument objects are rendered into the prompt.
#[test]
fn test_format_approval_request_complex_args() {
let config = ApprovalConfig::default();
let gate = ApprovalGate::new(config);
let args = json!({
"path": "/home/user/file.txt",
"content": "Hello, world!",
"overwrite": true
});
let output = gate.format_approval_request("write_file", &args);
assert!(output.contains("Tool: write_file"));
assert!(output.contains("/home/user/file.txt"));
assert!(output.contains("Hello, world!"));
}
// A zero timeout yields a request without an auto-approve deadline.
#[test]
fn test_approval_request_construction() {
let args = json!({"command": "ls -la"});
let request = ApprovalRequest::new("shell".to_string(), args.clone(), 0);
assert_eq!(request.tool_name, "shell");
assert_eq!(request.arguments, args);
assert!(request.auto_approve_at.is_none());
}
// The deadline lies exactly `timeout` seconds after the creation time,
// bracketed by clock readings taken before and after construction.
#[test]
fn test_approval_request_with_auto_approve_timeout() {
let args = json!({"command": "ls"});
let before = Utc::now();
let request = ApprovalRequest::new("shell".to_string(), args, 30);
let after = Utc::now();
assert!(request.auto_approve_at.is_some());
let deadline = request.auto_approve_at.unwrap();
let earliest = before + Duration::seconds(30);
let latest = after + Duration::seconds(30);
assert!(deadline >= earliest);
assert!(deadline <= latest);
}
// A fresh request with a future deadline is not yet auto-approved.
#[test]
fn test_approval_request_not_auto_approved_immediately() {
let args = json!({"command": "ls"});
let request = ApprovalRequest::new("shell".to_string(), args, 60);
assert!(!request.is_auto_approved());
}
// Without a deadline a request is never auto-approved.
#[test]
fn test_approval_request_no_auto_approve_when_disabled() {
let args = json!({"command": "ls"});
let request = ApprovalRequest::new("shell".to_string(), args, 0);
assert!(!request.is_auto_approved());
}
// `ApprovalResponse` variants compare by value.
#[test]
fn test_approval_response_variants() {
let approved = ApprovalResponse::Approved;
let denied = ApprovalResponse::Denied("too dangerous".to_string());
let timed_out = ApprovalResponse::TimedOut;
assert_eq!(approved, ApprovalResponse::Approved);
assert_eq!(
denied,
ApprovalResponse::Denied("too dangerous".to_string())
);
assert_eq!(timed_out, ApprovalResponse::TimedOut);
}
// Serializing and deserializing a config preserves every field.
#[test]
fn test_approval_config_serialization_roundtrip() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::RequireForTools,
require_for: vec!["shell".to_string(), "write_file".to_string()],
dangerous_tools: vec![
"shell".to_string(),
"write_file".to_string(),
"edit_file".to_string(),
],
auto_approve_timeout_secs: 30,
};
let json_str = serde_json::to_string(&config).expect("serialize");
let deserialized: ApprovalConfig = serde_json::from_str(&json_str).expect("deserialize");
assert_eq!(deserialized.enabled, config.enabled);
assert_eq!(deserialized.policy, config.policy);
assert_eq!(deserialized.require_for, config.require_for);
assert_eq!(deserialized.dangerous_tools, config.dangerous_tools);
assert_eq!(
deserialized.auto_approve_timeout_secs,
config.auto_approve_timeout_secs
);
}
// Fields omitted from the JSON fall back to their defaults, and the
// policy is read from its snake_case name.
#[test]
fn test_approval_config_deserialize_from_json() {
let json_str = r#"{
"enabled": true,
"policy": "require_for_dangerous",
"dangerous_tools": ["shell", "write_file"]
}"#;
let config: ApprovalConfig = serde_json::from_str(json_str).unwrap();
assert!(config.enabled);
assert_eq!(config.policy, ApprovalPolicyConfig::RequireForDangerous);
assert_eq!(
config.dangerous_tools,
vec!["shell".to_string(), "write_file".to_string()]
);
assert!(config.require_for.is_empty());
assert_eq!(config.auto_approve_timeout_secs, 0);
}
// Pins the contents of the built-in dangerous-tool list.
#[test]
fn test_default_dangerous_tools_list() {
let defaults = ApprovalGate::default_dangerous_tools();
assert_eq!(defaults.len(), 4);
assert!(defaults.contains(&"shell".to_string()));
assert!(defaults.contains(&"write_file".to_string()));
assert!(defaults.contains(&"edit_file".to_string()));
assert!(defaults.contains(&"google".to_string()));
}
// Requests created through the gate inherit its auto-approve timeout.
#[test]
fn test_gate_create_request() {
let config = ApprovalConfig {
enabled: true,
auto_approve_timeout_secs: 45,
..Default::default()
};
let gate = ApprovalGate::new(config);
let args = json!({"command": "echo test"});
let request = gate.create_request("shell", &args);
assert_eq!(request.tool_name, "shell");
assert_eq!(request.arguments, args);
assert!(request.auto_approve_at.is_some());
}
// `is_enabled` mirrors the `enabled` flag of the config.
#[test]
fn test_gate_is_enabled() {
let enabled_gate = ApprovalGate::new(ApprovalConfig {
enabled: true,
..Default::default()
});
assert!(enabled_gate.is_enabled());
let disabled_gate = ApprovalGate::new(ApprovalConfig::default());
assert!(!disabled_gate.is_enabled());
}
// `policy()` exposes the resolved `AlwaysRequire` policy.
#[test]
fn test_gate_policy_accessor() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::AlwaysRequire,
..Default::default()
};
let gate = ApprovalGate::new(config);
assert_eq!(*gate.policy(), ApprovalPolicy::AlwaysRequire);
}
// `policy()` exposes the resolved `AlwaysAllow` policy.
#[test]
fn test_gate_policy_always_allow_accessor() {
let config = ApprovalConfig {
enabled: true,
policy: ApprovalPolicyConfig::AlwaysAllow,
..Default::default()
};
let gate = ApprovalGate::new(config);
assert_eq!(*gate.policy(), ApprovalPolicy::AlwaysAllow);
}
}