1use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{Result, ShieldError};
11
/// Newest baseline schema version this build produces and accepts.
///
/// `BaselineFile::new` stamps it into new baselines, and `BaselineFile::load`
/// rejects any file whose `schema_version` is greater than this value.
const CURRENT_SCHEMA_VERSION: u32 = 1;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct BaselineFile {
17 pub schema_version: u32,
19 pub created_at: String,
21 pub tool_version: String,
23 pub entries: Vec<BaselineEntry>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct BaselineEntry {
30 pub fingerprint: String,
32 pub rule_id: String,
34 pub first_seen: String,
36}
37
38impl BaselineFile {
39 pub fn new(entries: Vec<BaselineEntry>) -> Self {
41 Self {
42 schema_version: CURRENT_SCHEMA_VERSION,
43 created_at: chrono::Utc::now().to_rfc3339(),
44 tool_version: env!("CARGO_PKG_VERSION").to_string(),
45 entries,
46 }
47 }
48
49 pub fn load(path: &Path) -> Result<Self> {
54 let content = std::fs::read_to_string(path)?;
55 let baseline: Self = serde_json::from_str(&content)?;
56 if baseline.schema_version > CURRENT_SCHEMA_VERSION {
57 return Err(ShieldError::Internal(format!(
58 "Baseline schema version {} is newer than supported version {}; \
59 please upgrade agentshield",
60 baseline.schema_version, CURRENT_SCHEMA_VERSION
61 )));
62 }
63 Ok(baseline)
64 }
65
66 pub fn save(&self, path: &Path) -> Result<()> {
68 let content = serde_json::to_string_pretty(self)?;
69 std::fs::write(path, content)?;
70 Ok(())
71 }
72
73 pub fn contains(&self, fingerprint: &str) -> bool {
75 self.entries.iter().any(|e| e.fingerprint == fingerprint)
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Timestamp shared by every fixture entry.
    const FIRST_SEEN: &str = "2026-03-20T00:00:00Z";

    /// Builds a fixture entry with the given fingerprint and rule id.
    fn entry(fingerprint: &str, rule_id: &str) -> BaselineEntry {
        BaselineEntry {
            fingerprint: fingerprint.to_owned(),
            rule_id: rule_id.to_owned(),
            first_seen: FIRST_SEEN.to_owned(),
        }
    }

    /// Writes `baseline` to a fresh temporary file and reads it back,
    /// panicking if either step fails.
    fn save_then_load(baseline: &BaselineFile) -> BaselineFile {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        baseline.save(path).unwrap();
        BaselineFile::load(path).unwrap()
    }

    #[test]
    fn test_round_trip_serialization() {
        let original = BaselineFile::new(vec![
            entry("abc123", "SHIELD-001"),
            entry("def456", "SHIELD-003"),
        ]);

        let loaded = save_then_load(&original);

        assert_eq!(loaded.schema_version, 1);
        assert_eq!(loaded.entries.len(), 2);
        assert_eq!(loaded.entries[0].fingerprint, "abc123");
        assert_eq!(loaded.entries[1].rule_id, "SHIELD-003");
    }

    #[test]
    fn test_contains_present() {
        let subject = BaselineFile::new(vec![entry("abc123", "SHIELD-001")]);
        assert!(subject.contains("abc123"));
    }

    #[test]
    fn test_contains_absent() {
        let subject = BaselineFile::new(vec![entry("abc123", "SHIELD-001")]);
        assert!(!subject.contains("xyz789"));
    }

    #[test]
    fn test_empty_baseline_round_trip() {
        let loaded = save_then_load(&BaselineFile::new(Vec::new()));
        assert!(loaded.entries.is_empty());
        assert_eq!(loaded.schema_version, 1);
    }

    #[test]
    fn test_future_schema_version_rejected() {
        let json = r#"{
            "schema_version": 99,
            "created_at": "2026-03-20T00:00:00Z",
            "tool_version": "0.2.4",
            "entries": []
        }"#;
        let file = NamedTempFile::new().unwrap();
        std::fs::write(file.path(), json).unwrap();

        let outcome = BaselineFile::load(file.path());
        assert!(outcome.is_err());

        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("newer than supported"),
            "error message should explain the version mismatch, got: {msg}"
        );
    }

    #[test]
    fn test_current_schema_version_accepted() {
        // A file written by this build must always be readable by this build.
        let file = NamedTempFile::new().unwrap();
        BaselineFile::new(vec![entry("fp1", "SHIELD-007")])
            .save(file.path())
            .unwrap();
        assert!(BaselineFile::load(file.path()).is_ok());
    }

    #[test]
    fn test_tool_version_populated() {
        let fresh = BaselineFile::new(Vec::new());
        assert!(!fresh.tool_version.is_empty());
    }

    #[test]
    fn test_created_at_is_rfc3339() {
        let fresh = BaselineFile::new(Vec::new());
        chrono::DateTime::parse_from_rfc3339(&fresh.created_at)
            .expect("created_at must be valid RFC3339");
    }
}