use crate::agent::AgentEvent;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, oneshot, RwLock};
pub use crate::queue::SessionLane;
/// What happens to a pending confirmation whose timeout elapses before anyone answers it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum TimeoutAction {
/// Resolve the expired request as not approved. This is the default.
#[default]
Reject,
/// Resolve the expired request as approved.
AutoApprove,
}
/// Settings that decide which tools need explicit approval and how unanswered requests expire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfirmationPolicy {
/// Master switch. When `false`, no tool requires confirmation (`is_yolo` returns `true` for every tool).
pub enabled: bool,
/// Timeout, in milliseconds, assigned to each newly requested confirmation.
pub default_timeout_ms: u64,
/// Outcome applied when a pending confirmation times out.
pub timeout_action: TimeoutAction,
/// Lanes (derived from tool names via `SessionLane::from_tool_name`) whose tools skip confirmation.
pub yolo_lanes: HashSet<SessionLane>,
}
impl Default for ConfirmationPolicy {
    /// Confirmation switched off, a 30-second timeout that rejects on expiry, and no YOLO lanes.
    fn default() -> Self {
        // 30 seconds expressed in milliseconds.
        const DEFAULT_TIMEOUT_MS: u64 = 30 * 1000;

        Self {
            enabled: false,
            default_timeout_ms: DEFAULT_TIMEOUT_MS,
            // `TimeoutAction`'s `#[default]` variant is `Reject`.
            timeout_action: TimeoutAction::default(),
            yolo_lanes: HashSet::default(),
        }
    }
}
impl ConfirmationPolicy {
    /// Returns a policy with confirmation switched on and every other setting at its default.
    pub fn enabled() -> Self {
        let mut policy = Self::default();
        policy.enabled = true;
        policy
    }

    /// Replaces the set of lanes whose tools bypass confirmation.
    ///
    /// Any previously configured lanes are discarded, and duplicate lanes collapse into one entry.
    pub fn with_yolo_lanes(self, lanes: impl IntoIterator<Item = SessionLane>) -> Self {
        Self {
            yolo_lanes: lanes.into_iter().collect(),
            ..self
        }
    }

    /// Sets the per-request timeout (milliseconds) and the action applied when it expires.
    pub fn with_timeout(self, timeout_ms: u64, action: TimeoutAction) -> Self {
        Self {
            default_timeout_ms: timeout_ms,
            timeout_action: action,
            ..self
        }
    }

    /// Returns `true` when `tool_name` may run without asking.
    ///
    /// A disabled policy never asks. Otherwise the tool's lane, taken from
    /// `SessionLane::from_tool_name`, must be listed in `yolo_lanes`.
    pub fn is_yolo(&self, tool_name: &str) -> bool {
        // `||` short-circuits, so the lane lookup only happens when the policy is enabled.
        !self.enabled || self.yolo_lanes.contains(&SessionLane::from_tool_name(tool_name))
    }

