keyhog_core/
calibration.rs1use std::collections::HashMap;
24use std::path::{Path, PathBuf};
25
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
33pub struct BetaCounters {
34 pub alpha: u32,
35 pub beta: u32,
36}
37
38impl Default for BetaCounters {
39 fn default() -> Self {
40 Self { alpha: 1, beta: 1 }
41 }
42}
43
44impl BetaCounters {
45 pub fn posterior_mean(&self) -> f64 {
48 let total = self.alpha as f64 + self.beta as f64;
49 if total == 0.0 {
50 0.5
51 } else {
52 self.alpha as f64 / total
53 }
54 }
55
56 pub fn observations(&self) -> u32 {
59 self.alpha.saturating_sub(1) + self.beta.saturating_sub(1)
61 }
62}
63
64#[derive(Debug, Serialize, Deserialize)]
66struct OnDisk {
67 version: u32,
68 detectors: HashMap<String, BetaCounters>,
69}
70
/// Version number written into every saved calibration file; `load` discards
/// any file whose `version` field differs from it.
const SCHEMA_VERSION: u32 = 1;
72
73#[derive(Debug, Default)]
79pub struct Calibration {
80 inner: RwLock<HashMap<String, BetaCounters>>,
81}
82
83impl Calibration {
84 pub fn empty() -> Self {
85 Self::default()
86 }
87
88 pub fn load(path: &Path) -> Self {
89 let bytes = match std::fs::read(path) {
90 Ok(b) => b,
91 Err(_) => return Self::empty(),
92 };
93 let on_disk: OnDisk = match serde_json::from_slice(&bytes) {
94 Ok(d) => d,
95 Err(e) => {
96 tracing::warn!(
97 cache = %path.display(),
98 error = %e,
99 "calibration parse failed; treating as cold start"
100 );
101 return Self::empty();
102 }
103 };
104 if on_disk.version != SCHEMA_VERSION {
105 tracing::warn!(
106 cache = %path.display(),
107 version = on_disk.version,
108 expected = SCHEMA_VERSION,
109 "calibration schema mismatch; treating as cold start"
110 );
111 return Self::empty();
112 }
113 Self {
114 inner: RwLock::new(on_disk.detectors),
115 }
116 }
117
118 pub fn save(&self, path: &Path) -> std::io::Result<()> {
119 let detectors = self.inner.read().clone();
120 let on_disk = OnDisk {
121 version: SCHEMA_VERSION,
122 detectors,
123 };
124 let serialized = serde_json::to_vec_pretty(&on_disk)
125 .map_err(|e| std::io::Error::other(format!("calibration encode: {e}")))?;
126 if let Some(parent) = path.parent() {
127 std::fs::create_dir_all(parent)?;
128 }
129 let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
130 std::fs::write(&tmp, &serialized)?;
131 std::fs::rename(&tmp, path)?;
132 Ok(())
133 }
134
135 pub fn record_true_positive(&self, detector_id: &str) {
137 self.inner
138 .write()
139 .entry(detector_id.to_string())
140 .or_default()
141 .alpha += 1;
142 }
143
144 pub fn record_false_positive(&self, detector_id: &str) {
146 self.inner
147 .write()
148 .entry(detector_id.to_string())
149 .or_default()
150 .beta += 1;
151 }
152
153 pub fn confidence_multiplier(&self, detector_id: &str) -> f64 {
159 self.inner
160 .read()
161 .get(detector_id)
162 .copied()
163 .unwrap_or_default()
164 .posterior_mean()
165 }
166
167 pub fn counters(&self, detector_id: &str) -> BetaCounters {
169 self.inner
170 .read()
171 .get(detector_id)
172 .copied()
173 .unwrap_or_default()
174 }
175
176 pub fn entries(&self) -> Vec<(String, BetaCounters)> {
179 let mut out: Vec<_> = self
180 .inner
181 .read()
182 .iter()
183 .map(|(k, v)| (k.clone(), *v))
184 .collect();
185 out.sort_by(|a, b| a.0.cmp(&b.0));
186 out
187 }
188}
189
190pub fn default_cache_path() -> Option<PathBuf> {
193 dirs::cache_dir().map(|d| d.join("keyhog").join("calibration.json"))
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a scratch directory and a `calibration.json` path inside it.
    /// The returned guard must stay alive for as long as the path is used.
    fn scratch_cache() -> (tempfile::TempDir, std::path::PathBuf) {
        let guard = tempfile::tempdir().unwrap();
        let file = guard.path().join("calibration.json");
        (guard, file)
    }

    /// Writes `contents` to a fresh scratch cache file and loads it back.
    fn load_from_bytes(contents: &[u8]) -> Calibration {
        let (_guard, file) = scratch_cache();
        std::fs::write(&file, contents).unwrap();
        Calibration::load(&file)
    }

    #[test]
    fn fresh_detector_returns_uniform_prior() {
        let store = Calibration::empty();
        assert_eq!(store.confidence_multiplier("never-seen"), 0.5);
    }

    #[test]
    fn true_positives_drive_posterior_up() {
        let store = Calibration::empty();
        (0..9).for_each(|_| store.record_true_positive("aws-access-key"));
        let m = store.confidence_multiplier("aws-access-key");
        assert!(m > 0.85, "expected >0.85, got {m}");
    }

    #[test]
    fn false_positives_drive_posterior_down() {
        let store = Calibration::empty();
        (0..9).for_each(|_| store.record_false_positive("noisy-detector"));
        let m = store.confidence_multiplier("noisy-detector");
        assert!(m < 0.15, "expected <0.15, got {m}");
    }

    #[test]
    fn observations_excludes_prior() {
        let store = Calibration::empty();
        assert_eq!(store.counters("x").observations(), 0);
        store.record_true_positive("x");
        store.record_false_positive("x");
        assert_eq!(store.counters("x").observations(), 2);
    }

    #[test]
    fn save_load_roundtrip() {
        let (_guard, file) = scratch_cache();

        let written = Calibration::empty();
        written.record_true_positive("aws-access-key");
        written.record_false_positive("aws-access-key");
        written.record_true_positive("github-pat");
        written.save(&file).unwrap();

        let reloaded = Calibration::load(&file);
        let BetaCounters { alpha, beta } = reloaded.counters("aws-access-key");
        assert_eq!(alpha, 2);
        assert_eq!(beta, 2);
        let BetaCounters { alpha, beta } = reloaded.counters("github-pat");
        assert_eq!(alpha, 2);
        assert_eq!(beta, 1);
    }

    #[test]
    fn corrupted_cache_returns_empty() {
        let reloaded = load_from_bytes(b"this is not json");
        assert!(reloaded.entries().is_empty());
    }

    #[test]
    fn schema_mismatch_returns_empty() {
        let bad = serde_json::json!({
            "version": 99,
            "detectors": { "x": { "alpha": 5, "beta": 5 } }
        });
        let reloaded = load_from_bytes(&serde_json::to_vec(&bad).unwrap());
        assert!(reloaded.entries().is_empty());
    }

    #[test]
    fn entries_returns_sorted() {
        let store = Calibration::empty();
        for id in ["zzz", "aaa", "mmm"] {
            store.record_true_positive(id);
        }
        let ids: Vec<String> = store.entries().into_iter().map(|(id, _)| id).collect();
        assert_eq!(ids, ["aaa", "mmm", "zzz"]);
    }
}