1use crate::agent::AgentEvent;
11use crate::queue::SessionLane;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use tokio::sync::{broadcast, oneshot, RwLock};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
20pub enum TimeoutAction {
21 #[default]
23 Reject,
24 AutoApprove,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ConfirmationPolicy {
35 pub enabled: bool,
37
38 pub default_timeout_ms: u64,
40
41 pub timeout_action: TimeoutAction,
43
44 pub yolo_lanes: HashSet<SessionLane>,
48}
49
50impl Default for ConfirmationPolicy {
51 fn default() -> Self {
52 Self {
53 enabled: false, default_timeout_ms: 30_000, timeout_action: TimeoutAction::Reject,
56 yolo_lanes: HashSet::new(), }
58 }
59}
60
61impl ConfirmationPolicy {
62 pub fn enabled() -> Self {
64 Self {
65 enabled: true,
66 ..Default::default()
67 }
68 }
69
70 pub fn with_yolo_lanes(mut self, lanes: impl IntoIterator<Item = SessionLane>) -> Self {
72 self.yolo_lanes = lanes.into_iter().collect();
73 self
74 }
75
76 pub fn with_timeout(mut self, timeout_ms: u64, action: TimeoutAction) -> Self {
78 self.default_timeout_ms = timeout_ms;
79 self.timeout_action = action;
80 self
81 }
82
83 pub fn is_yolo(&self, tool_name: &str) -> bool {
88 if !self.enabled {
89 return true; }
91 let lane = SessionLane::from_tool_name(tool_name);
92 self.yolo_lanes.contains(&lane)
93 }
94
95 pub fn requires_confirmation(&self, tool_name: &str) -> bool {
100 !self.is_yolo(tool_name)
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ConfirmationResponse {
107 pub approved: bool,
109 pub reason: Option<String>,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct PendingConfirmationInfo {
116 pub tool_id: String,
117 pub tool_name: String,
118 pub args: serde_json::Value,
119 pub remaining_ms: u64,
120}
121
122#[async_trait::async_trait]
127pub trait ConfirmationProvider: Send + Sync {
128 async fn requires_confirmation(&self, tool_name: &str) -> bool;
130
131 async fn request_confirmation(
135 &self,
136 tool_id: &str,
137 tool_name: &str,
138 args: &serde_json::Value,
139 ) -> oneshot::Receiver<ConfirmationResponse>;
140
141 async fn confirm(
146 &self,
147 tool_id: &str,
148 approved: bool,
149 reason: Option<String>,
150 ) -> Result<bool, String>;
151
152 async fn policy(&self) -> ConfirmationPolicy;
154
155 async fn set_policy(&self, policy: ConfirmationPolicy);
157
158 async fn check_timeouts(&self) -> usize;
160
161 async fn cancel_all(&self) -> usize;
163
164 async fn pending_confirmations(&self) -> Vec<PendingConfirmationInfo> {
166 Vec::new()
167 }
168}
169
170pub struct PendingConfirmation {
172 pub tool_id: String,
174 pub tool_name: String,
176 pub args: serde_json::Value,
178 pub created_at: Instant,
180 pub timeout_ms: u64,
182 response_tx: oneshot::Sender<ConfirmationResponse>,
184}
185
186impl PendingConfirmation {
187 pub fn is_timed_out(&self) -> bool {
189 self.created_at.elapsed() > Duration::from_millis(self.timeout_ms)
190 }
191
192 pub fn remaining_ms(&self) -> u64 {
194 let elapsed = self.created_at.elapsed().as_millis() as u64;
195 self.timeout_ms.saturating_sub(elapsed)
196 }
197}
198
199pub struct ConfirmationManager {
201 policy: RwLock<ConfirmationPolicy>,
203 pending: Arc<RwLock<HashMap<String, PendingConfirmation>>>,
205 event_tx: broadcast::Sender<AgentEvent>,
207}
208
209impl ConfirmationManager {
210 pub fn new(policy: ConfirmationPolicy, event_tx: broadcast::Sender<AgentEvent>) -> Self {
212 Self {
213 policy: RwLock::new(policy),
214 pending: Arc::new(RwLock::new(HashMap::new())),
215 event_tx,
216 }
217 }
218
219 pub async fn policy(&self) -> ConfirmationPolicy {
221 self.policy.read().await.clone()
222 }
223
224 pub async fn set_policy(&self, policy: ConfirmationPolicy) {
226 *self.policy.write().await = policy;
227 }
228
229 pub async fn requires_confirmation(&self, tool_name: &str) -> bool {
231 self.policy.read().await.requires_confirmation(tool_name)
232 }
233
234 pub async fn request_confirmation(
239 &self,
240 tool_id: &str,
241 tool_name: &str,
242 args: &serde_json::Value,
243 ) -> oneshot::Receiver<ConfirmationResponse> {
244 let (tx, rx) = oneshot::channel();
245
246 let policy = self.policy.read().await;
247 let timeout_ms = policy.default_timeout_ms;
248 drop(policy);
249
250 let pending = PendingConfirmation {
251 tool_id: tool_id.to_string(),
252 tool_name: tool_name.to_string(),
253 args: args.clone(),
254 created_at: Instant::now(),
255 timeout_ms,
256 response_tx: tx,
257 };
258
259 {
261 let mut pending_map = self.pending.write().await;
262 pending_map.insert(tool_id.to_string(), pending);
263 }
264
265 let _ = self.event_tx.send(AgentEvent::ConfirmationRequired {
267 tool_id: tool_id.to_string(),
268 tool_name: tool_name.to_string(),
269 args: args.clone(),
270 timeout_ms,
271 });
272
273 rx
274 }
275
276 pub async fn confirm(
281 &self,
282 tool_id: &str,
283 approved: bool,
284 reason: Option<String>,
285 ) -> Result<bool, String> {
286 let pending = {
287 let mut pending_map = self.pending.write().await;
288 pending_map.remove(tool_id)
289 };
290
291 if let Some(confirmation) = pending {
292 let _ = self.event_tx.send(AgentEvent::ConfirmationReceived {
294 tool_id: tool_id.to_string(),
295 approved,
296 reason: reason.clone(),
297 });
298
299 let response = ConfirmationResponse { approved, reason };
301 let _ = confirmation.response_tx.send(response);
302
303 Ok(true)
304 } else {
305 Ok(false)
306 }
307 }
308
309 pub async fn check_timeouts(&self) -> usize {
313 let policy = self.policy.read().await;
314 let timeout_action = policy.timeout_action;
315 drop(policy);
316
317 let mut timed_out = Vec::new();
318
319 {
321 let pending_map = self.pending.read().await;
322 for (tool_id, pending) in pending_map.iter() {
323 if pending.is_timed_out() {
324 timed_out.push(tool_id.clone());
325 }
326 }
327 }
328
329 for tool_id in &timed_out {
331 let pending = {
332 let mut pending_map = self.pending.write().await;
333 pending_map.remove(tool_id)
334 };
335
336 if let Some(confirmation) = pending {
337 let (approved, action_taken) = match timeout_action {
338 TimeoutAction::Reject => (false, "rejected"),
339 TimeoutAction::AutoApprove => (true, "auto_approved"),
340 };
341
342 let _ = self.event_tx.send(AgentEvent::ConfirmationTimeout {
344 tool_id: tool_id.clone(),
345 action_taken: action_taken.to_string(),
346 });
347
348 let response = ConfirmationResponse {
350 approved,
351 reason: Some(format!("Confirmation timed out, action: {}", action_taken)),
352 };
353 let _ = confirmation.response_tx.send(response);
354 }
355 }
356
357 timed_out.len()
358 }
359
360 pub async fn pending_count(&self) -> usize {
362 self.pending.read().await.len()
363 }
364
365 pub async fn pending_confirmations(&self) -> Vec<(String, String, u64)> {
367 let pending_map = self.pending.read().await;
368 pending_map
369 .values()
370 .map(|p| (p.tool_id.clone(), p.tool_name.clone(), p.remaining_ms()))
371 .collect()
372 }
373
374 pub async fn pending_confirmation_details(&self) -> Vec<PendingConfirmationInfo> {
376 let pending_map = self.pending.read().await;
377 pending_map
378 .values()
379 .map(|p| PendingConfirmationInfo {
380 tool_id: p.tool_id.clone(),
381 tool_name: p.tool_name.clone(),
382 args: p.args.clone(),
383 remaining_ms: p.remaining_ms(),
384 })
385 .collect()
386 }
387
388 pub async fn cancel(&self, tool_id: &str) -> bool {
390 let pending = {
391 let mut pending_map = self.pending.write().await;
392 pending_map.remove(tool_id)
393 };
394
395 if let Some(confirmation) = pending {
396 let response = ConfirmationResponse {
397 approved: false,
398 reason: Some("Confirmation cancelled".to_string()),
399 };
400 let _ = confirmation.response_tx.send(response);
401 true
402 } else {
403 false
404 }
405 }
406
407 pub async fn cancel_all(&self) -> usize {
409 let pending_list: Vec<_> = {
410 let mut pending_map = self.pending.write().await;
411 pending_map.drain().collect()
412 };
413
414 let count = pending_list.len();
415
416 for (_, confirmation) in pending_list {
417 let response = ConfirmationResponse {
418 approved: false,
419 reason: Some("Confirmation cancelled".to_string()),
420 };
421 let _ = confirmation.response_tx.send(response);
422 }
423
424 count
425 }
426}
427
428#[async_trait::async_trait]
430impl ConfirmationProvider for ConfirmationManager {
431 async fn requires_confirmation(&self, tool_name: &str) -> bool {
432 self.requires_confirmation(tool_name).await
433 }
434
435 async fn request_confirmation(
436 &self,
437 tool_id: &str,
438 tool_name: &str,
439 args: &serde_json::Value,
440 ) -> oneshot::Receiver<ConfirmationResponse> {
441 self.request_confirmation(tool_id, tool_name, args).await
442 }
443
444 async fn confirm(
445 &self,
446 tool_id: &str,
447 approved: bool,
448 reason: Option<String>,
449 ) -> Result<bool, String> {
450 self.confirm(tool_id, approved, reason).await
451 }
452
453 async fn policy(&self) -> ConfirmationPolicy {
454 self.policy().await
455 }
456
457 async fn set_policy(&self, policy: ConfirmationPolicy) {
458 self.set_policy(policy).await
459 }
460
461 async fn check_timeouts(&self) -> usize {
462 self.check_timeouts().await
463 }
464
465 async fn cancel_all(&self) -> usize {
466 self.cancel_all().await
467 }
468
469 async fn pending_confirmations(&self) -> Vec<PendingConfirmationInfo> {
470 self.pending_confirmation_details().await
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // SessionLane classification and priority (SessionLane lives in crate::queue).
    // ---------------------------------------------------------------------

    #[test]
    fn test_session_lane() {
        assert_eq!(SessionLane::from_tool_name("read"), SessionLane::Query);
        assert_eq!(SessionLane::from_tool_name("grep"), SessionLane::Query);
        assert_eq!(SessionLane::from_tool_name("bash"), SessionLane::Execute);
        assert_eq!(SessionLane::from_tool_name("write"), SessionLane::Execute);
    }

    #[test]
    fn test_session_lane_priority() {
        assert_eq!(SessionLane::Control.priority(), 0);
        assert_eq!(SessionLane::Query.priority(), 1);
        assert_eq!(SessionLane::Execute.priority(), 2);
        assert_eq!(SessionLane::Generate.priority(), 3);

        // The priority values must be strictly increasing in this lane order.
        assert!(SessionLane::Control.priority() < SessionLane::Query.priority());
        assert!(SessionLane::Query.priority() < SessionLane::Execute.priority());
        assert!(SessionLane::Execute.priority() < SessionLane::Generate.priority());
    }

    #[test]
    fn test_session_lane_all_query() {
        // Read-only tools are expected to map to the Query lane.
        let query_tools = ["read", "glob", "ls", "grep", "list_files", "search"];
        for tool in query_tools {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Query,
                "Tool '{}' should be in Query lane",
                tool
            );
        }
    }

    #[test]
    fn test_session_lane_all_execute() {
        // Mutating / command tools are expected to map to the Execute lane.
        let execute_tools = ["bash", "write", "edit", "delete", "move", "copy", "execute"];
        for tool in execute_tools {
            assert_eq!(
                SessionLane::from_tool_name(tool),
                SessionLane::Execute,
                "Tool '{}' should be in Execute lane",
                tool
            );
        }
    }

    // ---------------------------------------------------------------------
    // ConfirmationPolicy (pure, synchronous).
    // ---------------------------------------------------------------------

    #[test]
    fn test_confirmation_policy_default() {
        let policy = ConfirmationPolicy::default();
        assert!(!policy.enabled);
        // A disabled policy never asks for confirmation, whatever the tool.
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("write"));
        assert!(!policy.requires_confirmation("read"));
    }

    #[test]
    fn test_confirmation_policy_enabled() {
        let policy = ConfirmationPolicy::enabled();
        assert!(policy.enabled);
        // Enabled with no YOLO lanes: every tool needs confirmation.
        assert!(policy.requires_confirmation("bash"));
        assert!(policy.requires_confirmation("write"));
        assert!(policy.requires_confirmation("read"));
        assert!(policy.requires_confirmation("grep"));
    }

    #[test]
    fn test_confirmation_policy_yolo_mode() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);

        // Execute-lane tools skip confirmation; Query-lane tools still need it.
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("write"));
        assert!(policy.requires_confirmation("read"));
    }

    #[test]
    fn test_confirmation_policy_yolo_multiple_lanes() {
        let policy = ConfirmationPolicy::enabled()
            .with_yolo_lanes([SessionLane::Query, SessionLane::Execute]);

        // Both lanes are YOLO, so none of these tools need confirmation.
        assert!(!policy.requires_confirmation("bash"));
        assert!(!policy.requires_confirmation("read"));
        assert!(!policy.requires_confirmation("grep"));
    }

    #[test]
    fn test_confirmation_policy_is_yolo() {
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]);

        // is_yolo is the inverse of requires_confirmation.
        assert!(policy.is_yolo("bash"));
        assert!(policy.is_yolo("write"));
        assert!(!policy.is_yolo("read"));
    }

    #[test]
    fn test_confirmation_policy_disabled_is_always_yolo() {
        // Default policy is disabled, which short-circuits to YOLO for any tool name.
        let policy = ConfirmationPolicy::default();
        assert!(policy.is_yolo("bash"));
        assert!(policy.is_yolo("read"));
        assert!(policy.is_yolo("unknown_tool"));
    }

    #[test]
    fn test_confirmation_policy_with_timeout() {
        let policy = ConfirmationPolicy::enabled().with_timeout(5000, TimeoutAction::AutoApprove);

        assert_eq!(policy.default_timeout_ms, 5000);
        assert_eq!(policy.timeout_action, TimeoutAction::AutoApprove);
    }

    // ---------------------------------------------------------------------
    // ConfirmationManager: policy queries.
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_confirmation_manager_no_hitl() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::default(), event_tx);

        assert!(!manager.requires_confirmation("bash").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_hitl() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        assert!(manager.requires_confirmation("bash").await);
        assert!(manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_with_yolo() {
        let (event_tx, _) = broadcast::channel(100);
        let policy = ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Query]);
        let manager = ConfirmationManager::new(policy, event_tx);

        // Execute lane still confirmed; Query lane is YOLO.
        assert!(manager.requires_confirmation("bash").await);
        assert!(!manager.requires_confirmation("read").await);
    }

    #[tokio::test]
    async fn test_confirmation_manager_policy_update() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::default(), event_tx);

        // Starts disabled.
        assert!(!manager.requires_confirmation("bash").await);

        // Enabling takes effect immediately.
        manager.set_policy(ConfirmationPolicy::enabled()).await;
        assert!(manager.requires_confirmation("bash").await);

        // Making the Execute lane YOLO lifts the requirement for "bash" again.
        manager
            .set_policy(ConfirmationPolicy::enabled().with_yolo_lanes([SessionLane::Execute]))
            .await;
        assert!(!manager.requires_confirmation("bash").await);
    }

    // ---------------------------------------------------------------------
    // ConfirmationManager: request / confirm flows. Event order on the
    // broadcast channel matters — each recv() expects the next event.
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_confirmation_flow_approve() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({"command": "ls"}))
            .await;

        // First event: the request itself.
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationRequired {
                tool_id,
                tool_name,
                timeout_ms,
                ..
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(tool_name, "bash");
                // 30_000 is ConfirmationPolicy::default()'s timeout, inherited by enabled().
                assert_eq!(timeout_ms, 30_000);
            }
            _ => panic!("Expected ConfirmationRequired event"),
        }

        // Ok(true): a pending request was found and resolved.
        let result = manager.confirm("tool-1", true, None).await;
        assert!(result.is_ok());
        assert!(result.unwrap());

        // Second event: the decision.
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationReceived {
                tool_id, approved, ..
            } => {
                assert_eq!(tool_id, "tool-1");
                assert!(approved);
            }
            _ => panic!("Expected ConfirmationReceived event"),
        }

        // The waiter receives the same decision.
        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }

    #[tokio::test]
    async fn test_confirmation_flow_reject() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx = manager
            .request_confirmation(
                "tool-1",
                "bash",
                &serde_json::json!({"command": "rm -rf /"}),
            )
            .await;

        // Discard the ConfirmationRequired event.
        let _ = event_rx.recv().await.unwrap();

        let result = manager
            .confirm("tool-1", false, Some("Dangerous command".to_string()))
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap());

        // The rejection reason is carried on the event...
        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationReceived {
                tool_id,
                approved,
                reason,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert!(!approved);
                assert_eq!(reason, Some("Dangerous command".to_string()));
            }
            _ => panic!("Expected ConfirmationReceived event"),
        }

        // ...and on the response delivered to the waiter.
        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason, Some("Dangerous command".to_string()));
    }

    #[tokio::test]
    async fn test_confirmation_not_found() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        // Unknown ids are not an error; confirm reports Ok(false).
        let result = manager.confirm("non-existent", true, None).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[tokio::test]
    async fn test_multiple_confirmations() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({"cmd": "1"}))
            .await;
        let rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({"cmd": "2"}))
            .await;
        let rx3 = manager
            .request_confirmation("tool-3", "edit", &serde_json::json!({"cmd": "3"}))
            .await;

        // Distinct ids coexist as separate pending entries.
        assert_eq!(manager.pending_count().await, 3);

        // Each confirm routes to its own receiver.
        manager.confirm("tool-1", true, None).await.unwrap();
        let response1 = rx1.await.unwrap();
        assert!(response1.approved);

        manager.confirm("tool-2", false, None).await.unwrap();
        let response2 = rx2.await.unwrap();
        assert!(!response2.approved);

        manager.confirm("tool-3", true, None).await.unwrap();
        let response3 = rx3.await.unwrap();
        assert!(response3.approved);

        // All entries are removed once answered.
        assert_eq!(manager.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_pending_confirmations_info() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        // Receivers are kept alive (bound, not dropped) for the duration of the test.
        let _rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let _rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({}))
            .await;

        // Calls the inherent method, which returns (tool_id, tool_name, remaining_ms) tuples.
        let pending = manager.pending_confirmations().await;
        assert_eq!(pending.len(), 2);

        // HashMap iteration order is unspecified, so only membership is checked.
        let tool_ids: Vec<&str> = pending.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(tool_ids.contains(&"tool-1"));
        assert!(tool_ids.contains(&"tool-2"));
    }

    // ---------------------------------------------------------------------
    // Cancellation.
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_cancel_confirmation() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        assert_eq!(manager.pending_count().await, 1);

        let cancelled = manager.cancel("tool-1").await;
        assert!(cancelled);
        assert_eq!(manager.pending_count().await, 0);

        // Cancellation is delivered to the waiter as a rejection with a fixed reason.
        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert_eq!(response.reason, Some("Confirmation cancelled".to_string()));
    }

    #[tokio::test]
    async fn test_cancel_nonexistent() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let cancelled = manager.cancel("non-existent").await;
        assert!(!cancelled);
    }

    #[tokio::test]
    async fn test_cancel_all() {
        let (event_tx, _) = broadcast::channel(100);
        let manager = ConfirmationManager::new(ConfirmationPolicy::enabled(), event_tx);

        let rx1 = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;
        let rx2 = manager
            .request_confirmation("tool-2", "write", &serde_json::json!({}))
            .await;
        let rx3 = manager
            .request_confirmation("tool-3", "edit", &serde_json::json!({}))
            .await;

        assert_eq!(manager.pending_count().await, 3);

        let cancelled_count = manager.cancel_all().await;
        assert_eq!(cancelled_count, 3);
        assert_eq!(manager.pending_count().await, 0);

        // Every waiter receives the same cancellation rejection.
        for rx in [rx1, rx2, rx3] {
            let response = rx.await.unwrap();
            assert!(!response.approved);
            assert_eq!(response.reason, Some("Confirmation cancelled".to_string()));
        }
    }

    // ---------------------------------------------------------------------
    // Timeouts. These use real wall-clock sleeps: PendingConfirmation measures
    // age with std::time::Instant, which tokio's paused clock does not affect.
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_timeout_reject() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::Reject,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Discard the ConfirmationRequired event.
        let _ = event_rx.recv().await.unwrap();

        // Sleep past the 50 ms timeout.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 1);

        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationTimeout {
                tool_id,
                action_taken,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(action_taken, "rejected");
            }
            _ => panic!("Expected ConfirmationTimeout event"),
        }

        let response = rx.await.unwrap();
        assert!(!response.approved);
        assert!(response.reason.as_ref().unwrap().contains("timed out"));
    }

    #[tokio::test]
    async fn test_timeout_auto_approve() {
        let (event_tx, mut event_rx) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::AutoApprove,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Discard the ConfirmationRequired event.
        let _ = event_rx.recv().await.unwrap();

        // Sleep past the 50 ms timeout.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 1);

        let event = event_rx.recv().await.unwrap();
        match event {
            AgentEvent::ConfirmationTimeout {
                tool_id,
                action_taken,
            } => {
                assert_eq!(tool_id, "tool-1");
                assert_eq!(action_taken, "auto_approved");
            }
            _ => panic!("Expected ConfirmationTimeout event"),
        }

        // AutoApprove resolves the waiter as approved; the reason names the action.
        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.as_ref().unwrap().contains("auto_approved"));
    }

    #[tokio::test]
    async fn test_no_timeout_when_confirmed() {
        let (event_tx, _) = broadcast::channel(100);
        let policy = ConfirmationPolicy {
            enabled: true,
            default_timeout_ms: 50,
            timeout_action: TimeoutAction::Reject,
            ..Default::default()
        };
        let manager = ConfirmationManager::new(policy, event_tx);

        let rx = manager
            .request_confirmation("tool-1", "bash", &serde_json::json!({}))
            .await;

        // Answered before the deadline, so the entry is already gone...
        manager.confirm("tool-1", true, None).await.unwrap();

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // ...and the later timeout sweep finds nothing.
        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out, 0);

        // The waiter sees the explicit approval, not a timeout.
        let response = rx.await.unwrap();
        assert!(response.approved);
        assert!(response.reason.is_none());
    }
}
1019}