Skip to main content

ai_memory/confidence/
calibrate.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! v0.7.0 Form 5 — calibration sweep.
5//!
6//! Reads `confidence_shadow_observations` since N days back and emits
7//! per-(namespace, source) baselines: the median derived confidence the
8//! [`crate::confidence::derive`] engine produced over the observed
9//! window. Driven by the `ai-memory calibrate confidence --from-shadow`
10//! CLI subcommand and the `memory_calibrate_confidence` MCP tool.
11//!
12//! Audit-honest contract: the sweep is **read-only** by default. The
13//! computed baselines are surfaced as a report; persistence into a
14//! calibration store is an opt-in follow-up that operators run only
15//! after reviewing the output (so a poorly-sampled window can't
16//! silently re-pin a namespace's confidence ceiling).
17//!
18//! # Streaming aggregation (Cluster G, PERF-12)
19//!
20//! Pre-Cluster-G, this module materialised the entire window into a
21//! `Vec<(ShadowObservation, String)>` (via INNER JOIN against
22//! `memories` to pull the source role), then grouped + sorted in Rust.
23//! A long-running shadow-mode deployment with millions of observations
24//! exhausted memory on the calibration call.
25//!
26//! Post-Cluster-G, the sweep streams in two passes:
27//!
28//! 1. **Group counts + mean** (single SQL aggregation):
29//!    ```sql
30//!    SELECT namespace, source, COUNT(*), AVG(derived_confidence)
31//!    FROM confidence_shadow_observations
32//!    WHERE observed_at >= ?1
33//!    GROUP BY namespace, source
34//!    ```
35//!
36//! 2. **Per-group median + bucket histogram** (cursor-based scan):
37//!    ```sql
38//!    SELECT derived_confidence FROM confidence_shadow_observations
39//!    WHERE observed_at >= ?1 AND namespace = ?2 AND source = ?3
40//!    ORDER BY derived_confidence ASC
41//!    ```
42//!    The compound `(namespace, source, observed_at)` index added in
43//!    schema v40 keeps the WHERE-predicate scan tight; the ORDER BY
44//!    sort runs inside SQLite rather than in Rust (no full-table
45//!    Vec materialisation). Median is the middle row of the sorted
46//!    group for odd counts and the mean of the two middle rows for
47//!    even counts (the same `(a+b)/2` semantics as pre-Cluster-G);
48//!    buckets fold into 10 stack-allocated counters via a single pass.
49//!
50//! The denormalised `source` column (also schema v40) eliminates the
51//! join with `memories` entirely — orphan observation rows whose
52//! source memory has been CASCADE-deleted continue to surface in the
53//! report under their stamped `source` value, which is the audit-
54//! honest behaviour (the calibration sample was real; the source
55//! memory's later deletion doesn't unmake the observation).
56
57use anyhow::Result;
58use chrono::{Duration, Utc};
59use rusqlite::{Connection, params};
60use serde::Serialize;
61
/// Number of days the sweep looks back when the caller does not pick a
/// window. The Form 5 brief specifies 30 days; the CLI `--days N` flag
/// and the MCP `days` parameter override it per call.
pub const DEFAULT_WINDOW_DAYS: i64 = 30;
65
66/// One per-(namespace, source) row in the calibration report.
67///
68/// `source` is the `memories.source` role label (`user`, `claude`,
69/// `api`, …) denormalised onto each shadow observation via the
70/// v40-schema column. `count` is the number of observations that
71/// contributed; `median` and the bucket distribution let an operator
72/// spot a skewed sample.
73#[derive(Debug, Clone, Serialize, PartialEq)]
74pub struct PerSourceBaseline {
75    pub namespace: String,
76    pub source: String,
77    pub count: u64,
78    /// Median derived confidence across the window. Robust to outliers
79    /// vs. the mean.
80    pub median: f64,
81    /// Mean derived confidence — emitted alongside the median so a
82    /// caller can spot a skew-vs-tail distinction at a glance.
83    pub mean: f64,
84    /// Bucketed distribution of derived values. 10 buckets covering
85    /// `[0.0, 0.1)` … `[0.9, 1.0]` so a downstream UI can plot a
86    /// histogram without re-reading the observation table.
87    pub buckets: [u64; 10],
88}
89
90/// Top-level calibration report emitted by the sweep.
91#[derive(Debug, Clone, Serialize, PartialEq)]
92pub struct CalibrationReport {
93    pub window_days: i64,
94    pub total_observations: u64,
95    pub baselines: Vec<PerSourceBaseline>,
96}
97
98/// Compute the calibration report by scanning shadow observations from
99/// the last `days` days.
100///
101/// `now` is parameterised so tests can pin a deterministic clock. The
102/// production CLI/MCP wrappers pass `Utc::now()`.
103///
104/// # Errors
105///
106/// Returns the underlying `rusqlite` error if the SELECT fails.
107#[allow(clippy::cast_precision_loss)]
108pub fn calibrate_from_shadow(
109    conn: &Connection,
110    days: i64,
111    now: chrono::DateTime<Utc>,
112) -> Result<CalibrationReport> {
113    let since_dt = now - Duration::days(days);
114    let since = since_dt.to_rfc3339();
115
116    // Pass 1: per-group count + mean, computed entirely in SQL. The
117    // denormalised `source` column (schema v40) lets us avoid the
118    // INNER JOIN against `memories` that pre-Cluster-G code carried.
119    let mut stmt = conn.prepare(
120        "SELECT namespace, source, COUNT(*), AVG(derived_confidence)
121         FROM confidence_shadow_observations
122         WHERE observed_at >= ?1
123         GROUP BY namespace, source
124         ORDER BY namespace, source",
125    )?;
126    let groups: Vec<(String, String, i64, f64)> = stmt
127        .query_map(params![since.as_str()], |row| {
128            Ok((
129                row.get::<_, String>(0)?,
130                row.get::<_, String>(1)?,
131                row.get::<_, i64>(2)?,
132                row.get::<_, f64>(3)?,
133            ))
134        })?
135        .collect::<rusqlite::Result<Vec<_>>>()?;
136    drop(stmt);
137
138    let total_observations: u64 = groups.iter().map(|(_, _, c, _)| *c as u64).sum();
139
140    // Pass 2: per-group cursor scan for median + bucket histogram.
141    // The compound (namespace, source, observed_at) index from
142    // schema v40 makes the WHERE filter cheap; the per-group result
143    // set is bounded by the group size (typically thousands, not
144    // millions) so the streaming Vec<f64> stays small.
145    let mut median_stmt = conn.prepare(
146        "SELECT derived_confidence
147         FROM confidence_shadow_observations
148         WHERE observed_at >= ?1 AND namespace = ?2 AND source = ?3
149         ORDER BY derived_confidence ASC",
150    )?;
151
152    let mut baselines: Vec<PerSourceBaseline> = Vec::with_capacity(groups.len());
153    for (namespace, source, count_i64, mean) in groups {
154        if count_i64 <= 0 {
155            continue;
156        }
157        let count = count_i64 as u64;
158        let mut values: Vec<f64> = Vec::with_capacity(count as usize);
159        let mut rows =
160            median_stmt.query(params![since.as_str(), namespace.as_str(), source.as_str()])?;
161        let mut buckets = [0_u64; 10];
162        while let Some(row) = rows.next()? {
163            let v: f64 = row.get(0)?;
164            let idx = ((v.clamp(0.0, 1.0) * 10.0) as usize).min(9);
165            buckets[idx] += 1;
166            values.push(v);
167        }
168        // Values arrived ORDER BY ASC — pick the median by index.
169        let median = if values.is_empty() {
170            0.0
171        } else if values.len() % 2 == 0 {
172            let mid = values.len() / 2;
173            (values[mid - 1] + values[mid]) / 2.0
174        } else {
175            values[values.len() / 2]
176        };
177        baselines.push(PerSourceBaseline {
178            namespace,
179            source,
180            count,
181            median,
182            mean,
183            buckets,
184        });
185    }
186    drop(median_stmt);
187
188    Ok(CalibrationReport {
189        window_days: days,
190        total_observations,
191        baselines,
192    })
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::shadow::observe;
    use crate::models::ConfidenceSignals;
    use crate::storage::open as open_storage;

    /// Tolerance for comparing floating-point medians.
    const EPS: f64 = 1e-6;

    /// Create a database file in a fresh temp dir via the project's
    /// storage opener, then hand back a plain connection to it. The
    /// `TempDir` is returned so the directory outlives the connection.
    fn open_tmp() -> (Connection, tempfile::TempDir) {
        let dir = tempfile::tempdir().expect("tmpdir");
        let db_path = dir.path().join("test.db");
        drop(open_storage(&db_path).expect("open storage"));
        (Connection::open(&db_path).expect("open conn"), dir)
    }

    /// Insert a minimal `memories` row with the given id, namespace
    /// and source label.
    fn seed_mem(conn: &Connection, id: &str, ns: &str, source: &str) {
        conn.execute(
            "INSERT INTO memories (id, tier, namespace, title, content, source, created_at, updated_at)
             VALUES (?1, 'mid', ?2, ?1, 'c', ?3, '2026-05-15T00:00:00Z', '2026-05-15T00:00:00Z')",
            params![id, ns, source],
        )
        .expect("seed mem");
    }

    /// Default signal set; the sweep under test never reads it.
    fn signals() -> ConfidenceSignals {
        ConfidenceSignals::default()
    }

    /// First baseline in `report` whose source label equals `source`.
    fn baseline_for<'a>(report: &'a CalibrationReport, source: &str) -> Option<&'a PerSourceBaseline> {
        report.baselines.iter().find(|b| b.source == source)
    }

    #[test]
    fn calibrate_emits_per_source_baselines() {
        let (conn, _dir) = open_tmp();
        for (id, source) in [("m1", "user"), ("m2", "user"), ("m3", "claude")] {
            seed_mem(&conn, id, "ns", source);
        }
        observe(&conn, "m1", "ns", "user", 0.9, 0.5, &signals(), None).unwrap();
        observe(&conn, "m2", "ns", "user", 0.9, 0.7, &signals(), None).unwrap();
        observe(&conn, "m3", "ns", "claude", 0.9, 0.3, &signals(), None).unwrap();

        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 3);
        assert_eq!(report.baselines.len(), 2);

        // Two `user` samples (0.5, 0.7): even count, median is their mean.
        let user = baseline_for(&report, "user").expect("user baseline");
        assert_eq!(user.count, 2);
        assert!(
            (user.median - 0.6).abs() < EPS,
            "median got {}",
            user.median
        );

        // A single `claude` sample is its own median.
        let claude = baseline_for(&report, "claude").expect("claude baseline");
        assert!((claude.median - 0.3).abs() < EPS);
    }

    #[test]
    fn calibrate_buckets_cover_full_range() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        for derived in [0.05, 0.25, 0.45, 0.55, 0.95] {
            observe(&conn, "m1", "ns", "user", 0.9, derived, &signals(), None).unwrap();
        }

        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        let baseline = &report.baselines[0];
        // Exactly one value lands in each of buckets 0, 2, 4, 5, 9.
        for idx in [0, 2, 4, 5, 9] {
            assert_eq!(baseline.buckets[idx], 1);
        }
        assert_eq!(baseline.count, 5);
    }

    #[test]
    fn calibrate_filters_by_window() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        // A direct INSERT lets the test stamp an observed_at far in the past.
        conn.execute(
            "INSERT INTO confidence_shadow_observations
                (memory_id, namespace, source, caller_confidence, derived_confidence,
                 signals, recall_outcome, observed_at)
             VALUES ('m1', 'ns', 'user', 0.9, 0.5, '{}', NULL, '2020-01-01T00:00:00Z')",
            [],
        )
        .unwrap();
        observe(&conn, "m1", "ns", "user", 0.9, 0.7, &signals(), None).unwrap();

        let report = calibrate_from_shadow(&conn, 1, Utc::now()).expect("calibrate");
        // Only the fresh row survives the 1-day window.
        assert_eq!(report.total_observations, 1);
        let only = &report.baselines[0];
        assert!((only.median - 0.7).abs() < EPS);
    }

    #[test]
    fn calibrate_empty_table_returns_empty_report() {
        let (conn, _dir) = open_tmp();
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert!(report.baselines.is_empty());
        assert_eq!(report.total_observations, 0);
    }
}