1use async_trait::async_trait;
37use parking_lot::Mutex;
38use cortexai_core::{ToolCall, ToolSchema};
39use serde::{Deserialize, Serialize};
40use std::collections::HashSet;
41use std::io::{self, Write};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::time::Duration;
44use thiserror::Error;
45
/// Errors an [`ApprovalHandler`] can return while trying to obtain a decision.
#[derive(Debug, Error)]
pub enum ApprovalError {
    /// The user refused the tool call; carries a human-readable reason.
    /// Note: the handlers in this file report refusals as
    /// `Ok(ApprovalDecision::Denied { .. })` rather than constructing this variant.
    #[error("Tool execution denied by user: {0}")]
    Denied(String),

    /// No answer arrived within the configured limit; carries that limit in seconds.
    #[error("Approval timeout after {0} seconds")]
    Timeout(u64),

    /// The handler machinery itself failed (e.g. the blocking stdin task could
    /// not be joined in `TerminalApprovalHandler`).
    #[error("Approval handler error: {0}")]
    HandlerError(String),

    /// Writing the prompt or reading the answer failed.
    #[error("IO error: {0}")]
    IoError(#[from] io::Error),
}
61
/// Outcome returned by an [`ApprovalHandler`] for one request.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ApprovalDecision {
    /// Run the tool call as requested.
    Approved,

    /// Do not run the tool call; `reason` explains why.
    Denied { reason: String },

    /// Approved, with `new_arguments` proposed in place of the original arguments.
    /// NOTE(review): the substitution is presumably performed by the caller — not
    /// visible in this file.
    Modify { new_arguments: serde_json::Value },

