1use anyhow::Result;
58use chrono::{Duration, Utc};
59use rusqlite::{Connection, params};
60use serde::Serialize;
61
/// Default look-back window for shadow-observation calibration, in days.
///
/// NOTE(review): presumably the value callers pass as `days` to
/// `calibrate_from_shadow` when the user gives no explicit window; nothing in
/// this module reads it directly — confirm against callers.
pub const DEFAULT_WINDOW_DAYS: i64 = 30;
65
/// Calibration statistics for one `(namespace, source)` pair.
///
/// Built by `calibrate_from_shadow` from the `derived_confidence` values of the
/// rows in `confidence_shadow_observations` that fall inside the calibration
/// window.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct PerSourceBaseline {
    /// Namespace shared by every observation in this group.
    pub namespace: String,
    /// Source shared by every observation in this group.
    pub source: String,
    /// Number of observations for this pair inside the window.
    pub count: u64,
    /// Median `derived_confidence`; the mean of the two middle values when
    /// `count` is even.
    pub median: f64,
    /// Arithmetic mean of `derived_confidence`.
    pub mean: f64,
    /// Histogram of `derived_confidence` over ten equal-width bins spanning
    /// [0, 1].
    ///
    /// Bin `i` holds values in roughly `[i/10, (i+1)/10)`, subject to
    /// floating-point rounding. Values are clamped into [0, 1] before binning,
    /// and exactly 1.0 lands in the last bin.
    pub buckets: [u64; 10],
}
89
/// Output of `calibrate_from_shadow`: one baseline per `(namespace, source)`
/// pair observed inside the calibration window.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct CalibrationReport {
    /// The `days` argument the report was computed for, echoed back unchanged.
    pub window_days: i64,
    /// Observations inside the window across all pairs, i.e. the sum of every
    /// baseline's `count`.
    pub total_observations: u64,
    /// Per-pair baselines, ordered by namespace and then by source.
    pub baselines: Vec<PerSourceBaseline>,
}
97
98#[allow(clippy::cast_precision_loss)]
108pub fn calibrate_from_shadow(
109 conn: &Connection,
110 days: i64,
111 now: chrono::DateTime<Utc>,
112) -> Result<CalibrationReport> {
113 let since_dt = now - Duration::days(days);
114 let since = since_dt.to_rfc3339();
115
116 let mut stmt = conn.prepare(
120 "SELECT namespace, source, COUNT(*), AVG(derived_confidence)
121 FROM confidence_shadow_observations
122 WHERE observed_at >= ?1
123 GROUP BY namespace, source
124 ORDER BY namespace, source",
125 )?;
126 let groups: Vec<(String, String, i64, f64)> = stmt
127 .query_map(params![since.as_str()], |row| {
128 Ok((
129 row.get::<_, String>(0)?,
130 row.get::<_, String>(1)?,
131 row.get::<_, i64>(2)?,
132 row.get::<_, f64>(3)?,
133 ))
134 })?
135 .collect::<rusqlite::Result<Vec<_>>>()?;
136 drop(stmt);
137
138 let total_observations: u64 = groups.iter().map(|(_, _, c, _)| *c as u64).sum();
139
140 let mut median_stmt = conn.prepare(
146 "SELECT derived_confidence
147 FROM confidence_shadow_observations
148 WHERE observed_at >= ?1 AND namespace = ?2 AND source = ?3
149 ORDER BY derived_confidence ASC",
150 )?;
151
152 let mut baselines: Vec<PerSourceBaseline> = Vec::with_capacity(groups.len());
153 for (namespace, source, count_i64, mean) in groups {
154 if count_i64 <= 0 {
155 continue;
156 }
157 let count = count_i64 as u64;
158 let mut values: Vec<f64> = Vec::with_capacity(count as usize);
159 let mut rows =
160 median_stmt.query(params![since.as_str(), namespace.as_str(), source.as_str()])?;
161 let mut buckets = [0_u64; 10];
162 while let Some(row) = rows.next()? {
163 let v: f64 = row.get(0)?;
164 let idx = ((v.clamp(0.0, 1.0) * 10.0) as usize).min(9);
165 buckets[idx] += 1;
166 values.push(v);
167 }
168 let median = if values.is_empty() {
170 0.0
171 } else if values.len() % 2 == 0 {
172 let mid = values.len() / 2;
173 (values[mid - 1] + values[mid]) / 2.0
174 } else {
175 values[values.len() / 2]
176 };
177 baselines.push(PerSourceBaseline {
178 namespace,
179 source,
180 count,
181 median,
182 mean,
183 buckets,
184 });
185 }
186 drop(median_stmt);
187
188 Ok(CalibrationReport {
189 window_days: days,
190 total_observations,
191 baselines,
192 })
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::shadow::observe;
    use crate::models::ConfidenceSignals;
    use crate::storage::open as open_storage;

    /// Initializes storage in a fresh temp directory and opens a second
    /// connection to it.
    ///
    /// The `TempDir` is returned so the database file lives as long as the
    /// test holds it.
    fn open_tmp() -> (Connection, tempfile::TempDir) {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.db");
        let _ = open_storage(&path).expect("open storage");
        let conn = Connection::open(&path).expect("open conn");
        (conn, dir)
    }

    /// Inserts a minimal `memories` row with the given id, namespace and source.
    ///
    /// Each test seeds one before recording observations against that id.
    fn seed_mem(conn: &Connection, id: &str, ns: &str, source: &str) {
        conn.execute(
            "INSERT INTO memories (id, tier, namespace, title, content, source, created_at, updated_at)
             VALUES (?1, 'mid', ?2, ?1, 'c', ?3, '2026-05-15T00:00:00Z', '2026-05-15T00:00:00Z')",
            params![id, ns, source],
        )
        .expect("seed mem");
    }

    fn signals() -> ConfidenceSignals {
        ConfidenceSignals::default()
    }

    /// Asserts two floats agree to within 1e-6.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn calibrate_emits_per_source_baselines() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        seed_mem(&conn, "m2", "ns", "user");
        seed_mem(&conn, "m3", "ns", "claude");
        observe(&conn, "m1", "ns", "user", 0.9, 0.5, &signals(), None).unwrap();
        observe(&conn, "m2", "ns", "user", 0.9, 0.7, &signals(), None).unwrap();
        observe(&conn, "m3", "ns", "claude", 0.9, 0.3, &signals(), None).unwrap();

        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 3);
        assert_eq!(report.baselines.len(), 2);

        let user = report
            .baselines
            .iter()
            .find(|b| b.source == "user")
            .expect("user baseline");
        assert_eq!(user.count, 2);
        assert_close(user.median, 0.6);
        assert_close(user.mean, 0.6);

        let claude = report
            .baselines
            .iter()
            .find(|b| b.source == "claude")
            .expect("claude baseline");
        assert_eq!(claude.count, 1);
        assert_close(claude.median, 0.3);
        assert_close(claude.mean, 0.3);
    }

    #[test]
    fn calibrate_buckets_cover_full_range() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        for v in &[0.05, 0.25, 0.45, 0.55, 0.95] {
            observe(&conn, "m1", "ns", "user", 0.9, *v, &signals(), None).unwrap();
        }
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.baselines.len(), 1);
        let b = &report.baselines[0];
        assert_eq!(b.count, 5);
        // Checking the whole array also proves the bins sum to `count`.
        assert_eq!(b.buckets, [1, 0, 1, 0, 1, 1, 0, 0, 0, 1]);
        // Odd count: the median is the middle sorted value.
        assert_close(b.median, 0.45);
        assert_close(b.mean, 0.45);
    }

    #[test]
    fn calibrate_puts_range_endpoints_in_edge_buckets() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        let observed_at = Utc::now().to_rfc3339();
        for v in [0.0_f64, 1.0] {
            conn.execute(
                "INSERT INTO confidence_shadow_observations
                 (memory_id, namespace, source, caller_confidence, derived_confidence,
                  signals, recall_outcome, observed_at)
                 VALUES ('m1', 'ns', 'user', 0.9, ?1, '{}', NULL, ?2)",
                params![v, observed_at],
            )
            .unwrap();
        }
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.baselines.len(), 1);
        let b = &report.baselines[0];
        assert_eq!(b.buckets, [1, 0, 0, 0, 0, 0, 0, 0, 0, 1]);
        assert_close(b.median, 0.5);
    }

    #[test]
    fn calibrate_separates_namespaces_and_orders_groups() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "zeta", "user");
        seed_mem(&conn, "m2", "alpha", "user");
        observe(&conn, "m1", "zeta", "user", 0.9, 0.2, &signals(), None).unwrap();
        observe(&conn, "m2", "alpha", "user", 0.9, 0.8, &signals(), None).unwrap();

        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        let keys: Vec<(&str, &str)> = report
            .baselines
            .iter()
            .map(|b| (b.namespace.as_str(), b.source.as_str()))
            .collect();
        assert_eq!(keys, vec![("alpha", "user"), ("zeta", "user")]);
        assert_close(report.baselines[0].median, 0.8);
        assert_close(report.baselines[1].median, 0.2);
    }

    #[test]
    fn calibrate_filters_by_window() {
        let (conn, _dir) = open_tmp();
        seed_mem(&conn, "m1", "ns", "user");
        // An observation far outside a 1-day window must be ignored.
        conn.execute(
            "INSERT INTO confidence_shadow_observations
             (memory_id, namespace, source, caller_confidence, derived_confidence,
              signals, recall_outcome, observed_at)
             VALUES ('m1', 'ns', 'user', 0.9, 0.5, '{}', NULL, '2020-01-01T00:00:00Z')",
            [],
        )
        .unwrap();
        observe(&conn, "m1", "ns", "user", 0.9, 0.7, &signals(), None).unwrap();
        let report = calibrate_from_shadow(&conn, 1, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 1);
        assert_eq!(report.baselines.len(), 1);
        let b = &report.baselines[0];
        assert_eq!(b.count, 1);
        assert_close(b.median, 0.7);
    }

    #[test]
    fn calibrate_empty_table_returns_empty_report() {
        let (conn, _dir) = open_tmp();
        let report = calibrate_from_shadow(&conn, 30, Utc::now()).expect("calibrate");
        assert_eq!(report.total_observations, 0);
        assert!(report.baselines.is_empty());
    }
}