    /// Negation of [`is_yolo`](Self::is_yolo): `true` when `tool_name` needs explicit approval.
    pub fn requires_confirmation(&self, tool_name: &str) -> bool {
        !self.is_yolo(tool_name)
    }
}
/// Decision delivered to whoever requested a confirmation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfirmationResponse {
/// `true` when the request was approved, either explicitly or through timeout auto-approval.
pub approved: bool,
/// Optional explanation. For an explicit decision this is the reason supplied to `confirm`; on
/// timeout or cancellation the manager fills in its own message.
pub reason: Option<String>,
}
/// Serializable snapshot of one pending confirmation, used when listing what is awaiting a decision.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingConfirmationInfo {
/// Identifier of the tool call awaiting confirmation.
pub tool_id: String,
/// Name of the tool being invoked.
pub tool_name: String,
/// Arguments the tool was called with.
pub args: serde_json::Value,
/// Milliseconds left before the request times out, saturating at zero.
pub remaining_ms: u64,
}
/// Abstraction over a confirmation backend that gates tool calls on approval.
/// `ConfirmationManager` in this module is one implementation.
#[async_trait::async_trait]
pub trait ConfirmationProvider: Send + Sync {
/// Returns `true` if calls to `tool_name` need explicit approval.
async fn requires_confirmation(&self, tool_name: &str) -> bool;
/// Registers a confirmation request for `tool_id`.
/// Returns a receiver that yields the eventual decision.
async fn request_confirmation(
&self,
tool_id: &str,
tool_name: &str,
args: &serde_json::Value,
) -> oneshot::Receiver<ConfirmationResponse>;
/// Answers the pending request for `tool_id`.
/// In `ConfirmationManager`, `Ok(true)` means a pending request was found and resolved;
/// `Ok(false)` means nothing was pending under that id.
async fn confirm(
&self,
tool_id: &str,
approved: bool,
reason: Option<String>,
) -> Result<bool, String>;
/// Returns the active policy by value.
async fn policy(&self) -> ConfirmationPolicy;
/// Replaces the active policy.
async fn set_policy(&self, policy: ConfirmationPolicy);
/// Resolves requests whose timeout has elapsed and returns how many timed out.
async fn check_timeouts(&self) -> usize;
/// Cancels every pending request and returns how many were cancelled.
async fn cancel_all(&self) -> usize;
/// Lists pending requests.
/// The default implementation reports none; implementations that track requests override it.
async fn pending_confirmations(&self) -> Vec<PendingConfirmationInfo> {
Vec::new()
}
}
/// A confirmation request awaiting a decision, as stored by `ConfirmationManager`.
pub struct PendingConfirmation {
/// Identifier of the tool call; also used as the key in the manager's pending map.
pub tool_id: String,
/// Name of the tool being invoked.
pub tool_name: String,
/// Arguments the tool was called with.
pub args: serde_json::Value,
/// When the request was registered; the timeout is measured from this instant.
pub created_at: Instant,
/// Allowed wait, in milliseconds, before the request counts as timed out.
pub timeout_ms: u64,
/// One-shot channel that delivers the decision back to the requester.
response_tx: oneshot::Sender<ConfirmationResponse>,
}
impl PendingConfirmation {
    /// Returns `true` once strictly more than `timeout_ms` milliseconds have passed since creation.
    pub fn is_timed_out(&self) -> bool {
        let limit = Duration::from_millis(self.timeout_ms);
        limit < self.created_at.elapsed()
    }

