use anyhow::Result;
use chrono::{Duration, Utc};
use rusqlite::{Connection, params};
use serde::Serialize;
/// Default look-back window, in days, for shadow-observation calibration.
///
/// NOTE(review): not referenced in this file; presumably callers pass it as the
/// `days` argument of `calibrate_from_shadow` — confirm.
pub const DEFAULT_WINDOW_DAYS: i64 = 30;
/// Calibration statistics of `derived_confidence` for one `(namespace, source)`
/// pair, computed over the observations inside the calibration window.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct PerSourceBaseline {
/// Namespace of the grouped observations.
pub namespace: String,
/// Source of the grouped observations.
pub source: String,
/// Number of observations in this group within the window.
pub count: u64,
/// Median `derived_confidence`; for an even count, the mean of the two middle values.
pub median: f64,
/// Arithmetic mean of `derived_confidence` over the same observations.
pub mean: f64,
/// Histogram of `derived_confidence` clamped to `[0, 1]`: bucket `i` counts values
/// in `[i/10, (i+1)/10)`, with exactly `1.0` counted in bucket 9.
pub buckets: [u64; 10],
}
/// Result of a shadow-calibration run: one baseline per `(namespace, source)`
/// pair seen inside the window.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct CalibrationReport {
/// The `days` argument the report was computed with, echoed back unchanged.
pub window_days: i64,
/// Total number of observations in the window, summed across all baselines.
pub total_observations: u64,
/// Per-pair baselines, ordered by namespace and then source.
pub baselines: Vec<PerSourceBaseline>,
}
/// Computes per-`(namespace, source)` baselines of `derived_confidence` from
/// `confidence_shadow_observations`. Only rows whose `observed_at` is at or after
/// `now - days` are included.
///
/// Every statistic in a baseline (`count`, `mean`, `median`, `buckets`) comes from
/// the rows returned by a single `SELECT`. They therefore all describe the same
/// set of observations, even if other connections insert rows while calibration
/// runs.
///
/// Baselines are returned in `namespace, source` order. `total_observations` is
/// the sum of their counts. An empty window yields an empty report.
///
/// NOTE(review): the cutoff is rendered with `to_rfc3339()` and compared with
/// `observed_at` via SQL `>=`. This assumes `observed_at` is stored as RFC 3339
/// UTC text, so that text order matches time order. Confirm against the writer
/// of this table.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the query cannot be prepared or stepped;
/// - a column cannot be read as the expected Rust type (for example, a NULL
///   `derived_confidence`).
///
/// # Panics
///
/// May panic inside `chrono` if `days` is so large that `now - days` overflows.
pub fn calibrate_from_shadow(
    conn: &Connection,
    days: i64,
    now: chrono::DateTime<Utc>,
) -> Result<CalibrationReport> {
    let since = (now - Duration::days(days)).to_rfc3339();

    // One ordered scan replaces an aggregate query plus one query per group.
    // - A single statement reads one consistent snapshot.
    // - Rows arrive grouped by (namespace, source).
    // - Each group's values arrive sorted, ready for the median.
    let mut stmt = conn.prepare(
        "SELECT namespace, source, derived_confidence
         FROM confidence_shadow_observations
         WHERE observed_at >= ?1
         ORDER BY namespace, source, derived_confidence ASC",
    )?;
    let mut rows = stmt.query(params![since.as_str()])?;

    let mut baselines: Vec<PerSourceBaseline> = Vec::new();
    // Group currently being accumulated: (namespace, source, ascending values).
    let mut current: Option<(String, String, Vec<f64>)> = None;
    while let Some(row) = rows.next()? {
        let namespace: String = row.get(0)?;
        let source: String = row.get(1)?;
        let value: f64 = row.get(2)?;
        if let Some((ns, src, values)) = current.as_mut() {
            if *ns == namespace && *src == source {
                values.push(value);
                continue;
            }
        }
        // A new key begins: close out the previous group, if there was one.
        if let Some((ns, src, values)) = current.replace((namespace, source, vec![value])) {
            baselines.push(summarize_group(ns, src, &values));
        }
    }
    if let Some((ns, src, values)) = current {
        baselines.push(summarize_group(ns, src, &values));
    }

    let total_observations = baselines.iter().map(|b| b.count).sum();
    Ok(CalibrationReport {
        window_days: days,
        total_observations,
        baselines,
    })
}

