agentshield/rules/
policy.rs1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use super::{Finding, Severity};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct PolicyVerdict {
11 pub pass: bool,
12 pub total_findings: usize,
13 pub effective_findings: usize,
14 pub highest_severity: Option<Severity>,
15 pub fail_threshold: Severity,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Policy {
21 #[serde(default = "default_fail_on")]
23 pub fail_on: Severity,
24 #[serde(default)]
26 pub ignore_rules: HashSet<String>,
27 #[serde(default)]
29 pub overrides: HashMap<String, Severity>,
30}
31
32fn default_fail_on() -> Severity {
33 Severity::High
34}
35
36impl Default for Policy {
37 fn default() -> Self {
38 Self {
39 fail_on: Severity::High,
40 ignore_rules: HashSet::new(),
41 overrides: HashMap::new(),
42 }
43 }
44}
45
46impl Policy {
47 pub fn evaluate(&self, findings: &[Finding]) -> PolicyVerdict {
49 let effective: Vec<Severity> = findings
50 .iter()
51 .filter(|f| !self.ignore_rules.contains(&f.rule_id))
52 .map(|f| {
53 self.overrides
54 .get(&f.rule_id)
55 .copied()
56 .unwrap_or(f.severity)
57 })
58 .collect();
59
60 let highest = effective.iter().copied().max();
61 let failed = effective.iter().any(|&sev| sev >= self.fail_on);
62
63 PolicyVerdict {
64 pass: !failed,
65 total_findings: findings.len(),
66 effective_findings: effective.len(),
67 highest_severity: highest,
68 fail_threshold: self.fail_on,
69 }
70 }
71
72 pub fn apply(&self, findings: &[Finding]) -> Vec<Finding> {
74 findings
75 .iter()
76 .filter(|f| !self.ignore_rules.contains(&f.rule_id))
77 .map(|f| {
78 let mut f = f.clone();
79 if let Some(&override_sev) = self.overrides.get(&f.rule_id) {
80 f.severity = override_sev;
81 }
82 f
83 })
84 .collect()
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::{AttackCategory, Confidence};

    /// Builds a minimal finding for `rule_id` with the given severity; every
    /// other field gets a fixed placeholder value.
    fn make_finding(rule_id: &str, severity: Severity) -> Finding {
        Finding {
            rule_id: rule_id.into(),
            rule_name: "Test".into(),
            severity,
            confidence: Confidence::High,
            attack_category: AttackCategory::CommandInjection,
            message: "test".into(),
            location: None,
            evidence: vec![],
            taint_path: None,
            remediation: None,
            cwe_id: None,
        }
    }

    #[test]
    fn default_policy_fails_on_high() {
        let policy = Policy::default();
        let findings = vec![make_finding("SHIELD-001", Severity::High)];
        let verdict = policy.evaluate(&findings);
        assert!(!verdict.pass);
    }

    #[test]
    fn default_policy_passes_on_medium() {
        let policy = Policy::default();
        let findings = vec![make_finding("SHIELD-009", Severity::Medium)];
        let verdict = policy.evaluate(&findings);
        assert!(verdict.pass);
    }

    #[test]
    fn empty_findings_pass() {
        let policy = Policy::default();
        let verdict = policy.evaluate(&[]);
        assert!(verdict.pass);
        assert_eq!(verdict.total_findings, 0);
        assert_eq!(verdict.effective_findings, 0);
        assert!(verdict.highest_severity.is_none());
    }

    #[test]
    fn ignore_rule_removes_finding() {
        let mut policy = Policy::default();
        policy.ignore_rules.insert("SHIELD-001".into());
        let findings = vec![make_finding("SHIELD-001", Severity::Critical)];
        let verdict = policy.evaluate(&findings);
        assert!(verdict.pass);
        assert_eq!(verdict.total_findings, 1);
        assert_eq!(verdict.effective_findings, 0);
        assert!(verdict.highest_severity.is_none());
    }

    #[test]
    fn override_downgrades_severity() {
        let mut policy = Policy::default();
        policy.overrides.insert("SHIELD-001".into(), Severity::Info);
        let findings = vec![make_finding("SHIELD-001", Severity::Critical)];
        let verdict = policy.evaluate(&findings);
        assert!(verdict.pass);
        assert_eq!(verdict.highest_severity, Some(Severity::Info));
    }

    #[test]
    fn override_upgrades_severity() {
        let mut policy = Policy::default();
        policy.overrides.insert("SHIELD-009".into(), Severity::High);
        let findings = vec![make_finding("SHIELD-009", Severity::Medium)];
        let verdict = policy.evaluate(&findings);
        assert!(!verdict.pass);
        assert_eq!(verdict.highest_severity, Some(Severity::High));
    }

    #[test]
    fn apply_drops_ignored_and_rewrites_overridden() {
        let mut policy = Policy::default();
        policy.ignore_rules.insert("SHIELD-001".into());
        policy.overrides.insert("SHIELD-002".into(), Severity::Info);
        let findings = vec![
            make_finding("SHIELD-001", Severity::High),
            make_finding("SHIELD-002", Severity::High),
            make_finding("SHIELD-003", Severity::Medium),
        ];

        let applied = policy.apply(&findings);

        assert_eq!(applied.len(), 2);
        assert_eq!(applied[0].rule_id, "SHIELD-002");
        assert_eq!(applied[0].severity, Severity::Info);
        assert_eq!(applied[1].rule_id, "SHIELD-003");
        assert_eq!(applied[1].severity, Severity::Medium);
        // `apply` works on clones; the caller's findings keep their severity.
        assert_eq!(findings[1].severity, Severity::High);
    }
}