use std::collections::HashMap;
use std::path::{Path, PathBuf};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
/// Beta-distribution pseudo-counts tracked for a single detector.
///
/// `alpha` is bumped by `Calibration::record_true_positive` and `beta` by
/// `Calibration::record_false_positive`; both start from the `1/1` prior
/// given by [`Default`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct BetaCounters {
    /// Pseudo-count of true positives (prior included).
    pub alpha: u32,
    /// Pseudo-count of false positives (prior included).
    pub beta: u32,
}
impl Default for BetaCounters {
fn default() -> Self {
Self { alpha: 1, beta: 1 }
}
}
impl BetaCounters {
    /// Posterior mean of the Beta distribution, `alpha / (alpha + beta)`.
    ///
    /// Returns `0.5` when both counters are zero (never the case for the
    /// default `1/1` prior), so the result is always finite and in `[0, 1]`.
    pub fn posterior_mean(&self) -> f64 {
        // Both counters are unsigned, so the f64 sum is zero iff both are zero;
        // test that in the integer domain instead of comparing floats.
        if self.alpha == 0 && self.beta == 0 {
            return 0.5;
        }
        let alpha = f64::from(self.alpha);
        alpha / (alpha + f64::from(self.beta))
    }

    /// Number of real observations, excluding the one-per-side prior pseudo-count.
    ///
    /// Saturates at `u32::MAX` rather than overflowing when both counters are
    /// very large (e.g. values read from a hand-edited or corrupted cache file).
    pub fn observations(&self) -> u32 {
        self.alpha
            .saturating_sub(1)
            .saturating_add(self.beta.saturating_sub(1))
    }
}
/// Serialized layout of the calibration cache, written as JSON by
/// `Calibration::save` and read back by `Calibration::load`.
#[derive(Deserialize, Serialize, Debug)]
struct OnDisk {
    /// Schema version; `Calibration::load` ignores files whose version is not
    /// `SCHEMA_VERSION`.
    version: u32,
    /// Detector id -> its counters, persisted as-is.
    detectors: HashMap<String, BetaCounters>,
}
/// Version stamped into every saved cache file. `Calibration::load` treats a
/// file carrying any other version as a cold start, so bump this whenever the
/// on-disk layout changes incompatibly.
const SCHEMA_VERSION: u32 = 1;
/// Per-detector store of [`BetaCounters`], keyed by detector id.
///
/// The map sits behind a read/write lock, so every method takes `&self`.
/// Detectors that have never been recorded report the default prior.
#[derive(Default, Debug)]
pub struct Calibration {
    /// Detector id -> its Beta pseudo-counts.
    inner: RwLock<HashMap<String, BetaCounters>>,
}
impl Calibration {
    /// Creates a store with no recorded feedback; every detector reports the
    /// default `1/1` prior.
    pub fn empty() -> Self {
        Self::default()
    }

    /// Loads counters from the JSON cache at `path`.
    ///
    /// Never returns an error. Any problem yields an empty store so the caller
    /// cold-starts: a missing, unreadable or malformed file, or a schema version
    /// other than `SCHEMA_VERSION`. A missing file is the normal first-run case
    /// and is not logged; every other failure emits a `tracing` warning.
    pub fn load(path: &Path) -> Self {
        let bytes = match std::fs::read(path) {
            Ok(b) => b,
            // First run: there is simply no cache yet.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Self::empty(),
            Err(e) => {
                // Surface unexpected I/O failures (permissions, bad path, ...) so a
                // cache that can never be read is diagnosable instead of silent.
                tracing::warn!(
                    cache = %path.display(),
                    error = %e,
                    "calibration read failed; treating as cold start"
                );
                return Self::empty();
            }
        };
        let on_disk: OnDisk = match serde_json::from_slice(&bytes) {
            Ok(d) => d,
            Err(e) => {
                tracing::warn!(
                    cache = %path.display(),
                    error = %e,
                    "calibration parse failed; treating as cold start"
                );
                return Self::empty();
            }
        };
        if on_disk.version != SCHEMA_VERSION {
            tracing::warn!(
                cache = %path.display(),
                version = on_disk.version,
                expected = SCHEMA_VERSION,
                "calibration schema mismatch; treating as cold start"
            );
            return Self::empty();
        }
        Self {
            inner: RwLock::new(on_disk.detectors),
        }
    }

    /// Writes all counters to `path` as pretty-printed JSON.
    ///
    /// The data goes to a temporary file in the destination directory, which is
    /// fsynced and then moved over `path` with `NamedTempFile::persist`. A crash
    /// mid-write therefore does not leave a truncated file at `path`. Missing
    /// parent directories are created.
    ///
    /// # Errors
    /// Returns any I/O error from directory creation, writing, syncing or the
    /// final persist. JSON encoding failures are reported as an
    /// `io::ErrorKind::Other` error.
    pub fn save(&self, path: &Path) -> std::io::Result<()> {
        // Snapshot under the read lock, then release it before doing any I/O.
        let detectors = self.inner.read().clone();
        let on_disk = OnDisk {
            version: SCHEMA_VERSION,
            detectors,
        };
        let serialized = serde_json::to_vec_pretty(&on_disk)
            .map_err(|e| std::io::Error::other(format!("calibration encode: {e}")))?;
        // `Path::parent` yields `Some("")` for a bare file name and `None` for a
        // root; normalise both to the current directory.
        let parent = path
            .parent()
            .filter(|p| !p.as_os_str().is_empty())
            .unwrap_or_else(|| Path::new("."));
        std::fs::create_dir_all(parent)?;
        // Temp file in the same directory so the final persist is a rename.
        let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
        std::io::Write::write_all(&mut tmp, &serialized)?;
        tmp.as_file().sync_all()?;
        tmp.persist(path).map_err(|e| e.error)?;
        Ok(())
    }

