use std::collections::{HashMap, HashSet};
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
/// A reviewer's answer to a pending tool-call approval request.
///
/// Decisions are forwarded unchanged to the receiver returned by `ApprovalGate::request`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ApprovalDecision {
    /// Approve the pending call; no runtime override is recorded.
    Approve,
    /// Deny the pending call; no runtime override is recorded.
    Deny,
    /// Approve the pending call and, depending on `scope`, record a runtime override so
    /// that later calls are reported as approved by `ApprovalGate::is_runtime_approved`.
    ApproveAlways { scope: ApprovalScope },
}
/// How far an [`ApprovalDecision::ApproveAlways`] reaches beyond the call it answers.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub enum ApprovalScope {
    /// Only the call being answered; nothing is remembered afterwards.
    ThisCall,
    /// Every later call of the same tool name counts as runtime-approved.
    ThisTool,
    /// Every tool counts as runtime-approved until the overrides are reset.
    ThisSession,
}
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ApprovalError {
#[error("approval request not found: {call_id}")]
NotFound {
call_id: String,
},
#[error("approval expired: executor no longer waiting for {call_id}")]
Expired {
call_id: String,
},
}
/// Read-only snapshot of one pending approval, as produced by `ApprovalGate::pending`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PendingApprovalInfo {
    /// Identifier the request was registered under.
    pub call_id: String,
    /// Name of the tool awaiting a decision.
    pub tool_name: String,
    /// Arguments exactly as they were passed to `ApprovalGate::request`.
    pub arguments: serde_json::Value,
    /// How long the request had been pending when the snapshot was taken.
    pub elapsed: Duration,
}
/// Internal bookkeeping for one outstanding approval request.
struct PendingApproval {
    /// Key under which this entry is stored in the gate's pending map.
    call_id: String,
    /// Tool whose call is awaiting a decision; used when recording a `ThisTool` override.
    tool_name: String,
    /// Tool arguments, retained so they can be reported by `pending()`.
    arguments: serde_json::Value,
    /// Registration time; drives both `elapsed` and `cleanup_expired`.
    requested_at: Instant,
    /// Sending half paired with the receiver that `request` handed out.
    respond: oneshot::Sender<ApprovalDecision>,
}
/// Approvals granted at runtime through `ApproveAlways` decisions.
///
/// The default value grants nothing: `approve_all` is `false` and the tool set is empty.
#[derive(Default, Debug)]
struct RuntimeOverrides {
    /// Set by a `ThisSession` approval; when `true`, every tool counts as approved.
    approve_all: bool,
    /// Tool names that received a `ThisTool` approval.
    approved_tools: HashSet<String>,
}
/// Rendezvous point between tool executors waiting for approval and whoever decides.
///
/// Both pieces of state sit behind `RwLock`s, so every operation takes `&self`. The gate
/// can therefore be shared by reference, e.g. through an `Arc<ApprovalGate>`.
pub struct ApprovalGate {
    /// Outstanding requests, keyed by call id.
    pending: RwLock<HashMap<String, PendingApproval>>,
    /// Approvals granted via `ApproveAlways`, consulted by `is_runtime_approved`.
    overrides: RwLock<RuntimeOverrides>,
}
impl ApprovalGate {
pub fn new() -> Self {
Self {
pending: RwLock::new(HashMap::new()),
overrides: RwLock::new(RuntimeOverrides::default()),
}
}
pub fn request(
&self,
call_id: &str,
tool_name: &str,
arguments: serde_json::Value,
) -> oneshot::Receiver<ApprovalDecision> {
let (tx, rx) = oneshot::channel();
let entry = PendingApproval {
call_id: call_id.to_string(),
tool_name: tool_name.to_string(),
arguments,
requested_at: Instant::now(),
respond: tx,
};
self.pending.write().insert(call_id.to_string(), entry);
rx
}
pub fn deliver(&self, call_id: &str, decision: ApprovalDecision) -> Result<(), ApprovalError> {
let entry =
self.pending
.write()
.remove(call_id)
.ok_or_else(|| ApprovalError::NotFound {
call_id: call_id.to_string(),
})?;
if let ApprovalDecision::ApproveAlways { scope } = &decision {
self.apply_override(*scope, &entry.tool_name);
}
entry
.respond
.send(decision)
.map_err(|_| ApprovalError::Expired {
call_id: call_id.to_string(),
})
}
pub fn pending(&self) -> Vec<PendingApprovalInfo> {
let map = self.pending.read();
map.values()
.map(|p| PendingApprovalInfo {
call_id: p.call_id.clone(),
tool_name: p.tool_name.clone(),
arguments: p.arguments.clone(),
elapsed: p.requested_at.elapsed(),
})
.collect()
}
pub fn pending_count(&self) -> usize {
self.pending.read().len()
}
pub fn is_runtime_approved(&self, tool_name: &str) -> bool {
let overrides = self.overrides.read();
overrides.approve_all || overrides.approved_tools.contains(tool_name)
}
pub fn cleanup_closed(&self) -> usize {
let mut map = self.pending.write();
let before = map.len();
map.retain(|_, p| !p.respond.is_closed());
before - map.len()
}
pub fn cleanup_expired(&self, timeout: Duration) -> usize {
let mut map = self.pending.write();
let before = map.len();
map.retain(|_, p| p.requested_at.elapsed() < timeout);
before - map.len()
}
pub fn reset_overrides(&self) {
let mut overrides = self.overrides.write();
overrides.approve_all = false;
overrides.approved_tools.clear();
}
fn apply_override(&self, scope: ApprovalScope, tool_name: &str) {
match scope {
ApprovalScope::ThisCall => {
},
ApprovalScope::ThisTool => {
self.overrides
.write()
.approved_tools
.insert(tool_name.to_string());
},
ApprovalScope::ThisSession => {
self.overrides.write().approve_all = true;
},
}
}
}
impl Default for ApprovalGate {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for ApprovalGate {
    /// Shows the number of pending requests (not their contents) and the current overrides.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read each lock separately so the two guards are never held at the same time.
        let pending_count = self.pending.read().len();
        let overrides = self.overrides.read();
        let mut out = f.debug_struct("ApprovalGate");
        out.field("pending_count", &pending_count);
        out.field("overrides", &*overrides);
        out.finish()
    }
}
// Tests for `ApprovalGate`: request/deliver round trips, the `NotFound` / `Expired`
// error paths, the cleanup helpers, and the runtime overrides recorded by `ApproveAlways`.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
// `request` registers an entry that `pending()` reports with the same id, tool and args.
// Binding the receiver to `_rx` (not `_`) keeps it alive for the rest of the test.
#[test]
fn test_request_creates_pending_entry() {
let gate = ApprovalGate::new();
assert_eq!(gate.pending_count(), 0);
let _rx = gate.request("call_1", "bash", serde_json::json!({"command": "ls"}));
assert_eq!(gate.pending_count(), 1);
let pending = gate.pending();
assert_eq!(pending.len(), 1);
assert_eq!(pending[0].call_id, "call_1");
assert_eq!(pending[0].tool_name, "bash");
assert_eq!(pending[0].arguments, serde_json::json!({"command": "ls"}));
}
// An `Approve` decision reaches the receiver returned by `request`.
#[tokio::test]
async fn test_approve_delivers_to_receiver() {
let gate = ApprovalGate::new();
let rx = gate.request("call_1", "bash", serde_json::json!({}));
gate.deliver("call_1", ApprovalDecision::Approve).unwrap();
let decision = rx.await.unwrap();
assert_eq!(decision, ApprovalDecision::Approve);
}
// A `Deny` decision reaches the receiver returned by `request`.
#[tokio::test]
async fn test_deny_delivers_to_receiver() {
let gate = ApprovalGate::new();
let rx = gate.request("call_1", "bash", serde_json::json!({}));
gate.deliver("call_1", ApprovalDecision::Deny).unwrap();
let decision = rx.await.unwrap();
assert_eq!(decision, ApprovalDecision::Deny);
}
// `ApproveAlways` is forwarded unchanged, scope included.
#[tokio::test]
async fn test_approve_always_delivers_to_receiver() {
let gate = ApprovalGate::new();
let rx = gate.request("call_1", "bash", serde_json::json!({}));
let expected = ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
};
gate.deliver("call_1", expected.clone()).unwrap();
let decision = rx.await.unwrap();
assert_eq!(decision, expected);
}
// Delivering to an id that was never requested yields `NotFound` carrying that id.
#[test]
fn test_deliver_unknown_call_id_returns_not_found() {
let gate = ApprovalGate::new();
let result = gate.deliver("nonexistent", ApprovalDecision::Approve);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
ApprovalError::NotFound { call_id } if call_id == "nonexistent"
));
}
// A fresh gate has nothing to deliver to; only the error variant is checked.
#[test]
fn test_deliver_to_empty_gate_returns_not_found() {
let gate = ApprovalGate::new();
let result = gate.deliver("call_1", ApprovalDecision::Deny);
assert!(matches!(
result.unwrap_err(),
ApprovalError::NotFound { .. }
));
}
// Dropping the receiver does not remove the entry (count stays 1); delivering to it
// afterwards fails with `Expired`.
#[test]
fn test_dropped_receiver_causes_expired_error() {
let gate = ApprovalGate::new();
let rx = gate.request("call_1", "bash", serde_json::json!({}));
drop(rx);
assert_eq!(gate.pending_count(), 1);
let result = gate.deliver("call_1", ApprovalDecision::Approve);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
ApprovalError::Expired { call_id } if call_id == "call_1"
));
}
// With no decision delivered, awaiting the receiver under a 50 ms timeout times out.
#[tokio::test]
async fn test_tokio_timeout_expires() {
let gate = Arc::new(ApprovalGate::new());
let rx = gate.request("call_1", "bash", serde_json::json!({}));
let result = tokio::time::timeout(Duration::from_millis(50), rx).await;
assert!(result.is_err(), "Should have timed out");
}
// A decision sent from another task after ~30 ms arrives inside the 200 ms window.
// Timing-based: assumes the spawned task gets scheduled well within that margin.
#[tokio::test]
async fn test_decision_just_before_timeout_is_honored() {
let gate = Arc::new(ApprovalGate::new());
let rx = gate.request("call_1", "bash", serde_json::json!({}));
let gate_clone = Arc::clone(&gate);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(30)).await;
gate_clone
.deliver("call_1", ApprovalDecision::Approve)
.unwrap();
});
let result = tokio::time::timeout(Duration::from_millis(200), rx).await;
assert!(result.is_ok(), "Should not have timed out");
assert_eq!(result.unwrap().unwrap(), ApprovalDecision::Approve);
}
// With a zero timeout no elapsed time is below the limit, so every entry is removed.
#[test]
fn test_cleanup_expired_removes_old_entries() {
let gate = ApprovalGate::new();
let _rx = gate.request("call_1", "bash", serde_json::json!({}));
assert_eq!(gate.pending_count(), 1);
let removed = gate.cleanup_expired(Duration::ZERO);
assert_eq!(removed, 1);
assert_eq!(gate.pending_count(), 0);
}
// An entry far younger than the one-hour timeout is kept.
#[test]
fn test_cleanup_expired_keeps_fresh_entries() {
let gate = ApprovalGate::new();
let _rx = gate.request("call_1", "bash", serde_json::json!({}));
let removed = gate.cleanup_expired(Duration::from_secs(3600));
assert_eq!(removed, 0);
assert_eq!(gate.pending_count(), 1);
}
// Only the entry whose receiver was dropped (call_1) is removed.
#[test]
fn test_cleanup_closed_removes_dropped_receivers() {
let gate = ApprovalGate::new();
let rx1 = gate.request("call_1", "bash", serde_json::json!({}));
let _rx2 = gate.request("call_2", "read_file", serde_json::json!({}));
assert_eq!(gate.pending_count(), 2);
drop(rx1);
let removed = gate.cleanup_closed();
assert_eq!(removed, 1);
assert_eq!(gate.pending_count(), 1);
let pending = gate.pending();
assert_eq!(pending[0].call_id, "call_2");
}
// Entries whose receivers are still alive are left untouched.
#[test]
fn test_cleanup_closed_preserves_active_entries() {
let gate = ApprovalGate::new();
let _rx1 = gate.request("call_1", "bash", serde_json::json!({}));
let _rx2 = gate.request("call_2", "read_file", serde_json::json!({}));
let removed = gate.cleanup_closed();
assert_eq!(removed, 0);
assert_eq!(gate.pending_count(), 2);
}
// A `ThisTool` approval records a runtime override for the requested tool.
#[tokio::test]
async fn test_approve_always_this_tool_adds_runtime_override() {
let gate = ApprovalGate::new();
assert!(!gate.is_runtime_approved("bash"));
let rx = gate.request("call_1", "bash", serde_json::json!({}));
gate.deliver(
"call_1",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
},
)
.unwrap();
let _ = rx.await;
assert!(gate.is_runtime_approved("bash"));
}
// A `ThisTool` override does not extend to other tool names.
#[tokio::test]
async fn test_approve_always_this_tool_scoped_to_tool_name() {
let gate = ApprovalGate::new();
let rx = gate.request("call_1", "bash", serde_json::json!({}));
gate.deliver(
"call_1",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
},
)
.unwrap();
let _ = rx.await;
assert!(gate.is_runtime_approved("bash"));
assert!(!gate.is_runtime_approved("write_file"));
assert!(!gate.is_runtime_approved("read_file"));
assert!(!gate.is_runtime_approved("edit_file"));
}
// A `ThisSession` approval approves every tool name, including never-requested ones.
#[tokio::test]
async fn test_approve_always_this_session_approves_all_tools() {
let gate = ApprovalGate::new();
let rx = gate.request("call_1", "bash", serde_json::json!({}));
gate.deliver(
"call_1",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisSession,
},
)
.unwrap();
let _ = rx.await;
assert!(gate.is_runtime_approved("bash"));
assert!(gate.is_runtime_approved("write_file"));
assert!(gate.is_runtime_approved("read_file"));
assert!(gate.is_runtime_approved("any_tool"));
}
// A `ThisCall` approval leaves no persistent override behind.
#[tokio::test]
async fn test_approve_always_this_call_no_persistent_override() {
let gate = ApprovalGate::new();
let rx = gate.request("call_1", "bash", serde_json::json!({}));
gate.deliver(
"call_1",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisCall,
},
)
.unwrap();
let _ = rx.await;
assert!(!gate.is_runtime_approved("bash"));
}
// Successive `ThisTool` approvals add to the approved set rather than replacing it.
#[tokio::test]
async fn test_multiple_tool_overrides_accumulate() {
let gate = ApprovalGate::new();
let rx1 = gate.request("call_1", "bash", serde_json::json!({}));
gate.deliver(
"call_1",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
},
)
.unwrap();
let _ = rx1.await;
let rx2 = gate.request("call_2", "write_file", serde_json::json!({}));
gate.deliver(
"call_2",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
},
)
.unwrap();
let _ = rx2.await;
assert!(gate.is_runtime_approved("bash"));
assert!(gate.is_runtime_approved("write_file"));
assert!(!gate.is_runtime_approved("read_file"));
}
// `reset_overrides` clears both the session flag and the per-tool set. The private
// fields are set directly so the test does not depend on `deliver`.
#[test]
fn test_reset_overrides_clears_all() {
let gate = ApprovalGate::new();
gate.overrides
.write()
.approved_tools
.insert("bash".to_string());
gate.overrides.write().approve_all = true;
assert!(gate.is_runtime_approved("bash"));
assert!(gate.is_runtime_approved("anything"));
gate.reset_overrides();
assert!(!gate.is_runtime_approved("bash"));
assert!(!gate.is_runtime_approved("anything"));
}
// A new gate approves nothing at runtime.
#[test]
fn test_no_runtime_overrides_initially() {
let gate = ApprovalGate::new();
assert!(!gate.is_runtime_approved("bash"));
assert!(!gate.is_runtime_approved("read_file"));
assert!(!gate.is_runtime_approved("write_file"));
}
// Two concurrent requests each receive their own decision, and both entries are removed.
#[tokio::test]
async fn test_multiple_pending_approvals_independent() {
let gate = Arc::new(ApprovalGate::new());
let rx_a = gate.request("call_a", "bash", serde_json::json!({}));
let rx_b = gate.request("call_b", "write_file", serde_json::json!({}));
assert_eq!(gate.pending_count(), 2);
gate.deliver("call_a", ApprovalDecision::Approve).unwrap();
gate.deliver("call_b", ApprovalDecision::Deny).unwrap();
assert_eq!(rx_a.await.unwrap(), ApprovalDecision::Approve);
assert_eq!(rx_b.await.unwrap(), ApprovalDecision::Deny);
assert_eq!(gate.pending_count(), 0);
}
// `rx_b` is moved into the `timeout` future and dropped once it completes, so the
// later delivery to call_b reports `Expired`.
#[tokio::test]
async fn test_one_timeout_one_approved() {
let gate = Arc::new(ApprovalGate::new());
let rx_a = gate.request("call_a", "bash", serde_json::json!({}));
let rx_b = gate.request("call_b", "read_file", serde_json::json!({}));
gate.deliver("call_a", ApprovalDecision::Approve).unwrap();
assert_eq!(rx_a.await.unwrap(), ApprovalDecision::Approve);
let result_b = tokio::time::timeout(Duration::from_millis(50), rx_b).await;
assert!(result_b.is_err(), "B should have timed out");
let result = gate.deliver("call_b", ApprovalDecision::Approve);
assert!(matches!(result.unwrap_err(), ApprovalError::Expired { .. }));
}
// Deliveries from three spawned tasks each reach their own receiver. Only the
// `ThisTool` delivery (for write_file) records an override.
#[tokio::test]
async fn test_concurrent_delivery_from_multiple_tasks() {
let gate = Arc::new(ApprovalGate::new());
let rx1 = gate.request("call_1", "bash", serde_json::json!({}));
let rx2 = gate.request("call_2", "read_file", serde_json::json!({}));
let rx3 = gate.request("call_3", "write_file", serde_json::json!({}));
let g1 = Arc::clone(&gate);
let g2 = Arc::clone(&gate);
let g3 = Arc::clone(&gate);
let (r1, r2, r3) = tokio::join!(
tokio::spawn(async move { g1.deliver("call_1", ApprovalDecision::Approve) }),
tokio::spawn(async move { g2.deliver("call_2", ApprovalDecision::Deny) }),
tokio::spawn(async move {
g3.deliver(
"call_3",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
},
)
}),
);
assert!(r1.unwrap().is_ok());
assert!(r2.unwrap().is_ok());
assert!(r3.unwrap().is_ok());
assert_eq!(rx1.await.unwrap(), ApprovalDecision::Approve);
assert_eq!(rx2.await.unwrap(), ApprovalDecision::Deny);
assert_eq!(
rx3.await.unwrap(),
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
}
);
assert!(gate.is_runtime_approved("write_file"));
assert!(!gate.is_runtime_approved("bash"));
assert!(!gate.is_runtime_approved("read_file"));
}
// The first delivery removes the entry, so a second one for the same id is `NotFound`.
#[test]
fn test_oneshot_consumed_once() {
let gate = ApprovalGate::new();
let _rx = gate.request("call_1", "bash", serde_json::json!({}));
let first = gate.deliver("call_1", ApprovalDecision::Approve);
assert!(first.is_ok());
let second = gate.deliver("call_1", ApprovalDecision::Approve);
assert!(second.is_err());
assert!(matches!(
second.unwrap_err(),
ApprovalError::NotFound { call_id } if call_id == "call_1"
));
}
// `pending()` grows with each request and shrinks after a successful delivery.
#[test]
fn test_pending_list_reflects_state() {
let gate = ApprovalGate::new();
assert!(gate.pending().is_empty());
let _rx1 = gate.request("call_1", "bash", serde_json::json!({}));
assert_eq!(gate.pending().len(), 1);
let _rx2 = gate.request("call_2", "read_file", serde_json::json!({}));
assert_eq!(gate.pending().len(), 2);
gate.deliver("call_1", ApprovalDecision::Approve).unwrap();
assert_eq!(gate.pending().len(), 1);
let remaining = gate.pending();
assert_eq!(remaining[0].call_id, "call_2");
}
// Structured arguments round-trip unchanged through `pending()`.
#[test]
fn test_pending_info_contains_arguments() {
let gate = ApprovalGate::new();
let args = serde_json::json!({"command": "git status", "timeout_secs": 30});
let _rx = gate.request("call_1", "bash", args.clone());
let pending = gate.pending();
assert_eq!(pending.len(), 1);
assert_eq!(pending[0].arguments, args);
}
// The manual `Debug` impl names the type and exposes the pending count.
#[test]
fn test_debug_format() {
let gate = ApprovalGate::new();
let debug = format!("{gate:?}");
assert!(debug.contains("ApprovalGate"));
assert!(debug.contains("pending_count"));
}
// `Default` produces an empty gate with no overrides.
#[test]
fn test_default_gate() {
let gate = ApprovalGate::default();
assert_eq!(gate.pending_count(), 0);
assert!(!gate.is_runtime_approved("anything"));
}
// Property-based versions of the tests above, over every decision variant and a fixed
// set of tool names.
mod proptest_approval {
use super::*;
use proptest::prelude::*;
// Generates every `ApprovalDecision`, including each `ApproveAlways` scope.
fn approval_decision_strategy() -> impl Strategy<Value = ApprovalDecision> {
prop_oneof![
Just(ApprovalDecision::Approve),
Just(ApprovalDecision::Deny),
prop_oneof![
Just(ApprovalScope::ThisCall),
Just(ApprovalScope::ThisTool),
Just(ApprovalScope::ThisSession),
]
.prop_map(|scope| ApprovalDecision::ApproveAlways { scope }),
]
}
// Picks one of a fixed list of tool names.
fn tool_name_strategy() -> impl Strategy<Value = String> {
prop_oneof![
Just("bash".to_string()),
Just("read_file".to_string()),
Just("write_file".to_string()),
Just("edit_file".to_string()),
Just("list_files".to_string()),
Just("search_files".to_string()),
Just("claude_code".to_string()),
]
}
proptest! {
// For any decision, the first delivery succeeds and the second is `NotFound`.
#[test]
fn prop_oneshot_consumed_once(
decision in approval_decision_strategy(),
) {
let gate = ApprovalGate::new();
let _rx = gate.request("call_x", "bash", serde_json::json!({}));
let first = gate.deliver("call_x", decision.clone());
prop_assert!(first.is_ok());
let second = gate.deliver("call_x", decision);
prop_assert!(second.is_err());
let is_not_found = matches!(
second.unwrap_err(),
ApprovalError::NotFound { .. }
);
prop_assert!(is_not_found, "expected NotFound after consumption");
}
// For any tool name, a `ThisTool` approval flips it from not approved to approved.
#[test]
fn prop_approve_always_this_tool_takes_effect(
tool_name in tool_name_strategy(),
) {
let gate = ApprovalGate::new();
let _rx = gate.request("call_x", &tool_name, serde_json::json!({}));
prop_assert!(!gate.is_runtime_approved(&tool_name));
gate.deliver(
"call_x",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
},
)
.unwrap();
prop_assert!(gate.is_runtime_approved(&tool_name));
}
// A `ThisTool` approval for one name never approves a different name.
#[test]
fn prop_approve_always_scoped_to_tool(
approved in tool_name_strategy(),
other in tool_name_strategy(),
) {
prop_assume!(approved != other);
let gate = ApprovalGate::new();
let _rx = gate.request("call_x", &approved, serde_json::json!({}));
gate.deliver(
"call_x",
ApprovalDecision::ApproveAlways {
scope: ApprovalScope::ThisTool,
},
)
.unwrap();
prop_assert!(gate.is_runtime_approved(&approved));
prop_assert!(!gate.is_runtime_approved(&other));
}
// A current-thread runtime is built per case so the receiver can be awaited
// inside the synchronous proptest body.
#[test]
fn prop_deny_delivers_correctly(
tool_name in tool_name_strategy(),
) {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("tokio runtime");
rt.block_on(async {
let gate = ApprovalGate::new();
let rx = gate.request("call_x", &tool_name, serde_json::json!({}));
gate.deliver("call_x", ApprovalDecision::Deny).unwrap();
let decision = rx.await.unwrap();
assert_eq!(decision, ApprovalDecision::Deny);
});
}
// Same pattern as above, for `Approve`.
#[test]
fn prop_approve_delivers_correctly(
tool_name in tool_name_strategy(),
) {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("tokio runtime");
rt.block_on(async {
let gate = ApprovalGate::new();
let rx = gate.request("call_x", &tool_name, serde_json::json!({}));
gate.deliver("call_x", ApprovalDecision::Approve).unwrap();
let decision = rx.await.unwrap();
assert_eq!(decision, ApprovalDecision::Approve);
});
}
}
}
}