    /// Neither approved nor denied; `is_approved` returns `false` for it.
    /// NOTE(review): how callers treat a skipped call is not visible here.
    Skip,
}
77
/// Lifecycle state reported through [`ApprovalHandler::on_status_change`].
///
/// NOTE(review): no handler in this file emits these values; the meaning of
/// each state is inferred from its name only.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ApprovalStatus {
    Pending,
    Approved,
    Rejected,
    Modified,
    TimedOut,
}
87
88impl ApprovalDecision {
89 pub fn is_approved(&self) -> bool {
90 matches!(self, Self::Approved | Self::Modify { .. })
91 }
92
93 pub fn denied(reason: impl Into<String>) -> Self {
94 Self::Denied {
95 reason: reason.into(),
96 }
97 }
98}
99
/// Everything a handler is given to decide on a single tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRequest {
    /// The call awaiting a decision (its name and arguments are what
    /// `TerminalApprovalHandler` shows to the user).
    pub tool_call: ToolCall,

    /// Schema of the called tool, when known.
    pub tool_schema: Option<ToolSchema>,

    /// Why approval is required (as produced by `ApprovalConfig::requires_approval`).
    pub reason: ApprovalReason,

    /// Timestamp in UTC; the tests in this file fill it with `chrono::Utc::now()`.
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Optional free-form context; `TerminalApprovalHandler` displays it when present.
    pub context: Option<String>,
}
118
/// Which rule of [`ApprovalConfig::requires_approval`] flagged a tool call.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ApprovalReason {
    /// The tool's schema has `dangerous == true`.
    DangerousTool,

    /// The tool name is listed in `additional_sensitive_tools`.
    SensitiveTool,

    /// The tool's schema metadata has `"external": true`.
    ExternalApi,

    /// The tool name matched this approval regex.
    MatchesPattern(String),

    /// `require_approval_for_all` is enabled.
    AllToolsRequireApproval,

    /// Any other reason, as free text (e.g. "Tool is in deny list").
    Custom(String),
}
140
141impl std::fmt::Display for ApprovalReason {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 match self {
144 Self::DangerousTool => write!(f, "Tool is marked as dangerous"),
145 Self::SensitiveTool => write!(f, "Tool is marked as sensitive"),
146 Self::ExternalApi => write!(f, "Tool is marked as external"),
147 Self::MatchesPattern(p) => write!(f, "Tool matches pattern: {}", p),
148 Self::AllToolsRequireApproval => write!(f, "All tools require approval"),
149 Self::Custom(r) => write!(f, "{}", r),
150 }
151 }
152}
153
/// Policy deciding which tool calls need human approval; evaluated by
/// [`ApprovalConfig::requires_approval`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalConfig {
    /// Gate tools whose schema has `dangerous == true`.
    pub require_approval_for_dangerous: bool,

    /// Tool names that always need approval (reason `SensitiveTool`), unless
    /// they appear in `always_deny_tools` or `always_approve_tools`.
    pub additional_sensitive_tools: HashSet<String>,

    /// Gate every tool except those listed in `always_approve_tools`.
    pub require_approval_for_all: bool,

    /// Gate tools whose schema metadata has `"external": true`.
    pub require_approval_for_external: bool,

    /// NOTE(review): not read anywhere in this file; intended semantics unconfirmed.
    pub auto_approve_timeout: Option<Duration>,

    /// Tool names that never need approval; checked right after
    /// `always_deny_tools` and before every other rule.
    pub always_approve_tools: HashSet<String>,

    /// Tool names flagged with reason `Custom("Tool is in deny list")`; checked
    /// before every other rule.
    /// NOTE(review): `requires_approval` only flags these — nothing in this file
    /// denies them automatically.
    pub always_deny_tools: HashSet<String>,

    /// Regular expressions searched (unanchored) in the tool name; the first
    /// match yields `MatchesPattern`. Invalid patterns are silently skipped.
    pub approval_patterns: Vec<String>,

    /// Seconds `TerminalApprovalHandler` waits for an answer; 0 waits indefinitely.
    pub timeout_seconds: u64,

    /// Number of concurrent prompts `TerminalApprovalHandler` allows before it
    /// denies new requests with "Too many pending approvals".
    pub max_pending_approvals: usize,
}
187
188impl Default for ApprovalConfig {
189 fn default() -> Self {
190 Self {
191 require_approval_for_dangerous: true,
192 additional_sensitive_tools: HashSet::new(),
193 require_approval_for_all: false,
194 require_approval_for_external: false,
195 auto_approve_timeout: None,
196 always_approve_tools: HashSet::new(),
197 always_deny_tools: HashSet::new(),
198 approval_patterns: Vec::new(),
199 timeout_seconds: 0,
200 max_pending_approvals: 100,
201 }
202 }
203}
204
205impl ApprovalConfig {
206 pub fn builder() -> ApprovalConfigBuilder {
208 ApprovalConfigBuilder::default()
209 }
210
211 pub fn requires_approval(
213 &self,
214 tool_name: &str,
215 tool_schema: Option<&ToolSchema>,
216 ) -> Option<ApprovalReason> {
217 if self.always_deny_tools.contains(tool_name) {
219 return Some(ApprovalReason::Custom("Tool is in deny list".to_string()));
220 }
221
222 if self.always_approve_tools.contains(tool_name) {
224 return None;
225 }
226
227 if self.additional_sensitive_tools.contains(tool_name) {
229 return Some(ApprovalReason::SensitiveTool);
230 }
231
232 if self.require_approval_for_all {
234 return Some(ApprovalReason::AllToolsRequireApproval);
235 }
236
237 let is_dangerous = tool_schema.map(|s| s.dangerous).unwrap_or(false);
239 if self.require_approval_for_dangerous && is_dangerous {
240 return Some(ApprovalReason::DangerousTool);
241 }
242
243 let is_external = tool_schema
245 .and_then(|s| s.metadata.get("external"))
246 .and_then(|v| v.as_bool())
247 .unwrap_or(false);
248 if self.require_approval_for_external && is_external {
249 return Some(ApprovalReason::ExternalApi);
250 }
251
252 for pattern in &self.approval_patterns {
254 if let Ok(re) = regex::Regex::new(pattern) {
255 if re.is_match(tool_name) {
256 return Some(ApprovalReason::MatchesPattern(pattern.clone()));
257 }
258 }
259 }
260
261 None
262 }
263}
264
265#[derive(Default)]
267pub struct ApprovalConfigBuilder {
268 config: ApprovalConfig,
269}
270
271impl ApprovalConfigBuilder {
272 pub fn require_approval_for_dangerous(mut self, value: bool) -> Self {
273 self.config.require_approval_for_dangerous = value;
274 self
275 }
276
277 pub fn additional_sensitive_tools(mut self, tools: Vec<impl Into<String>>) -> Self {
278 self.config.additional_sensitive_tools = tools.into_iter().map(Into::into).collect();
279 self
280 }
281
282 pub fn add_sensitive_tool(mut self, tool: impl Into<String>) -> Self {
283 self.config.additional_sensitive_tools.insert(tool.into());
284 self
285 }
286
287 pub fn require_approval_for_all(mut self, value: bool) -> Self {
288 self.config.require_approval_for_all = value;
289 self
290 }
291
292 pub fn require_approval_for_external(mut self, value: bool) -> Self {
293 self.config.require_approval_for_external = value;
294 self
295 }
296
297 pub fn auto_approve_timeout(mut self, timeout: Option<Duration>) -> Self {
298 self.config.auto_approve_timeout = timeout;
299 self
300 }
301
302 pub fn always_approve_tools(mut self, tools: Vec<impl Into<String>>) -> Self {
303 self.config.always_approve_tools = tools.into_iter().map(Into::into).collect();
304 self
305 }
306
307 pub fn always_deny_tools(mut self, tools: Vec<impl Into<String>>) -> Self {
308 self.config.always_deny_tools = tools.into_iter().map(Into::into).collect();
309 self
310 }
311
312 pub fn add_approval_pattern(mut self, pattern: impl Into<String>) -> Self {
313 self.config.approval_patterns.push(pattern.into());
314 self
315 }
316
317 pub fn timeout_seconds(mut self, seconds: u64) -> Self {
318 self.config.timeout_seconds = seconds;
319 self
320 }
321
322 pub fn max_pending_approvals(mut self, max: usize) -> Self {
323 self.config.max_pending_approvals = max;
324 self
325 }
326
327 pub fn build(self) -> ApprovalConfig {
328 self.config
329 }
330}
331
/// Strategy that decides whether a gated tool call may run.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait ApprovalHandler: Send + Sync {
    /// Obtains a decision for `request`.
    ///
    /// # Errors
    /// Implementation-specific; `TerminalApprovalHandler`, for example, can
    /// fail with I/O, timeout or handler errors.
    async fn request_approval(
        &self,
        request: ApprovalRequest,
    ) -> Result<ApprovalDecision, ApprovalError>;

    /// The policy this handler applies when deciding which calls need approval.
    fn config(&self) -> &ApprovalConfig;

    /// Hook for approval lifecycle notifications; the default implementation does nothing.
    async fn on_status_change(&self, _request_id: &str, _status: ApprovalStatus) {}

    /// Returns why `tool_call` needs approval under this handler's config, or
    /// `None` if it may run without asking. Delegates to
    /// `ApprovalConfig::requires_approval` using the call's tool name.
    fn check_requires_approval(
        &self,
        tool_call: &ToolCall,
        tool_schema: Option<&ToolSchema>,
    ) -> Option<ApprovalReason> {
        self.config()
            .requires_approval(&tool_call.name, tool_schema)
    }
}
357
358pub struct TerminalApprovalHandler {
360 config: ApprovalConfig,
361 pending_count: AtomicUsize,
362}
363
364impl TerminalApprovalHandler {
365 pub fn new(config: ApprovalConfig) -> Self {
366 Self {
367 config,
368 pending_count: AtomicUsize::new(0),
369 }
370 }
371
372 fn format_request(&self, request: &ApprovalRequest) -> String {
373 let mut output = String::new();
374 output.push('\n');
375 output.push_str("╔══════════════════════════════════════════════════════════════╗\n");
376 output.push_str("║ 🔒 TOOL APPROVAL REQUIRED ║\n");
377 output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
378 output.push_str(&format!("║ Tool: {:<55} ║\n", request.tool_call.name));
379 output.push_str(&format!("║ Reason: {:<53} ║\n", request.reason));
380 output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
381 output.push_str("║ Arguments: ║\n");
382
383 let args_str = serde_json::to_string_pretty(&request.tool_call.arguments)
385 .unwrap_or_else(|_| request.tool_call.arguments.to_string());
386
387 for line in args_str.lines().take(10) {
388 let truncated = if line.len() > 60 { &line[..57] } else { line };
389 output.push_str(&format!("║ {:<59} ║\n", truncated));
390 }
391
392 if args_str.lines().count() > 10 {
393 output.push_str("║ ... (truncated) ║\n");
394 }
395
396 if let Some(ref ctx) = request.context {
397 output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
398 output.push_str(&format!("║ Context: {:<52} ║\n", ctx));
399 }
400
401 output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
402 output.push_str("║ [Y]es / [N]o / [S]kip / [A]lways approve / [D]eny always ║\n");
403 output.push_str("╚══════════════════════════════════════════════════════════════╝\n");
404 output.push_str("> ");
405
406 output
407 }
408}
409
410#[async_trait]
411impl ApprovalHandler for TerminalApprovalHandler {
412 async fn request_approval(
413 &self,
414 request: ApprovalRequest,
415 ) -> Result<ApprovalDecision, ApprovalError> {
416 let pending = self.pending_count.fetch_add(1, Ordering::SeqCst);
418 if pending >= self.config.max_pending_approvals {
419 self.pending_count.fetch_sub(1, Ordering::SeqCst);
420 return Ok(ApprovalDecision::denied("Too many pending approvals"));
421 }
422
423 let prompt = self.format_request(&request);
425 print!("{}", prompt);
426 io::stdout().flush()?;
427
428 let input: String;
430
431 if self.config.timeout_seconds > 0 {
433 let timeout = tokio::time::Duration::from_secs(self.config.timeout_seconds);
434 let read_result = tokio::time::timeout(timeout, async {
435 tokio::task::spawn_blocking(|| {
436 let mut buf = String::new();
437 io::stdin().read_line(&mut buf)?;
438 Ok::<_, io::Error>(buf)
439 })
440 .await
441 })
442 .await;
443
444 match read_result {
445 Ok(Ok(Ok(s))) => input = s,
446 Ok(Ok(Err(e))) => {
447 self.pending_count.fetch_sub(1, Ordering::SeqCst);
448 return Err(ApprovalError::IoError(e));
449 }
450 Ok(Err(e)) => {
451 self.pending_count.fetch_sub(1, Ordering::SeqCst);
452 return Err(ApprovalError::HandlerError(e.to_string()));
453 }
454 Err(_) => {
455 self.pending_count.fetch_sub(1, Ordering::SeqCst);
456 return Err(ApprovalError::Timeout(self.config.timeout_seconds));
457 }
458 }
459 } else {
460 let mut buf = String::new();
461 io::stdin().read_line(&mut buf)?;
462 input = buf;
463 }
464
465 self.pending_count.fetch_sub(1, Ordering::SeqCst);
466
467 let decision = match input.trim().to_lowercase().as_str() {
468 "y" | "yes" => ApprovalDecision::Approved,
469 "n" | "no" => ApprovalDecision::denied("User denied"),
470 "s" | "skip" => ApprovalDecision::Skip,
471 "a" | "always" => {
472 println!(
475 " ✓ Tool '{}' will be auto-approved in this session",
476 request.tool_call.name
477 );
478 ApprovalDecision::Approved
479 }
480 "d" | "deny" => {
481 println!(
482 " ✗ Tool '{}' will be auto-denied in this session",
483 request.tool_call.name
484 );
485 ApprovalDecision::denied("User set to always deny")
486 }
487 _ => {
488 println!(" Invalid input, defaulting to deny");
489 ApprovalDecision::denied("Invalid input")
490 }
491 };
492
493 Ok(decision)
494 }
495
496 fn config(&self) -> &ApprovalConfig {
497 &self.config
498 }
499}
500
501pub struct TestApprovalHandler {
503 config: ApprovalConfig,
504 decisions: Mutex<std::collections::HashMap<String, ApprovalDecision>>,
506 default_decision: ApprovalDecision,
508 requests: Mutex<Vec<ApprovalRequest>>,
510 call_count: AtomicUsize,
512}
513
514impl TestApprovalHandler {
515 pub fn approve_all() -> Self {
517 Self {
518 config: ApprovalConfig::default(),
519 decisions: Mutex::new(std::collections::HashMap::new()),
520 default_decision: ApprovalDecision::Approved,
521 requests: Mutex::new(Vec::new()),
522 call_count: AtomicUsize::new(0),
523 }
524 }
525
526 pub fn deny_all() -> Self {
528 Self {
529 config: ApprovalConfig::default(),
530 decisions: Mutex::new(std::collections::HashMap::new()),
531 default_decision: ApprovalDecision::denied("Test handler denies all"),
532 requests: Mutex::new(Vec::new()),
533 call_count: AtomicUsize::new(0),
534 }
535 }
536
537 pub fn with_decisions(decisions: Vec<(impl Into<String>, ApprovalDecision)>) -> Self {
539 let map: std::collections::HashMap<String, ApprovalDecision> =
540 decisions.into_iter().map(|(k, v)| (k.into(), v)).collect();
541
542 Self {
543 config: ApprovalConfig::default(),
544 decisions: Mutex::new(map),
545 default_decision: ApprovalDecision::Approved,
546 requests: Mutex::new(Vec::new()),
547 call_count: AtomicUsize::new(0),
548 }
549 }
550
551 pub fn with_config(config: ApprovalConfig) -> Self {
553 Self {
554 config,
555 decisions: Mutex::new(std::collections::HashMap::new()),
556 default_decision: ApprovalDecision::Approved,
557 requests: Mutex::new(Vec::new()),
558 call_count: AtomicUsize::new(0),
559 }
560 }
561
562 pub fn set_decision(&self, tool_name: impl Into<String>, decision: ApprovalDecision) {
564 self.decisions.lock().insert(tool_name.into(), decision);
565 }
566
567 pub fn get_requests(&self) -> Vec<ApprovalRequest> {
569 self.requests.lock().clone()
570 }
571
572 pub fn request_count(&self) -> usize {
574 self.call_count.load(Ordering::SeqCst)
575 }
576
577 pub fn clear(&self) {
579 self.requests.lock().clear();
580 self.call_count.store(0, Ordering::SeqCst);
581 }
582
583 pub fn was_approval_requested(&self, tool_name: &str) -> bool {
585 self.requests
586 .lock()
587 .iter()
588 .any(|r| r.tool_call.name == tool_name)
589 }
590}
591
592#[async_trait]
593impl ApprovalHandler for TestApprovalHandler {
594 async fn request_approval(
595 &self,
596 request: ApprovalRequest,
597 ) -> Result<ApprovalDecision, ApprovalError> {
598 self.call_count.fetch_add(1, Ordering::SeqCst);
599
600 let tool_name = request.tool_call.name.clone();
601 self.requests.lock().push(request);
602
603 let decisions = self.decisions.lock();
604 let decision = decisions
605 .get(&tool_name)
606 .cloned()
607 .unwrap_or_else(|| self.default_decision.clone());
608
609 Ok(decision)
610 }
611
612 fn config(&self) -> &ApprovalConfig {
613 &self.config
614 }
615}
616
617#[derive(Clone)]
619pub struct AutoApproveHandler {
620 config: ApprovalConfig,
621}
622
623impl AutoApproveHandler {
624 pub fn new() -> Self {
625 Self {
626 config: ApprovalConfig {
627 require_approval_for_dangerous: false,
628 require_approval_for_all: false,
629 ..Default::default()
630 },
631 }
632 }
633}
634
635impl Default for AutoApproveHandler {
636 fn default() -> Self {
637 Self::new()
638 }
639}
640
641#[async_trait]
642impl ApprovalHandler for AutoApproveHandler {
643 async fn request_approval(
644 &self,
645 _request: ApprovalRequest,
646 ) -> Result<ApprovalDecision, ApprovalError> {
647 Ok(ApprovalDecision::Approved)
648 }
649
650 fn config(&self) -> &ApprovalConfig {
651 &self.config
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Schema fixture with empty parameters, metadata and scopes.
    fn schema(name: &str, description: &str, dangerous: bool) -> ToolSchema {
        ToolSchema {
            name: name.to_string(),
            description: description.to_string(),
            parameters: json!({}),
            dangerous,
            metadata: HashMap::new(),
            required_scopes: vec![],
        }
    }

    /// Tool-call fixture with id "1".
    fn call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: "1".to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    /// Request fixture: empty arguments, no schema, no context, reason `DangerousTool`.
    fn request(name: &str) -> ApprovalRequest {
        ApprovalRequest {
            tool_call: call(name, json!({})),
            tool_schema: None,
            reason: ApprovalReason::DangerousTool,
            timestamp: chrono::Utc::now(),
            context: None,
        }
    }

    #[test]
    fn test_approval_config_dangerous() {
        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();
        let risky = schema("delete_file", "delete", true);
        let benign = schema("read_file", "read", false);

        assert!(config.requires_approval(&risky.name, Some(&risky)).is_some());
        assert!(config.requires_approval(&benign.name, Some(&benign)).is_none());
    }

    #[test]
    fn test_approval_config_always_approve() {
        let config = ApprovalConfig::builder()
            .require_approval_for_all(true)
            .always_approve_tools(vec!["safe_tool"])
            .build();
        // Allow-listed even though the schema is flagged dangerous.
        let allowed = schema("safe_tool", "", true);

        assert!(config.requires_approval(&allowed.name, Some(&allowed)).is_none());
        assert!(config.requires_approval("other_tool", None).is_some());
    }

    #[test]
    fn test_approval_config_always_deny() {
        let config = ApprovalConfig::builder()
            .always_deny_tools(vec!["dangerous_tool"])
            .build();

        assert!(config.requires_approval("dangerous_tool", None).is_some());
    }

    #[test]
    fn test_approval_config_patterns() {
        let config = ApprovalConfig::builder()
            .add_approval_pattern("delete_.*")
            .add_approval_pattern(".*_dangerous")
            .build();

        for gated in ["delete_file", "delete_folder", "run_dangerous"] {
            assert!(
                config.requires_approval(gated, None).is_some(),
                "{} should require approval",
                gated
            );
        }
        assert!(config.requires_approval("read_file", None).is_none());
    }

    #[test]
    fn test_approval_config_additional_sensitive() {
        let config = ApprovalConfig::builder()
            .add_sensitive_tool("export_data")
            .build();

        assert!(config.requires_approval("export_data", None).is_some());
        assert!(config.requires_approval("other", None).is_none());
    }

    #[test]
    fn test_approval_config_external_metadata() {
        let mut external = schema("call_api", "External API", false);
        external
            .metadata
            .insert("external".to_string(), json!(true));

        let config = ApprovalConfig::builder()
            .require_approval_for_external(true)
            .build();

        assert!(config
            .requires_approval(&external.name, Some(&external))
            .is_some());
    }

    #[tokio::test]
    async fn test_test_handler_approve_all() {
        let handler = TestApprovalHandler::approve_all();

        let decision = handler.request_approval(request("any_tool")).await.unwrap();
        assert_eq!(decision, ApprovalDecision::Approved);
        assert_eq!(handler.request_count(), 1);
    }

    #[tokio::test]
    async fn test_test_handler_deny_all() {
        let handler = TestApprovalHandler::deny_all();

        let decision = handler.request_approval(request("any_tool")).await.unwrap();
        assert!(!decision.is_approved());
    }

    #[tokio::test]
    async fn test_test_handler_specific_decisions() {
        let handler = TestApprovalHandler::with_decisions(vec![
            ("tool_a", ApprovalDecision::Approved),
            ("tool_b", ApprovalDecision::denied("Not allowed")),
        ]);

        let first = handler.request_approval(request("tool_a")).await.unwrap();
        assert_eq!(first, ApprovalDecision::Approved);

        let second = handler.request_approval(request("tool_b")).await.unwrap();
        assert!(!second.is_approved());

        // Unlisted tools fall back to the default decision (approve).
        let third = handler.request_approval(request("tool_c")).await.unwrap();
        assert_eq!(third, ApprovalDecision::Approved);
    }

    #[tokio::test]
    async fn test_test_handler_records_requests() {
        let handler = TestApprovalHandler::approve_all();
        let req = ApprovalRequest {
            tool_call: call("test_tool", json!({"key": "value"})),
            context: Some("Test context".to_string()),
            ..request("test_tool")
        };

        handler.request_approval(req).await.unwrap();

        assert!(handler.was_approval_requested("test_tool"));
        assert!(!handler.was_approval_requested("other_tool"));

        let recorded = handler.get_requests();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].tool_call.name, "test_tool");
    }

    #[tokio::test]
    async fn test_auto_approve_handler() {
        let handler = AutoApproveHandler::new();
        let req = ApprovalRequest {
            tool_schema: Some(schema("dangerous_tool", "A dangerous tool", true)),
            ..request("dangerous_tool")
        };

        let decision = handler.request_approval(req).await.unwrap();
        assert_eq!(decision, ApprovalDecision::Approved);
    }

    #[test]
    fn test_approval_decision_is_approved() {
        assert!(ApprovalDecision::Approved.is_approved());
        assert!(ApprovalDecision::Modify {
            new_arguments: json!({})
        }
        .is_approved());
        assert!(!ApprovalDecision::Denied {
            reason: "no".to_string()
        }
        .is_approved());
        assert!(!ApprovalDecision::Skip.is_approved());
    }

    #[test]
    fn test_check_requires_approval() {
        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();
        let handler = TestApprovalHandler::with_config(config);

        let safe_call = call("safe_tool", json!({}));
        let safe_schema = schema("safe_tool", "Safe", false);
        let dangerous_schema = schema("dangerous_tool", "Dangerous", true);

        assert!(handler
            .check_requires_approval(&safe_call, Some(&safe_schema))
            .is_none());
        // The supplied schema, not the call's name, drives the danger check.
        assert!(handler
            .check_requires_approval(&safe_call, Some(&dangerous_schema))
            .is_some());
        assert!(handler.check_requires_approval(&safe_call, None).is_none());
    }
}