    /// Milliseconds left before the timeout, clamped at zero once it has passed.
    pub fn remaining_ms(&self) -> u64 {
        // `as_millis` yields a u128; narrowing only loses information after ~584 million years.
        let waited_ms = self.created_at.elapsed().as_millis() as u64;
        u64::saturating_sub(self.timeout_ms, waited_ms)
    }
}
/// In-memory `ConfirmationProvider` implementation that tracks pending requests and broadcasts
/// confirmation lifecycle events.
pub struct ConfirmationManager {
/// Active policy; every read or replacement goes through this lock.
policy: RwLock<ConfirmationPolicy>,
/// Pending requests keyed by tool id.
pending: Arc<RwLock<HashMap<String, PendingConfirmation>>>,
/// Sender for `ConfirmationRequired` / `ConfirmationReceived` / `ConfirmationTimeout` events;
/// send failures (no subscribers) are ignored by the manager.
event_tx: broadcast::Sender<AgentEvent>,
}
impl ConfirmationManager {
pub fn new(policy: ConfirmationPolicy, event_tx: broadcast::Sender<AgentEvent>) -> Self {
Self {
policy: RwLock::new(policy),
pending: Arc::new(RwLock::new(HashMap::new())),
event_tx,
}
}
pub async fn policy(&self) -> ConfirmationPolicy {
self.policy.read().await.clone()
}
pub async fn set_policy(&self, policy: ConfirmationPolicy) {
*self.policy.write().await = policy;
}
pub async fn requires_confirmation(&self, tool_name: &str) -> bool {
self.policy.read().await.requires_confirmation(tool_name)
}
pub async fn request_confirmation(
&self,
tool_id: &str,
tool_name: &str,
args: &serde_json::Value,
) -> oneshot::Receiver<ConfirmationResponse> {
let (tx, rx) = oneshot::channel();
let policy = self.policy.read().await;
let timeout_ms = policy.default_timeout_ms;
drop(policy);
let pending = PendingConfirmation {
tool_id: tool_id.to_string(),
tool_name: tool_name.to_string(),
args: args.clone(),
created_at: Instant::now(),
timeout_ms,
response_tx: tx,
};
{
let mut pending_map = self.pending.write().await;
pending_map.insert(tool_id.to_string(), pending);
}
let _ = self.event_tx.send(AgentEvent::ConfirmationRequired {
tool_id: tool_id.to_string(),
tool_name: tool_name.to_string(),
args: args.clone(),
timeout_ms,
});
rx
}
pub async fn confirm(
&self,
tool_id: &str,
approved: bool,
reason: Option<String>,
) -> Result<bool, String> {
let pending = {
let mut pending_map = self.pending.write().await;
pending_map.remove(tool_id)
};
if let Some(confirmation) = pending {
let _ = self.event_tx.send(AgentEvent::ConfirmationReceived {
tool_id: tool_id.to_string(),
approved,
reason: reason.clone(),
});
let response = ConfirmationResponse { approved, reason };
let _ = confirmation.response_tx.send(response);
Ok(true)
} else {
Ok(false)
}
}
pub async fn check_timeouts(&self) -> usize {
let policy = self.policy.read().await;
let timeout_action = policy.timeout_action;
drop(policy);
let mut timed_out = Vec::new();
{
let pending_map = self.pending.read().await;
for (tool_id, pending) in pending_map.iter() {
if pending.is_timed_out() {
timed_out.push(tool_id.clone());
}
}
}
for tool_id in &timed_out {
let pending = {
let mut pending_map = self.pending.write().await;
pending_map.remove(tool_id)
};
if let Some(confirmation) = pending {
let (approved, action_taken) = match timeout_action {
TimeoutAction::Reject => (false, "rejected"),
TimeoutAction::AutoApprove => (true, "auto_approved"),
};
let _ = self.event_tx.send(AgentEvent::ConfirmationTimeout {
tool_id: tool_id.clone(),
action_taken: action_taken.to_string(),
});
let response = ConfirmationResponse {
approved,
reason: Some(format!("Confirmation timed out, action: {}", action_taken)),
};
let _ = confirmation.response_tx.send(response);
}
}
timed_out.len()
}
pub async fn pending_count(&self) -> usize {
self.pending.read().await.len()
}
pub async fn pending_confirmations(&self) -> Vec<(String, String, u64)> {
let pending_map = self.pending.read().await;
pending_map
.values()
.map(|p| (p.tool_id.clone(), p.tool_name.clone(), p.remaining_ms()))
.collect()
}
pub async fn pending_confirmation_details(&self) -> Vec<PendingConfirmationInfo> {
let pending_map = self.pending.read().await;
pending_map
.values()
.map(|p| PendingConfirmationInfo {
tool_id: p.tool_id.clone(),
tool_name: p.tool_name.clone(),
args: p.args.clone(),
remaining_ms: p.remaining_ms(),
})
.collect()
}
pub async fn cancel(&self, tool_id: &str) -> bool {
let pending = {
let mut pending_map = self.pending.write().await;
pending_map.remove(tool_id)
};
if let Some(confirmation) = pending {
let response = ConfirmationResponse {
approved: false,
reason: Some("Confirmation cancelled".to_string()),
};
let _ = confirmation.response_tx.send(response);
true
} else {
false
}
}
pub async fn cancel_all(&self) -> usize {
let pending_list: Vec<_> = {
let mut pending_map = self.pending.write().await;
pending_map.drain().collect()
};
let count = pending_list.len();
for (_, confirmation) in pending_list {
let response = ConfirmationResponse {
approved: false,
reason: Some("Confirmation cancelled".to_string()),
};
let _ = confirmation.response_tx.send(response);
}
count
}
}
#[async_trait::async_trait]
impl ConfirmationProvider for ConfirmationManager {
async fn requires_confirmation(&self, tool_name: &str) -> bool {
self.requires_confirmation(tool_name).await
}
async fn request_confirmation(
&self,
tool_id: &str,
tool_name: &str,
args: &serde_json::Value,
) -> oneshot::Receiver<ConfirmationResponse> {
self.request_confirmation(tool_id, tool_name, args).await
}
async fn confirm(
&self,
tool_id: &str,
approved: bool,
reason: Option<String>,
) -> Result<bool, String> {
self.confirm(tool_id, approved, reason).await
}
async fn policy(&self) -> ConfirmationPolicy {
self.policy().await
}
async fn set_policy(&self, policy: ConfirmationPolicy) {
self.set_policy(policy).await
}
async fn check_timeouts(&self) -> usize {
self.check_timeouts().await
}
async fn cancel_all(&self) -> usize {
self.cancel_all().await
}
async fn pending_confirmations(&self) -> Vec<PendingConfirmationInfo> {
self.pending_confirmation_details().await
}
}
#[cfg(test)]
mod tests {
    //! Tests for lane classification, policy evaluation, and the `ConfirmationManager`
    //! lifecycle (request → confirm / timeout / cancel).
    use super::*;

