1use crate::agent::AgentEvent;
11use crate::queue::SessionLane;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use tokio::sync::{broadcast, oneshot, RwLock};
17
/// What [`ConfirmationManager::check_timeouts`] does with a pending request whose
/// timeout has elapsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum TimeoutAction {
    /// Resolve the request with `approved: false`. This is the default.
    #[default]
    Reject,
    /// Resolve the request with `approved: true`.
    AutoApprove,
}
27
/// Human-in-the-loop (HITL) confirmation policy for tool calls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfirmationPolicy {
    /// Master switch. When `false`, no tool requires confirmation.
    pub enabled: bool,

    /// Timeout, in milliseconds, assigned to each newly created confirmation request.
    pub default_timeout_ms: u64,

    /// Action applied to a pending request once its timeout has elapsed.
    pub timeout_action: TimeoutAction,

    /// Lanes whose tools run without confirmation even when `enabled` is `true`.
    pub yolo_lanes: HashSet<SessionLane>,
}
49
50impl Default for ConfirmationPolicy {
51 fn default() -> Self {
52 Self {
53 enabled: false, default_timeout_ms: 30_000, timeout_action: TimeoutAction::Reject,
56 yolo_lanes: HashSet::new(), }
58 }
59}
60
61impl ConfirmationPolicy {
62 pub fn enabled() -> Self {
64 Self {
65 enabled: true,
66 ..Default::default()
67 }
68 }
69
70 pub fn with_yolo_lanes(mut self, lanes: impl IntoIterator<Item = SessionLane>) -> Self {
72 self.yolo_lanes = lanes.into_iter().collect();
73 self
74 }
75
76 pub fn with_timeout(mut self, timeout_ms: u64, action: TimeoutAction) -> Self {
78 self.default_timeout_ms = timeout_ms;
79 self.timeout_action = action;
80 self
81 }
82
83 pub fn is_yolo(&self, tool_name: &str) -> bool {
88 if !self.enabled {
89 return true; }
91 let lane = SessionLane::from_tool_name(tool_name);
92 self.yolo_lanes.contains(&lane)
93 }
94
95 pub fn requires_confirmation(&self, tool_name: &str) -> bool {
100 !self.is_yolo(tool_name)
101 }
102}
103
/// Outcome delivered to whoever awaits a confirmation request's receiver.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfirmationResponse {
    /// `true` if the request was approved, either explicitly or by the timeout policy.
    pub approved: bool,
    /// Optional explanation: a caller-supplied reason, or a timeout/cancellation note.
    pub reason: Option<String>,
}
112
/// Serializable snapshot of one pending confirmation request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingConfirmationInfo {
    /// Identifier the request was registered under.
    pub tool_id: String,
    /// Name of the tool awaiting confirmation.
    pub tool_name: String,
    /// Arguments the tool would be invoked with.
    pub args: serde_json::Value,
    /// Milliseconds left before the request times out, at the time of the snapshot.
    pub remaining_ms: u64,
}
121
122#[async_trait::async_trait]
127pub trait ConfirmationProvider: Send + Sync {
128 async fn requires_confirmation(&self, tool_name: &str) -> bool;
130
131 async fn request_confirmation(
135 &self,
136 tool_id: &str,
137 tool_name: &str,
138 args: &serde_json::Value,
139 ) -> oneshot::Receiver<ConfirmationResponse>;
140
141 async fn confirm(
146 &self,
147 tool_id: &str,
148 approved: bool,
149 reason: Option<String>,
150 ) -> Result<bool, String>;
151
152 async fn policy(&self) -> ConfirmationPolicy;
154
155 async fn set_policy(&self, policy: ConfirmationPolicy);
157
158 async fn check_timeouts(&self) -> usize;
160
161 async fn cancel_all(&self) -> usize;
163
164 async fn pending_confirmations(&self) -> Vec<PendingConfirmationInfo> {
166 Vec::new()
167 }
168}
169
/// A confirmation request awaiting a decision, as stored by [`ConfirmationManager`].
pub struct PendingConfirmation {
    /// Identifier the request was registered under (also its key in the pending map).
    pub tool_id: String,
    /// Name of the tool awaiting confirmation.
    pub tool_name: String,
    /// Arguments the tool would be invoked with.
    pub args: serde_json::Value,
    /// When the request was registered; the timeout is measured from here.
    pub created_at: Instant,
    /// Timeout in milliseconds, copied from the policy when the request was created.
    pub timeout_ms: u64,
    /// Sender that resolves the requester's receiver; consumed on confirm/timeout/cancel.
    response_tx: oneshot::Sender<ConfirmationResponse>,
}
185
186impl PendingConfirmation {
187 pub fn is_timed_out(&self) -> bool {
189 self.created_at.elapsed() > Duration::from_millis(self.timeout_ms)
190 }
191
192 pub fn remaining_ms(&self) -> u64 {
194 let elapsed = self.created_at.elapsed().as_millis() as u64;
195 self.timeout_ms.saturating_sub(elapsed)
196 }
197}
198
/// In-memory [`ConfirmationProvider`] that tracks pending requests and publishes
/// lifecycle events on a broadcast channel.
pub struct ConfirmationManager {
    /// Current policy; read on each request/timeout check, replaceable at runtime.
    policy: RwLock<ConfirmationPolicy>,
    /// Pending requests keyed by `tool_id`.
    pending: Arc<RwLock<HashMap<String, PendingConfirmation>>>,
    /// Channel on which `ConfirmationRequired`/`Received`/`Timeout` events are sent.
    event_tx: broadcast::Sender<AgentEvent>,
}
208
209impl ConfirmationManager {
210 pub fn new(policy: ConfirmationPolicy, event_tx: broadcast::Sender<AgentEvent>) -> Self {
212 Self {
213 policy: RwLock::new(policy),
214 pending: Arc::new(RwLock::new(HashMap::new())),
215 event_tx,
216 }
217 }
218
219 pub async fn policy(&self) -> ConfirmationPolicy {
221 self.policy.read().await.clone()
222 }
223
224 pub async fn set_policy(&self, policy: ConfirmationPolicy) {
226 *self.policy.write().await = policy;
227 }
228
229 pub async fn requires_confirmation(&self, tool_name: &str) -> bool {
231 self.policy.read().await.requires_confirmation(tool_name)
232 }
233
234 pub async fn request_confirmation(
239 &self,
240 tool_id: &str,
241 tool_name: &str,
242 args: &serde_json::Value,
243 ) -> oneshot::Receiver<ConfirmationResponse> {
244 let (tx, rx) = oneshot::channel();
245
246 let policy = self.policy.read().await;
247 let timeout_ms = policy.default_timeout_ms;
248 drop(policy);
249
250 let pending = PendingConfirmation {
251 tool_id: tool_id.to_string(),
252 tool_name: tool_name.to_string(),
253 args: args.clone(),
254 created_at: Instant::now(),
255 timeout_ms,
256 response_tx: tx,
257 };
258
259 {
261 let mut pending_map = self.pending.write().await;
262 pending_map.insert(tool_id.to_string(), pending);
263 }
264
265 let _ = self.event_tx.send(AgentEvent::ConfirmationRequired {
267 tool_id: tool_id.to_string(),
268 tool_name: tool_name.to_string(),
269 args: args.clone(),
270 timeout_ms,
271 });
272
273 rx
274 }
275
276 pub async fn confirm(
281 &self,
282 tool_id: &str,
283 approved: bool,
284 reason: Option<String>,
285 ) -> Result<bool, String> {
286 let pending = {
287 let mut pending_map = self.pending.write().await;
288 pending_map.remove(tool_id)
289 };
290
291 if let Some(confirmation) = pending {
292 let _ = self.event_tx.send(AgentEvent::ConfirmationReceived {
294 tool_id: tool_id.to_string(),
295 approved,
296 reason: reason.clone(),
297 });
298
299 let response = ConfirmationResponse { approved, reason };
301 let _ = confirmation.response_tx.send(response);
302
303 Ok(true)
304 } else {
305 Ok(false)
306 }
307 }
308
309 pub async fn check_timeouts(&self) -> usize {
313 let policy = self.policy.read().await;
314 let timeout_action = policy.timeout_action;
315 drop(policy);
316
317 let mut timed_out = Vec::new();
318
319 {
321 let pending_map = self.pending.read().await;
322 for (tool_id, pending) in pending_map.iter() {
323 if pending.is_timed_out() {
324 timed_out.push(tool_id.clone());
325 }
326 }
327 }
328
329 for tool_id in &timed_out {
331 let pending = {
332 let mut pending_map = self.pending.write().await;
333 pending_map.remove(tool_id)
334 };
335
336 if let Some(confirmation) = pending {
337 let (approved, action_taken) = match timeout_action {
338 TimeoutAction::Reject => (false, "rejected"),
339 TimeoutAction::AutoApprove => (true, "auto_approved"),
340 };
341
342 let _ = self.event_tx.send(AgentEvent::ConfirmationTimeout {
344 tool_id: tool_id.clone(),
345 action_taken: action_taken.to_string(),
346 });
347
348 let response = ConfirmationResponse {
350 approved,
351 reason: Some(format!("Confirmation timed out, action: {}", action_taken)),
352 };
353 let _ = confirmation.response_tx.send(response);
354 }
355 }
356
357 timed_out.len()
358 }
359
360 pub async fn pending_count(&self) -> usize {
362 self.pending.read().await.len()
363 }
364
365 pub async fn pending_confirmations(&self) -> Vec<(String, String, u64)> {
367 let pending_map = self.pending.read().await;
368 pending_map
369 .values()
370 .map(|p| (p.tool_id.clone(), p.tool_name.clone(), p.remaining_ms()))
371 .collect()
372 }
373
374 pub async fn pending_confirmation_details(&self) -> Vec<PendingConfirmationInfo> {
376 let pending_map = self.pending.read().await;
377 pending_map
378 .values()
379 .map(|p| PendingConfirmationInfo {
380 tool_id: p.tool_id.clone(),
381 tool_name: p.tool_name.clone(),
382 args: p.args.clone(),
383 remaining_ms: p.remaining_ms(),
384 })
385 .collect()
386 }
387
388 pub async fn cancel(&self, tool_id: &str) -> bool {
390 let pending = {
391 let mut pending_map = self.pending.write().await;
392 pending_map.remove(tool_id)
393 };
394
395 if let Some(confirmation) = pending {
396 let response = ConfirmationResponse {
397 approved: false,
398 reason: Some("Confirmation cancelled".to_string()),
399 };
400 let _ = confirmation.response_tx.send(response);
401 true
402 } else {
403 false
404 }
405 }
406
407 pub async fn cancel_all(&self) -> usize {
409 let pending_list: Vec<_> = {
410 let mut pending_map = self.pending.write().await;
411 pending_map.drain().collect()
412 };
413
414 let count = pending_list.len();
415
416 for (_, confirmation) in pending_list {
417 let response = ConfirmationResponse {
418 approved: false,
419 reason: Some("Confirmation cancelled".to_string()),
420 };
421 let _ = confirmation.response_tx.send(response);
422 }
423
424 count
425 }
426}
427
428#[async_trait::async_trait]
430impl ConfirmationProvider for ConfirmationManager {
431 async fn requires_confirmation(&self, tool_name: &str) -> bool {
432 self.requires_confirmation(tool_name).await
433 }
434
435 async fn request_confirmation(
436 &self,
437 tool_id: &str,
438 tool_name: &str,
439 args: &serde_json::Value,
440 ) -> oneshot::Receiver<ConfirmationResponse> {
441 self.request_confirmation(tool_id, tool_name, args).await
442 }
443
444 async fn confirm(
445 &self,
446 tool_id: &str,
447 approved: bool,
448 reason: Option<String>,
449 ) -> Result<bool, String> {
450 self.confirm(tool_id, approved, reason).await
451 }
452
453 async fn policy(&self) -> ConfirmationPolicy {
454 self.policy().await
455 }
456
457 async fn set_policy(&self, policy: ConfirmationPolicy) {
458 self.set_policy(policy).await
459 }
460
461 async fn check_timeouts(&self) -> usize {
462 self.check_timeouts().await
463 }
464
465 async fn cancel_all(&self) -> usize {
466 self.cancel_all().await
467 }
468
469 async fn pending_confirmations(&self) -> Vec<PendingConfirmationInfo> {
470 self.pending_confirmation_details().await
471 }
472}
473
/// [`ConfirmationProvider`] that never asks.
///
/// No tool requires confirmation, and any request is approved immediately.
pub struct AutoApproveConfirmation;