/// Summarizes one `(namespace, source)` group whose `derived_confidence` values
/// are given in ascending order.
///
/// - `median` averages the two middle values when the count is even.
/// - `buckets` is a 10-bin histogram of the values clamped to `[0, 1]`.
/// - An empty slice yields zero `count`, `mean` and `median`. Groups built by
///   `calibrate_from_shadow` are never empty.
// `len as f64` only loses precision beyond 2^53 observations, which is harmless here.
#[allow(clippy::cast_precision_loss)]
fn summarize_group(namespace: String, source: String, sorted: &[f64]) -> PerSourceBaseline {
    let len = sorted.len();
    let mid = len / 2;
    let (mean, median) = if len == 0 {
        (0.0, 0.0)
    } else {
        let mean = sorted.iter().sum::<f64>() / len as f64;
        let median = if len % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) / 2.0
        } else {
            sorted[mid]
        };
        (mean, median)
    };

    let mut buckets = [0_u64; 10];
    for &v in sorted {
        // Bin i covers [i/10, (i+1)/10). Clamping plus `min(9)` keeps the index in
        // range: 1.0 (or anything above) lands in bin 9, anything negative in bin 0.
        let idx = ((v.clamp(0.0, 1.0) * 10.0) as usize).min(9);
        buckets[idx] += 1;
    }

    PerSourceBaseline {
        namespace,
        source,
        count: len as u64,
        median,
        mean,
        buckets,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::shadow::observe;
    use crate::models::ConfidenceSignals;
    use crate::storage::open as open_storage;

    /// Creates a temp database, initialises it with `storage::open`, then reopens
    /// it as a plain connection.
    ///
    /// The returned `TempDir` must outlive the connection.
    fn open_tmp() -> (Connection, tempfile::TempDir) {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.db");
        let _ = open_storage(&path).expect("open storage");
        let conn = Connection::open(&path).expect("open conn");
        (conn, dir)
    }

    /// Inserts a minimal `memories` row with id/title `id`, namespace `ns` and the
    /// given `source`.
    fn seed_mem(conn: &Connection, id: &str, ns: &str, source: &str) {
        conn.execute(
            "INSERT INTO memories (id, tier, namespace, title, content, source, created_at, updated_at)
VALUES (?1, 'mid', ?2, ?1, 'c', ?3, '2026-05-15T00:00:00Z', '2026-05-15T00:00:00Z')",
            params![id, ns, source],
        )
        .expect("seed mem");
    }

    /// Default signals; calibration never reads the stored `signals` column.
    fn signals() -> ConfidenceSignals {
        ConfidenceSignals::default()
    }

    #[test]
    fn calibrate_emits_per_source_baselines() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        seed_mem(&conn, "m2", "ns", "user");
        seed_mem(&conn, "m3", "ns", "claude");
        observe(&conn, "m1", "ns", "user", 0.9, 0.5, &signals(), None).unwrap();
        observe(&conn, "m2", "ns", "user", 0.9, 0.7, &signals(), None).unwrap();
        observe(&conn, "m3", "ns", "claude", 0.9, 0.3, &signals(), None).unwrap();
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 3);
        assert_eq!(report.baselines.len(), 2);
        let user = report
            .baselines
            .iter()
            .find(|b| b.source == "user")
            .expect("user baseline");
        assert_eq!(user.count, 2);
        assert!(
            (user.median - 0.6).abs() < 1e-6,
            "median got {}",
            user.median
        );
        assert!((user.mean - 0.6).abs() < 1e-6, "mean got {}", user.mean);
        let claude = report
            .baselines
            .iter()
            .find(|b| b.source == "claude")
            .expect("claude baseline");
        assert_eq!(claude.count, 1);
        assert!((claude.median - 0.3).abs() < 1e-6);
        assert!((claude.mean - 0.3).abs() < 1e-6);
    }

    #[test]
    fn calibrate_buckets_cover_full_range() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        for v in &[0.05, 0.25, 0.45, 0.55, 0.95] {
            observe(&conn, "m1", "ns", "user", 0.9, *v, &signals(), None).unwrap();
        }
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        let b = &report.baselines[0];
        assert_eq!(b.buckets[0], 1);
        assert_eq!(b.buckets[2], 1);
        assert_eq!(b.buckets[4], 1);
        assert_eq!(b.buckets[5], 1);
        assert_eq!(b.buckets[9], 1);
        assert_eq!(b.count, 5);
        // Every counted observation must land in exactly one bucket.
        assert_eq!(b.buckets.iter().sum::<u64>(), b.count);
    }

    #[test]
    fn calibrate_filters_by_window() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        conn.execute(
            "INSERT INTO confidence_shadow_observations
(memory_id, namespace, source, caller_confidence, derived_confidence,
signals, recall_outcome, observed_at)
VALUES ('m1', 'ns', 'user', 0.9, 0.5, '{}', NULL, '2020-01-01T00:00:00Z')",
            [],
        )
        .unwrap();
        observe(&conn, "m1", "ns", "user", 0.9, 0.7, &signals(), None).unwrap();
        let report = calibrate_from_shadow(&conn, 1, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 1);
        let b = &report.baselines[0];
        assert!((b.median - 0.7).abs() < 1e-6);
    }

    #[test]
    fn calibrate_separates_namespaces_and_orders_baselines() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns-b", "user");
        seed_mem(&conn, "m2", "ns-a", "user");
        seed_mem(&conn, "m3", "ns-a", "user");
        observe(&conn, "m1", "ns-b", "user", 0.9, 0.2, &signals(), None).unwrap();
        observe(&conn, "m2", "ns-a", "user", 0.9, 0.4, &signals(), None).unwrap();
        observe(&conn, "m3", "ns-a", "user", 0.9, 0.8, &signals(), None).unwrap();
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 3);
        // Same source in two namespaces must yield two baselines, ordered by namespace.
        let keys: Vec<(&str, &str)> = report
            .baselines
            .iter()
            .map(|b| (b.namespace.as_str(), b.source.as_str()))
            .collect();
        assert_eq!(keys, vec![("ns-a", "user"), ("ns-b", "user")]);
        let a = &report.baselines[0];
        assert_eq!(a.count, 2);
        assert!((a.mean - 0.6).abs() < 1e-6, "mean got {}", a.mean);
        assert!((a.median - 0.6).abs() < 1e-6, "median got {}", a.median);
        assert_eq!(a.buckets.iter().sum::<u64>(), a.count);
        let b = &report.baselines[1];
        assert_eq!(b.count, 1);
        assert!((b.median - 0.2).abs() < 1e-6);
    }

    #[test]
    fn calibrate_empty_table_returns_empty_report() {
        let (conn, _dir) = open_tmp();
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 0);
        assert!(report.baselines.is_empty());
    }
}