keyhog_core/
calibration.rs1use std::collections::HashMap;
24use std::path::{Path, PathBuf};
25
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
33pub struct BetaCounters {
34 pub alpha: u32,
35 pub beta: u32,
36}
37
38impl Default for BetaCounters {
39 fn default() -> Self {
40 Self { alpha: 1, beta: 1 }
41 }
42}
43
44impl BetaCounters {
45 pub fn posterior_mean(&self) -> f64 {
48 let total = self.alpha as f64 + self.beta as f64;
49 if total == 0.0 {
50 0.5
51 } else {
52 self.alpha as f64 / total
53 }
54 }
55
56 pub fn observations(&self) -> u32 {
59 self.alpha.saturating_sub(1) + self.beta.saturating_sub(1)
61 }
62}
63
64#[derive(Debug, Serialize, Deserialize)]
66struct OnDisk {
67 version: u32,
68 detectors: HashMap<String, BetaCounters>,
69}
70
71const SCHEMA_VERSION: u32 = 1;
72
73#[derive(Debug, Default)]
79pub struct Calibration {
80 inner: RwLock<HashMap<String, BetaCounters>>,
81}
82
83impl Calibration {
84 pub fn empty() -> Self {
85 Self::default()
86 }
87
88 pub fn load(path: &Path) -> Self {
89 let bytes = match std::fs::read(path) {
90 Ok(b) => b,
91 Err(_) => return Self::empty(),
92 };
93 let on_disk: OnDisk = match serde_json::from_slice(&bytes) {
94 Ok(d) => d,
95 Err(e) => {
96 tracing::warn!(
97 cache = %path.display(),
98 error = %e,
99 "calibration parse failed; treating as cold start"
100 );
101 return Self::empty();
102 }
103 };
104 if on_disk.version != SCHEMA_VERSION {
105 tracing::warn!(
106 cache = %path.display(),
107 version = on_disk.version,
108 expected = SCHEMA_VERSION,
109 "calibration schema mismatch; treating as cold start"
110 );
111 return Self::empty();
112 }
113 Self {
114 inner: RwLock::new(on_disk.detectors),
115 }
116 }
117
118 pub fn save(&self, path: &Path) -> std::io::Result<()> {
119 let detectors = self.inner.read().clone();
120 let on_disk = OnDisk {
121 version: SCHEMA_VERSION,
122 detectors,
123 };
124 let serialized = serde_json::to_vec_pretty(&on_disk)
125 .map_err(|e| std::io::Error::other(format!("calibration encode: {e}")))?;
126 let parent = path.parent().unwrap_or_else(|| std::path::Path::new("."));
127 std::fs::create_dir_all(parent)?;
128 let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
131 std::io::Write::write_all(&mut tmp, &serialized)?;
132 tmp.as_file().sync_all()?;
133 tmp.persist(path).map_err(|e| e.error)?;
134 Ok(())
135 }
136
137 pub fn record_true_positive(&self, detector_id: &str) {
139 self.inner
140 .write()
141 .entry(detector_id.to_string())
142 .or_default()
143 .alpha += 1;
144 }
145
146 pub fn record_false_positive(&self, detector_id: &str) {
148 self.inner
149 .write()
150 .entry(detector_id.to_string())
151 .or_default()
152 .beta += 1;
153 }
154
155 pub fn confidence_multiplier(&self, detector_id: &str) -> f64 {
161 self.inner
162 .read()
163 .get(detector_id)
164 .copied()
165 .unwrap_or_default()
166 .posterior_mean()
167 }
168
169 pub fn counters(&self, detector_id: &str) -> BetaCounters {
171 self.inner
172 .read()
173 .get(detector_id)
174 .copied()
175 .unwrap_or_default()
176 }
177
178 pub fn entries(&self) -> Vec<(String, BetaCounters)> {
181 let mut out: Vec<_> = self
182 .inner
183 .read()
184 .iter()
185 .map(|(k, v)| (k.clone(), *v))
186 .collect();
187 out.sort_by(|a, b| a.0.cmp(&b.0));
188 out
189 }
190}
191
192pub fn default_cache_path() -> Option<PathBuf> {
195 dirs::cache_dir().map(|d| d.join("keyhog").join("calibration.json"))
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh temp directory plus a cache path inside it. The returned
    /// `TempDir` must stay alive for as long as the path is used.
    fn scratch_cache() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");
        (dir, path)
    }

    #[test]
    fn fresh_detector_returns_uniform_prior() {
        let calibration = Calibration::empty();
        assert_eq!(calibration.confidence_multiplier("never-seen"), 0.5);
    }

    #[test]
    fn true_positives_drive_posterior_up() {
        let calibration = Calibration::empty();
        (0..9).for_each(|_| calibration.record_true_positive("aws-access-key"));
        // Beta(10, 1): mean = 10 / 11 ≈ 0.909.
        let mean = calibration.confidence_multiplier("aws-access-key");
        assert!(mean > 0.85, "expected >0.85, got {mean}");
    }

    #[test]
    fn false_positives_drive_posterior_down() {
        let calibration = Calibration::empty();
        (0..9).for_each(|_| calibration.record_false_positive("noisy-detector"));
        // Beta(1, 10): mean = 1 / 11 ≈ 0.091.
        let mean = calibration.confidence_multiplier("noisy-detector");
        assert!(mean < 0.15, "expected <0.15, got {mean}");
    }

    #[test]
    fn observations_excludes_prior() {
        let calibration = Calibration::empty();
        assert_eq!(calibration.counters("x").observations(), 0);
        calibration.record_true_positive("x");
        calibration.record_false_positive("x");
        assert_eq!(calibration.counters("x").observations(), 2);
    }

    #[test]
    fn save_load_roundtrip() {
        let (_dir, path) = scratch_cache();

        let original = Calibration::empty();
        original.record_true_positive("aws-access-key");
        original.record_false_positive("aws-access-key");
        original.record_true_positive("github-pat");
        original.save(&path).unwrap();

        let reloaded = Calibration::load(&path);
        assert_eq!(
            reloaded.counters("aws-access-key"),
            BetaCounters { alpha: 2, beta: 2 }
        );
        assert_eq!(
            reloaded.counters("github-pat"),
            BetaCounters { alpha: 2, beta: 1 }
        );
    }

    #[test]
    fn corrupted_cache_returns_empty() {
        let (_dir, path) = scratch_cache();
        std::fs::write(&path, b"this is not json").unwrap();
        assert!(Calibration::load(&path).entries().is_empty());
    }

    #[test]
    fn schema_mismatch_returns_empty() {
        let (_dir, path) = scratch_cache();
        let future_schema = serde_json::json!({
            "version": 99,
            "detectors": { "x": { "alpha": 5, "beta": 5 } }
        });
        std::fs::write(&path, serde_json::to_vec(&future_schema).unwrap()).unwrap();
        assert!(Calibration::load(&path).entries().is_empty());
    }

    #[test]
    fn entries_returns_sorted() {
        let calibration = Calibration::empty();
        for id in ["zzz", "aaa", "mmm"] {
            calibration.record_true_positive(id);
        }
        let ids: Vec<String> = calibration
            .entries()
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(ids, ["aaa", "mmm", "zzz"]);
    }
}