Skip to main content

enact_runner/
approval.rs

1//! Human-in-the-loop approval — check before tool execution.
2//!
3//! When set, the runner calls the checker before each tool execution.
4//! Return true to allow, false to block (tool result will indicate "blocked by approval policy").
5//!
6//! ## Policies
7//!
8//! - `AlwaysApprove` - Allow all tool calls without prompting
9//! - `AlwaysDeny` - Block all tool calls
10//! - `AskApprovalChecker` - Prompt user for every tool call
11//! - `AskOnceApprovalChecker` - Prompt once per tool, remember decisions
12//! - `PatternApprovalChecker` - Only prompt for tools matching patterns
13//! - `PolicyWithOverrides` - Per-tool policy overrides
14
15use async_trait::async_trait;
16use regex::Regex;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{self, BufRead, Write};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::RwLock;
23use tokio::time::timeout;
24
25/// Checker invoked before each tool execution (PreToolUse). Return true to allow, false to block.
26#[async_trait]
27pub trait ApprovalChecker: Send + Sync {
28    async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool;
29}
30
31/// Trait for prompting the user for approval decisions.
32///
33/// Implement this trait to customize how approval prompts are presented
34/// (CLI stdin, API callback, GUI dialog, etc.).
35#[async_trait]
36pub trait ApprovalPrompter: Send + Sync {
37    /// Prompt the user to approve a tool call.
38    ///
39    /// Returns `true` if approved, `false` if denied.
40    async fn prompt(&self, tool_name: &str, args: &Value) -> io::Result<bool>;
41}
42
43/// Always allow all tool calls (default when no approval config).
44pub struct AlwaysApprove;
45
46#[async_trait]
47impl ApprovalChecker for AlwaysApprove {
48    async fn allow_tool(&self, _tool_name: &str, _args: &Value) -> bool {
49        true
50    }
51}
52
53/// Always deny all tool calls.
54pub struct AlwaysDeny;
55
56#[async_trait]
57impl ApprovalChecker for AlwaysDeny {
58    async fn allow_tool(&self, _tool_name: &str, _args: &Value) -> bool {
59        false
60    }
61}
62
63// ─────────────────────────────────────────────────────────────────────────────
64// CLI Prompter
65// ─────────────────────────────────────────────────────────────────────────────
66
/// Terminal prompter backed by stdout/stdin.
///
/// Shows the tool name and its arguments, then asks the user to answer [y/n].
pub struct CliPrompter;

impl CliPrompter {
    /// Build a CLI prompter. The type is stateless, so this is free.
    pub fn new() -> Self {
        CliPrompter
    }
}