    #[test]
    fn test_session_lane() {
        assert_eq!(SessionLane::from_tool_name("read"), SessionLane::Query);
        assert_eq!(SessionLane::from_tool_name("grep"), SessionLane::Query);
        assert_eq!(SessionLane::from_tool_name("bash"), SessionLane::Execute);
        assert_eq!(SessionLane::from_tool_name("write"), SessionLane::Execute);
    }

    #[test]
    fn test_session_lane_priority() {
        assert_eq!(SessionLane::Control.priority(), 0);
        assert_eq!(SessionLane::Query.priority(), 1);
        assert_eq!(SessionLane::Execute.priority(), 2);
        assert_eq!(SessionLane::Generate.priority(), 3);
        assert!(SessionLane::Control.priority() < SessionLane::Query.priority());
        assert!(SessionLane::Query.priority() < SessionLane::Execute.priority());
        assert!(SessionLane::Execute.priority() < SessionLane::Generate.priority());
    }

    #[test]
    fn test_session_lane_all_query() {
        let query_tools = ["read", "glob", "ls", "grep", "list_files", "search"];
        for tool in query_tools {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Query,
                "Tool '{}' should be in Query lane",
                tool
            );
        }
    }

    #[test]
    fn test_session_lane_all_execute() {
        let execute_tools = ["bash", "write", "edit", "delete", "move", "copy", "execute"];
        for tool in execute_tools {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Execute,
                "Tool '{}' should be in Execute lane",
                tool
            );
        }
    }

    #[test]
    fn test_confirmation_policy_default() {
        let policy = ConfirmationPolicy::default();
        assert!(!policy.enabled);
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("write"));
        assert!(!policy.requires_confirmation("read"));
    }

    #[test]
    fn test_confirmation_policy_enabled() {
        let policy = ConfirmationPolicy::enabled();
        assert!(policy.enabled);
        assert!(policy.requires_confirmation("bash"));
        assert!(policy.requires_confirmation("write"));
        assert!(policy.requires_confirmation("read"));
        assert!(policy.requires_confirmation("grep"));
    }

    #[test]
    fn test_confirmation_policy_yolo_mode() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);
        // Execute-lane tools skip confirmation...
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("write"));
        // ...while Query-lane tools still need it.
        assert!(policy.requires_confirmation("read"));
    }

    #[test]
    fn test_confirmation_policy_yolo_multiple_lanes() {
        let policy = ConfirmationPolicy::enabled()
            .with_yolo_lanes([SessionLane::Query, SessionLane::Execute]);
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("read"));
        assert!(!policy.requires_confirmation("grep"));
    }

    #[test]
    fn test_confirmation_policy_is_yolo() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);
        assert!(policy.is_yolo("bash"));
        assert!(policy.is_yolo("write"));
        assert!(!policy.is_yolo("read"));
    }

    #[test]
    fn test_confirmation_policy_disabled_is_always_yolo() {
        // Default policy is disabled, so every tool is YOLO regardless of lane.
        let policy = ConfirmationPolicy::default();
        assert!(policy.is_yolo("bash"));
        assert!(policy.is_yolo("read"));
        assert!(policy.is_yolo("unknown_tool"));
    }

    #[test]
    fn test_confirmation_policy_with_timeout() {
        let policy = ConfirmationPolicy::enabled().with_timeout(5000, TimeoutAction::AutoApprove);
        assert_eq!(policy.default_timeout_ms, 5000);
        assert_eq!(policy.timeout_action, TimeoutAction::AutoApprove);
    }

    #[tokio::test]
    async fn test_confirmation_manager_no_hitl() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::default(), event_tx);
        assert!(!manager.requires_confirmation("bash").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_hitl() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        assert!(manager.requires_confirmation("bash").await);
        assert!(manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_yolo() {
        let (event_tx, _) = broadcast::channel(100);
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Query]);
        let manager = ConfirmationManager::new(policy, event_tx);
        assert!(manager.requires_confirmation("bash").await);
        assert!(!manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_policy_update() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::default(), event_tx);
        assert!(!manager.requires_confirmation("bash").await);
        manager.set_policy(ConfirmationPolicy::enabled()).await;
        assert!(manager.requires_confirmation("bash").await);
        manager
            .set_policy(ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]))
            .await;
        assert!(!manager.requires_confirmation("bash").await);
    }

    #[tokio::test]
    async fn test_confirmation_flow_approve() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({"command": "ls"}))
            .await;
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationRequired {
                tool_id,
                tool_name,
                timeout_ms,
                ..
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(tool_name, "bash");
                // Default policy timeout.
                assert_eq!(timeout_ms, 30_000);
            }
            _ => panic!("Expected ConfirmationRequired event"),
        }
        let result = manager.confirm("tool-1", true, None).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationReceived {
                tool_id, approved, ..
            } => {
                assert_eq!(tool_id, "tool-1");
                assert!(approved);
            }
            _ => panic!("Expected ConfirmationReceived event"),
        }
        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }

    #[tokio::test]
    async fn test_confirmation_flow_reject() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let rx = manager
            .request_confirmation(
                "tool-1",
                "bash",
                &serde_json::json!({"command": "rm -rf /"}),
            )
            .await;
        let _ = event_rx.recv().await.unwrap();
        let result = manager
            .confirm("tool-1", false, Some("Dangerous command".to_string()))
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap());
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationReceived {
                tool_id,
                approved,
                reason,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert!(!approved);
                assert_eq!(reason, Some("Dangerous command".to_string()));
            }
            _ => panic!("Expected ConfirmationReceived event"),
        }
        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason, Some("Dangerous command".to_string()));
    }

    #[tokio::test]
    async fn test_confirmation_not_found() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let result = manager.confirm("non-existent", true, None).await;
        assert!(result.is_ok());
        // Nothing was pending under that id.
        assert!(!result.unwrap());
    }

    #[tokio::test]
    async fn test_multiple_confirmations() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({"cmd": "1"}))
            .await;
        let rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({"cmd": "2"}))
            .await;
        let rx3 = manager
            .request_confirmation("tool-3", "edit", &serde_json::json!({"cmd": "3"}))
            .await;
        assert_eq!(manager.pending_count().await, 3);
        manager.confirm("tool-1", true, None).await.unwrap();
        let response1 = rx1.await.unwrap();
        assert!(response1.approved);
        manager.confirm("tool-2", false, None).await.unwrap();
        let response2 = rx2.await.unwrap();
        assert!(!response2.approved);
        manager.confirm("tool-3", true, None).await.unwrap();
        let response3 = rx3.await.unwrap();
        assert!(response3.approved);
        assert_eq!(manager.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_pending_confirmations_info() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let _rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let _rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({}))
            .await;
        let pending = manager.pending_confirmations().await;
        assert_eq!(pending.len(), 2);
        let tool_ids: Vec<&str> = pending.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(tool_ids.contains(&"tool-1"));
        assert!(tool_ids.contains(&"tool-2"));
    }

    #[tokio::test]
    async fn test_cancel_confirmation() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        assert_eq!(manager.pending_count().await, 1);
        let cancelled = manager.cancel("tool-1").await;
        assert!(cancelled);
        assert_eq!(manager.pending_count().await, 0);
        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason, Some("Confirmation cancelled".to_string()));
    }

    #[tokio::test]
    async fn test_cancel_nonexistent() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let cancelled = manager.cancel("non-existent").await;
        assert!(!cancelled);
    }

    #[tokio::test]
    async fn test_cancel_all() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({}))
            .await;
        let rx3 = manager
            .request_confirmation("tool-3", "edit", &serde_json::json!({}))
            .await;
        assert_eq!(manager.pending_count().await, 3);
        let cancelled_count = manager.cancel_all().await;
        assert_eq!(cancelled_count, 3);
        assert_eq!(manager.pending_count().await, 0);
        for rx in [rx1, rx2, rx3] {
            let response = rx.await.unwrap();
            assert!(!response.approved);
            assert_eq!(response.reason, Some("Confirmation cancelled".to_string()));
        }
    }

    #[tokio::test]
    async fn test_timeout_reject() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            // Short timeout so the test only needs a brief sleep.
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::Reject,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);
        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let _ = event_rx.recv().await.unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 1);
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationTimeout {
                tool_id,
                action_taken,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(action_taken, "rejected");
            }
            _ => panic!("Expected ConfirmationTimeout event"),
        }
        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert!(response.reason.as_ref().unwrap().contains("timed out"));
    }

    #[tokio::test]
    async fn test_timeout_auto_approve() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            // Short timeout so the test only needs a brief sleep.
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::AutoApprove,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);
        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let _ = event_rx.recv().await.unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 1);
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationTimeout {
                tool_id,
                action_taken,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(action_taken, "auto_approved");
            }
            _ => panic!("Expected ConfirmationTimeout event"),
        }
        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.as_ref().unwrap().contains("auto_approved"));
    }

    #[tokio::test]
    async fn test_no_timeout_when_confirmed() {
        let (event_tx, _) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::Reject,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);
        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        manager.confirm("tool-1", true, None).await.unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 0);
        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }

    #[tokio::test]
    async fn test_check_timeouts_ignores_fresh_requests() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        let _rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        // With the default 30 s timeout nothing can have expired yet.
        assert_eq!(manager.check_timeouts().await, 0);
        assert_eq!(manager.pending_count().await, 1);
        let pending = manager.pending_confirmation_details().await;
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].tool_id, "tool-1");
        assert!(pending[0].remaining_ms > 0 && pending[0].remaining_ms <= 30_000);
    }

    #[tokio::test]
    async fn test_provider_trait_delegates_to_manager() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);
        // Drive the manager purely through the trait object to exercise the delegation.
        let provider: &dyn ConfirmationProvider = &manager;
        assert!(provider.requires_confirmation("bash").await);
        let rx = provider
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        assert_eq!(provider.pending_confirmations().await.len(), 1);
        assert_eq!(provider.confirm("tool-1", true, None).await, Ok(true));
        assert!(rx.await.unwrap().approved);
        assert_eq!(provider.cancel_all().await, 0);
    }
}