Skip to main content

cortexai_agents/
approvals.rs

//! Human-in-the-loop approval system for dangerous tool executions
//!
//! This module provides a configurable approval mechanism that intercepts
//! tool calls needing review and asks a human before they execute. By default
//! that means tools whose schema is marked `dangerous`; the configuration can
//! additionally gate explicitly sensitive tools, tools flagged as external,
//! tools whose names match a regex pattern, or every tool.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────┐     ┌──────────────┐     ┌──────────────┐
//! │   Engine    │────>│  Approvals   │────>│   Executor   │
//! │             │     │   Handler    │     │              │
//! └─────────────┘     └──────────────┘     └──────────────┘
//!                            │
//!                     ┌──────┴──────┐
//!                     │             │
//!               ┌─────▼─────┐ ┌─────▼─────┐
//!               │ Terminal  │ │   Test    │
//!               │ Handler   │ │  Handler  │
//!               └───────────┘ └───────────┘
//! ```
//!
//! # Example
//!
//! ```rust,ignore
//! use cortexai_agents::approvals::{ApprovalConfig, TerminalApprovalHandler};
//!
//! let config = ApprovalConfig::builder()
//!     .require_approval_for_dangerous(true)
//!     .always_approve_tools(vec!["read_file"])
//!     .always_deny_tools(vec!["delete_all"])
//!     .build();
//!
//! let handler = TerminalApprovalHandler::new(config);
//! ```

use async_trait::async_trait;
use cortexai_core::{ToolCall, ToolSchema};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::io::{self, Write};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use thiserror::Error;

