1use std::sync::Arc;
18use std::sync::atomic::{AtomicU32, Ordering};
19
20use async_trait::async_trait;
21use serde::{Deserialize, Serialize};
22
23use crate::error::PunchResult;
24use crate::fighter::FighterId;
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
35#[serde(rename_all = "snake_case")]
36pub enum RiskLevel {
37 Low,
39 Medium,
41 High,
43 Critical,
46}
47
48impl std::fmt::Display for RiskLevel {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 Self::Low => write!(f, "low"),
52 Self::Medium => write!(f, "medium"),
53 Self::High => write!(f, "high"),
54 Self::Critical => write!(f, "critical"),
55 }
56 }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
65#[serde(rename_all = "snake_case", tag = "decision", content = "reason")]
66pub enum ApprovalDecision {
67 Allow,
69 Deny(String),
71 NeedsApproval(String),
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct ApprovalRequest {
82 pub tool_name: String,
84 pub input_summary: String,
86 pub risk_level: RiskLevel,
88 pub fighter_id: FighterId,
90 pub reason: String,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ApprovalPolicy {
105 pub name: String,
107 pub tool_patterns: Vec<String>,
109 pub risk_level: RiskLevel,
111 pub auto_approve: bool,
114 pub max_auto_approvals: Option<u32>,
117}
118
119#[async_trait]
129pub trait ApprovalHandler: Send + Sync {
130 async fn request_approval(&self, request: &ApprovalRequest) -> PunchResult<ApprovalDecision>;
133}
134
135#[derive(Debug, Clone)]
144pub struct AutoApproveHandler;
145
146#[async_trait]
147impl ApprovalHandler for AutoApproveHandler {
148 async fn request_approval(&self, _request: &ApprovalRequest) -> PunchResult<ApprovalDecision> {
149 Ok(ApprovalDecision::Allow)
150 }
151}
152
153#[derive(Debug, Clone)]
158pub struct DenyAllHandler;
159
160#[async_trait]
161impl ApprovalHandler for DenyAllHandler {
162 async fn request_approval(&self, request: &ApprovalRequest) -> PunchResult<ApprovalDecision> {
163 Ok(ApprovalDecision::Deny(format!(
164 "all tool calls denied by policy: {}",
165 request.tool_name
166 )))
167 }
168}
169
170pub struct PolicyEngine {
180 policies: Vec<ApprovalPolicy>,
182 handler: Arc<dyn ApprovalHandler>,
184 auto_approve_counts: Vec<AtomicU32>,
187}
188
189impl std::fmt::Debug for PolicyEngine {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 f.debug_struct("PolicyEngine")
192 .field("policies", &self.policies)
193 .field(
194 "auto_approve_counts",
195 &self
196 .auto_approve_counts
197 .iter()
198 .map(|c| c.load(Ordering::Relaxed))
199 .collect::<Vec<_>>(),
200 )
201 .finish()
202 }
203}
204
205impl PolicyEngine {
206 pub fn new(policies: Vec<ApprovalPolicy>, handler: Arc<dyn ApprovalHandler>) -> Self {
208 let auto_approve_counts = policies.iter().map(|_| AtomicU32::new(0)).collect();
209 Self {
210 policies,
211 handler,
212 auto_approve_counts,
213 }
214 }
215
216 pub async fn evaluate(
225 &self,
226 tool_name: &str,
227 input: &serde_json::Value,
228 fighter_id: &FighterId,
229 ) -> PunchResult<ApprovalDecision> {
230 let matched = self.find_matching_policy(tool_name);
232
233 let Some((policy_index, policy)) = matched else {
234 return Ok(ApprovalDecision::Allow);
236 };
237
238 if policy.risk_level == RiskLevel::Critical {
240 let request = Self::build_request(
241 tool_name,
242 input,
243 policy.risk_level,
244 fighter_id,
245 &format!(
246 "critical risk tool '{}' matched policy '{}'",
247 tool_name, policy.name
248 ),
249 );
250 return self.handler.request_approval(&request).await;
251 }
252
253 if policy.auto_approve {
255 if let Some(max) = policy.max_auto_approvals {
256 let current =
257 self.auto_approve_counts[policy_index].fetch_add(1, Ordering::Relaxed);
258 if current < max {
259 return Ok(ApprovalDecision::Allow);
260 }
261 let request = Self::build_request(
263 tool_name,
264 input,
265 policy.risk_level,
266 fighter_id,
267 &format!(
268 "auto-approval limit ({}) reached for policy '{}'",
269 max, policy.name
270 ),
271 );
272 return self.handler.request_approval(&request).await;
273 }
274 return Ok(ApprovalDecision::Allow);
276 }
277
278 let request = Self::build_request(
280 tool_name,
281 input,
282 policy.risk_level,
283 fighter_id,
284 &format!(
285 "tool '{}' matched policy '{}' (risk: {})",
286 tool_name, policy.name, policy.risk_level
287 ),
288 );
289 self.handler.request_approval(&request).await
290 }
291
292 fn find_matching_policy(&self, tool_name: &str) -> Option<(usize, &ApprovalPolicy)> {
295 for (i, policy) in self.policies.iter().enumerate() {
296 for pattern_str in &policy.tool_patterns {
297 if pattern_str == "*" || pattern_str == "**" {
298 return Some((i, policy));
299 }
300 if let Ok(pattern) = glob::Pattern::new(pattern_str)
301 && pattern.matches(tool_name)
302 {
303 return Some((i, policy));
304 }
305 }
306 }
307 None
308 }
309
310 fn build_request(
312 tool_name: &str,
313 input: &serde_json::Value,
314 risk_level: RiskLevel,
315 fighter_id: &FighterId,
316 reason: &str,
317 ) -> ApprovalRequest {
318 let input_summary = match input {
320 serde_json::Value::Object(map) => {
321 let pairs: Vec<String> = map
322 .iter()
323 .take(5)
324 .map(|(k, v)| {
325 let v_str = match v {
326 serde_json::Value::String(s) => {
327 if s.len() > 100 {
328 format!("{}...", &s[..100])
329 } else {
330 s.clone()
331 }
332 }
333 other => {
334 let s = other.to_string();
335 if s.len() > 100 {
336 format!("{}...", &s[..100])
337 } else {
338 s
339 }
340 }
341 };
342 format!("{}: {}", k, v_str)
343 })
344 .collect();
345 pairs.join(", ")
346 }
347 other => {
348 let s = other.to_string();
349 if s.len() > 200 {
350 format!("{}...", &s[..200])
351 } else {
352 s
353 }
354 }
355 };
356
357 ApprovalRequest {
358 tool_name: tool_name.to_string(),
359 input_summary,
360 risk_level,
361 fighter_id: *fighter_id,
362 reason: reason.to_string(),
363 }
364 }
365
366 pub fn auto_approve_count(&self, policy_index: usize) -> Option<u32> {
369 self.auto_approve_counts
370 .get(policy_index)
371 .map(|c| c.load(Ordering::Relaxed))
372 }
373
374 pub fn reset_counters(&self) {
376 for counter in &self.auto_approve_counts {
377 counter.store(0, Ordering::Relaxed);
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    fn test_fighter_id() -> FighterId {
        FighterId(Uuid::nil())
    }

    /// Handler that echoes the engine's reason back as `NeedsApproval`, so tests
    /// can observe exactly what the engine sent to the handler.
    struct EchoReasonHandler;

    #[async_trait]
    impl ApprovalHandler for EchoReasonHandler {
        async fn request_approval(
            &self,
            request: &ApprovalRequest,
        ) -> PunchResult<ApprovalDecision> {
            Ok(ApprovalDecision::NeedsApproval(request.reason.clone()))
        }
    }

    #[test]
    fn test_risk_level_ordering() {
        assert!(RiskLevel::Low < RiskLevel::Medium);
        assert!(RiskLevel::Medium < RiskLevel::High);
        assert!(RiskLevel::High < RiskLevel::Critical);
        assert!(RiskLevel::Low < RiskLevel::Critical);
    }

    #[test]
    fn test_policy_matching_exact() {
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "block-shell".into(),
                tool_patterns: vec!["shell_exec".into()],
                risk_level: RiskLevel::High,
                auto_approve: false,
                max_auto_approvals: None,
            }],
            Arc::new(DenyAllHandler),
        );
        let matched = engine.find_matching_policy("shell_exec");
        assert!(matched.is_some());
        assert_eq!(
            matched.as_ref().map(|(_, p)| p.name.as_str()),
            Some("block-shell")
        );
    }

    #[test]
    fn test_policy_matching_wildcard() {
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "all-file-ops".into(),
                tool_patterns: vec!["file_*".into()],
                risk_level: RiskLevel::Medium,
                auto_approve: true,
                max_auto_approvals: None,
            }],
            Arc::new(AutoApproveHandler),
        );
        assert!(engine.find_matching_policy("file_read").is_some());
        assert!(engine.find_matching_policy("file_write").is_some());
        assert!(engine.find_matching_policy("file_list").is_some());
        assert!(engine.find_matching_policy("shell_exec").is_none());
    }

    #[test]
    fn test_policy_matching_no_match() {
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "shell-only".into(),
                tool_patterns: vec!["shell_*".into()],
                risk_level: RiskLevel::High,
                auto_approve: false,
                max_auto_approvals: None,
            }],
            Arc::new(DenyAllHandler),
        );
        assert!(engine.find_matching_policy("file_read").is_none());
        assert!(engine.find_matching_policy("web_fetch").is_none());
    }

    #[test]
    fn test_policy_matching_later_pattern() {
        // Any pattern of a policy may match, not only the first one.
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "mixed".into(),
                tool_patterns: vec!["shell_*".into(), "file_read".into()],
                risk_level: RiskLevel::High,
                auto_approve: false,
                max_auto_approvals: None,
            }],
            Arc::new(DenyAllHandler),
        );
        assert!(engine.find_matching_policy("shell_exec").is_some());
        assert!(engine.find_matching_policy("file_read").is_some());
        assert!(engine.find_matching_policy("file_write").is_none());
    }

    #[tokio::test]
    async fn test_auto_approve_counter() {
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "limited-reads".into(),
                tool_patterns: vec!["file_read".into()],
                risk_level: RiskLevel::Low,
                auto_approve: true,
                max_auto_approvals: Some(3),
            }],
            Arc::new(DenyAllHandler),
        );

        let fid = test_fighter_id();
        let input = serde_json::json!({"path": "test.txt"});

        // The first three calls are within the limit.
        for _ in 0..3 {
            let decision = engine
                .evaluate("file_read", &input, &fid)
                .await
                .expect("evaluate failed");
            assert_eq!(decision, ApprovalDecision::Allow);
        }

        // The fourth call exceeds the limit and reaches the deny-all handler.
        let decision = engine
            .evaluate("file_read", &input, &fid)
            .await
            .expect("evaluate failed");
        match decision {
            ApprovalDecision::Deny(_) => {}
            other => panic!("expected Deny, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_auto_approve_handler_always_approves() {
        let handler = AutoApproveHandler;
        let request = ApprovalRequest {
            tool_name: "shell_exec".into(),
            input_summary: "rm -rf /".into(),
            risk_level: RiskLevel::Critical,
            fighter_id: test_fighter_id(),
            reason: "test".into(),
        };
        let decision = handler
            .request_approval(&request)
            .await
            .expect("handler failed");
        assert_eq!(decision, ApprovalDecision::Allow);
    }

    #[tokio::test]
    async fn test_deny_all_handler_always_denies() {
        let handler = DenyAllHandler;
        let request = ApprovalRequest {
            tool_name: "file_read".into(),
            input_summary: "path: readme.md".into(),
            risk_level: RiskLevel::Low,
            fighter_id: test_fighter_id(),
            reason: "test".into(),
        };
        let decision = handler
            .request_approval(&request)
            .await
            .expect("handler failed");
        match decision {
            ApprovalDecision::Deny(_) => {}
            other => panic!("expected Deny, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_evaluate_first_match_wins() {
        let engine = PolicyEngine::new(
            vec![
                ApprovalPolicy {
                    name: "allow-file-read".into(),
                    tool_patterns: vec!["file_read".into()],
                    risk_level: RiskLevel::Low,
                    auto_approve: true,
                    max_auto_approvals: None,
                },
                ApprovalPolicy {
                    name: "deny-all-files".into(),
                    tool_patterns: vec!["file_*".into()],
                    risk_level: RiskLevel::High,
                    auto_approve: false,
                    max_auto_approvals: None,
                },
            ],
            Arc::new(DenyAllHandler),
        );

        let fid = test_fighter_id();
        let input = serde_json::json!({"path": "test.txt"});

        // file_read hits the first (auto-approving) policy.
        let decision = engine
            .evaluate("file_read", &input, &fid)
            .await
            .expect("evaluate failed");
        assert_eq!(decision, ApprovalDecision::Allow);

        // file_write only matches the second policy and goes to the handler.
        let decision = engine
            .evaluate("file_write", &input, &fid)
            .await
            .expect("evaluate failed");
        match decision {
            ApprovalDecision::Deny(_) => {}
            other => panic!("expected Deny for file_write, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_empty_policies_allow_all() {
        let engine = PolicyEngine::new(vec![], Arc::new(DenyAllHandler));
        let fid = test_fighter_id();
        let input = serde_json::json!({"command": "rm -rf /"});

        let decision = engine
            .evaluate("shell_exec", &input, &fid)
            .await
            .expect("evaluate failed");
        assert_eq!(decision, ApprovalDecision::Allow);
    }

    #[tokio::test]
    async fn test_critical_risk_requires_approval_even_with_auto_approve() {
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "critical-shell".into(),
                tool_patterns: vec!["shell_exec".into()],
                risk_level: RiskLevel::Critical,
                auto_approve: true,
                max_auto_approvals: None,
            }],
            Arc::new(DenyAllHandler),
        );

        let fid = test_fighter_id();
        let input = serde_json::json!({"command": "rm -rf /"});

        let decision = engine
            .evaluate("shell_exec", &input, &fid)
            .await
            .expect("evaluate failed");
        match decision {
            ApprovalDecision::Deny(_) => {}
            other => panic!("expected Deny for critical tool, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_handler_receives_expected_reasons() {
        // Each non-allowed path sends a distinct reason to the handler.
        let engine = PolicyEngine::new(
            vec![
                ApprovalPolicy {
                    name: "critical-shell".into(),
                    tool_patterns: vec!["shell_exec".into()],
                    risk_level: RiskLevel::Critical,
                    auto_approve: true,
                    max_auto_approvals: None,
                },
                ApprovalPolicy {
                    name: "limited".into(),
                    tool_patterns: vec!["file_read".into()],
                    risk_level: RiskLevel::Low,
                    auto_approve: true,
                    max_auto_approvals: Some(0),
                },
                ApprovalPolicy {
                    name: "review-web".into(),
                    tool_patterns: vec!["web_*".into()],
                    risk_level: RiskLevel::High,
                    auto_approve: false,
                    max_auto_approvals: None,
                },
            ],
            Arc::new(EchoReasonHandler),
        );

        let fid = test_fighter_id();
        let input = serde_json::json!({});

        let decision = engine
            .evaluate("shell_exec", &input, &fid)
            .await
            .expect("evaluate failed");
        assert_eq!(
            decision,
            ApprovalDecision::NeedsApproval(
                "critical risk tool 'shell_exec' matched policy 'critical-shell'".into()
            )
        );

        let decision = engine
            .evaluate("file_read", &input, &fid)
            .await
            .expect("evaluate failed");
        assert_eq!(
            decision,
            ApprovalDecision::NeedsApproval(
                "auto-approval limit (0) reached for policy 'limited'".into()
            )
        );

        let decision = engine
            .evaluate("web_fetch", &input, &fid)
            .await
            .expect("evaluate failed");
        assert_eq!(
            decision,
            ApprovalDecision::NeedsApproval(
                "tool 'web_fetch' matched policy 'review-web' (risk: high)".into()
            )
        );
    }

    #[test]
    fn test_input_summary_short_values_unchanged() {
        let input = serde_json::json!({"path": "test.txt"});
        let request = PolicyEngine::build_request(
            "file_read",
            &input,
            RiskLevel::Low,
            &test_fighter_id(),
            "test",
        );
        assert_eq!(request.input_summary, "path: test.txt");
        assert_eq!(request.tool_name, "file_read");
        assert_eq!(request.risk_level, RiskLevel::Low);
        assert_eq!(request.reason, "test");

        let input = serde_json::json!({"n": 5});
        let request = PolicyEngine::build_request(
            "counter",
            &input,
            RiskLevel::Low,
            &test_fighter_id(),
            "test",
        );
        assert_eq!(request.input_summary, "n: 5");
    }

    #[test]
    fn test_input_summary_truncates_long_field_values() {
        let input = serde_json::json!({"cmd": "a".repeat(150)});
        let request = PolicyEngine::build_request(
            "shell_exec",
            &input,
            RiskLevel::High,
            &test_fighter_id(),
            "test",
        );
        assert_eq!(request.input_summary, format!("cmd: {}...", "a".repeat(100)));
    }

    #[test]
    fn test_input_summary_limits_field_count() {
        let input = serde_json::json!({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7});
        let request = PolicyEngine::build_request(
            "many_args",
            &input,
            RiskLevel::Low,
            &test_fighter_id(),
            "test",
        );
        // Field order depends on serde_json's map implementation, so only the
        // number of rendered pairs is checked.
        assert_eq!(request.input_summary.matches(": ").count(), 5);
    }

    #[test]
    fn test_input_summary_non_object_input() {
        // Non-object input is rendered as JSON (quotes included) and capped at 200 bytes.
        let input = serde_json::Value::String("b".repeat(300));
        let request = PolicyEngine::build_request(
            "echo",
            &input,
            RiskLevel::Low,
            &test_fighter_id(),
            "test",
        );
        assert_eq!(request.input_summary, format!("\"{}...", "b".repeat(199)));
    }

    #[test]
    fn test_approval_request_serialization() {
        let request = ApprovalRequest {
            tool_name: "file_write".into(),
            input_summary: "path: /etc/passwd, content: hacked".into(),
            risk_level: RiskLevel::Critical,
            fighter_id: test_fighter_id(),
            reason: "critical operation detected".into(),
        };

        let json = serde_json::to_string(&request).expect("serialization failed");
        let deserialized: ApprovalRequest =
            serde_json::from_str(&json).expect("deserialization failed");

        assert_eq!(deserialized.tool_name, "file_write");
        assert_eq!(deserialized.risk_level, RiskLevel::Critical);
        assert_eq!(deserialized.reason, "critical operation detected");
    }

    #[test]
    fn test_approval_decision_serialization() {
        let allow = ApprovalDecision::Allow;
        let deny = ApprovalDecision::Deny("not permitted".into());
        let needs = ApprovalDecision::NeedsApproval("requires human review".into());

        let allow_json = serde_json::to_string(&allow).expect("serialize allow");
        let deny_json = serde_json::to_string(&deny).expect("serialize deny");
        let needs_json = serde_json::to_string(&needs).expect("serialize needs_approval");

        let allow_back: ApprovalDecision = serde_json::from_str(&allow_json).expect("deser allow");
        let deny_back: ApprovalDecision = serde_json::from_str(&deny_json).expect("deser deny");
        let needs_back: ApprovalDecision = serde_json::from_str(&needs_json).expect("deser needs");

        assert_eq!(allow_back, ApprovalDecision::Allow);
        assert_eq!(deny_back, ApprovalDecision::Deny("not permitted".into()));
        assert_eq!(
            needs_back,
            ApprovalDecision::NeedsApproval("requires human review".into())
        );
    }

    #[tokio::test]
    async fn test_catch_all_policy() {
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "catch-all".into(),
                tool_patterns: vec!["*".into()],
                risk_level: RiskLevel::Medium,
                auto_approve: false,
                max_auto_approvals: None,
            }],
            Arc::new(DenyAllHandler),
        );

        let fid = test_fighter_id();
        let input = serde_json::json!({});

        for tool in &["file_read", "shell_exec", "web_fetch", "memory_store"] {
            let decision = engine
                .evaluate(tool, &input, &fid)
                .await
                .expect("evaluate failed");
            match decision {
                ApprovalDecision::Deny(_) => {}
                other => panic!("expected Deny for {}, got {:?}", tool, other),
            }
        }
    }

    #[tokio::test]
    async fn test_reset_counters() {
        let engine = PolicyEngine::new(
            vec![ApprovalPolicy {
                name: "limited".into(),
                tool_patterns: vec!["file_read".into()],
                risk_level: RiskLevel::Low,
                auto_approve: true,
                max_auto_approvals: Some(2),
            }],
            Arc::new(DenyAllHandler),
        );

        let fid = test_fighter_id();
        let input = serde_json::json!({"path": "test.txt"});

        engine
            .evaluate("file_read", &input, &fid)
            .await
            .expect("eval 1");
        engine
            .evaluate("file_read", &input, &fid)
            .await
            .expect("eval 2");
        assert_eq!(engine.auto_approve_count(0), Some(2));

        engine.reset_counters();
        assert_eq!(engine.auto_approve_count(0), Some(0));

        let decision = engine
            .evaluate("file_read", &input, &fid)
            .await
            .expect("eval after reset");
        assert_eq!(decision, ApprovalDecision::Allow);
    }
}