Skip to main content

agentshield/
baseline.rs

1//! Baseline schema for tracking known findings across scan runs.
2//!
3//! A baseline file records previously seen findings (by fingerprint) so that
4//! subsequent scans can suppress already-known issues and surface only new ones.
5
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{Result, ShieldError};
11
/// Schema version written into new baselines by `BaselineFile::new`, and the
/// highest version `BaselineFile::load` will accept.
const CURRENT_SCHEMA_VERSION: u32 = 1;
13
14/// A versioned file that records known findings by fingerprint.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct BaselineFile {
17    /// Schema version — incremented when the format changes in a breaking way.
18    pub schema_version: u32,
19    /// RFC3339 timestamp of when this baseline was created.
20    pub created_at: String,
21    /// Version of agentshield that wrote this file.
22    pub tool_version: String,
23    /// Known finding entries.
24    pub entries: Vec<BaselineEntry>,
25}
26
27/// A single known finding recorded in the baseline.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct BaselineEntry {
30    /// Stable fingerprint that uniquely identifies this finding.
31    pub fingerprint: String,
32    /// Rule that produced this finding (e.g. `"SHIELD-001"`).
33    pub rule_id: String,
34    /// RFC3339 timestamp of when this finding was first observed.
35    pub first_seen: String,
36}
37
38impl BaselineFile {
39    /// Create a new baseline from a list of entries, stamped with the current time.
40    pub fn new(entries: Vec<BaselineEntry>) -> Self {
41        Self {
42            schema_version: CURRENT_SCHEMA_VERSION,
43            created_at: chrono::Utc::now().to_rfc3339(),
44            tool_version: env!("CARGO_PKG_VERSION").to_string(),
45            entries,
46        }
47    }
48
49    /// Load a baseline from a JSON file on disk.
50    ///
51    /// Returns an error if the file cannot be read, if the JSON is malformed,
52    /// or if the `schema_version` is newer than this tool supports.
53    pub fn load(path: &Path) -> Result<Self> {
54        let content = std::fs::read_to_string(path)?;
55        let baseline: Self = serde_json::from_str(&content)?;
56        if baseline.schema_version > CURRENT_SCHEMA_VERSION {
57            return Err(ShieldError::Internal(format!(
58                "Baseline schema version {} is newer than supported version {}; \
59                 please upgrade agentshield",
60                baseline.schema_version, CURRENT_SCHEMA_VERSION
61            )));
62        }
63        Ok(baseline)
64    }
65
66    /// Persist this baseline to a JSON file on disk (pretty-printed).
67    pub fn save(&self, path: &Path) -> Result<()> {
68        let content = serde_json::to_string_pretty(self)?;
69        std::fs::write(path, content)?;
70        Ok(())
71    }
72
73    /// Return `true` if the given fingerprint is recorded in this baseline.
74    pub fn contains(&self, fingerprint: &str) -> bool {
75        self.entries.iter().any(|e| e.fingerprint == fingerprint)
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Build an entry with the given identifiers and a fixed `first_seen` stamp.
    fn make_entry(fingerprint: &str, rule_id: &str) -> BaselineEntry {
        BaselineEntry {
            fingerprint: String::from(fingerprint),
            rule_id: String::from(rule_id),
            first_seen: String::from("2026-03-20T00:00:00Z"),
        }
    }

    /// Save `baseline` into a fresh temporary file and return the file handle,
    /// which the caller holds while reading the file back.
    fn saved(baseline: &BaselineFile) -> NamedTempFile {
        let file = NamedTempFile::new().unwrap();
        baseline.save(file.path()).unwrap();
        file
    }

    /// Entries written by `save` come back from `load` in the same order.
    #[test]
    fn test_round_trip_serialization() {
        let entries = vec![
            make_entry("abc123", "SHIELD-001"),
            make_entry("def456", "SHIELD-003"),
        ];
        let file = saved(&BaselineFile::new(entries));

        let reloaded = BaselineFile::load(file.path()).unwrap();
        assert_eq!(reloaded.schema_version, 1);
        assert_eq!(reloaded.entries.len(), 2);
        assert_eq!(reloaded.entries[0].fingerprint, "abc123");
        assert_eq!(reloaded.entries[1].rule_id, "SHIELD-003");
    }

    /// A recorded fingerprint is reported as present.
    #[test]
    fn test_contains_present() {
        let entries = vec![make_entry("abc123", "SHIELD-001")];
        let baseline = BaselineFile::new(entries);
        assert!(baseline.contains("abc123"));
    }

    /// An unrecorded fingerprint is reported as absent.
    #[test]
    fn test_contains_absent() {
        let entries = vec![make_entry("abc123", "SHIELD-001")];
        let baseline = BaselineFile::new(entries);
        assert!(!baseline.contains("xyz789"));
    }

    /// A baseline without entries survives a save/load cycle.
    #[test]
    fn test_empty_baseline_round_trip() {
        let file = saved(&BaselineFile::new(Vec::new()));
        let reloaded = BaselineFile::load(file.path()).unwrap();
        assert!(reloaded.entries.is_empty());
        assert_eq!(reloaded.schema_version, 1);
    }

    /// A file claiming a newer schema version is refused with a clear message.
    #[test]
    fn test_future_schema_version_rejected() {
        let json = r#"{
            "schema_version": 99,
            "created_at": "2026-03-20T00:00:00Z",
            "tool_version": "0.2.4",
            "entries": []
        }"#;
        let file = NamedTempFile::new().unwrap();
        std::fs::write(file.path(), json).unwrap();

        let outcome = BaselineFile::load(file.path());
        assert!(outcome.is_err());

        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("newer than supported"),
            "error message should explain the version mismatch, got: {msg}"
        );
    }

    /// A file written by this build (current schema version) loads cleanly.
    #[test]
    fn test_current_schema_version_accepted() {
        let entries = vec![make_entry("fp1", "SHIELD-007")];
        let file = saved(&BaselineFile::new(entries));
        // schema_version == CURRENT_SCHEMA_VERSION must load without error
        let outcome = BaselineFile::load(file.path());
        assert!(outcome.is_ok());
    }

    /// `new` always fills in a non-empty tool version.
    #[test]
    fn test_tool_version_populated() {
        let baseline = BaselineFile::new(Vec::new());
        assert!(!baseline.tool_version.is_empty());
    }

    /// The `created_at` stamp produced by `new` parses back as RFC3339.
    #[test]
    fn test_created_at_is_rfc3339() {
        let baseline = BaselineFile::new(Vec::new());
        // chrono should produce a valid RFC3339 string; verify it parses back
        let reparsed = chrono::DateTime::parse_from_rfc3339(&baseline.created_at);
        reparsed.expect("created_at must be valid RFC3339");
    }
}