impl Default for CliPrompter {
    /// Same as [`CliPrompter::new`].
    fn default() -> Self {
        CliPrompter
    }
}
83
84#[async_trait]
85impl ApprovalPrompter for CliPrompter {
86    async fn prompt(&self, tool_name: &str, args: &Value) -> io::Result<bool> {
87        // Format arguments for display (truncate if too long)
88        let args_str = serde_json::to_string_pretty(args).unwrap_or_else(|_| args.to_string());
89        let args_display = if args_str.len() > 500 {
90            format!("{}...", &args_str[..500])
91        } else {
92            args_str
93        };
94
95        // Print prompt
96        let mut stdout = io::stdout();
97        writeln!(stdout)?;
98        writeln!(
99            stdout,
100            "╭─────────────────────────────────────────────────────╮"
101        )?;
102        writeln!(
103            stdout,
104            "│  Tool Approval Required                             │"
105        )?;
106        writeln!(
107            stdout,
108            "╰─────────────────────────────────────────────────────╯"
109        )?;
110        writeln!(stdout)?;
111        writeln!(stdout, "Tool: {}", tool_name)?;
112        writeln!(stdout, "Arguments:")?;
113        for line in args_display.lines() {
114            writeln!(stdout, "  {}", line)?;
115        }
116        writeln!(stdout)?;
117        write!(stdout, "Allow this tool call? [y/n]: ")?;
118        stdout.flush()?;
119
120        // Read response - use blocking I/O in spawn_blocking to avoid blocking async runtime
121        let response = tokio::task::spawn_blocking(|| {
122            let stdin = io::stdin();
123            let mut line = String::new();
124            stdin.lock().read_line(&mut line)?;
125            Ok::<_, io::Error>(line)
126        })
127        .await
128        .map_err(io::Error::other)??;
129
130        let answer = response.trim().to_lowercase();
131        Ok(answer == "y" || answer == "yes")
132    }
133}
134
135// ─────────────────────────────────────────────────────────────────────────────
136// Ask Policy (prompt every time)
137// ─────────────────────────────────────────────────────────────────────────────
138
139/// Ask policy — prompts the user for every tool call.
140///
141/// Use this for maximum control over what the agent can do.
142pub struct AskApprovalChecker {
143    prompter: Arc<dyn ApprovalPrompter>,
144    timeout_duration: Duration,
145}
146
147impl AskApprovalChecker {
148    /// Create a new AskApprovalChecker with the given prompter and timeout.
149    pub fn new(prompter: Arc<dyn ApprovalPrompter>, timeout_seconds: u64) -> Self {
150        Self {
151            prompter,
152            timeout_duration: Duration::from_secs(timeout_seconds),
153        }
154    }
155
156    /// Create with CLI prompter and default 5-minute timeout.
157    pub fn cli_default() -> Self {
158        Self::new(Arc::new(CliPrompter::new()), 300)
159    }
160}
161
162#[async_trait]
163impl ApprovalChecker for AskApprovalChecker {
164    async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
165        match timeout(self.timeout_duration, self.prompter.prompt(tool_name, args)).await {
166            Ok(Ok(approved)) => approved,
167            Ok(Err(e)) => {
168                tracing::warn!(error = %e, tool = tool_name, "Approval prompt failed, denying");
169                false
170            }
171            Err(_) => {
172                tracing::warn!(tool = tool_name, "Approval prompt timed out, denying");
173                false
174            }
175        }
176    }
177}
178
179// ─────────────────────────────────────────────────────────────────────────────
180// AskOnce Policy (remember decisions)
181// ─────────────────────────────────────────────────────────────────────────────
182
183/// AskOnce policy — prompts once per tool name, remembers decisions.
184///
185/// After approving/denying a tool once, subsequent calls to the same tool
186/// use the cached decision without prompting again.
187pub struct AskOnceApprovalChecker {
188    prompter: Arc<dyn ApprovalPrompter>,
189    decisions: RwLock<HashMap<String, bool>>,
190    timeout_duration: Duration,
191}
192
193impl AskOnceApprovalChecker {
194    /// Create a new AskOnceApprovalChecker with the given prompter and timeout.
195    pub fn new(prompter: Arc<dyn ApprovalPrompter>, timeout_seconds: u64) -> Self {
196        Self {
197            prompter,
198            decisions: RwLock::new(HashMap::new()),
199            timeout_duration: Duration::from_secs(timeout_seconds),
200        }
201    }
202
203    /// Create with CLI prompter and default 5-minute timeout.
204    pub fn cli_default() -> Self {
205        Self::new(Arc::new(CliPrompter::new()), 300)
206    }
207
208    /// Clear all cached decisions.
209    pub async fn clear_decisions(&self) {
210        self.decisions.write().await.clear();
211    }
212
213    /// Pre-approve a tool (useful for allowing known-safe tools).
214    pub async fn pre_approve(&self, tool_name: &str) {
215        self.decisions
216            .write()
217            .await
218            .insert(tool_name.to_string(), true);
219    }
220
221    /// Pre-deny a tool (useful for blocking known-dangerous tools).
222    pub async fn pre_deny(&self, tool_name: &str) {
223        self.decisions
224            .write()
225            .await
226            .insert(tool_name.to_string(), false);
227    }
228}
229
230#[async_trait]
231impl ApprovalChecker for AskOnceApprovalChecker {
232    async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
233        // Check cached decision first
234        {
235            let decisions = self.decisions.read().await;
236            if let Some(&cached) = decisions.get(tool_name) {
237                tracing::debug!(tool = tool_name, cached, "Using cached approval decision");
238                return cached;
239            }
240        }
241
242        // No cached decision — prompt user
243        let approved =
244            match timeout(self.timeout_duration, self.prompter.prompt(tool_name, args)).await {
245                Ok(Ok(approved)) => approved,
246                Ok(Err(e)) => {
247                    tracing::warn!(error = %e, tool = tool_name, "Approval prompt failed, denying");
248                    false
249                }
250                Err(_) => {
251                    tracing::warn!(tool = tool_name, "Approval prompt timed out, denying");
252                    false
253                }
254            };
255
256        // Cache the decision
257        self.decisions
258            .write()
259            .await
260            .insert(tool_name.to_string(), approved);
261
262        approved
263    }
264}
265
266// ─────────────────────────────────────────────────────────────────────────────
267// Pattern Policy (only prompt for matching tools)
268// ─────────────────────────────────────────────────────────────────────────────
269
270/// Pattern policy — only prompts for tools matching specified patterns.
271///
272/// Tools that don't match any pattern are automatically approved.
273/// Tools matching a pattern require user approval.
274pub struct PatternApprovalChecker {
275    prompter: Arc<dyn ApprovalPrompter>,
276    patterns: Vec<Regex>,
277    timeout_duration: Duration,
278}
279
280impl PatternApprovalChecker {
281    /// Create a new PatternApprovalChecker with the given patterns.
282    ///
283    /// Patterns are regular expressions matched against tool names.
284    pub fn new(
285        prompter: Arc<dyn ApprovalPrompter>,
286        patterns: Vec<String>,
287        timeout_seconds: u64,
288    ) -> Result<Self, regex::Error> {
289        let compiled: Result<Vec<Regex>, _> = patterns.iter().map(|p| Regex::new(p)).collect();
290        Ok(Self {
291            prompter,
292            patterns: compiled?,
293            timeout_duration: Duration::from_secs(timeout_seconds),
294        })
295    }
296
297    /// Create with CLI prompter and default timeout.
298    pub fn cli_with_patterns(patterns: Vec<String>) -> Result<Self, regex::Error> {
299        Self::new(Arc::new(CliPrompter::new()), patterns, 300)
300    }
301
302    /// Check if a tool name matches any pattern.
303    fn matches_pattern(&self, tool_name: &str) -> bool {
304        self.patterns.iter().any(|p| p.is_match(tool_name))
305    }
306}
307
308#[async_trait]
309impl ApprovalChecker for PatternApprovalChecker {
310    async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
311        // If no pattern matches, auto-approve
312        if !self.matches_pattern(tool_name) {
313            tracing::debug!(
314                tool = tool_name,
315                "Tool doesn't match approval patterns, auto-approving"
316            );
317            return true;
318        }
319
320        // Pattern matched — prompt user
321        tracing::debug!(tool = tool_name, "Tool matches approval pattern, prompting");
322        match timeout(self.timeout_duration, self.prompter.prompt(tool_name, args)).await {
323            Ok(Ok(approved)) => approved,
324            Ok(Err(e)) => {
325                tracing::warn!(error = %e, tool = tool_name, "Approval prompt failed, denying");
326                false
327            }
328            Err(_) => {
329                tracing::warn!(tool = tool_name, "Approval prompt timed out, denying");
330                false
331            }
332        }
333    }
334}
335
336// ─────────────────────────────────────────────────────────────────────────────
337// Policy with Per-Tool Overrides
338// ─────────────────────────────────────────────────────────────────────────────
339
340/// Policy with per-tool overrides.
341///
342/// Allows specifying different policies for specific tools while using
343/// a default policy for everything else.
344///
345/// # Example
346///
347/// ```ignore
348/// let policy = PolicyWithOverrides::new(Arc::new(AskApprovalChecker::cli_default()))
349///     .with_override("Read", Arc::new(AlwaysApprove))  // Auto-approve Read
350///     .with_override("Write", Arc::new(AlwaysDeny));   // Block Write
351/// ```
352pub struct PolicyWithOverrides {
353    default: Arc<dyn ApprovalChecker>,
354    overrides: HashMap<String, Arc<dyn ApprovalChecker>>,
355}
356
357impl PolicyWithOverrides {
358    /// Create a new policy with the given default checker.
359    pub fn new(default: Arc<dyn ApprovalChecker>) -> Self {
360        Self {
361            default,
362            overrides: HashMap::new(),
363        }
364    }
365
366    /// Add a per-tool override.
367    pub fn with_override(mut self, tool_name: &str, checker: Arc<dyn ApprovalChecker>) -> Self {
368        self.overrides.insert(tool_name.to_string(), checker);
369        self
370    }
371
372    /// Add multiple overrides at once.
373    pub fn with_overrides(mut self, overrides: HashMap<String, Arc<dyn ApprovalChecker>>) -> Self {
374        self.overrides.extend(overrides);
375        self
376    }
377}
378
379#[async_trait]
380impl ApprovalChecker for PolicyWithOverrides {
381    async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
382        // Check for tool-specific override
383        if let Some(checker) = self.overrides.get(tool_name) {
384            return checker.allow_tool(tool_name, args).await;
385        }
386
387        // Fall back to default policy
388        self.default.allow_tool(tool_name, args).await
389    }
390}
391
392// ─────────────────────────────────────────────────────────────────────────────
393// Factory function from config
394// ─────────────────────────────────────────────────────────────────────────────
395
396/// Create an approval checker from ApprovalConfig.
397///
398/// Maps policy strings to checker implementations:
399/// - "always_approve" → AlwaysApprove
400/// - "always_deny" → AlwaysDeny
401/// - "ask" or "always_require" → AskApprovalChecker
402/// - "ask_once" → AskOnceApprovalChecker
403/// - "pattern" → PatternApprovalChecker (requires patterns in config)
404pub fn checker_from_config(
405    policy: &str,
406    timeout_seconds: u64,
407    patterns: Option<&[String]>,
408) -> Result<Arc<dyn ApprovalChecker>, String> {
409    match policy {
410        "always_approve" => Ok(Arc::new(AlwaysApprove)),
411        "always_deny" => Ok(Arc::new(AlwaysDeny)),
412        "ask" | "always_require" => Ok(Arc::new(AskApprovalChecker::new(
413            Arc::new(CliPrompter::new()),
414            timeout_seconds,
415        ))),
416        "ask_once" => Ok(Arc::new(AskOnceApprovalChecker::new(
417            Arc::new(CliPrompter::new()),
418            timeout_seconds,
419        ))),
420        "pattern" => {
421            let patterns = patterns
422                .ok_or("Pattern policy requires 'require_patterns' in config")?
423                .to_vec();
424            PatternApprovalChecker::new(Arc::new(CliPrompter::new()), patterns, timeout_seconds)
425                .map(|c| Arc::new(c) as Arc<dyn ApprovalChecker>)
426                .map_err(|e| format!("Invalid pattern regex: {}", e))
427        }
428        other => Err(format!("Unknown approval policy: '{}'", other)),
429    }
430}
431
432/// Create an approval checker from full ApprovalConfig, including tool overrides.
433///
434/// If `tool_overrides` is set, wraps the base checker in `PolicyWithOverrides`.
435pub fn checker_from_approval_config(
436    config: &enact_config::ApprovalConfig,
437) -> Result<Arc<dyn ApprovalChecker>, String> {
438    // Create base checker from policy
439    let base = checker_from_config(
440        &config.policy,
441        config.timeout_seconds,
442        config.require_patterns.as_deref(),
443    )?;
444
445    // Apply tool overrides if present
446    if let Some(ref overrides) = config.tool_overrides {
447        if overrides.is_empty() {
448            return Ok(base);
449        }
450
451        let mut policy = PolicyWithOverrides::new(base);
452        for (tool_name, tool_policy) in overrides {
453            let override_checker = checker_from_config(tool_policy, config.timeout_seconds, None)?;
454            policy = policy.with_override(tool_name, override_checker);
455        }
456        Ok(Arc::new(policy))
457    } else {
458        Ok(base)
459    }
460}
461
462// ─────────────────────────────────────────────────────────────────────────────
463// Tests
464// ─────────────────────────────────────────────────────────────────────────────
465
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn always_approve_allows_all() {
        let checker = AlwaysApprove;
        assert!(checker.allow_tool("anything", &serde_json::json!({})).await);
        assert!(
            checker
                .allow_tool("dangerous_tool", &serde_json::json!({"rm": "-rf /"}))
                .await
        );
    }

    #[tokio::test]
    async fn always_deny_blocks_all() {
        let checker = AlwaysDeny;
        assert!(!checker.allow_tool("anything", &serde_json::json!({})).await);
        assert!(
            !checker
                .allow_tool("safe_tool", &serde_json::json!({}))
                .await
        );
    }

    /// Mock prompter for testing.
    ///
    /// Always gives the same answer and counts how many times it was asked,
    /// so tests can check whether a checker really prompted.
    struct MockPrompter {
        response: bool,
        calls: AtomicUsize,
    }

    impl MockPrompter {
        /// Prompter that always answers `response`, starting with zero calls.
        fn answering(response: bool) -> Arc<Self> {
            Arc::new(Self {
                response,
                calls: AtomicUsize::new(0),
            })
        }

        fn approving() -> Arc<Self> {
            Self::answering(true)
        }

        fn denying() -> Arc<Self> {
            Self::answering(false)
        }

        /// Number of times `prompt` has been invoked so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl ApprovalPrompter for MockPrompter {
        async fn prompt(&self, _tool_name: &str, _args: &Value) -> io::Result<bool> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.response)
        }
    }

    #[tokio::test]
    async fn ask_checker_uses_prompter() {
        let approving = AskApprovalChecker::new(MockPrompter::approving(), 60);
        assert!(approving.allow_tool("test", &serde_json::json!({})).await);

        let denying = AskApprovalChecker::new(MockPrompter::denying(), 60);
        assert!(!denying.allow_tool("test", &serde_json::json!({})).await);
    }

    #[tokio::test]
    async fn ask_once_caches_decisions() {
        let prompter = MockPrompter::approving();
        let checker = AskOnceApprovalChecker::new(prompter.clone(), 60);

        // First call prompts and caches
        assert!(
            checker
                .allow_tool("test_tool", &serde_json::json!({}))
                .await
        );
        assert_eq!(prompter.call_count(), 1);
        {
            let decisions = checker.decisions.read().await;
            assert_eq!(decisions.get("test_tool"), Some(&true));
        }

        // Second call for the same tool is served from the cache: same answer,
        // and the prompter is not consulted again.
        assert!(
            checker
                .allow_tool("test_tool", &serde_json::json!({}))
                .await
        );
        assert_eq!(prompter.call_count(), 1);

        // A different tool has no cached decision and therefore prompts.
        assert!(
            checker
                .allow_tool("other_tool", &serde_json::json!({}))
                .await
        );
        assert_eq!(prompter.call_count(), 2);
    }

    #[tokio::test]
    async fn ask_once_pre_approve_works() {
        let prompter = MockPrompter::denying();
        let checker = AskOnceApprovalChecker::new(prompter.clone(), 60);

        // Pre-approve the tool
        checker.pre_approve("safe_tool").await;

        // Should be approved without prompting
        assert!(
            checker
                .allow_tool("safe_tool", &serde_json::json!({}))
                .await
        );
        assert_eq!(prompter.call_count(), 0);
    }

    #[tokio::test]
    async fn ask_once_pre_deny_works() {
        let prompter = MockPrompter::approving();
        let checker = AskOnceApprovalChecker::new(prompter.clone(), 60);

        // Pre-deny the tool
        checker.pre_deny("dangerous_tool").await;

        // Should be denied without prompting
        assert!(
            !checker
                .allow_tool("dangerous_tool", &serde_json::json!({}))
                .await
        );
        assert_eq!(prompter.call_count(), 0);
    }

    #[tokio::test]
    async fn pattern_checker_auto_approves_non_matching() {
        let checker =
            PatternApprovalChecker::new(MockPrompter::denying(), vec!["^Write".to_string()], 60)
                .unwrap();

        // Read doesn't match ^Write pattern, auto-approved
        assert!(checker.allow_tool("Read", &serde_json::json!({})).await);

        // Write matches, uses prompter (which denies)
        assert!(!checker.allow_tool("Write", &serde_json::json!({})).await);
    }

    #[tokio::test]
    async fn pattern_checker_prompts_for_matching() {
        let checker = PatternApprovalChecker::new(
            MockPrompter::approving(),
            vec!["Edit|Write|Bash".to_string()],
            60,
        )
        .unwrap();

        // These match the pattern
        assert!(checker.allow_tool("Edit", &serde_json::json!({})).await);
        assert!(checker.allow_tool("Write", &serde_json::json!({})).await);
        assert!(checker.allow_tool("Bash", &serde_json::json!({})).await);

        // This doesn't match
        assert!(checker.allow_tool("Read", &serde_json::json!({})).await);
    }

    #[tokio::test]
    async fn policy_with_overrides_uses_specific_policy() {
        let default = Arc::new(AlwaysDeny);
        let policy = PolicyWithOverrides::new(default)
            .with_override("Read", Arc::new(AlwaysApprove))
            .with_override("Glob", Arc::new(AlwaysApprove));

        // Read and Glob use override (AlwaysApprove)
        assert!(policy.allow_tool("Read", &serde_json::json!({})).await);
        assert!(policy.allow_tool("Glob", &serde_json::json!({})).await);

        // Others use default (AlwaysDeny)
        assert!(!policy.allow_tool("Write", &serde_json::json!({})).await);
        assert!(!policy.allow_tool("Edit", &serde_json::json!({})).await);
    }

    #[tokio::test]
    async fn checker_from_config_creates_correct_types() {
        // Test always_approve
        let checker = checker_from_config("always_approve", 60, None).unwrap();
        assert!(checker.allow_tool("test", &serde_json::json!({})).await);

        // Test always_deny
        let checker = checker_from_config("always_deny", 60, None).unwrap();
        assert!(!checker.allow_tool("test", &serde_json::json!({})).await);

        // Test pattern with valid regex
        let patterns = vec!["^Edit$".to_string()];
        let checker = checker_from_config("pattern", 60, Some(&patterns)).unwrap();
        // Non-matching tools should be auto-approved
        assert!(checker.allow_tool("Read", &serde_json::json!({})).await);

        // Test pattern without patterns should error
        let result = checker_from_config("pattern", 60, None);
        assert!(result.is_err());

        // Test unknown policy should error
        let result = checker_from_config("unknown_policy", 60, None);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn checker_from_approval_config_with_overrides() {
        use std::collections::HashMap;

        // Create config with base policy=always_deny but override Read to always_approve
        let mut overrides = HashMap::new();
        overrides.insert("Read".to_string(), "always_approve".to_string());
        overrides.insert("Glob".to_string(), "always_approve".to_string());

        let config = enact_config::ApprovalConfig {
            enabled: true,
            policy: "always_deny".to_string(),
            max_steps: None,
            require_patterns: None,
            timeout_seconds: 60,
            tool_overrides: Some(overrides),
        };

        let checker = checker_from_approval_config(&config).unwrap();

        // Read and Glob should be approved (override)
        assert!(checker.allow_tool("Read", &serde_json::json!({})).await);
        assert!(checker.allow_tool("Glob", &serde_json::json!({})).await);

        // Other tools should be denied (base policy)
        assert!(!checker.allow_tool("Write", &serde_json::json!({})).await);
        assert!(!checker.allow_tool("Edit", &serde_json::json!({})).await);
    }

    #[tokio::test]
    async fn checker_from_approval_config_without_overrides() {
        let config = enact_config::ApprovalConfig {
            enabled: true,
            policy: "always_approve".to_string(),
            max_steps: None,
            require_patterns: None,
            timeout_seconds: 60,
            tool_overrides: None,
        };

        let checker = checker_from_approval_config(&config).unwrap();
        assert!(checker.allow_tool("anything", &serde_json::json!({})).await);
    }
}