    /// Records a confirmed true positive for `detector_id` (increments `alpha`,
    /// saturating at `u32::MAX`).
    pub fn record_true_positive(&self, detector_id: &str) {
        self.modify_counters(detector_id, |c| c.alpha = c.alpha.saturating_add(1));
    }

    /// Records a false positive for `detector_id` (increments `beta`,
    /// saturating at `u32::MAX`).
    pub fn record_false_positive(&self, detector_id: &str) {
        self.modify_counters(detector_id, |c| c.beta = c.beta.saturating_add(1));
    }

    /// Applies `edit` to the counters for `detector_id` under the write lock.
    /// Missing entries are first inserted with the default prior.
    fn modify_counters(&self, detector_id: &str, edit: impl FnOnce(&mut BetaCounters)) {
        edit(self.inner.write().entry(detector_id.to_owned()).or_default());
    }

    /// Posterior-mean confidence multiplier for `detector_id`, in `[0, 1]`.
    /// Returns `0.5` for a detector with no recorded feedback.
    pub fn confidence_multiplier(&self, detector_id: &str) -> f64 {
        self.counters(detector_id).posterior_mean()
    }

    /// Returns a copy of the counters for `detector_id`. Returns the default
    /// prior if nothing has been recorded for it.
    pub fn counters(&self, detector_id: &str) -> BetaCounters {
        self.inner
            .read()
            .get(detector_id)
            .copied()
            .unwrap_or_default()
    }

    /// Snapshot of every tracked detector and its counters, sorted by detector id.
    pub fn entries(&self) -> Vec<(String, BetaCounters)> {
        let mut out: Vec<_> = self
            .inner
            .read()
            .iter()
            .map(|(id, counters)| (id.clone(), *counters))
            .collect();
        // Map keys are unique, so an unstable sort produces the same order.
        out.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        out
    }
}
/// Default location of the calibration cache:
/// `<platform cache dir>/keyhog/calibration.json`.
///
/// Returns `None` when `dirs::cache_dir` cannot determine a cache directory.
pub fn default_cache_path() -> Option<PathBuf> {
    let cache_root = dirs::cache_dir()?;
    Some(cache_root.join("keyhog").join("calibration.json"))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fresh_detector_returns_uniform_prior() {
        let c = Calibration::empty();
        assert_eq!(c.confidence_multiplier("never-seen"), 0.5);
    }

    #[test]
    fn true_positives_drive_posterior_up() {
        let c = Calibration::empty();
        for _ in 0..9 {
            c.record_true_positive("aws-access-key");
        }
        let m = c.confidence_multiplier("aws-access-key");
        assert!(m > 0.85, "expected >0.85, got {m}");
    }

    #[test]
    fn false_positives_drive_posterior_down() {
        let c = Calibration::empty();
        for _ in 0..9 {
            c.record_false_positive("noisy-detector");
        }
        let m = c.confidence_multiplier("noisy-detector");
        assert!(m < 0.15, "expected <0.15, got {m}");
    }

    #[test]
    fn observations_excludes_prior() {
        let c = Calibration::empty();
        assert_eq!(c.counters("x").observations(), 0);
        c.record_true_positive("x");
        c.record_false_positive("x");
        assert_eq!(c.counters("x").observations(), 2);
    }

    #[test]
    fn zero_counters_fall_back_to_half() {
        // Possible only via a hand-edited cache; must not produce NaN.
        let z = BetaCounters { alpha: 0, beta: 0 };
        assert_eq!(z.posterior_mean(), 0.5);
        assert_eq!(z.observations(), 0);
    }

    #[test]
    fn save_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");
        let c = Calibration::empty();
        c.record_true_positive("aws-access-key");
        c.record_false_positive("aws-access-key");
        c.record_true_positive("github-pat");
        c.save(&path).unwrap();
        let loaded = Calibration::load(&path);
        let aws = loaded.counters("aws-access-key");
        assert_eq!(aws.alpha, 2);
        assert_eq!(aws.beta, 2);
        let gh = loaded.counters("github-pat");
        assert_eq!(gh.alpha, 2);
        assert_eq!(gh.beta, 1);
    }

    #[test]
    fn save_creates_missing_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir
            .path()
            .join("nested")
            .join("deeper")
            .join("calibration.json");
        let c = Calibration::empty();
        c.record_false_positive("x");
        c.save(&path).unwrap();
        assert_eq!(
            Calibration::load(&path).counters("x"),
            BetaCounters { alpha: 1, beta: 2 }
        );
    }

    #[test]
    fn missing_cache_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        let loaded = Calibration::load(&dir.path().join("does-not-exist.json"));
        assert!(loaded.entries().is_empty());
    }

    #[test]
    fn corrupted_cache_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");
        std::fs::write(&path, b"this is not json").unwrap();
        let loaded = Calibration::load(&path);
        assert_eq!(loaded.entries().len(), 0);
    }

    #[test]
    fn schema_mismatch_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("calibration.json");
        let bad = serde_json::json!({
            "version": 99,
            "detectors": { "x": { "alpha": 5, "beta": 5 } }
        });
        std::fs::write(&path, serde_json::to_vec(&bad).unwrap()).unwrap();
        let loaded = Calibration::load(&path);
        assert_eq!(loaded.entries().len(), 0);
    }

    #[test]
    fn entries_returns_sorted() {
        let c = Calibration::empty();
        c.record_true_positive("zzz");
        c.record_true_positive("aaa");
        c.record_true_positive("mmm");
        let e = c.entries();
        assert_eq!(e.len(), 3);
        assert_eq!(e[0].0, "aaa");
        assert_eq!(e[1].0, "mmm");
        assert_eq!(e[2].0, "zzz");
    }
}