Skip to main content

roboticus_db/
shadow_routing.rs

1//! Shadow routing predictions — counterfactual ML recommendations stored
2//! alongside production decisions for offline validation.
3//!
4//! The shadow pipeline records what a candidate ML model *would* have chosen
5//! without affecting live routing. Agreement rate and regret analysis run
6//! against this data to decide when (if ever) to promote the ML model.
7
8use roboticus_core::Result;
9
10use crate::{Database, DbResultExt};
11
/// One record of the `shadow_routing_predictions` table.
///
/// Pairs the routing decision production actually made for a turn with the
/// decision the shadow recommender reported for the same turn.
#[derive(Debug, Clone)]
pub struct ShadowPredictionRow {
    /// Identifier of this prediction record.
    pub id: String,
    /// Identifier of the turn the prediction belongs to.
    pub turn_id: String,
    /// Model chosen by production routing.
    pub production_model: String,
    /// Model the shadow recommender would have chosen; `None` when the
    /// shadow abstained.
    pub shadow_model: Option<String>,
    /// Complexity estimate that production routing worked with.
    pub production_complexity: Option<f64>,
    /// Complexity estimate produced by the shadow model (may differ from
    /// the production estimate).
    pub shadow_complexity: Option<f64>,
    /// Whether production and shadow agreed. Persisted as the integer `1`
    /// (agree) or `0` (otherwise).
    pub agreed: bool,
    /// Free-form JSON detail blob (scores, feature weights, etc.).
    pub detail_json: Option<String>,
    /// Creation timestamp, kept as the string stored in the database.
    pub created_at: String,
}
31
32/// Insert a shadow prediction record.
33pub fn record_shadow_prediction(db: &Database, row: &ShadowPredictionRow) -> Result<()> {
34    let conn = db.conn();
35    conn.execute(
36        "INSERT INTO shadow_routing_predictions
37         (id, turn_id, production_model, shadow_model, production_complexity,
38          shadow_complexity, agreed, detail_json, created_at)
39         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
40        rusqlite::params![
41            row.id,
42            row.turn_id,
43            row.production_model,
44            row.shadow_model,
45            row.production_complexity,
46            row.shadow_complexity,
47            row.agreed as i32,
48            row.detail_json,
49            row.created_at,
50        ],
51    )
52    .db_err()?;
53    Ok(())
54}
55
/// Aggregate agreement counts between production and shadow routing.
#[derive(Debug, Clone)]
pub struct ShadowAgreementSummary {
    /// Number of predictions considered.
    pub total: usize,
    /// Number of predictions where production and shadow agreed.
    pub agreed: usize,
    /// Number of predictions where they did not agree (`total - agreed`).
    pub disagreed: usize,
    /// Fraction of agreeing predictions in `[0.0, 1.0]`; `None` when there
    /// are no predictions at all.
    pub agreement_rate: Option<f64>,
}
65
66/// Compute agreement summary for shadow predictions, optionally filtered by
67/// a time window (`since` in ISO-8601 format).
68pub fn shadow_agreement_summary(
69    db: &Database,
70    since: Option<&str>,
71) -> Result<ShadowAgreementSummary> {
72    let conn = db.conn();
73    let (sql, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) = if let Some(s) = since {
74        (
75            "SELECT
76                COUNT(*) AS total,
77                COALESCE(SUM(CASE WHEN agreed = 1 THEN 1 ELSE 0 END), 0) AS agreed
78             FROM shadow_routing_predictions
79             WHERE created_at >= ?1",
80            vec![Box::new(s.to_string())],
81        )
82    } else {
83        (
84            "SELECT
85                COUNT(*) AS total,
86                COALESCE(SUM(CASE WHEN agreed = 1 THEN 1 ELSE 0 END), 0) AS agreed
87             FROM shadow_routing_predictions",
88            vec![],
89        )
90    };
91
92    let (total, agreed): (usize, usize) = conn
93        .query_row(sql, rusqlite::params_from_iter(params.iter()), |r| {
94            Ok((r.get::<_, usize>(0)?, r.get::<_, usize>(1)?))
95        })
96        .db_err()?;
97
98    let disagreed = total.saturating_sub(agreed);
99    let agreement_rate = if total > 0 {
100        Some(agreed as f64 / total as f64)
101    } else {
102        None
103    };
104
105    Ok(ShadowAgreementSummary {
106        total,
107        agreed,
108        disagreed,
109        agreement_rate,
110    })
111}
112
113/// Fetch the N most recent shadow predictions (newest first).
114pub fn recent_shadow_predictions(db: &Database, limit: usize) -> Result<Vec<ShadowPredictionRow>> {
115    let conn = db.conn();
116    let mut stmt = conn
117        .prepare(
118            "SELECT id, turn_id, production_model, shadow_model,
119                    production_complexity, shadow_complexity, agreed,
120                    detail_json, created_at
121             FROM shadow_routing_predictions
122             ORDER BY created_at DESC
123             LIMIT ?1",
124        )
125        .db_err()?;
126
127    let rows = stmt
128        .query_map(rusqlite::params![limit as i64], |r| {
129            Ok(ShadowPredictionRow {
130                id: r.get(0)?,
131                turn_id: r.get(1)?,
132                production_model: r.get(2)?,
133                shadow_model: r.get(3)?,
134                production_complexity: r.get(4)?,
135                shadow_complexity: r.get(5)?,
136                agreed: r.get::<_, i32>(6)? != 0,
137                detail_json: r.get(7)?,
138                created_at: r.get(8)?,
139            })
140        })
141        .db_err()?;
142
143    let mut results = Vec::new();
144    for row in rows {
145        results.push(row.db_err()?);
146    }
147    Ok(results)
148}
149
150/// Delete shadow routing predictions older than `retention_days` days.
151///
152/// Returns the number of rows deleted.
153pub fn prune_shadow_predictions(db: &Database, retention_days: u32) -> Result<usize> {
154    let conn = db.conn();
155    let deleted = conn
156        .execute(
157            "DELETE FROM shadow_routing_predictions \
158             WHERE created_at < datetime('now', ?1)",
159            [format!("-{retention_days} days")],
160        )
161        .db_err()?;
162    Ok(deleted)
163}
164
165/// Count disagreements where shadow would have picked a different model,
166/// grouped by (production_model, shadow_model) pair. Useful for identifying
167/// systematic divergence patterns.
168pub fn disagreement_pairs(
169    db: &Database,
170    since: Option<&str>,
171) -> Result<Vec<(String, String, usize)>> {
172    let conn = db.conn();
173    let (sql, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) = if let Some(s) = since {
174        (
175            "SELECT production_model, shadow_model, COUNT(*) AS cnt
176             FROM shadow_routing_predictions
177             WHERE agreed = 0 AND shadow_model IS NOT NULL AND created_at >= ?1
178             GROUP BY production_model, shadow_model
179             ORDER BY cnt DESC",
180            vec![Box::new(s.to_string())],
181        )
182    } else {
183        (
184            "SELECT production_model, shadow_model, COUNT(*) AS cnt
185             FROM shadow_routing_predictions
186             WHERE agreed = 0 AND shadow_model IS NOT NULL
187             GROUP BY production_model, shadow_model
188             ORDER BY cnt DESC",
189            vec![],
190        )
191    };
192
193    let mut stmt = conn.prepare(sql).db_err()?;
194
195    let rows = stmt
196        .query_map(rusqlite::params_from_iter(params.iter()), |r| {
197            Ok((
198                r.get::<_, String>(0)?,
199                r.get::<_, String>(1)?,
200                r.get::<_, usize>(2)?,
201            ))
202        })
203        .db_err()?;
204
205    let mut results = Vec::new();
206    for row in rows {
207        results.push(row.db_err()?);
208    }
209    Ok(results)
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database for a single test.
    fn test_db() -> Database {
        Database::new(":memory:").expect("in-memory db")
    }

    /// Build a prediction row with fixed complexity estimates, no detail
    /// blob, and a fixed `created_at` of `2025-01-15T10:00:00`.
    fn make_row(
        id: &str,
        turn: &str,
        prod: &str,
        shadow: Option<&str>,
        agreed: bool,
    ) -> ShadowPredictionRow {
        ShadowPredictionRow {
            id: id.into(),
            turn_id: turn.into(),
            production_model: prod.into(),
            shadow_model: shadow.map(String::from),
            production_complexity: Some(0.5),
            shadow_complexity: Some(0.5),
            agreed,
            detail_json: None,
            created_at: "2025-01-15T10:00:00".into(),
        }
    }

    #[test]
    fn record_and_retrieve() {
        let db = test_db();
        let row = make_row(
            "sp-1",
            "t-1",
            "openai/gpt-4o",
            Some("ollama/qwen3:8b"),
            false,
        );
        record_shadow_prediction(&db, &row).unwrap();

        let recent = recent_shadow_predictions(&db, 10).unwrap();
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].production_model, "openai/gpt-4o");
        assert_eq!(recent[0].shadow_model.as_deref(), Some("ollama/qwen3:8b"));
        assert!(!recent[0].agreed);
    }

    #[test]
    fn agreement_summary_empty() {
        let db = test_db();
        let summary = shadow_agreement_summary(&db, None).unwrap();
        assert_eq!(summary.total, 0);
        assert!(summary.agreement_rate.is_none());
    }

    #[test]
    fn agreement_summary_mixed() {
        let db = test_db();
        // 3 agreed, 2 disagreed
        for (i, agreed) in [true, true, false, true, false].iter().enumerate() {
            let row = make_row(
                &format!("sp-{i}"),
                &format!("t-{i}"),
                "openai/gpt-4o",
                Some("ollama/qwen3:8b"),
                *agreed,
            );
            record_shadow_prediction(&db, &row).unwrap();
        }

        let summary = shadow_agreement_summary(&db, None).unwrap();
        assert_eq!(summary.total, 5);
        assert_eq!(summary.agreed, 3);
        assert_eq!(summary.disagreed, 2);
        let rate = summary.agreement_rate.unwrap();
        assert!((rate - 0.6).abs() < 1e-9);
    }

    #[test]
    fn agreement_summary_with_since_filter() {
        let db = test_db();
        // Old prediction
        let mut old = make_row("sp-old", "t-old", "openai/gpt-4o", Some("local"), false);
        old.created_at = "2024-01-01T00:00:00".into();
        record_shadow_prediction(&db, &old).unwrap();

        // Recent prediction
        let recent_row = make_row("sp-new", "t-new", "openai/gpt-4o", Some("local"), true);
        record_shadow_prediction(&db, &recent_row).unwrap();

        let summary = shadow_agreement_summary(&db, Some("2025-01-01T00:00:00")).unwrap();
        assert_eq!(summary.total, 1);
        assert_eq!(summary.agreed, 1);
    }

    #[test]
    fn recent_predictions_ordering() {
        let db = test_db();
        let mut r1 = make_row("sp-1", "t-1", "m1", None, true);
        r1.created_at = "2025-01-15T10:00:00".into();
        let mut r2 = make_row("sp-2", "t-2", "m2", None, true);
        r2.created_at = "2025-01-15T11:00:00".into();
        record_shadow_prediction(&db, &r1).unwrap();
        record_shadow_prediction(&db, &r2).unwrap();

        let recent = recent_shadow_predictions(&db, 10).unwrap();
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].id, "sp-2"); // newer first
        assert_eq!(recent[1].id, "sp-1");
    }

    #[test]
    fn recent_predictions_limit() {
        let db = test_db();
        for i in 0..5 {
            let row = make_row(&format!("sp-{i}"), &format!("t-{i}"), "m", None, true);
            record_shadow_prediction(&db, &row).unwrap();
        }

        let recent = recent_shadow_predictions(&db, 2).unwrap();
        assert_eq!(recent.len(), 2);
    }

    #[test]
    fn prune_removes_only_expired_rows() {
        let db = test_db();
        // Timestamps far in the past and far in the future keep the outcome
        // independent of the wall clock at the time the test runs.
        let mut expired = make_row("sp-old", "t-old", "m", None, true);
        expired.created_at = "2000-01-01T00:00:00".into();
        record_shadow_prediction(&db, &expired).unwrap();

        let mut fresh = make_row("sp-new", "t-new", "m", None, true);
        fresh.created_at = "2999-01-01T00:00:00".into();
        record_shadow_prediction(&db, &fresh).unwrap();

        let deleted = prune_shadow_predictions(&db, 30).unwrap();
        assert_eq!(deleted, 1);

        let remaining = recent_shadow_predictions(&db, 10).unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, "sp-new");
    }

    #[test]
    fn disagreement_pairs_basic() {
        let db = test_db();
        // 2 disagreements on same pair, 1 on different pair, 1 agreement (excluded)
        record_shadow_prediction(
            &db,
            &make_row("sp-1", "t-1", "gpt-4o", Some("qwen3:8b"), false),
        )
        .unwrap();
        record_shadow_prediction(
            &db,
            &make_row("sp-2", "t-2", "gpt-4o", Some("qwen3:8b"), false),
        )
        .unwrap();
        record_shadow_prediction(
            &db,
            &make_row("sp-3", "t-3", "gpt-4o", Some("claude-3"), false),
        )
        .unwrap();
        record_shadow_prediction(
            &db,
            &make_row("sp-4", "t-4", "gpt-4o", Some("qwen3:8b"), true),
        )
        .unwrap();

        let pairs = disagreement_pairs(&db, None).unwrap();
        assert_eq!(pairs.len(), 2);
        // Most frequent disagreement first
        assert_eq!(pairs[0], ("gpt-4o".into(), "qwen3:8b".into(), 2));
        assert_eq!(pairs[1], ("gpt-4o".into(), "claude-3".into(), 1));
    }

    #[test]
    fn disagreement_pairs_with_since_filter() {
        let db = test_db();
        // Disagreement recorded before the cutoff: must be filtered out.
        let mut old = make_row("sp-old", "t-old", "gpt-4o", Some("qwen3:8b"), false);
        old.created_at = "2024-01-01T00:00:00".into();
        record_shadow_prediction(&db, &old).unwrap();

        // Disagreement recorded after the cutoff: must be counted.
        record_shadow_prediction(
            &db,
            &make_row("sp-new", "t-new", "gpt-4o", Some("claude-3"), false),
        )
        .unwrap();

        let pairs = disagreement_pairs(&db, Some("2025-01-01T00:00:00")).unwrap();
        assert_eq!(
            pairs,
            vec![("gpt-4o".to_string(), "claude-3".to_string(), 1)]
        );
    }

    #[test]
    fn shadow_model_none_excluded_from_disagreement_pairs() {
        let db = test_db();
        // Shadow abstained (None) — should NOT appear in disagreement pairs
        record_shadow_prediction(&db, &make_row("sp-1", "t-1", "gpt-4o", None, false)).unwrap();

        let pairs = disagreement_pairs(&db, None).unwrap();
        assert!(pairs.is_empty());
    }
}
376}