46/// Errors that can occur during approval
47#[derive(Debug, Error)]
48pub enum ApprovalError {
49    #[error("Tool execution denied by user: {0}")]
50    Denied(String),
51
52    #[error("Approval timeout after {0} seconds")]
53    Timeout(u64),
54
55    #[error("Approval handler error: {0}")]
56    HandlerError(String),
57
58    #[error("IO error: {0}")]
59    IoError(#[from] io::Error),
60}
61
62/// Result of an approval request
63#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
64pub enum ApprovalDecision {
65    /// Approved - proceed with execution
66    Approved,
67
68    /// Denied - do not execute
69    Denied { reason: String },
70
71    /// Modify - execute with modified arguments
72    Modify { new_arguments: serde_json::Value },
73
74    /// Skip - skip this tool but continue the agent loop
75    Skip,
76}
77
78/// Lifecycle states for an approval request
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
80pub enum ApprovalStatus {
81    Pending,
82    Approved,
83    Rejected,
84    Modified,
85    TimedOut,
86}
87
88impl ApprovalDecision {
89    pub fn is_approved(&self) -> bool {
90        matches!(self, Self::Approved | Self::Modify { .. })
91    }
92
93    pub fn denied(reason: impl Into<String>) -> Self {
94        Self::Denied {
95            reason: reason.into(),
96        }
97    }
98}
99
100/// Information about a pending approval request
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ApprovalRequest {
103    /// The tool call requiring approval
104    pub tool_call: ToolCall,
105
106    /// The tool schema (for context)
107    pub tool_schema: Option<ToolSchema>,
108
109    /// Why approval is required
110    pub reason: ApprovalReason,
111
112    /// Request timestamp
113    pub timestamp: chrono::DateTime<chrono::Utc>,
114
115    /// Additional context for the approver
116    pub context: Option<String>,
117}
118
119/// Reason why approval is required
120#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
121pub enum ApprovalReason {
122    /// Tool is marked as dangerous
123    DangerousTool,
124
125    /// Tool is explicitly configured as sensitive
126    SensitiveTool,
127
128    /// Tool metadata declares external API usage
129    ExternalApi,
130
131    /// Tool matches a pattern requiring approval
132    MatchesPattern(String),
133
134    /// All tools require approval
135    AllToolsRequireApproval,
136
137    /// Custom reason
138    Custom(String),
139}
140
141impl std::fmt::Display for ApprovalReason {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        match self {
144            Self::DangerousTool => write!(f, "Tool is marked as dangerous"),
145            Self::SensitiveTool => write!(f, "Tool is marked as sensitive"),
146            Self::ExternalApi => write!(f, "Tool is marked as external"),
147            Self::MatchesPattern(p) => write!(f, "Tool matches pattern: {}", p),
148            Self::AllToolsRequireApproval => write!(f, "All tools require approval"),
149            Self::Custom(r) => write!(f, "{}", r),
150        }
151    }
152}
153
154/// Configuration for the approval system
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct ApprovalConfig {
157    /// Require approval for tools marked as dangerous
158    pub require_approval_for_dangerous: bool,
159
160    /// Additional tools that must always go through approval
161    pub additional_sensitive_tools: HashSet<String>,
162
163    /// Require approval for all tool calls
164    pub require_approval_for_all: bool,
165
166    /// Require approval for tools marked as external
167    pub require_approval_for_external: bool,
168
169    /// Auto-approve after timeout if configured (None = wait forever)
170    pub auto_approve_timeout: Option<Duration>,
171
172    /// Tools that are always approved (bypass approval)
173    pub always_approve_tools: HashSet<String>,
174
175    /// Tools that are always denied
176    pub always_deny_tools: HashSet<String>,
177
178    /// Patterns for tools requiring approval (regex)
179    pub approval_patterns: Vec<String>,
180
181    /// Timeout for approval requests in seconds (0 = no timeout)
182    pub timeout_seconds: u64,
183
184    /// Maximum number of approvals to request before auto-denying
185    pub max_pending_approvals: usize,
186}
187
188impl Default for ApprovalConfig {
189    fn default() -> Self {
190        Self {
191            require_approval_for_dangerous: true,
192            additional_sensitive_tools: HashSet::new(),
193            require_approval_for_all: false,
194            require_approval_for_external: false,
195            auto_approve_timeout: None,
196            always_approve_tools: HashSet::new(),
197            always_deny_tools: HashSet::new(),
198            approval_patterns: Vec::new(),
199            timeout_seconds: 0,
200            max_pending_approvals: 100,
201        }
202    }
203}
204
205impl ApprovalConfig {
206    /// Create a new builder
207    pub fn builder() -> ApprovalConfigBuilder {
208        ApprovalConfigBuilder::default()
209    }
210
211    /// Check if a tool requires approval
212    pub fn requires_approval(
213        &self,
214        tool_name: &str,
215        tool_schema: Option<&ToolSchema>,
216    ) -> Option<ApprovalReason> {
217        // Check always deny first
218        if self.always_deny_tools.contains(tool_name) {
219            return Some(ApprovalReason::Custom("Tool is in deny list".to_string()));
220        }
221
222        // Check always approve (bypass)
223        if self.always_approve_tools.contains(tool_name) {
224            return None;
225        }
226
227        // Additional sensitive tools override other checks
228        if self.additional_sensitive_tools.contains(tool_name) {
229            return Some(ApprovalReason::SensitiveTool);
230        }
231
232        // Check all tools require approval
233        if self.require_approval_for_all {
234            return Some(ApprovalReason::AllToolsRequireApproval);
235        }
236
237        // Check dangerous flag
238        let is_dangerous = tool_schema.map(|s| s.dangerous).unwrap_or(false);
239        if self.require_approval_for_dangerous && is_dangerous {
240            return Some(ApprovalReason::DangerousTool);
241        }
242
243        // Check external metadata
244        let is_external = tool_schema
245            .and_then(|s| s.metadata.get("external"))
246            .and_then(|v| v.as_bool())
247            .unwrap_or(false);
248        if self.require_approval_for_external && is_external {
249            return Some(ApprovalReason::ExternalApi);
250        }
251
252        // Check patterns
253        for pattern in &self.approval_patterns {
254            if let Ok(re) = regex::Regex::new(pattern) {
255                if re.is_match(tool_name) {
256                    return Some(ApprovalReason::MatchesPattern(pattern.clone()));
257                }
258            }
259        }
260
261        None
262    }
263}
264
265/// Builder for ApprovalConfig
266#[derive(Default)]
267pub struct ApprovalConfigBuilder {
268    config: ApprovalConfig,
269}
270
271impl ApprovalConfigBuilder {
272    pub fn require_approval_for_dangerous(mut self, value: bool) -> Self {
273        self.config.require_approval_for_dangerous = value;
274        self
275    }
276
277    pub fn additional_sensitive_tools(mut self, tools: Vec<impl Into<String>>) -> Self {
278        self.config.additional_sensitive_tools = tools.into_iter().map(Into::into).collect();
279        self
280    }
281
282    pub fn add_sensitive_tool(mut self, tool: impl Into<String>) -> Self {
283        self.config.additional_sensitive_tools.insert(tool.into());
284        self
285    }
286
287    pub fn require_approval_for_all(mut self, value: bool) -> Self {
288        self.config.require_approval_for_all = value;
289        self
290    }
291
292    pub fn require_approval_for_external(mut self, value: bool) -> Self {
293        self.config.require_approval_for_external = value;
294        self
295    }
296
297    pub fn auto_approve_timeout(mut self, timeout: Option<Duration>) -> Self {
298        self.config.auto_approve_timeout = timeout;
299        self
300    }
301
302    pub fn always_approve_tools(mut self, tools: Vec<impl Into<String>>) -> Self {
303        self.config.always_approve_tools = tools.into_iter().map(Into::into).collect();
304        self
305    }
306
307    pub fn always_deny_tools(mut self, tools: Vec<impl Into<String>>) -> Self {
308        self.config.always_deny_tools = tools.into_iter().map(Into::into).collect();
309        self
310    }
311
312    pub fn add_approval_pattern(mut self, pattern: impl Into<String>) -> Self {
313        self.config.approval_patterns.push(pattern.into());
314        self
315    }
316
317    pub fn timeout_seconds(mut self, seconds: u64) -> Self {
318        self.config.timeout_seconds = seconds;
319        self
320    }
321
322    pub fn max_pending_approvals(mut self, max: usize) -> Self {
323        self.config.max_pending_approvals = max;
324        self
325    }
326
327    pub fn build(self) -> ApprovalConfig {
328        self.config
329    }
330}
331
332/// Trait for handling approval requests
333#[async_trait]
334pub trait ApprovalHandler: Send + Sync {
335    /// Request approval for a tool call
336    async fn request_approval(
337        &self,
338        request: ApprovalRequest,
339    ) -> Result<ApprovalDecision, ApprovalError>;
340
341    /// Get the configuration
342    fn config(&self) -> &ApprovalConfig;
343
344    /// React to status updates; default no-op
345    async fn on_status_change(&self, _request_id: &str, _status: ApprovalStatus) {}
346
347    /// Check if a tool requires approval and return the reason
348    fn check_requires_approval(
349        &self,
350        tool_call: &ToolCall,
351        tool_schema: Option<&ToolSchema>,
352    ) -> Option<ApprovalReason> {
353        self.config()
354            .requires_approval(&tool_call.name, tool_schema)
355    }
356}
357
358/// Terminal-based approval handler for interactive use
359pub struct TerminalApprovalHandler {
360    config: ApprovalConfig,
361    pending_count: AtomicUsize,
362}
363
364impl TerminalApprovalHandler {
365    pub fn new(config: ApprovalConfig) -> Self {
366        Self {
367            config,
368            pending_count: AtomicUsize::new(0),
369        }
370    }
371
372    fn format_request(&self, request: &ApprovalRequest) -> String {
373        let mut output = String::new();
374        output.push('\n');
375        output.push_str("╔══════════════════════════════════════════════════════════════╗\n");
376        output.push_str("║              🔒 TOOL APPROVAL REQUIRED                       ║\n");
377        output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
378        output.push_str(&format!("║ Tool: {:<55} ║\n", request.tool_call.name));
379        output.push_str(&format!("║ Reason: {:<53} ║\n", request.reason));
380        output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
381        output.push_str("║ Arguments:                                                   ║\n");
382
383        // Format arguments nicely
384        let args_str = serde_json::to_string_pretty(&request.tool_call.arguments)
385            .unwrap_or_else(|_| request.tool_call.arguments.to_string());
386
387        for line in args_str.lines().take(10) {
388            let truncated = if line.len() > 60 { &line[..57] } else { line };
389            output.push_str(&format!("║   {:<59} ║\n", truncated));
390        }
391
392        if args_str.lines().count() > 10 {
393            output.push_str("║   ... (truncated)                                           ║\n");
394        }
395
396        if let Some(ref ctx) = request.context {
397            output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
398            output.push_str(&format!("║ Context: {:<52} ║\n", ctx));
399        }
400
401        output.push_str("╠══════════════════════════════════════════════════════════════╣\n");
402        output.push_str("║ [Y]es / [N]o / [S]kip / [A]lways approve / [D]eny always     ║\n");
403        output.push_str("╚══════════════════════════════════════════════════════════════╝\n");
404        output.push_str("> ");
405
406        output
407    }
408}
409
410#[async_trait]
411impl ApprovalHandler for TerminalApprovalHandler {
412    async fn request_approval(
413        &self,
414        request: ApprovalRequest,
415    ) -> Result<ApprovalDecision, ApprovalError> {
416        // Check pending count
417        let pending = self.pending_count.fetch_add(1, Ordering::SeqCst);
418        if pending >= self.config.max_pending_approvals {
419            self.pending_count.fetch_sub(1, Ordering::SeqCst);
420            return Ok(ApprovalDecision::denied("Too many pending approvals"));
421        }
422
423        // Display the request
424        let prompt = self.format_request(&request);
425        print!("{}", prompt);
426        io::stdout().flush()?;
427
428        // Read response
429        let input: String;
430
431        // Handle timeout if configured
432        if self.config.timeout_seconds > 0 {
433            let timeout = tokio::time::Duration::from_secs(self.config.timeout_seconds);
434            let read_result = tokio::time::timeout(timeout, async {
435                tokio::task::spawn_blocking(|| {
436                    let mut buf = String::new();
437                    io::stdin().read_line(&mut buf)?;
438                    Ok::<_, io::Error>(buf)
439                })
440                .await
441            })
442            .await;
443
444            match read_result {
445                Ok(Ok(Ok(s))) => input = s,
446                Ok(Ok(Err(e))) => {
447                    self.pending_count.fetch_sub(1, Ordering::SeqCst);
448                    return Err(ApprovalError::IoError(e));
449                }
450                Ok(Err(e)) => {
451                    self.pending_count.fetch_sub(1, Ordering::SeqCst);
452                    return Err(ApprovalError::HandlerError(e.to_string()));
453                }
454                Err(_) => {
455                    self.pending_count.fetch_sub(1, Ordering::SeqCst);
456                    return Err(ApprovalError::Timeout(self.config.timeout_seconds));
457                }
458            }
459        } else {
460            let mut buf = String::new();
461            io::stdin().read_line(&mut buf)?;
462            input = buf;
463        }
464
465        self.pending_count.fetch_sub(1, Ordering::SeqCst);
466
467        let decision = match input.trim().to_lowercase().as_str() {
468            "y" | "yes" => ApprovalDecision::Approved,
469            "n" | "no" => ApprovalDecision::denied("User denied"),
470            "s" | "skip" => ApprovalDecision::Skip,
471            "a" | "always" => {
472                // Note: In a real implementation, this would modify the config
473                // For now, just approve
474                println!(
475                    "  ✓ Tool '{}' will be auto-approved in this session",
476                    request.tool_call.name
477                );
478                ApprovalDecision::Approved
479            }
480            "d" | "deny" => {
481                println!(
482                    "  ✗ Tool '{}' will be auto-denied in this session",
483                    request.tool_call.name
484                );
485                ApprovalDecision::denied("User set to always deny")
486            }
487            _ => {
488                println!("  Invalid input, defaulting to deny");
489                ApprovalDecision::denied("Invalid input")
490            }
491        };
492
493        Ok(decision)
494    }
495
496    fn config(&self) -> &ApprovalConfig {
497        &self.config
498    }
499}
500
501/// Test approval handler with configurable responses
502pub struct TestApprovalHandler {
503    config: ApprovalConfig,
504    /// Pre-configured decisions (tool_name -> decision)
505    decisions: Mutex<std::collections::HashMap<String, ApprovalDecision>>,
506    /// Default decision for unspecified tools
507    default_decision: ApprovalDecision,
508    /// Recorded approval requests
509    requests: Mutex<Vec<ApprovalRequest>>,
510    /// Call counter
511    call_count: AtomicUsize,
512}
513
514impl TestApprovalHandler {
515    /// Create a handler that approves everything
516    pub fn approve_all() -> Self {
517        Self {
518            config: ApprovalConfig::default(),
519            decisions: Mutex::new(std::collections::HashMap::new()),
520            default_decision: ApprovalDecision::Approved,
521            requests: Mutex::new(Vec::new()),
522            call_count: AtomicUsize::new(0),
523        }
524    }
525
526    /// Create a handler that denies everything
527    pub fn deny_all() -> Self {
528        Self {
529            config: ApprovalConfig::default(),
530            decisions: Mutex::new(std::collections::HashMap::new()),
531            default_decision: ApprovalDecision::denied("Test handler denies all"),
532            requests: Mutex::new(Vec::new()),
533            call_count: AtomicUsize::new(0),
534        }
535    }
536
537    /// Create a handler with specific decisions per tool
538    pub fn with_decisions(decisions: Vec<(impl Into<String>, ApprovalDecision)>) -> Self {
539        let map: std::collections::HashMap<String, ApprovalDecision> =
540            decisions.into_iter().map(|(k, v)| (k.into(), v)).collect();
541
542        Self {
543            config: ApprovalConfig::default(),
544            decisions: Mutex::new(map),
545            default_decision: ApprovalDecision::Approved,
546            requests: Mutex::new(Vec::new()),
547            call_count: AtomicUsize::new(0),
548        }
549    }
550
551    /// Create with custom config
552    pub fn with_config(config: ApprovalConfig) -> Self {
553        Self {
554            config,
555            decisions: Mutex::new(std::collections::HashMap::new()),
556            default_decision: ApprovalDecision::Approved,
557            requests: Mutex::new(Vec::new()),
558            call_count: AtomicUsize::new(0),
559        }
560    }
561
562    /// Set decision for a specific tool
563    pub fn set_decision(&self, tool_name: impl Into<String>, decision: ApprovalDecision) {
564        self.decisions.lock().insert(tool_name.into(), decision);
565    }
566
567    /// Get all recorded requests
568    pub fn get_requests(&self) -> Vec<ApprovalRequest> {
569        self.requests.lock().clone()
570    }
571
572    /// Get the number of approval requests made
573    pub fn request_count(&self) -> usize {
574        self.call_count.load(Ordering::SeqCst)
575    }
576
577    /// Clear recorded requests
578    pub fn clear(&self) {
579        self.requests.lock().clear();
580        self.call_count.store(0, Ordering::SeqCst);
581    }
582
583    /// Check if a specific tool had approval requested
584    pub fn was_approval_requested(&self, tool_name: &str) -> bool {
585        self.requests
586            .lock()
587            .iter()
588            .any(|r| r.tool_call.name == tool_name)
589    }
590}
591
592#[async_trait]
593impl ApprovalHandler for TestApprovalHandler {
594    async fn request_approval(
595        &self,
596        request: ApprovalRequest,
597    ) -> Result<ApprovalDecision, ApprovalError> {
598        self.call_count.fetch_add(1, Ordering::SeqCst);
599
600        let tool_name = request.tool_call.name.clone();
601        self.requests.lock().push(request);
602
603        let decisions = self.decisions.lock();
604        let decision = decisions
605            .get(&tool_name)
606            .cloned()
607            .unwrap_or_else(|| self.default_decision.clone());
608
609        Ok(decision)
610    }
611
612    fn config(&self) -> &ApprovalConfig {
613        &self.config
614    }
615}
616
617/// Auto-approve handler (no human interaction)
618#[derive(Clone)]
619pub struct AutoApproveHandler {
620    config: ApprovalConfig,
621}
622
623impl AutoApproveHandler {
624    pub fn new() -> Self {
625        Self {
626            config: ApprovalConfig {
627                require_approval_for_dangerous: false,
628                require_approval_for_all: false,
629                ..Default::default()
630            },
631        }
632    }
633}
634
635impl Default for AutoApproveHandler {
636    fn default() -> Self {
637        Self::new()
638    }
639}
640
641#[async_trait]
642impl ApprovalHandler for AutoApproveHandler {
643    async fn request_approval(
644        &self,
645        _request: ApprovalRequest,
646    ) -> Result<ApprovalDecision, ApprovalError> {
647        Ok(ApprovalDecision::Approved)
648    }
649
650    fn config(&self) -> &ApprovalConfig {
651        &self.config
652    }
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Schema fixture with empty metadata and no required scopes.
    fn schema(name: &str, description: &str, dangerous: bool) -> ToolSchema {
        ToolSchema {
            name: name.to_string(),
            description: description.to_string(),
            parameters: json!({}),
            dangerous,
            metadata: std::collections::HashMap::new(),
            required_scopes: vec![],
        }
    }

    /// Tool call fixture with id "1".
    fn call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: "1".to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    /// Request fixture: empty arguments, no schema, no context, `DangerousTool` reason.
    fn request(name: &str) -> ApprovalRequest {
        ApprovalRequest {
            tool_call: call(name, json!({})),
            tool_schema: None,
            reason: ApprovalReason::DangerousTool,
            timestamp: chrono::Utc::now(),
            context: None,
        }
    }

    #[test]
    fn test_approval_config_dangerous() {
        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();
        let dangerous = schema("delete_file", "delete", true);
        let safe = schema("read_file", "read", false);

        assert!(
            config
                .requires_approval(&dangerous.name, Some(&dangerous))
                .is_some(),
            "dangerous tool should require approval"
        );
        assert!(
            config.requires_approval(&safe.name, Some(&safe)).is_none(),
            "non-dangerous tool should not"
        );
    }

    #[test]
    fn test_approval_config_always_approve() {
        let config = ApprovalConfig::builder()
            .require_approval_for_all(true)
            .always_approve_tools(vec!["safe_tool"])
            .build();

        // The allow-list wins even over a dangerous schema and require-all.
        let allowed = schema("safe_tool", "", true);
        assert!(config
            .requires_approval(&allowed.name, Some(&allowed))
            .is_none());

        // Everything else still requires approval.
        assert!(config.requires_approval("other_tool", None).is_some());
    }

    #[test]
    fn test_approval_config_always_deny() {
        let config = ApprovalConfig::builder()
            .always_deny_tools(vec!["dangerous_tool"])
            .build();

        // The deny list is checked first, so a reason is always reported.
        assert!(config.requires_approval("dangerous_tool", None).is_some());
    }

    #[test]
    fn test_approval_config_patterns() {
        let config = ApprovalConfig::builder()
            .add_approval_pattern("delete_.*")
            .add_approval_pattern(".*_dangerous")
            .build();

        for gated in ["delete_file", "delete_folder", "run_dangerous"] {
            assert!(
                config.requires_approval(gated, None).is_some(),
                "{} should be gated",
                gated
            );
        }
        assert!(config.requires_approval("read_file", None).is_none());
    }

    #[test]
    fn test_approval_config_additional_sensitive() {
        let config = ApprovalConfig::builder()
            .add_sensitive_tool("export_data")
            .build();

        assert!(config.requires_approval("export_data", None).is_some());
        assert!(config.requires_approval("other", None).is_none());
    }

    #[test]
    fn test_approval_config_external_metadata() {
        let mut api = schema("call_api", "External API", false);
        api.metadata.insert("external".to_string(), json!(true));

        let config = ApprovalConfig::builder()
            .require_approval_for_external(true)
            .build();

        assert!(config.requires_approval(&api.name, Some(&api)).is_some());
    }

    #[tokio::test]
    async fn test_test_handler_approve_all() {
        let handler = TestApprovalHandler::approve_all();

        let decision = handler.request_approval(request("any_tool")).await.unwrap();
        assert_eq!(decision, ApprovalDecision::Approved);
        assert_eq!(handler.request_count(), 1);
    }

    #[tokio::test]
    async fn test_test_handler_deny_all() {
        let handler = TestApprovalHandler::deny_all();

        let decision = handler.request_approval(request("any_tool")).await.unwrap();
        assert!(!decision.is_approved());
    }

    #[tokio::test]
    async fn test_test_handler_specific_decisions() {
        let handler = TestApprovalHandler::with_decisions(vec![
            ("tool_a", ApprovalDecision::Approved),
            ("tool_b", ApprovalDecision::denied("Not allowed")),
        ]);

        let d1 = handler.request_approval(request("tool_a")).await.unwrap();
        assert_eq!(d1, ApprovalDecision::Approved);

        let d2 = handler.request_approval(request("tool_b")).await.unwrap();
        assert!(!d2.is_approved());

        // Unlisted tools fall back to the default (approve).
        let d3 = handler.request_approval(request("tool_c")).await.unwrap();
        assert_eq!(d3, ApprovalDecision::Approved);
    }

    #[tokio::test]
    async fn test_test_handler_records_requests() {
        let handler = TestApprovalHandler::approve_all();

        let req = ApprovalRequest {
            tool_call: call("test_tool", json!({"key": "value"})),
            context: Some("Test context".to_string()),
            ..request("test_tool")
        };
        handler.request_approval(req).await.unwrap();

        assert!(handler.was_approval_requested("test_tool"));
        assert!(!handler.was_approval_requested("other_tool"));

        let requests = handler.get_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].tool_call.name, "test_tool");
    }

    #[tokio::test]
    async fn test_auto_approve_handler() {
        let handler = AutoApproveHandler::new();

        let req = ApprovalRequest {
            tool_schema: Some(schema("dangerous_tool", "A dangerous tool", true)),
            ..request("dangerous_tool")
        };

        let decision = handler.request_approval(req).await.unwrap();
        assert_eq!(decision, ApprovalDecision::Approved);
    }

    #[test]
    fn test_approval_decision_is_approved() {
        assert!(ApprovalDecision::Approved.is_approved());
        assert!(ApprovalDecision::Modify {
            new_arguments: json!({})
        }
        .is_approved());
        assert!(!ApprovalDecision::Denied {
            reason: "no".to_string()
        }
        .is_approved());
        assert!(!ApprovalDecision::Skip.is_approved());
    }

    #[test]
    fn test_check_requires_approval() {
        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();
        let handler = TestApprovalHandler::with_config(config);

        let safe_call = call("safe_tool", json!({}));
        let safe_schema = schema("safe_tool", "Safe", false);
        let dangerous_schema = schema("dangerous_tool", "Dangerous", true);

        // Safe tool with safe schema - no approval needed
        assert!(handler
            .check_requires_approval(&safe_call, Some(&safe_schema))
            .is_none());

        // Safe tool with dangerous schema - approval needed
        assert!(handler
            .check_requires_approval(&safe_call, Some(&dangerous_schema))
            .is_some());

        // No schema, defaults to not dangerous
        assert!(handler.check_requires_approval(&safe_call, None).is_none());
    }
}
964
965        // No schema, defaults to not dangerous
966        assert!(handler.check_requires_approval(&safe_call, None).is_none());
967    }
968}