1use async_trait::async_trait;
16use regex::Regex;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{self, BufRead, Write};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::RwLock;
23use tokio::time::timeout;
24
25#[async_trait]
27pub trait ApprovalChecker: Send + Sync {
28 async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool;
29}
30
31#[async_trait]
36pub trait ApprovalPrompter: Send + Sync {
37 async fn prompt(&self, tool_name: &str, args: &Value) -> io::Result<bool>;
41}
42
43pub struct AlwaysApprove;
45
46#[async_trait]
47impl ApprovalChecker for AlwaysApprove {
48 async fn allow_tool(&self, _tool_name: &str, _args: &Value) -> bool {
49 true
50 }
51}
52
53pub struct AlwaysDeny;
55
56#[async_trait]
57impl ApprovalChecker for AlwaysDeny {
58 async fn allow_tool(&self, _tool_name: &str, _args: &Value) -> bool {
59 false
60 }
61}
62
/// Prompter that asks for approval interactively on the terminal.
///
/// The type carries no state.
pub struct CliPrompter;

impl CliPrompter {
    /// Creates a new terminal prompter.
    pub fn new() -> Self {
        CliPrompter
    }
}

impl Default for CliPrompter {
    /// Equivalent to [`CliPrompter::new`].
    fn default() -> Self {
        CliPrompter::new()
    }
}
83
84#[async_trait]
85impl ApprovalPrompter for CliPrompter {
86 async fn prompt(&self, tool_name: &str, args: &Value) -> io::Result<bool> {
87 let args_str = serde_json::to_string_pretty(args).unwrap_or_else(|_| args.to_string());
89 let args_display = if args_str.len() > 500 {
90 format!("{}...", &args_str[..500])
91 } else {
92 args_str
93 };
94
95 let mut stdout = io::stdout();
97 writeln!(stdout)?;
98 writeln!(
99 stdout,
100 "╭─────────────────────────────────────────────────────╮"
101 )?;
102 writeln!(
103 stdout,
104 "│ Tool Approval Required │"
105 )?;
106 writeln!(
107 stdout,
108 "╰─────────────────────────────────────────────────────╯"
109 )?;
110 writeln!(stdout)?;
111 writeln!(stdout, "Tool: {}", tool_name)?;
112 writeln!(stdout, "Arguments:")?;
113 for line in args_display.lines() {
114 writeln!(stdout, " {}", line)?;
115 }
116 writeln!(stdout)?;
117 write!(stdout, "Allow this tool call? [y/n]: ")?;
118 stdout.flush()?;
119
120 let response = tokio::task::spawn_blocking(|| {
122 let stdin = io::stdin();
123 let mut line = String::new();
124 stdin.lock().read_line(&mut line)?;
125 Ok::<_, io::Error>(line)
126 })
127 .await
128 .map_err(io::Error::other)??;
129
130 let answer = response.trim().to_lowercase();
131 Ok(answer == "y" || answer == "yes")
132 }
133}
134
135pub struct AskApprovalChecker {
143 prompter: Arc<dyn ApprovalPrompter>,
144 timeout_duration: Duration,
145}
146
147impl AskApprovalChecker {
148 pub fn new(prompter: Arc<dyn ApprovalPrompter>, timeout_seconds: u64) -> Self {
150 Self {
151 prompter,
152 timeout_duration: Duration::from_secs(timeout_seconds),
153 }
154 }
155
156 pub fn cli_default() -> Self {
158 Self::new(Arc::new(CliPrompter::new()), 300)
159 }
160}
161
162#[async_trait]
163impl ApprovalChecker for AskApprovalChecker {
164 async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
165 match timeout(self.timeout_duration, self.prompter.prompt(tool_name, args)).await {
166 Ok(Ok(approved)) => approved,
167 Ok(Err(e)) => {
168 tracing::warn!(error = %e, tool = tool_name, "Approval prompt failed, denying");
169 false
170 }
171 Err(_) => {
172 tracing::warn!(tool = tool_name, "Approval prompt timed out, denying");
173 false
174 }
175 }
176 }
177}
178
179pub struct AskOnceApprovalChecker {
188 prompter: Arc<dyn ApprovalPrompter>,
189 decisions: RwLock<HashMap<String, bool>>,
190 timeout_duration: Duration,
191}
192
193impl AskOnceApprovalChecker {
194 pub fn new(prompter: Arc<dyn ApprovalPrompter>, timeout_seconds: u64) -> Self {
196 Self {
197 prompter,
198 decisions: RwLock::new(HashMap::new()),
199 timeout_duration: Duration::from_secs(timeout_seconds),
200 }
201 }
202
203 pub fn cli_default() -> Self {
205 Self::new(Arc::new(CliPrompter::new()), 300)
206 }
207
208 pub async fn clear_decisions(&self) {
210 self.decisions.write().await.clear();
211 }
212
213 pub async fn pre_approve(&self, tool_name: &str) {
215 self.decisions
216 .write()
217 .await
218 .insert(tool_name.to_string(), true);
219 }
220
221 pub async fn pre_deny(&self, tool_name: &str) {
223 self.decisions
224 .write()
225 .await
226 .insert(tool_name.to_string(), false);
227 }
228}
229
230#[async_trait]
231impl ApprovalChecker for AskOnceApprovalChecker {
232 async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
233 {
235 let decisions = self.decisions.read().await;
236 if let Some(&cached) = decisions.get(tool_name) {
237 tracing::debug!(tool = tool_name, cached, "Using cached approval decision");
238 return cached;
239 }
240 }
241
242 let approved =
244 match timeout(self.timeout_duration, self.prompter.prompt(tool_name, args)).await {
245 Ok(Ok(approved)) => approved,
246 Ok(Err(e)) => {
247 tracing::warn!(error = %e, tool = tool_name, "Approval prompt failed, denying");
248 false
249 }
250 Err(_) => {
251 tracing::warn!(tool = tool_name, "Approval prompt timed out, denying");
252 false
253 }
254 };
255
256 self.decisions
258 .write()
259 .await
260 .insert(tool_name.to_string(), approved);
261
262 approved
263 }
264}
265
266pub struct PatternApprovalChecker {
275 prompter: Arc<dyn ApprovalPrompter>,
276 patterns: Vec<Regex>,
277 timeout_duration: Duration,
278}
279
280impl PatternApprovalChecker {
281 pub fn new(
285 prompter: Arc<dyn ApprovalPrompter>,
286 patterns: Vec<String>,
287 timeout_seconds: u64,
288 ) -> Result<Self, regex::Error> {
289 let compiled: Result<Vec<Regex>, _> = patterns.iter().map(|p| Regex::new(p)).collect();
290 Ok(Self {
291 prompter,
292 patterns: compiled?,
293 timeout_duration: Duration::from_secs(timeout_seconds),
294 })
295 }
296
297 pub fn cli_with_patterns(patterns: Vec<String>) -> Result<Self, regex::Error> {
299 Self::new(Arc::new(CliPrompter::new()), patterns, 300)
300 }
301
302 fn matches_pattern(&self, tool_name: &str) -> bool {
304 self.patterns.iter().any(|p| p.is_match(tool_name))
305 }
306}
307
308#[async_trait]
309impl ApprovalChecker for PatternApprovalChecker {
310 async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
311 if !self.matches_pattern(tool_name) {
313 tracing::debug!(
314 tool = tool_name,
315 "Tool doesn't match approval patterns, auto-approving"
316 );
317 return true;
318 }
319
320 tracing::debug!(tool = tool_name, "Tool matches approval pattern, prompting");
322 match timeout(self.timeout_duration, self.prompter.prompt(tool_name, args)).await {
323 Ok(Ok(approved)) => approved,
324 Ok(Err(e)) => {
325 tracing::warn!(error = %e, tool = tool_name, "Approval prompt failed, denying");
326 false
327 }
328 Err(_) => {
329 tracing::warn!(tool = tool_name, "Approval prompt timed out, denying");
330 false
331 }
332 }
333 }
334}
335
336pub struct PolicyWithOverrides {
353 default: Arc<dyn ApprovalChecker>,
354 overrides: HashMap<String, Arc<dyn ApprovalChecker>>,
355}
356
357impl PolicyWithOverrides {
358 pub fn new(default: Arc<dyn ApprovalChecker>) -> Self {
360 Self {
361 default,
362 overrides: HashMap::new(),
363 }
364 }
365
366 pub fn with_override(mut self, tool_name: &str, checker: Arc<dyn ApprovalChecker>) -> Self {
368 self.overrides.insert(tool_name.to_string(), checker);
369 self
370 }
371
372 pub fn with_overrides(mut self, overrides: HashMap<String, Arc<dyn ApprovalChecker>>) -> Self {
374 self.overrides.extend(overrides);
375 self
376 }
377}
378
379#[async_trait]
380impl ApprovalChecker for PolicyWithOverrides {
381 async fn allow_tool(&self, tool_name: &str, args: &Value) -> bool {
382 if let Some(checker) = self.overrides.get(tool_name) {
384 return checker.allow_tool(tool_name, args).await;
385 }
386
387 self.default.allow_tool(tool_name, args).await
389 }
390}
391
392pub fn checker_from_config(
405 policy: &str,
406 timeout_seconds: u64,
407 patterns: Option<&[String]>,
408) -> Result<Arc<dyn ApprovalChecker>, String> {
409 match policy {
410 "always_approve" => Ok(Arc::new(AlwaysApprove)),
411 "always_deny" => Ok(Arc::new(AlwaysDeny)),
412 "ask" | "always_require" => Ok(Arc::new(AskApprovalChecker::new(
413 Arc::new(CliPrompter::new()),
414 timeout_seconds,
415 ))),
416 "ask_once" => Ok(Arc::new(AskOnceApprovalChecker::new(
417 Arc::new(CliPrompter::new()),
418 timeout_seconds,
419 ))),
420 "pattern" => {
421 let patterns = patterns
422 .ok_or("Pattern policy requires 'require_patterns' in config")?
423 .to_vec();
424 PatternApprovalChecker::new(Arc::new(CliPrompter::new()), patterns, timeout_seconds)
425 .map(|c| Arc::new(c) as Arc<dyn ApprovalChecker>)
426 .map_err(|e| format!("Invalid pattern regex: {}", e))
427 }
428 other => Err(format!("Unknown approval policy: '{}'", other)),
429 }
430}
431
432pub fn checker_from_approval_config(
436 config: &enact_config::ApprovalConfig,
437) -> Result<Arc<dyn ApprovalChecker>, String> {
438 let base = checker_from_config(
440 &config.policy,
441 config.timeout_seconds,
442 config.require_patterns.as_deref(),
443 )?;
444
445 if let Some(ref overrides) = config.tool_overrides {
447 if overrides.is_empty() {
448 return Ok(base);
449 }
450
451 let mut policy = PolicyWithOverrides::new(base);
452 for (tool_name, tool_policy) in overrides {
453 let override_checker = checker_from_config(tool_policy, config.timeout_seconds, None)?;
454 policy = policy.with_override(tool_name, override_checker);
455 }
456 Ok(Arc::new(policy))
457 } else {
458 Ok(base)
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty JSON object used as the argument payload in most tests.
    fn no_args() -> Value {
        serde_json::json!({})
    }

    /// Prompter that always answers with a fixed decision.
    struct MockPrompter {
        response: bool,
    }

    impl MockPrompter {
        /// Prompter that approves every request.
        fn approving() -> Arc<Self> {
            Arc::new(Self { response: true })
        }

        /// Prompter that rejects every request.
        fn denying() -> Arc<Self> {
            Arc::new(Self { response: false })
        }
    }

    #[async_trait]
    impl ApprovalPrompter for MockPrompter {
        async fn prompt(&self, _tool: &str, _arguments: &Value) -> io::Result<bool> {
            Ok(self.response)
        }
    }

    #[tokio::test]
    async fn always_approve_allows_all() {
        let checker = AlwaysApprove;
        let risky_args = serde_json::json!({"rm": "-rf /"});

        assert!(checker.allow_tool("anything", &no_args()).await);
        assert!(checker.allow_tool("dangerous_tool", &risky_args).await);
    }

    #[tokio::test]
    async fn always_deny_blocks_all() {
        let checker = AlwaysDeny;

        for tool in ["anything", "safe_tool"] {
            let allowed = checker.allow_tool(tool, &no_args()).await;
            assert!(!allowed, "expected {tool} to be denied");
        }
    }

    #[tokio::test]
    async fn ask_checker_uses_prompter() {
        let args = no_args();

        let approving = AskApprovalChecker::new(MockPrompter::approving(), 60);
        assert!(approving.allow_tool("test", &args).await);

        let denying = AskApprovalChecker::new(MockPrompter::denying(), 60);
        assert!(!denying.allow_tool("test", &args).await);
    }

    #[tokio::test]
    async fn ask_once_caches_decisions() {
        let checker = AskOnceApprovalChecker::new(MockPrompter::approving(), 60);

        let allowed = checker.allow_tool("test_tool", &no_args()).await;
        assert!(allowed);

        // The decision from the prompt must now be recorded under the tool name.
        let recorded = checker.decisions.read().await.get("test_tool").copied();
        assert_eq!(recorded, Some(true));
    }

    #[tokio::test]
    async fn ask_once_pre_approve_works() {
        // The prompter would deny, so approval can only come from the record.
        let checker = AskOnceApprovalChecker::new(MockPrompter::denying(), 60);
        checker.pre_approve("safe_tool").await;

        let allowed = checker.allow_tool("safe_tool", &no_args()).await;
        assert!(allowed);
    }

    #[tokio::test]
    async fn ask_once_pre_deny_works() {
        // The prompter would approve, so denial can only come from the record.
        let checker = AskOnceApprovalChecker::new(MockPrompter::approving(), 60);
        checker.pre_deny("dangerous_tool").await;

        let allowed = checker.allow_tool("dangerous_tool", &no_args()).await;
        assert!(!allowed);
    }

    #[tokio::test]
    async fn pattern_checker_auto_approves_non_matching() {
        let patterns = vec!["^Write".to_string()];
        let checker = PatternApprovalChecker::new(MockPrompter::denying(), patterns, 60).unwrap();
        let args = no_args();

        // "Read" matches nothing and bypasses the (denying) prompter.
        assert!(checker.allow_tool("Read", &args).await);

        // "Write" matches and is therefore decided by the prompter.
        assert!(!checker.allow_tool("Write", &args).await);
    }

    #[tokio::test]
    async fn pattern_checker_prompts_for_matching() {
        let patterns = vec!["Edit|Write|Bash".to_string()];
        let checker =
            PatternApprovalChecker::new(MockPrompter::approving(), patterns, 60).unwrap();
        let args = no_args();

        for tool in ["Edit", "Write", "Bash"] {
            assert!(
                checker.allow_tool(tool, &args).await,
                "expected {tool} to be allowed"
            );
        }

        assert!(checker.allow_tool("Read", &args).await);
    }

    #[tokio::test]
    async fn policy_with_overrides_uses_specific_policy() {
        let fallback = Arc::new(AlwaysDeny);
        let policy = PolicyWithOverrides::new(fallback)
            .with_override("Read", Arc::new(AlwaysApprove))
            .with_override("Glob", Arc::new(AlwaysApprove));
        let args = no_args();

        for tool in ["Read", "Glob"] {
            assert!(
                policy.allow_tool(tool, &args).await,
                "expected {tool} to be allowed"
            );
        }

        for tool in ["Write", "Edit"] {
            assert!(
                !policy.allow_tool(tool, &args).await,
                "expected {tool} to be denied"
            );
        }
    }

    #[tokio::test]
    async fn checker_from_config_creates_correct_types() {
        let args = no_args();

        let approve_all = checker_from_config("always_approve", 60, None).unwrap();
        assert!(approve_all.allow_tool("test", &args).await);

        let deny_all = checker_from_config("always_deny", 60, None).unwrap();
        assert!(!deny_all.allow_tool("test", &args).await);

        let patterns = vec!["^Edit$".to_string()];
        let by_pattern = checker_from_config("pattern", 60, Some(&patterns)).unwrap();
        assert!(by_pattern.allow_tool("Read", &args).await);

        let missing_patterns = checker_from_config("pattern", 60, None);
        assert!(missing_patterns.is_err());

        let unknown = checker_from_config("unknown_policy", 60, None);
        assert!(unknown.is_err());
    }

    #[tokio::test]
    async fn checker_from_approval_config_with_overrides() {
        let overrides: HashMap<String, String> = ["Read", "Glob"]
            .into_iter()
            .map(|tool| (tool.to_string(), "always_approve".to_string()))
            .collect();

        let config = enact_config::ApprovalConfig {
            enabled: true,
            policy: "always_deny".to_string(),
            max_steps: None,
            require_patterns: None,
            timeout_seconds: 60,
            tool_overrides: Some(overrides),
        };

        let checker = checker_from_approval_config(&config).unwrap();
        let args = no_args();

        for tool in ["Read", "Glob"] {
            assert!(
                checker.allow_tool(tool, &args).await,
                "expected {tool} to be allowed"
            );
        }

        for tool in ["Write", "Edit"] {
            assert!(
                !checker.allow_tool(tool, &args).await,
                "expected {tool} to be denied"
            );
        }
    }

    #[tokio::test]
    async fn checker_from_approval_config_without_overrides() {
        let config = enact_config::ApprovalConfig {
            enabled: true,
            policy: "always_approve".to_string(),
            max_steps: None,
            require_patterns: None,
            timeout_seconds: 60,
            tool_overrides: None,
        };

        let checker = checker_from_approval_config(&config).unwrap();
        let allowed = checker.allow_tool("anything", &no_args()).await;
        assert!(allowed);
    }
}