1use crate::agent::AgentEvent;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::sync::{broadcast, oneshot, RwLock};
16
17pub use crate::queue::SessionLane;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
22pub enum TimeoutAction {
23 #[default]
25 Reject,
26 AutoApprove,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ConfirmationPolicy {
37 pub enabled: bool,
39
40 pub default_timeout_ms: u64,
42
43 pub timeout_action: TimeoutAction,
45
46 pub yolo_lanes: HashSet<SessionLane>,
50}
51
52impl Default for ConfirmationPolicy {
53 fn default() -> Self {
54 Self {
55 enabled: false, default_timeout_ms: 30_000, timeout_action: TimeoutAction::Reject,
58 yolo_lanes: HashSet::new(), }
60 }
61}
62
63impl ConfirmationPolicy {
64 pub fn enabled() -> Self {
66 Self {
67 enabled: true,
68 ..Default::default()
69 }
70 }
71
72 pub fn with_yolo_lanes(mut self, lanes: impl IntoIterator<Item = SessionLane>) -> Self {
74 self.yolo_lanes = lanes.into_iter().collect();
75 self
76 }
77
78 pub fn with_timeout(mut self, timeout_ms: u64, action: TimeoutAction) -> Self {
80 self.default_timeout_ms = timeout_ms;
81 self.timeout_action = action;
82 self
83 }
84
85 pub fn is_yolo(&self, tool_name: &str) -> bool {
90 if !self.enabled {
91 return true; }
93 let lane = SessionLane::from_tool_name(tool_name);
94 self.yolo_lanes.contains(&lane)
95 }
96
97 pub fn requires_confirmation(&self, tool_name: &str) -> bool {
102 !self.is_yolo(tool_name)
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ConfirmationResponse {
109 pub approved: bool,
111 pub reason: Option<String>,
113}
114
115pub struct PendingConfirmation {
117 pub tool_id: String,
119 pub tool_name: String,
121 pub args: serde_json::Value,
123 pub created_at: Instant,
125 pub timeout_ms: u64,
127 response_tx: oneshot::Sender<ConfirmationResponse>,
129}
130
131impl PendingConfirmation {
132 pub fn is_timed_out(&self) -> bool {
134 self.created_at.elapsed() > Duration::from_millis(self.timeout_ms)
135 }
136
137 pub fn remaining_ms(&self) -> u64 {
139 let elapsed = self.created_at.elapsed().as_millis() as u64;
140 self.timeout_ms.saturating_sub(elapsed)
141 }
142}
143
144pub struct ConfirmationManager {
146 policy: RwLock<ConfirmationPolicy>,
148 pending: Arc<RwLock<HashMap<String, PendingConfirmation>>>,
150 event_tx: broadcast::Sender<AgentEvent>,
152}
153
154impl ConfirmationManager {
155 pub fn new(policy: ConfirmationPolicy, event_tx: broadcast::Sender<AgentEvent>) -> Self {
157 Self {
158 policy: RwLock::new(policy),
159 pending: Arc::new(RwLock::new(HashMap::new())),
160 event_tx,
161 }
162 }
163
164 pub async fn policy(&self) -> ConfirmationPolicy {
166 self.policy.read().await.clone()
167 }
168
169 pub async fn set_policy(&self, policy: ConfirmationPolicy) {
171 *self.policy.write().await = policy;
172 }
173
174 pub async fn requires_confirmation(&self, tool_name: &str) -> bool {
176 self.policy.read().await.requires_confirmation(tool_name)
177 }
178
179 pub async fn request_confirmation(
184 &self,
185 tool_id: &str,
186 tool_name: &str,
187 args: &serde_json::Value,
188 ) -> oneshot::Receiver<ConfirmationResponse> {
189 let (tx, rx) = oneshot::channel();
190
191 let policy = self.policy.read().await;
192 let timeout_ms = policy.default_timeout_ms;
193 drop(policy);
194
195 let pending = PendingConfirmation {
196 tool_id: tool_id.to_string(),
197 tool_name: tool_name.to_string(),
198 args: args.clone(),
199 created_at: Instant::now(),
200 timeout_ms,
201 response_tx: tx,
202 };
203
204 {
206 let mut pending_map = self.pending.write().await;
207 pending_map.insert(tool_id.to_string(), pending);
208 }
209
210 let _ = self.event_tx.send(AgentEvent::ConfirmationRequired {
212 tool_id: tool_id.to_string(),
213 tool_name: tool_name.to_string(),
214 args: args.clone(),
215 timeout_ms,
216 });
217
218 rx
219 }
220
221 pub async fn confirm(
226 &self,
227 tool_id: &str,
228 approved: bool,
229 reason: Option<String>,
230 ) -> Result<bool, String> {
231 let pending = {
232 let mut pending_map = self.pending.write().await;
233 pending_map.remove(tool_id)
234 };
235
236 if let Some(confirmation) = pending {
237 let _ = self.event_tx.send(AgentEvent::ConfirmationReceived {
239 tool_id: tool_id.to_string(),
240 approved,
241 reason: reason.clone(),
242 });
243
244 let response = ConfirmationResponse { approved, reason };
246 let _ = confirmation.response_tx.send(response);
247
248 Ok(true)
249 } else {
250 Ok(false)
251 }
252 }
253
254 pub async fn check_timeouts(&self) -> usize {
258 let policy = self.policy.read().await;
259 let timeout_action = policy.timeout_action;
260 drop(policy);
261
262 let mut timed_out = Vec::new();
263
264 {
266 let pending_map = self.pending.read().await;
267 for (tool_id, pending) in pending_map.iter() {
268 if pending.is_timed_out() {
269 timed_out.push(tool_id.clone());
270 }
271 }
272 }
273
274 for tool_id in &timed_out {
276 let pending = {
277 let mut pending_map = self.pending.write().await;
278 pending_map.remove(tool_id)
279 };
280
281 if let Some(confirmation) = pending {
282 let (approved, action_taken) = match timeout_action {
283 TimeoutAction::Reject => (false, "rejected"),
284 TimeoutAction::AutoApprove => (true, "auto_approved"),
285 };
286
287 let _ = self.event_tx.send(AgentEvent::ConfirmationTimeout {
289 tool_id: tool_id.clone(),
290 action_taken: action_taken.to_string(),
291 });
292
293 let response = ConfirmationResponse {
295 approved,
296 reason: Some(format!("Confirmation timed out, action: {}", action_taken)),
297 };
298 let _ = confirmation.response_tx.send(response);
299 }
300 }
301
302 timed_out.len()
303 }
304
305 pub async fn pending_count(&self) -> usize {
307 self.pending.read().await.len()
308 }
309
310 pub async fn pending_confirmations(&self) -> Vec<(String, String, u64)> {
312 let pending_map = self.pending.read().await;
313 pending_map
314 .values()
315 .map(|p| (p.tool_id.clone(), p.tool_name.clone(), p.remaining_ms()))
316 .collect()
317 }
318
319 pub async fn cancel(&self, tool_id: &str) -> bool {
321 let pending = {
322 let mut pending_map = self.pending.write().await;
323 pending_map.remove(tool_id)
324 };
325
326 if let Some(confirmation) = pending {
327 let response = ConfirmationResponse {
328 approved: false,
329 reason: Some("Confirmation cancelled".to_string()),
330 };
331 let _ = confirmation.response_tx.send(response);
332 true
333 } else {
334 false
335 }
336 }
337
338 pub async fn cancel_all(&self) -> usize {
340 let pending_list: Vec<_> = {
341 let mut pending_map = self.pending.write().await;
342 pending_map.drain().collect()
343 };
344
345 let count = pending_list.len();
346
347 for (_, confirmation) in pending_list {
348 let response = ConfirmationResponse {
349 approved: false,
350 reason: Some("Confirmation cancelled".to_string()),
351 };
352 let _ = confirmation.response_tx.send(response);
353 }
354
355 count
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Capacity of every broadcast channel created by these tests.
    const EVENT_CAPACITY: usize = 100;

    /// Builds a manager whose event receiver is dropped straight away, for
    /// tests that do not inspect events.
    fn manager_for(policy: ConfirmationPolicy) -> ConfirmationManager {
        let (event_tx, _) = broadcast::channel(EVENT_CAPACITY);
        ConfirmationManager::new(policy, event_tx)
    }

    /// Builds a manager together with a receiver subscribed to its events.
    fn manager_with_events(
        policy: ConfirmationPolicy,
    ) -> (ConfirmationManager, broadcast::Receiver<AgentEvent>) {
        let (event_tx, event_rx) = broadcast::channel(EVENT_CAPACITY);
        (ConfirmationManager::new(policy, event_tx), event_rx)
    }

    /// Enabled policy with a 50 ms timeout and the given expiry action.
    fn short_timeout_policy(timeout_action: TimeoutAction) -> ConfirmationPolicy {
        ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action,
            ..Default::default()
        }
    }

    // ---- SessionLane -----------------------------------------------------

    #[test]
    fn test_session_lane() {
        let expectations = [
            ("read", SessionLane::Query),
            ("grep", SessionLane::Query),
            ("bash", SessionLane::Execute),
            ("write", SessionLane::Execute),
        ];
        for (tool, lane) in expectations {
            assert_eq!(SessionLane::from_tool_name(tool), lane);
        }
    }

    #[test]
    fn test_session_lane_priority() {
        let control = SessionLane::Control.priority();
        let query = SessionLane::Query.priority();
        let execute = SessionLane::Execute.priority();
        let generate = SessionLane::Generate.priority();

        assert_eq!(control, 0);
        assert_eq!(query, 1);
        assert_eq!(execute, 2);
        assert_eq!(generate, 3);

        // Lower number means higher priority.
        assert!(control < query);
        assert!(query < execute);
        assert!(execute < generate);
    }

    #[test]
    fn test_session_lane_all_query() {
        for tool in ["read", "glob", "ls", "grep", "list_files", "search"] {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Query,
                "Tool '{}' should be in Query lane",
                tool
            );
        }
    }

    #[test]
    fn test_session_lane_all_execute() {
        for tool in ["bash", "write", "edit", "delete", "move", "copy", "execute"] {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Execute,
                "Tool '{}' should be in Execute lane",
                tool
            );
        }
    }

    // ---- ConfirmationPolicy ----------------------------------------------

    #[test]
    fn test_confirmation_policy_default() {
        let policy = ConfirmationPolicy::default();
        assert!(!policy.enabled);
        for tool in ["bash", "write", "read"] {
            assert!(!policy.requires_confirmation(tool));
        }
    }

    #[test]
    fn test_confirmation_policy_enabled() {
        let policy = ConfirmationPolicy::enabled();
        assert!(policy.enabled);
        for tool in ["bash", "write", "read", "grep"] {
            assert!(policy.requires_confirmation(tool));
        }
    }

    #[test]
    fn test_confirmation_policy_yolo_mode() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);

        // Execute-lane tools are exempt; Query-lane tools still need approval.
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("write"));
        assert!(policy.requires_confirmation("read"));
    }

    #[test]
    fn test_confirmation_policy_yolo_multiple_lanes() {
        let exempt_lanes = [SessionLane::Query, SessionLane::Execute];
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes(exempt_lanes);

        for tool in ["bash", "read", "grep"] {
            assert!(!policy.requires_confirmation(tool));
        }
    }

    #[test]
    fn test_confirmation_policy_is_yolo() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);

        assert!(policy.is_yolo("bash"));
        assert!(policy.is_yolo("write"));
        assert!(!policy.is_yolo("read"));
    }

    #[test]
    fn test_confirmation_policy_disabled_is_always_yolo() {
        let policy = ConfirmationPolicy::default();
        for tool in ["bash", "read", "unknown_tool"] {
            assert!(policy.is_yolo(tool));
        }
    }

    #[test]
    fn test_confirmation_policy_with_timeout() {
        let policy = ConfirmationPolicy::enabled().with_timeout(5000, TimeoutAction::AutoApprove);

        assert_eq!(policy.default_timeout_ms, 5000);
        assert_eq!(policy.timeout_action, TimeoutAction::AutoApprove);
    }

    // ---- ConfirmationManager: policy handling ----------------------------

    #[tokio::test]
    async fn test_confirmation_manager_no_hitl() {
        let manager = manager_for(ConfirmationPolicy::default());

        assert!(!manager.requires_confirmation("bash").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_hitl() {
        let manager = manager_for(ConfirmationPolicy::enabled());

        assert!(manager.requires_confirmation("bash").await);
        assert!(manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_yolo() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Query]);
        let manager = manager_for(policy);

        assert!(manager.requires_confirmation("bash").await);
        assert!(!manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_policy_update() {
        let manager = manager_for(ConfirmationPolicy::default());

        // Disabled: nothing needs confirmation.
        assert!(!manager.requires_confirmation("bash").await);

        // Enabled: everything does.
        manager.set_policy(ConfirmationPolicy::enabled()).await;
        assert!(manager.requires_confirmation("bash").await);

        // Enabled with the Execute lane exempted.
        let exempt_execute = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);
        manager.set_policy(exempt_execute).await;
        assert!(!manager.requires_confirmation("bash").await);
    }

    // ---- ConfirmationManager: confirm flow -------------------------------

    #[tokio::test]
    async fn test_confirmation_flow_approve() {
        let (manager, mut events) = manager_with_events(ConfirmationPolicy::enabled());

        let args = serde_json::json!({"command": "ls"});
        let response_rx = manager.request_confirmation("tool-1", "bash", &args).await;

        let required = events.recv().await.unwrap();
        if let AgentEvent::ConfirmationRequired {
            tool_id,
            tool_name,
            timeout_ms,
            ..
        } = required
        {
            assert_eq!(tool_id, "tool-1");
            assert_eq!(tool_name, "bash");
            assert_eq!(timeout_ms, 30_000);
        } else {
            panic!("Expected ConfirmationRequired event");
        }

        let outcome = manager.confirm("tool-1", true, None).await;
        assert_eq!(outcome, Ok(true));

        let received = events.recv().await.unwrap();
        if let AgentEvent::ConfirmationReceived {
            tool_id, approved, ..
        } = received
        {
            assert_eq!(tool_id, "tool-1");
            assert!(approved);
        } else {
            panic!("Expected ConfirmationReceived event");
        }

        let response = response_rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }

    #[tokio::test]
    async fn test_confirmation_flow_reject() {
        let (manager, mut events) = manager_with_events(ConfirmationPolicy::enabled());

        let args = serde_json::json!({"command": "rm -rf /"});
        let response_rx = manager.request_confirmation("tool-1", "bash", &args).await;

        // Discard the ConfirmationRequired event.
        let _ = events.recv().await.unwrap();

        let outcome = manager
            .confirm("tool-1", false, Some("Dangerous command".to_string()))
            .await;
        assert_eq!(outcome, Ok(true));

        let received = events.recv().await.unwrap();
        if let AgentEvent::ConfirmationReceived {
            tool_id,
            approved,
            reason,
        } = received
        {
            assert_eq!(tool_id, "tool-1");
            assert!(!approved);
            assert_eq!(reason.as_deref(), Some("Dangerous command"));
        } else {
            panic!("Expected ConfirmationReceived event");
        }

        let response = response_rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason.as_deref(), Some("Dangerous command"));
    }

    #[tokio::test]
    async fn test_confirmation_not_found() {
        let manager = manager_for(ConfirmationPolicy::enabled());

        let outcome = manager.confirm("non-existent", true, None).await;
        assert_eq!(outcome, Ok(false));
    }

    // ---- ConfirmationManager: several requests at once -------------------

    #[tokio::test]
    async fn test_multiple_confirmations() {
        let manager = manager_for(ConfirmationPolicy::enabled());

        let mut receivers = Vec::new();
        for (tool_id, tool_name, cmd) in [
            ("tool-1", "bash", "1"),
            ("tool-2", "write", "2"),
            ("tool-3", "edit", "3"),
        ] {
            let args = serde_json::json!({ "cmd": cmd });
            receivers.push(manager.request_confirmation(tool_id, tool_name, &args).await);
        }

        assert_eq!(manager.pending_count().await, 3);

        // Resolve in submission order: approve, reject, approve.
        let decisions = [("tool-1", true), ("tool-2", false), ("tool-3", true)];
        for (response_rx, (tool_id, decision)) in receivers.into_iter().zip(decisions) {
            manager.confirm(tool_id, decision, None).await.unwrap();
            let response = response_rx.await.unwrap();
            assert_eq!(response.approved, decision);
        }

        assert_eq!(manager.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_pending_confirmations_info() {
        let manager = manager_for(ConfirmationPolicy::enabled());
        let no_args = serde_json::json!({});

        // Keep the receivers alive for the duration of the test.
        let _rx1 = manager.request_confirmation("tool-1", "bash", &no_args).await;
        let _rx2 = manager.request_confirmation("tool-2", "write", &no_args).await;

        let pending = manager.pending_confirmations().await;
        assert_eq!(pending.len(), 2);

        for expected_id in ["tool-1", "tool-2"] {
            assert!(pending.iter().any(|(id, _, _)| id.as_str() == expected_id));
        }
    }

    // ---- ConfirmationManager: cancellation -------------------------------

    #[tokio::test]
    async fn test_cancel_confirmation() {
        let manager = manager_for(ConfirmationPolicy::enabled());

        let response_rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        assert_eq!(manager.pending_count().await, 1);

        assert!(manager.cancel("tool-1").await);
        assert_eq!(manager.pending_count().await, 0);

        let response = response_rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason.as_deref(), Some("Confirmation cancelled"));
    }

    #[tokio::test]
    async fn test_cancel_nonexistent() {
        let manager = manager_for(ConfirmationPolicy::enabled());

        assert!(!manager.cancel("non-existent").await);
    }

    #[tokio::test]
    async fn test_cancel_all() {
        let manager = manager_for(ConfirmationPolicy::enabled());
        let no_args = serde_json::json!({});

        let mut receivers = Vec::new();
        for (tool_id, tool_name) in [("tool-1", "bash"), ("tool-2", "write"), ("tool-3", "edit")] {
            receivers.push(
                manager
                    .request_confirmation(tool_id, tool_name, &no_args)
                    .await,
            );
        }
        assert_eq!(manager.pending_count().await, 3);

        assert_eq!(manager.cancel_all().await, 3);
        assert_eq!(manager.pending_count().await, 0);

        for response_rx in receivers {
            let response = response_rx.await.unwrap();
            assert!(!response.approved);
            assert_eq!(response.reason.as_deref(), Some("Confirmation cancelled"));
        }
    }

    // ---- ConfirmationManager: timeouts -----------------------------------

    #[tokio::test]
    async fn test_timeout_reject() {
        let (manager, mut events) =
            manager_with_events(short_timeout_policy(TimeoutAction::Reject));

        let response_rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Discard the ConfirmationRequired event.
        let _ = events.recv().await.unwrap();

        // Wait past the 50 ms timeout.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert_eq!(manager.check_timeouts().await, 1);

        let timeout_event = events.recv().await.unwrap();
        if let AgentEvent::ConfirmationTimeout {
            tool_id,
            action_taken,
        } = timeout_event
        {
            assert_eq!(tool_id, "tool-1");
            assert_eq!(action_taken, "rejected");
        } else {
            panic!("Expected ConfirmationTimeout event");
        }

        let response = response_rx.await.unwrap();
        assert!(!response.approved);
        assert!(response.reason.as_ref().unwrap().contains("timed out"));
    }

    #[tokio::test]
    async fn test_timeout_auto_approve() {
        let (manager, mut events) =
            manager_with_events(short_timeout_policy(TimeoutAction::AutoApprove));

        let response_rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Discard the ConfirmationRequired event.
        let _ = events.recv().await.unwrap();

        // Wait past the 50 ms timeout.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert_eq!(manager.check_timeouts().await, 1);

        let timeout_event = events.recv().await.unwrap();
        if let AgentEvent::ConfirmationTimeout {
            tool_id,
            action_taken,
        } = timeout_event
        {
            assert_eq!(tool_id, "tool-1");
            assert_eq!(action_taken, "auto_approved");
        } else {
            panic!("Expected ConfirmationTimeout event");
        }

        let response = response_rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.as_ref().unwrap().contains("auto_approved"));
    }

    #[tokio::test]
    async fn test_no_timeout_when_confirmed() {
        let manager = manager_for(short_timeout_policy(TimeoutAction::Reject));

        let response_rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Answer before the timeout can fire.
        manager.confirm("tool-1", true, None).await.unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;

        // The request is already resolved, so nothing is left to time out.
        assert_eq!(manager.check_timeouts().await, 0);

        let response = response_rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }
}