Skip to main content

agentshield/rules/
policy.rs

1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3
4use chrono::Utc;
5use serde::{Deserialize, Serialize};
6
7use super::{Finding, Severity};
8
/// Policy verdict — the final pass/fail decision after applying
/// ignore list and severity overrides to raw findings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyVerdict {
    /// `true` when no effective finding has a severity at or above `fail_threshold`.
    pub pass: bool,
    /// Number of findings handed to the evaluation, before any filtering.
    pub total_findings: usize,
    /// Number of findings remaining once ignored rules have been dropped.
    pub effective_findings: usize,
    /// Highest severity among the effective findings, with overrides applied;
    /// `None` when no finding remains.
    pub highest_severity: Option<Severity>,
    /// The severity threshold (the policy's `fail_on`) the decision was made against.
    pub fail_threshold: Severity,
}
19
/// A suppression entry that silences a specific finding by fingerprint.
///
/// Both date fields are kept as plain strings; `expires` is only parsed when
/// `is_expired` is called.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Suppression {
    /// SHA-256 fingerprint of the finding to suppress.
    pub fingerprint: String,
    /// Mandatory reason explaining why this finding is suppressed.
    /// NOTE(review): this type does not itself reject empty reasons — presumably
    /// that is enforced by config validation; confirm against the loader.
    pub reason: String,
    /// Optional ISO-8601 date (YYYY-MM-DD) after which the suppression expires.
    /// The suppression is still active on this date itself. A value that does
    /// not parse as `YYYY-MM-DD` is treated by `is_expired` as "not expired".
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires: Option<String>,
    /// Optional ISO-8601 date when the suppression was created.
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub created_at: Option<String>,
}
34
35impl Suppression {
36    /// Returns `true` if this suppression has passed its expiration date.
37    pub fn is_expired(&self) -> bool {
38        if let Some(ref date_str) = self.expires {
39            if let Ok(expires_date) = chrono::NaiveDate::parse_from_str(date_str, "%Y-%m-%d") {
40                return expires_date < Utc::now().date_naive();
41            }
42        }
43        false
44    }
45}
46
/// Policy configuration loaded from `.agentshield.toml`.
///
/// Every field has a serde default, so a partially specified (or empty)
/// policy table deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Policy {
    /// Minimum severity to fail the scan.
    /// Defaults to `High` when omitted (see `default_fail_on`).
    #[serde(default = "default_fail_on")]
    pub fail_on: Severity,
    /// Rule IDs to ignore entirely.
    /// Findings whose `rule_id` is in this set are dropped by both `evaluate` and `apply`.
    #[serde(default)]
    pub ignore_rules: HashSet<String>,
    /// Per-rule severity overrides.
    /// Keyed by rule ID; the mapped severity replaces the finding's own severity.
    #[serde(default)]
    pub overrides: HashMap<String, Severity>,
    /// Per-finding suppressions by fingerprint.
    /// Consulted by `apply` only; `evaluate` does not look at suppressions.
    #[serde(default)]
    pub suppressions: Vec<Suppression>,
}
63
/// Serde default for the `fail_on` field of `Policy`: fail the scan at `High` severity.
fn default_fail_on() -> Severity {
    Severity::High
}
67
68impl Default for Policy {
69    fn default() -> Self {
70        Self {
71            fail_on: Severity::High,
72            ignore_rules: HashSet::new(),
73            overrides: HashMap::new(),
74            suppressions: Vec::new(),
75        }
76    }
77}
78
79impl Policy {
80    /// Evaluate findings against this policy and produce a verdict.
81    pub fn evaluate(&self, findings: &[Finding]) -> PolicyVerdict {
82        let effective: Vec<Severity> = findings
83            .iter()
84            .filter(|f| !self.ignore_rules.contains(&f.rule_id))
85            .map(|f| {
86                self.overrides
87                    .get(&f.rule_id)
88                    .copied()
89                    .unwrap_or(f.severity)
90            })
91            .collect();
92
93        let highest = effective.iter().copied().max();
94        let failed = effective.iter().any(|&sev| sev >= self.fail_on);
95
96        PolicyVerdict {
97            pass: !failed,
98            total_findings: findings.len(),
99            effective_findings: effective.len(),
100            highest_severity: highest,
101            fail_threshold: self.fail_on,
102        }
103    }
104
105    /// Build a set of active (non-expired) suppression fingerprints.
106    /// Logs a warning to stderr for each expired suppression.
107    fn active_suppressions(&self) -> HashSet<&str> {
108        let mut active = HashSet::new();
109        for s in &self.suppressions {
110            if s.is_expired() {
111                eprintln!(
112                    "warning: suppression for fingerprint {} has expired (expires: {})",
113                    &s.fingerprint,
114                    s.expires.as_deref().unwrap_or("unknown"),
115                );
116            } else {
117                active.insert(s.fingerprint.as_str());
118            }
119        }
120        active
121    }
122
123    /// Filter findings: remove ignored rules, apply overrides,
124    /// and filter out suppressed findings.
125    pub fn apply(&self, findings: &[Finding], scan_root: &Path) -> Vec<Finding> {
126        let suppressed = self.active_suppressions();
127
128        findings
129            .iter()
130            .filter(|f| !self.ignore_rules.contains(&f.rule_id))
131            .filter(|f| {
132                if suppressed.is_empty() {
133                    return true;
134                }
135                let fp = f.fingerprint(scan_root);
136                !suppressed.contains(fp.as_str())
137            })
138            .map(|f| {
139                let mut f = f.clone();
140                if let Some(&override_sev) = self.overrides.get(&f.rule_id) {
141                    f.severity = override_sev;
142                }
143                f
144            })
145            .collect()
146    }
147}
148
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::ir::SourceLocation;
    use crate::rules::{AttackCategory, Confidence, Evidence};

    /// Minimal finding with no location and no evidence.
    fn make_finding(rule_id: &str, severity: Severity) -> Finding {
        Finding {
            rule_id: rule_id.into(),
            rule_name: "Test".into(),
            severity,
            confidence: Confidence::High,
            attack_category: AttackCategory::CommandInjection,
            message: "test".into(),
            location: None,
            evidence: vec![],
            taint_path: None,
            remediation: None,
            cwe_id: None,
        }
    }

    /// Same as `make_finding`, plus a location at line 10 of `file` and a
    /// single evidence entry described by `evidence_desc`.
    fn make_finding_with_location(
        rule_id: &str,
        severity: Severity,
        file: &str,
        evidence_desc: &str,
    ) -> Finding {
        Finding {
            location: Some(SourceLocation {
                file: PathBuf::from(file),
                line: 10,
                column: 0,
                end_line: None,
                end_column: None,
            }),
            evidence: vec![Evidence {
                description: evidence_desc.into(),
                location: None,
                snippet: None,
            }],
            ..make_finding(rule_id, severity)
        }
    }

    /// Runs `Policy::apply` on one located critical finding, under a default
    /// policy carrying a single suppression for that finding's fingerprint.
    fn apply_with_suppression(reason: &str, expires: Option<&str>) -> Vec<Finding> {
        let scan_root = Path::new("/project");
        let finding = make_finding_with_location(
            "SHIELD-001",
            Severity::Critical,
            "/project/src/main.py",
            "subprocess.run receives parameter",
        );
        let fp = finding.fingerprint(scan_root);

        let mut policy = Policy::default();
        policy.suppressions.push(Suppression {
            fingerprint: fp,
            reason: reason.into(),
            expires: expires.map(String::from),
            created_at: None,
        });

        policy.apply(&[finding], scan_root)
    }

    /// Suppression with fixed fingerprint/reason and the given expiry.
    fn suppression_expiring(expires: Option<&str>) -> Suppression {
        Suppression {
            fingerprint: "abc".into(),
            reason: "test".into(),
            expires: expires.map(String::from),
            created_at: None,
        }
    }

    #[test]
    fn default_policy_fails_on_high() {
        let findings = [make_finding("SHIELD-001", Severity::High)];
        let verdict = Policy::default().evaluate(&findings);
        assert!(!verdict.pass);
    }

    #[test]
    fn default_policy_passes_on_medium() {
        let findings = [make_finding("SHIELD-009", Severity::Medium)];
        let verdict = Policy::default().evaluate(&findings);
        assert!(verdict.pass);
    }

    #[test]
    fn ignore_rule_removes_finding() {
        let mut policy = Policy::default();
        policy.ignore_rules.insert("SHIELD-001".into());

        let findings = [make_finding("SHIELD-001", Severity::Critical)];
        let verdict = policy.evaluate(&findings);

        assert!(verdict.pass);
        assert_eq!(verdict.effective_findings, 0);
    }

    #[test]
    fn override_downgrades_severity() {
        let mut policy = Policy::default();
        policy.overrides.insert("SHIELD-001".into(), Severity::Info);

        let findings = [make_finding("SHIELD-001", Severity::Critical)];
        let verdict = policy.evaluate(&findings);

        assert!(verdict.pass);
    }

    #[test]
    fn suppression_filters_matching_finding() {
        let result = apply_with_suppression("False positive: validated by middleware", None);
        assert!(
            result.is_empty(),
            "Suppressed finding should be filtered out"
        );
    }

    #[test]
    fn expired_suppression_does_not_filter() {
        let result = apply_with_suppression("Was a false positive", Some("2020-01-01"));
        assert_eq!(
            result.len(),
            1,
            "Expired suppression should not filter the finding"
        );
    }

    #[test]
    fn unexpired_suppression_filters() {
        let result = apply_with_suppression("Accepted risk: internal tool", Some("2099-12-31"));
        assert!(
            result.is_empty(),
            "Unexpired suppression should filter the finding"
        );
    }

    #[test]
    fn suppression_no_expiry_always_filters() {
        let result = apply_with_suppression("Permanent suppression", None);
        assert!(
            result.is_empty(),
            "Suppression without expiry should always filter"
        );
    }

    #[test]
    fn suppression_without_reason_rejected() {
        let toml_str = r#"
[policy]
fail_on = "high"

[[policy.suppressions]]
fingerprint = "abc123"
reason = "  "
"#;
        let config: crate::config::Config = toml::from_str(toml_str).unwrap();
        let result = config.validate_for_test();
        assert!(
            result.is_err(),
            "Suppression with whitespace-only reason should be rejected"
        );
    }

    #[test]
    fn is_expired_with_past_date() {
        assert!(suppression_expiring(Some("2020-01-01")).is_expired());
    }

    #[test]
    fn is_expired_with_future_date() {
        assert!(!suppression_expiring(Some("2099-12-31")).is_expired());
    }

    #[test]
    fn is_expired_with_no_date() {
        assert!(!suppression_expiring(None).is_expired());
    }
}
394}