Skip to main content

keyhog_core/
calibration.rs

1//! Bayesian Beta(α, β) calibration per detector.
2//!
3//! Tier-B moat innovation #4 from audits/legendary-2026-04-26: surface
4//! per-detector reliability based on observed true-positive vs false-
5//! positive history rather than a fixed threshold. Detectors with a long
6//! history of clean hits get a higher confidence multiplier; detectors
7//! that fire-then-suppress repeatedly get downweighted.
8//!
9//! Mathematical model:
10//!     each detector has a Beta(α, β) prior over P(true positive | match).
11//!     α counts confirmed TPs, β counts confirmed FPs (both incremented from
12//!     a starting prior of α=1, β=1 — uniform Beta(1, 1)).
13//!     posterior mean = α / (α + β)  ∈ [0, 1].
14//!
15//! Storage: JSON at `$XDG_CACHE_HOME/keyhog/calibration.json` with a schema
16//! version field. Load returns an empty store on miss / corrupted JSON /
17//! schema mismatch — never poison the cache from a damaged artifact.
18//!
19//! This module ships the DATA layer only. Live integration into the
20//! scanner's confidence-scoring path is a separate change that needs
21//! per-detector lookup at `apply_post_ml_penalties` time.
22
23use std::collections::HashMap;
24use std::path::{Path, PathBuf};
25
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28
29/// A detector's running Beta posterior counters. Always ≥1 each (Beta(1,1)
30/// uniform prior baseline) to avoid posterior_mean undefined when a detector
31/// has had no observations yet.
32#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
33pub struct BetaCounters {
34    pub alpha: u32,
35    pub beta: u32,
36}
37
38impl Default for BetaCounters {
39    fn default() -> Self {
40        Self { alpha: 1, beta: 1 }
41    }
42}
43
44impl BetaCounters {
45    /// Posterior mean: α / (α + β). Falls in [0, 1]; the higher, the more
46    /// reliable the detector is historically.
47    pub fn posterior_mean(&self) -> f64 {
48        let total = self.alpha as f64 + self.beta as f64;
49        if total == 0.0 {
50            0.5
51        } else {
52            self.alpha as f64 / total
53        }
54    }
55
56    /// Number of observations (excluding the prior) the posterior is built
57    /// on. Useful for "trust the recent history" UI gates.
58    pub fn observations(&self) -> u32 {
59        // Subtract the Beta(1, 1) prior baseline.
60        self.alpha.saturating_sub(1) + self.beta.saturating_sub(1)
61    }
62}
63
64/// On-disk format. The version field gates breaking schema changes.
65#[derive(Debug, Serialize, Deserialize)]
66struct OnDisk {
67    version: u32,
68    detectors: HashMap<String, BetaCounters>,
69}
70
71const SCHEMA_VERSION: u32 = 1;
72
73/// Process-wide calibration store. Concurrent updates are serialized via
74/// a single `RwLock` because update events are rare (one per `keyhog
75/// calibrate` invocation or per verifier outcome) and the locked region is
76/// constant-time. We deliberately don't shard via DashMap — the persisted
77/// artifact is small enough that contention is a non-issue.
78#[derive(Debug, Default)]
79pub struct Calibration {
80    inner: RwLock<HashMap<String, BetaCounters>>,
81}
82
83impl Calibration {
84    pub fn empty() -> Self {
85        Self::default()
86    }
87
88    pub fn load(path: &Path) -> Self {
89        let bytes = match std::fs::read(path) {
90            Ok(b) => b,
91            Err(_) => return Self::empty(),
92        };
93        let on_disk: OnDisk = match serde_json::from_slice(&bytes) {
94            Ok(d) => d,
95            Err(e) => {
96                tracing::warn!(
97                    cache = %path.display(),
98                    error = %e,
99                    "calibration parse failed; treating as cold start"
100                );
101                return Self::empty();
102            }
103        };
104        if on_disk.version != SCHEMA_VERSION {
105            tracing::warn!(
106                cache = %path.display(),
107                version = on_disk.version,
108                expected = SCHEMA_VERSION,
109                "calibration schema mismatch; treating as cold start"
110            );
111            return Self::empty();
112        }
113        Self {
114            inner: RwLock::new(on_disk.detectors),
115        }
116    }
117
118    pub fn save(&self, path: &Path) -> std::io::Result<()> {
119        let detectors = self.inner.read().clone();
120        let on_disk = OnDisk {
121            version: SCHEMA_VERSION,
122            detectors,
123        };
124        let serialized = serde_json::to_vec_pretty(&on_disk)
125            .map_err(|e| std::io::Error::other(format!("calibration encode: {e}")))?;
126        let parent = path.parent().unwrap_or_else(|| std::path::Path::new("."));
127        std::fs::create_dir_all(parent)?;
128        // Same atomic-write-via-NamedTempFile pattern used by
129        // `merkle_index::save` — see that file's note for rationale.
130        let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
131        std::io::Write::write_all(&mut tmp, &serialized)?;
132        tmp.as_file().sync_all()?;
133        tmp.persist(path).map_err(|e| e.error)?;
134        Ok(())
135    }
136
137    /// Record a true positive for `detector_id` (α += 1).
138    pub fn record_true_positive(&self, detector_id: &str) {
139        self.inner
140            .write()
141            .entry(detector_id.to_string())
142            .or_default()
143            .alpha += 1;
144    }
145
146    /// Record a false positive for `detector_id` (β += 1).
147    pub fn record_false_positive(&self, detector_id: &str) {
148        self.inner
149            .write()
150            .entry(detector_id.to_string())
151            .or_default()
152            .beta += 1;
153    }
154
155    /// Return the posterior mean for `detector_id`, falling back to 0.5
156    /// when no observations exist (uniform prior over a never-calibrated
157    /// detector). Callers MAY use this value as a confidence multiplier
158    /// inside the scanner's confidence-scoring path; the live integration
159    /// is staged separately.
160    pub fn confidence_multiplier(&self, detector_id: &str) -> f64 {
161        self.inner
162            .read()
163            .get(detector_id)
164            .copied()
165            .unwrap_or_default()
166            .posterior_mean()
167    }
168
169    /// Return the full counters for `detector_id` (defaults to Beta(1, 1)).
170    pub fn counters(&self, detector_id: &str) -> BetaCounters {
171        self.inner
172            .read()
173            .get(detector_id)
174            .copied()
175            .unwrap_or_default()
176    }
177
178    /// Iterate every recorded `(detector_id, counters)`. Useful for
179    /// `keyhog calibrate --show`.
180    pub fn entries(&self) -> Vec<(String, BetaCounters)> {
181        let mut out: Vec<_> = self
182            .inner
183            .read()
184            .iter()
185            .map(|(k, v)| (k.clone(), *v))
186            .collect();
187        out.sort_by(|a, b| a.0.cmp(&b.0));
188        out
189    }
190}
191
192/// Default cache location: `$XDG_CACHE_HOME/keyhog/calibration.json` (or
193/// the macOS/Windows equivalents via the `dirs` crate).
194pub fn default_cache_path() -> Option<PathBuf> {
195    dirs::cache_dir().map(|d| d.join("keyhog").join("calibration.json"))
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a scratch directory and the cache path inside it. The returned
    /// `TempDir` must be kept alive for as long as the path is in use.
    fn scratch_cache() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");
        (dir, path)
    }

    #[test]
    fn fresh_detector_returns_uniform_prior() {
        let store = Calibration::empty();
        let multiplier = store.confidence_multiplier("never-seen");
        assert_eq!(multiplier, 0.5);
    }

    #[test]
    fn true_positives_drive_posterior_up() {
        let store = Calibration::empty();
        (0..9).for_each(|_| store.record_true_positive("aws-access-key"));
        // Prior plus nine TPs: α = 10, β = 1 → mean = 10/11 ≈ 0.909
        let m = store.confidence_multiplier("aws-access-key");
        assert!(m > 0.85, "expected >0.85, got {m}");
    }

    #[test]
    fn false_positives_drive_posterior_down() {
        let store = Calibration::empty();
        (0..9).for_each(|_| store.record_false_positive("noisy-detector"));
        // Prior plus nine FPs: α = 1, β = 10 → mean = 1/11 ≈ 0.091
        let m = store.confidence_multiplier("noisy-detector");
        assert!(m < 0.15, "expected <0.15, got {m}");
    }

    #[test]
    fn observations_excludes_prior() {
        let store = Calibration::empty();
        let before = store.counters("x").observations();
        assert_eq!(before, 0);

        store.record_true_positive("x");
        store.record_false_positive("x");
        let after = store.counters("x").observations();
        assert_eq!(after, 2);
    }

    #[test]
    fn save_load_roundtrip() {
        let (_dir, path) = scratch_cache();

        let original = Calibration::empty();
        original.record_true_positive("aws-access-key");
        original.record_false_positive("aws-access-key");
        original.record_true_positive("github-pat");
        original.save(&path).unwrap();

        let reloaded = Calibration::load(&path);
        assert_eq!(
            reloaded.counters("aws-access-key"),
            BetaCounters { alpha: 2, beta: 2 }
        );
        assert_eq!(
            reloaded.counters("github-pat"),
            BetaCounters { alpha: 2, beta: 1 }
        );
    }

    #[test]
    fn corrupted_cache_returns_empty() {
        let (_dir, path) = scratch_cache();
        std::fs::write(&path, b"this is not json").unwrap();

        let loaded = Calibration::load(&path);
        assert!(loaded.entries().is_empty());
    }

    #[test]
    fn schema_mismatch_returns_empty() {
        let (_dir, path) = scratch_cache();
        // A well-formed file whose version this build does not understand.
        let bad = serde_json::json!({
            "version": 99,
            "detectors": { "x": { "alpha": 5, "beta": 5 } }
        });
        let encoded = serde_json::to_vec(&bad).unwrap();
        std::fs::write(&path, encoded).unwrap();

        let loaded = Calibration::load(&path);
        assert!(loaded.entries().is_empty());
    }

    #[test]
    fn entries_returns_sorted() {
        let store = Calibration::empty();
        for id in ["zzz", "aaa", "mmm"] {
            store.record_true_positive(id);
        }

        let ids: Vec<String> = store.entries().into_iter().map(|(id, _)| id).collect();
        assert_eq!(ids, ["aaa", "mmm", "zzz"]);
    }
}