#![allow(missing_docs)]
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
/// Beta-distribution pseudo-counts for a single detector.
///
/// `alpha` is bumped for each confirmed finding and `beta` for each rejected
/// one; a fresh detector starts with both at 1 (see the `Default` impl).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct BetaCounters {
    /// Confirmed (true-positive) count, including the prior pseudo-count.
    pub alpha: u32,
    /// Rejected (false-positive) count, including the prior pseudo-count.
    pub beta: u32,
}
impl Default for BetaCounters {
fn default() -> Self {
Self { alpha: 1, beta: 1 }
}
}
impl BetaCounters {
    /// Posterior mean `alpha / (alpha + beta)` of the Beta distribution.
    ///
    /// Degenerate all-zero counters return 0.5 instead of dividing by zero.
    pub fn posterior_mean(&self) -> f64 {
        match (self.alpha, self.beta) {
            (0, 0) => 0.5,
            (hits, misses) => {
                let hits = f64::from(hits);
                hits / (hits + f64::from(misses))
            }
        }
    }

    /// Number of real observations recorded on top of the `Beta(1, 1)` prior,
    /// i.e. `(alpha - 1) + (beta - 1)`, with every step saturating at the
    /// `u32` bounds.
    pub fn observations(&self) -> u32 {
        let confirmed = self.alpha.saturating_sub(1);
        let rejected = self.beta.saturating_sub(1);
        confirmed.saturating_add(rejected)
    }
}
/// Version tag written into every calibration file; files carrying any other
/// `version` are discarded on load.
const SCHEMA_VERSION: u32 = 1;

/// Serialized layout of the calibration cache file.
#[derive(Debug, Serialize, Deserialize)]
struct OnDisk {
    /// Schema version of the file, compared against `SCHEMA_VERSION` on load.
    version: u32,
    /// Per-detector counters keyed by detector id.
    detectors: HashMap<String, BetaCounters>,
}
/// Shared store of per-detector `BetaCounters`, saved to and loaded from a
/// JSON file. The derived `Default` is an empty store.
#[derive(Default, Debug)]
pub struct Calibration {
    /// Counters keyed by detector id. All access goes through this
    /// reader-writer lock, so methods only need `&self`.
    inner: RwLock<HashMap<String, BetaCounters>>,
}
impl Calibration {
pub fn empty() -> Self {
Self::default()
}
pub fn load(path: &Path) -> Self {
let bytes = match std::fs::read(path) {
Ok(b) => b,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Self::empty(),
Err(e) => {
tracing::warn!(
cache = %path.display(),
error = %e,
"calibration file read failed; treating as cold start"
);
return Self::empty();
}
};
let on_disk: OnDisk = match serde_json::from_slice(&bytes) {
Ok(d) => d,
Err(e) => {
tracing::warn!(
cache = %path.display(),
error = %e,
"calibration parse failed; treating as cold start"
);
return Self::empty();
}
};
if on_disk.version != SCHEMA_VERSION {
tracing::warn!(
cache = %path.display(),
version = on_disk.version,
expected = SCHEMA_VERSION,
"calibration schema mismatch; treating as cold start"
);
return Self::empty();
}
Self {
inner: RwLock::new(on_disk.detectors),
}
}
pub fn save(&self, path: &Path) -> std::io::Result<()> {
let detectors = self.inner.read().clone();
let on_disk = OnDisk {
version: SCHEMA_VERSION,
detectors,
};
let serialized = serde_json::to_vec_pretty(&on_disk)
.map_err(|e| std::io::Error::other(format!("calibration encode: {e}")))?;
let parent = path.parent().unwrap_or_else(|| std::path::Path::new("."));
std::fs::create_dir_all(parent)?;
let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
std::io::Write::write_all(&mut tmp, &serialized)?;
tmp.as_file().sync_all()?;
tmp.persist(path).map_err(|e| e.error)?;
Ok(())
}
pub fn record_true_positive(&self, detector_id: &str) {
let mut guard = self.inner.write();
let entry = guard.entry(detector_id.to_string()).or_default();
entry.alpha = entry.alpha.saturating_add(1);
}
pub fn record_false_positive(&self, detector_id: &str) {
let mut guard = self.inner.write();
let entry = guard.entry(detector_id.to_string()).or_default();
entry.beta = entry.beta.saturating_add(1);
}
pub fn confidence_multiplier(&self, detector_id: &str) -> f64 {
self.inner
.read()
.get(detector_id)
.copied()
.unwrap_or_default()
.posterior_mean()
}
pub fn counters(&self, detector_id: &str) -> BetaCounters {
self.inner
.read()
.get(detector_id)
.copied()
.unwrap_or_default()
}
pub fn entries(&self) -> Vec<(String, BetaCounters)> {
let mut out: Vec<_> = self
.inner
.read()
.iter()
.map(|(k, v)| (k.clone(), *v))
.collect();
out.sort_by(|a, b| a.0.cmp(&b.0));
out
}
#[doc(hidden)]
pub fn test_seed_counters(&self, id: &str, alpha: u32, beta: u32) {
let mut guard = self.inner.write();
let entry = guard.entry(id.to_string()).or_default();
entry.alpha = alpha;
entry.beta = beta;
}
}
/// Default location of the calibration cache:
/// `<platform cache dir>/keyhog/calibration.json`.
///
/// Returns `None` when `dirs::cache_dir` cannot determine a cache directory.
pub fn default_cache_path() -> Option<PathBuf> {
    let mut path = dirs::cache_dir()?;
    path.push("keyhog");
    path.push("calibration.json");
    Some(path)
}