481#[async_trait::async_trait]
482impl ConfirmationProvider for AutoApproveConfirmation {
483 async fn requires_confirmation(&self, _tool_name: &str) -> bool {
484 false
485 }
486
487 async fn request_confirmation(
488 &self,
489 _tool_id: &str,
490 _tool_name: &str,
491 _args: &serde_json::Value,
492 ) -> oneshot::Receiver<ConfirmationResponse> {
493 let (tx, rx) = oneshot::channel();
494 let _ = tx.send(ConfirmationResponse {
495 approved: true,
496 reason: None,
497 });
498 rx
499 }
500
501 async fn confirm(
502 &self,
503 _tool_id: &str,
504 _approved: bool,
505 _reason: Option<String>,
506 ) -> Result<bool, String> {
507 Ok(false)
508 }
509
510 async fn policy(&self) -> ConfirmationPolicy {
511 ConfirmationPolicy {
512 enabled: false,
513 ..ConfirmationPolicy::default()
514 }
515 }
516
517 async fn set_policy(&self, _policy: ConfirmationPolicy) {}
518
519 async fn check_timeouts(&self) -> usize {
520 0
521 }
522
523 async fn cancel_all(&self) -> usize {
524 0
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    // ---- SessionLane classification (type defined in crate::queue) ----

    #[test]
    fn test_session_lane() {
        assert_eq!(SessionLane::from_tool_name("read"), SessionLane::Query);
        assert_eq!(SessionLane::from_tool_name("grep"), SessionLane::Query);
        assert_eq!(SessionLane::from_tool_name("bash"), SessionLane::Execute);
        assert_eq!(SessionLane::from_tool_name("write"), SessionLane::Execute);
    }

    #[test]
    fn test_session_lane_priority() {
        assert_eq!(SessionLane::Control.priority(), 0);
        assert_eq!(SessionLane::Query.priority(), 1);
        assert_eq!(SessionLane::Execute.priority(), 2);
        assert_eq!(SessionLane::Generate.priority(), 3);

        // The lanes are strictly ordered: Control < Query < Execute < Generate.
        assert!(SessionLane::Control.priority() < SessionLane::Query.priority());
        assert!(SessionLane::Query.priority() < SessionLane::Execute.priority());
        assert!(SessionLane::Execute.priority() < SessionLane::Generate.priority());
    }

    #[test]
    fn test_session_lane_all_query() {
        let query_tools = ["read", "glob", "ls", "grep", "list_files", "search"];
        for tool in query_tools {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Query,
                "Tool '{}' should be in Query lane",
                tool
            );
        }
    }

    #[test]
    fn test_session_lane_all_execute() {
        let execute_tools = ["bash", "write", "edit", "delete", "move", "copy", "execute"];
        for tool in execute_tools {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Execute,
                "Tool '{}' should be in Execute lane",
                tool
            );
        }
    }

    // ---- ConfirmationPolicy ----

    #[test]
    fn test_confirmation_policy_default() {
        // The default policy is disabled, so no tool requires confirmation.
        let policy = ConfirmationPolicy::default();
        assert!(!policy.enabled);
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("write"));
        assert!(!policy.requires_confirmation("read"));
    }

    #[test]
    fn test_confirmation_policy_enabled() {
        // Enabled with no YOLO lanes: every tool, including read-only ones, needs confirmation.
        let policy = ConfirmationPolicy::enabled();
        assert!(policy.enabled);
        assert!(policy.requires_confirmation("bash"));
        assert!(policy.requires_confirmation("write"));
        assert!(policy.requires_confirmation("read"));
        assert!(policy.requires_confirmation("grep"));
    }

    #[test]
    fn test_confirmation_policy_yolo_mode() {
        // Execute lane is YOLO: its tools skip confirmation; Query-lane tools still need it.
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);

        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("write"));
        assert!(policy.requires_confirmation("read"));
    }

    #[test]
    fn test_confirmation_policy_yolo_multiple_lanes() {
        // With both Query and Execute lanes YOLO, none of these tools need confirmation.
        let policy = ConfirmationPolicy::enabled()
            .with_yolo_lanes([SessionLane::Query, SessionLane::Execute]);

        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("read"));
        assert!(!policy.requires_confirmation("grep"));
    }

    #[test]
    fn test_confirmation_policy_is_yolo() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);

        assert!(policy.is_yolo("bash"));
        assert!(policy.is_yolo("write"));
        assert!(!policy.is_yolo("read"));
    }

    #[test]
    fn test_confirmation_policy_disabled_is_always_yolo() {
        // A disabled policy treats every tool, even unknown ones, as YOLO.
        let policy = ConfirmationPolicy::default();
        assert!(policy.is_yolo("bash"));
        assert!(policy.is_yolo("read"));
        assert!(policy.is_yolo("unknown_tool"));
    }

    #[test]
    fn test_confirmation_policy_with_timeout() {
        let policy = ConfirmationPolicy::enabled().with_timeout(5000, TimeoutAction::AutoApprove);

        assert_eq!(policy.default_timeout_ms, 5000);
        assert_eq!(policy.timeout_action, TimeoutAction::AutoApprove);
    }

    // ---- ConfirmationManager: policy queries ----

    #[tokio::test]
    async fn test_confirmation_manager_no_hitl() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::default(), event_tx);

        assert!(!manager.requires_confirmation("bash").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_hitl() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        assert!(manager.requires_confirmation("bash").await);
        assert!(manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_yolo() {
        let (event_tx, _) = broadcast::channel(100);
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Query]);
        let manager = ConfirmationManager::new(policy, event_tx);

        assert!(manager.requires_confirmation("bash").await);
        assert!(!manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_policy_update() {
        // set_policy takes effect for subsequent requires_confirmation calls.
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::default(), event_tx);

        assert!(!manager.requires_confirmation("bash").await);

        manager.set_policy(ConfirmationPolicy::enabled()).await;
        assert!(manager.requires_confirmation("bash").await);

        manager
            .set_policy(ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]))
            .await;
        assert!(!manager.requires_confirmation("bash").await);
    }

    // ---- ConfirmationManager: request/confirm flow ----

    #[tokio::test]
    async fn test_confirmation_flow_approve() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({"command": "ls"}))
            .await;

        // First event: the request itself, carrying the policy's default timeout.
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationRequired {
                tool_id,
                tool_name,
                timeout_ms,
                ..
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(tool_name, "bash");
                assert_eq!(timeout_ms, 30_000);
            }
            _ => panic!("Expected ConfirmationRequired event"),
        }

        // Ok(true): a pending request was found and resolved.
        let result = manager.confirm("tool-1", true, None).await;
        assert!(result.is_ok());
        assert!(result.unwrap());

        // Second event: the decision.
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationReceived {
                tool_id, approved, ..
            } => {
                assert_eq!(tool_id, "tool-1");
                assert!(approved);
            }
            _ => panic!("Expected ConfirmationReceived event"),
        }

        // The waiter receives the approval.
        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }

    #[tokio::test]
    async fn test_confirmation_flow_reject() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx = manager
            .request_confirmation(
                "tool-1",
                "bash",
                &serde_json::json!({"command": "rm -rf /"}),
            )
            .await;

        // Skip the ConfirmationRequired event.
        let _ = event_rx.recv().await.unwrap();

        let result = manager
            .confirm("tool-1", false, Some("Dangerous command".to_string()))
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap());

        // The rejection reason is carried on both the event and the response.
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationReceived {
                tool_id,
                approved,
                reason,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert!(!approved);
                assert_eq!(reason, Some("Dangerous command".to_string()));
            }
            _ => panic!("Expected ConfirmationReceived event"),
        }

        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason, Some("Dangerous command".to_string()));
    }

    #[tokio::test]
    async fn test_confirmation_not_found() {
        // Confirming an unknown id is not an error; it reports Ok(false).
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let result = manager.confirm("non-existent", true, None).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[tokio::test]
    async fn test_multiple_confirmations() {
        // Independent requests are tracked and resolved separately.
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({"cmd": "1"}))
            .await;
        let rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({"cmd": "2"}))
            .await;
        let rx3 = manager
            .request_confirmation("tool-3", "edit", &serde_json::json!({"cmd": "3"}))
            .await;

        assert_eq!(manager.pending_count().await, 3);

        manager.confirm("tool-1", true, None).await.unwrap();
        let response1 = rx1.await.unwrap();
        assert!(response1.approved);

        manager.confirm("tool-2", false, None).await.unwrap();
        let response2 = rx2.await.unwrap();
        assert!(!response2.approved);

        manager.confirm("tool-3", true, None).await.unwrap();
        let response3 = rx3.await.unwrap();
        assert!(response3.approved);

        assert_eq!(manager.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_pending_confirmations_info() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        // Receivers are kept alive (bound to _rx*) so the requests stay pending.
        let _rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let _rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({}))
            .await;

        let pending = manager.pending_confirmations().await;
        assert_eq!(pending.len(), 2);

        // HashMap order is unspecified, so check membership rather than position.
        let tool_ids: Vec<&str> = pending.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(tool_ids.contains(&"tool-1"));
        assert!(tool_ids.contains(&"tool-2"));
    }

    // ---- ConfirmationManager: cancellation ----

    #[tokio::test]
    async fn test_cancel_confirmation() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        assert_eq!(manager.pending_count().await, 1);

        let cancelled = manager.cancel("tool-1").await;
        assert!(cancelled);
        assert_eq!(manager.pending_count().await, 0);

        // A cancelled request is delivered as a rejection with a fixed reason.
        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason, Some("Confirmation cancelled".to_string()));
    }

    #[tokio::test]
    async fn test_cancel_nonexistent() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let cancelled = manager.cancel("non-existent").await;
        assert!(!cancelled);
    }

    #[tokio::test]
    async fn test_cancel_all() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({}))
            .await;
        let rx3 = manager
            .request_confirmation("tool-3", "edit", &serde_json::json!({}))
            .await;

        assert_eq!(manager.pending_count().await, 3);

        let cancelled_count = manager.cancel_all().await;
        assert_eq!(cancelled_count, 3);
        assert_eq!(manager.pending_count().await, 0);

        for rx in [rx1, rx2, rx3] {
            let response = rx.await.unwrap();
            assert!(!response.approved);
            assert_eq!(response.reason, Some("Confirmation cancelled".to_string()));
        }
    }

    // ---- ConfirmationManager: timeouts ----
    // These use a 50 ms timeout and sleep 100 ms so the deadline has clearly passed.

    #[tokio::test]
    async fn test_timeout_reject() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::Reject,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Skip the ConfirmationRequired event.
        let _ = event_rx.recv().await.unwrap();

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 1);

        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationTimeout {
                tool_id,
                action_taken,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(action_taken, "rejected");
            }
            _ => panic!("Expected ConfirmationTimeout event"),
        }

        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert!(response.reason.as_ref().unwrap().contains("timed out"));
    }

    #[tokio::test]
    async fn test_timeout_auto_approve() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::AutoApprove,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Skip the ConfirmationRequired event.
        let _ = event_rx.recv().await.unwrap();

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 1);

        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationTimeout {
                tool_id,
                action_taken,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(action_taken, "auto_approved");
            }
            _ => panic!("Expected ConfirmationTimeout event"),
        }

        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.as_ref().unwrap().contains("auto_approved"));
    }

    #[tokio::test]
    async fn test_no_timeout_when_confirmed() {
        // A request confirmed before its deadline is no longer pending,
        // so check_timeouts must not touch it.
        let (event_tx, _) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::Reject,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        manager.confirm("tool-1", true, None).await.unwrap();

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 0);

        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }
}