Skip to main content

keyhog_core/
calibration.rs

1//! Bayesian Beta(α, β) calibration per detector.
2//!
3//! Tier-B moat innovation #4 from audits/legendary-2026-04-26: surface
4//! per-detector reliability based on observed true-positive vs false-
5//! positive history rather than a fixed threshold. Detectors with a long
6//! history of clean hits get a higher confidence multiplier; detectors
7//! that fire-then-suppress repeatedly get downweighted.
8//!
9//! Mathematical model:
10//!     each detector has a Beta(α, β) prior over P(true positive | match).
11//!     α counts confirmed TPs, β counts confirmed FPs (both incremented from
12//!     a starting prior of α=1, β=1 — uniform Beta(1, 1)).
13//!     posterior mean = α / (α + β)  ∈ [0, 1].
14//!
15//! Storage: JSON at `$XDG_CACHE_HOME/keyhog/calibration.json` with a schema
16//! version field. Load returns an empty store on miss / corrupted JSON /
17//! schema mismatch — never poison the cache from a damaged artifact.
18//!
19//! This module ships the DATA layer only. Live integration into the
20//! scanner's confidence-scoring path is a separate change that needs
21//! per-detector lookup at `apply_post_ml_penalties` time.
22
23use std::collections::HashMap;
24use std::path::{Path, PathBuf};
25
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28
29/// A detector's running Beta posterior counters. Always ≥1 each (Beta(1,1)
30/// uniform prior baseline) to avoid posterior_mean undefined when a detector
31/// has had no observations yet.
32#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
33pub struct BetaCounters {
34    pub alpha: u32,
35    pub beta: u32,
36}
37
38impl Default for BetaCounters {
39    fn default() -> Self {
40        Self { alpha: 1, beta: 1 }
41    }
42}
43
44impl BetaCounters {
45    /// Posterior mean: α / (α + β). Falls in [0, 1]; the higher, the more
46    /// reliable the detector is historically.
47    pub fn posterior_mean(&self) -> f64 {
48        let total = self.alpha as f64 + self.beta as f64;
49        if total == 0.0 {
50            0.5
51        } else {
52            self.alpha as f64 / total
53        }
54    }
55
56    /// Number of observations (excluding the prior) the posterior is built
57    /// on. Useful for "trust the recent history" UI gates.
58    pub fn observations(&self) -> u32 {
59        // Subtract the Beta(1, 1) prior baseline.
60        self.alpha.saturating_sub(1) + self.beta.saturating_sub(1)
61    }
62}
63
64/// On-disk format. The version field gates breaking schema changes.
65#[derive(Debug, Serialize, Deserialize)]
66struct OnDisk {
67    version: u32,
68    detectors: HashMap<String, BetaCounters>,
69}
70
71const SCHEMA_VERSION: u32 = 1;
72
73/// Process-wide calibration store. Concurrent updates are serialized via
74/// a single `RwLock` because update events are rare (one per `keyhog
75/// calibrate` invocation or per verifier outcome) and the locked region is
76/// constant-time. We deliberately don't shard via DashMap — the persisted
77/// artifact is small enough that contention is a non-issue.
78#[derive(Debug, Default)]
79pub struct Calibration {
80    inner: RwLock<HashMap<String, BetaCounters>>,
81}
82
83impl Calibration {
84    pub fn empty() -> Self {
85        Self::default()
86    }
87
88    pub fn load(path: &Path) -> Self {
89        let bytes = match std::fs::read(path) {
90            Ok(b) => b,
91            Err(_) => return Self::empty(),
92        };
93        let on_disk: OnDisk = match serde_json::from_slice(&bytes) {
94            Ok(d) => d,
95            Err(e) => {
96                tracing::warn!(
97                    cache = %path.display(),
98                    error = %e,
99                    "calibration parse failed; treating as cold start"
100                );
101                return Self::empty();
102            }
103        };
104        if on_disk.version != SCHEMA_VERSION {
105            tracing::warn!(
106                cache = %path.display(),
107                version = on_disk.version,
108                expected = SCHEMA_VERSION,
109                "calibration schema mismatch; treating as cold start"
110            );
111            return Self::empty();
112        }
113        Self {
114            inner: RwLock::new(on_disk.detectors),
115        }
116    }
117
118    pub fn save(&self, path: &Path) -> std::io::Result<()> {
119        let detectors = self.inner.read().clone();
120        let on_disk = OnDisk {
121            version: SCHEMA_VERSION,
122            detectors,
123        };
124        let serialized = serde_json::to_vec_pretty(&on_disk)
125            .map_err(|e| std::io::Error::other(format!("calibration encode: {e}")))?;
126        if let Some(parent) = path.parent() {
127            std::fs::create_dir_all(parent)?;
128        }
129        let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
130        std::fs::write(&tmp, &serialized)?;
131        std::fs::rename(&tmp, path)?;
132        Ok(())
133    }
134
135    /// Record a true positive for `detector_id` (α += 1).
136    pub fn record_true_positive(&self, detector_id: &str) {
137        self.inner
138            .write()
139            .entry(detector_id.to_string())
140            .or_default()
141            .alpha += 1;
142    }
143
144    /// Record a false positive for `detector_id` (β += 1).
145    pub fn record_false_positive(&self, detector_id: &str) {
146        self.inner
147            .write()
148            .entry(detector_id.to_string())
149            .or_default()
150            .beta += 1;
151    }
152
153    /// Return the posterior mean for `detector_id`, falling back to 0.5
154    /// when no observations exist (uniform prior over a never-calibrated
155    /// detector). Callers MAY use this value as a confidence multiplier
156    /// inside the scanner's confidence-scoring path; the live integration
157    /// is staged separately.
158    pub fn confidence_multiplier(&self, detector_id: &str) -> f64 {
159        self.inner
160            .read()
161            .get(detector_id)
162            .copied()
163            .unwrap_or_default()
164            .posterior_mean()
165    }
166
167    /// Return the full counters for `detector_id` (defaults to Beta(1, 1)).
168    pub fn counters(&self, detector_id: &str) -> BetaCounters {
169        self.inner
170            .read()
171            .get(detector_id)
172            .copied()
173            .unwrap_or_default()
174    }
175
176    /// Iterate every recorded `(detector_id, counters)`. Useful for
177    /// `keyhog calibrate --show`.
178    pub fn entries(&self) -> Vec<(String, BetaCounters)> {
179        let mut out: Vec<_> = self
180            .inner
181            .read()
182            .iter()
183            .map(|(k, v)| (k.clone(), *v))
184            .collect();
185        out.sort_by(|a, b| a.0.cmp(&b.0));
186        out
187    }
188}
189
190/// Default cache location: `$XDG_CACHE_HOME/keyhog/calibration.json` (or
191/// the macOS/Windows equivalents via the `dirs` crate).
192pub fn default_cache_path() -> Option<PathBuf> {
193    dirs::cache_dir().map(|d| d.join("keyhog").join("calibration.json"))
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// A detector that was never recorded reports the Beta(1, 1) mean.
    #[test]
    fn fresh_detector_returns_uniform_prior() {
        let store = Calibration::empty();
        let multiplier = store.confidence_multiplier("never-seen");
        assert_eq!(multiplier, 0.5);
    }

    /// Nine true positives on top of the prior: α = 10, β = 1 → 10/11 ≈ 0.909.
    #[test]
    fn true_positives_drive_posterior_up() {
        let store = Calibration::empty();
        (0..9).for_each(|_| store.record_true_positive("aws-access-key"));
        let m = store.confidence_multiplier("aws-access-key");
        assert!(m > 0.85, "expected >0.85, got {m}");
    }

    /// Nine false positives on top of the prior: α = 1, β = 10 → 1/11 ≈ 0.091.
    #[test]
    fn false_positives_drive_posterior_down() {
        let store = Calibration::empty();
        (0..9).for_each(|_| store.record_false_positive("noisy-detector"));
        let m = store.confidence_multiplier("noisy-detector");
        assert!(m < 0.15, "expected <0.15, got {m}");
    }

    /// `observations` counts recorded events only, not the prior baseline.
    #[test]
    fn observations_excludes_prior() {
        let store = Calibration::empty();
        let before = store.counters("x").observations();
        assert_eq!(before, 0);

        store.record_true_positive("x");
        store.record_false_positive("x");
        let after = store.counters("x").observations();
        assert_eq!(after, 2);
    }

    /// Counters written by `save` come back unchanged from `load`.
    #[test]
    fn save_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");

        let original = Calibration::empty();
        original.record_true_positive("aws-access-key");
        original.record_false_positive("aws-access-key");
        original.record_true_positive("github-pat");
        original.save(&path).unwrap();

        let reloaded = Calibration::load(&path);
        assert_eq!(
            reloaded.counters("aws-access-key"),
            BetaCounters { alpha: 2, beta: 2 }
        );
        assert_eq!(
            reloaded.counters("github-pat"),
            BetaCounters { alpha: 2, beta: 1 }
        );
    }

    /// Unparseable cache contents produce an empty store.
    #[test]
    fn corrupted_cache_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");
        std::fs::write(&path, b"this is not json").unwrap();

        let reloaded = Calibration::load(&path);
        assert!(reloaded.entries().is_empty());
    }

    /// A well-formed file with the wrong schema version is ignored.
    #[test]
    fn schema_mismatch_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");
        let bad = serde_json::json!({
            "version": 99,
            "detectors": { "x": { "alpha": 5, "beta": 5 } }
        });
        let encoded = serde_json::to_vec(&bad).unwrap();
        std::fs::write(&path, encoded).unwrap();

        let reloaded = Calibration::load(&path);
        assert!(reloaded.entries().is_empty());
    }

    /// `entries` is ordered by detector id regardless of insertion order.
    #[test]
    fn entries_returns_sorted() {
        let store = Calibration::empty();
        for id in ["zzz", "aaa", "mmm"] {
            store.record_true_positive(id);
        }

        let listed = store.entries();
        let ids: Vec<&str> = listed.iter().map(|(id, _)| id.as_str()).collect();
        assert_eq!(ids, ["aaa", "mmm", "zzz"]);
    }
}