Skip to main content

roboticus_db/
model_selection.rs

1use roboticus_core::{RoboticusError, Result};
2use rusqlite::OptionalExtension;
3
4use crate::{Database, DbResultExt};
5
/// Current routing feature schema version, intended for the `schema_version`
/// column of `model_selection_events` (see `ModelSelectionEventRow::schema_version`).
///
/// Bump this whenever feature extraction or scoring logic changes. Rows recorded
/// under older logic then carry an older version, so they can be told apart from
/// (and are not treated as reproducible by) the current logic.
pub const ROUTING_SCHEMA_VERSION: i64 = 1;
9
/// One persisted routing decision from the `model_selection_events` table: which
/// model was selected for a turn, and the inputs/labels recorded alongside it.
///
/// Written by `record_model_selection_event`; read back by
/// `get_model_selection_by_turn_id` and `list_model_selection_events`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSelectionEventRow {
    /// Unique event identifier; inserting a second row with the same `id` fails.
    pub id: String,
    /// Turn this routing decision belongs to (lookup key for
    /// `get_model_selection_by_turn_id`).
    pub turn_id: String,
    /// Session the turn belongs to.
    pub session_id: String,
    /// Agent that handled the turn.
    pub agent_id: String,
    /// Channel the request arrived on (e.g. `"cli"` in tests).
    pub channel: String,
    /// Model that was ultimately selected.
    pub selected_model: String,
    /// Name of the routing strategy used (e.g. `"complexity"` in tests).
    pub strategy: String,
    /// Configured primary model at the time of the decision — presumably the
    /// default before any override; confirm against the router.
    pub primary_model: String,
    /// Model override in effect, if any.
    pub override_model: Option<String>,
    /// Complexity label assigned to the request, if any (e.g. `"high"`).
    pub complexity: Option<String>,
    /// Excerpt of the user message that was routed.
    pub user_excerpt: String,
    /// JSON text listing candidate models (e.g. `["claude-4","gpt-4"]` in tests).
    pub candidates_json: String,
    /// Creation timestamp as text. Queries order and filter on this column as
    /// stored, so all writers should use one format (tests use
    /// `2025-06-01T00:00:00`) — NOTE(review): confirm.
    pub created_at: String,
    // v0.9.4: routing baseline hardening fields
    /// Routing schema version in effect when the row was recorded; normally
    /// `ROUTING_SCHEMA_VERSION`.
    pub schema_version: i64,
    /// Why this model was chosen (e.g. `"metascore"`, `"override"`, `"fallback"`);
    /// `None` is reported as `"unknown"` by `attribution_breakdown`.
    pub attribution: Option<String>,
    /// Optional JSON text of the metascore breakdown
    /// (e.g. `{"efficacy":0.8,"cost":0.5}` in tests).
    pub metascore_json: Option<String>,
    /// Optional JSON text of the extracted routing features — presumably a
    /// numeric vector (e.g. `[0.3,0.5,0.1]` in tests); verify against the extractor.
    pub features_json: Option<String>,
}
31
32pub fn record_model_selection_event(db: &Database, row: &ModelSelectionEventRow) -> Result<()> {
33    let conn = db.conn();
34    conn.execute(
35        "INSERT INTO model_selection_events
36         (id, turn_id, session_id, agent_id, channel, selected_model, strategy, primary_model,
37          override_model, complexity, user_excerpt, candidates_json, created_at,
38          schema_version, attribution, metascore_json, features_json)
39         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17)",
40        rusqlite::params![
41            row.id,
42            row.turn_id,
43            row.session_id,
44            row.agent_id,
45            row.channel,
46            row.selected_model,
47            row.strategy,
48            row.primary_model,
49            row.override_model,
50            row.complexity,
51            row.user_excerpt,
52            row.candidates_json,
53            row.created_at,
54            row.schema_version,
55            row.attribution,
56            row.metascore_json,
57            row.features_json,
58        ],
59    )
60    .map_err(|e| RoboticusError::Database(format!("record model selection event: {e}")))?;
61    Ok(())
62}
63
64pub fn get_model_selection_by_turn_id(
65    db: &Database,
66    turn_id: &str,
67) -> Result<Option<ModelSelectionEventRow>> {
68    let conn = db.conn();
69    let mut stmt = conn
70        .prepare(
71            "SELECT id, turn_id, session_id, agent_id, channel, selected_model, strategy, primary_model,
72                    override_model, complexity, user_excerpt, candidates_json, created_at,
73                    schema_version, attribution, metascore_json, features_json
74             FROM model_selection_events
75             WHERE turn_id = ?1
76             ORDER BY created_at DESC
77             LIMIT 1",
78        )
79        .db_err()?;
80    let row = stmt
81        .query_row(rusqlite::params![turn_id], |r| {
82            Ok(ModelSelectionEventRow {
83                id: r.get(0)?,
84                turn_id: r.get(1)?,
85                session_id: r.get(2)?,
86                agent_id: r.get(3)?,
87                channel: r.get(4)?,
88                selected_model: r.get(5)?,
89                strategy: r.get(6)?,
90                primary_model: r.get(7)?,
91                override_model: r.get(8)?,
92                complexity: r.get(9)?,
93                user_excerpt: r.get(10)?,
94                candidates_json: r.get(11)?,
95                created_at: r.get(12)?,
96                schema_version: r.get(13)?,
97                attribution: r.get(14)?,
98                metascore_json: r.get(15)?,
99                features_json: r.get(16)?,
100            })
101        })
102        .optional()
103        .db_err()?;
104    Ok(row)
105}
106
107pub fn list_model_selection_events(
108    db: &Database,
109    limit: usize,
110) -> Result<Vec<ModelSelectionEventRow>> {
111    let conn = db.conn();
112    let mut stmt = conn
113        .prepare(
114            "SELECT id, turn_id, session_id, agent_id, channel, selected_model, strategy, primary_model,
115                    override_model, complexity, user_excerpt, candidates_json, created_at,
116                    schema_version, attribution, metascore_json, features_json
117             FROM model_selection_events
118             ORDER BY created_at DESC
119             LIMIT ?1",
120        )
121        .db_err()?;
122    let rows = stmt
123        .query_map(rusqlite::params![limit as i64], |r| {
124            Ok(ModelSelectionEventRow {
125                id: r.get(0)?,
126                turn_id: r.get(1)?,
127                session_id: r.get(2)?,
128                agent_id: r.get(3)?,
129                channel: r.get(4)?,
130                selected_model: r.get(5)?,
131                strategy: r.get(6)?,
132                primary_model: r.get(7)?,
133                override_model: r.get(8)?,
134                complexity: r.get(9)?,
135                user_excerpt: r.get(10)?,
136                candidates_json: r.get(11)?,
137                created_at: r.get(12)?,
138                schema_version: r.get(13)?,
139                attribution: r.get(14)?,
140                metascore_json: r.get(15)?,
141                features_json: r.get(16)?,
142            })
143        })
144        .db_err()?
145        .collect::<std::result::Result<Vec<_>, _>>()
146        .db_err()?;
147    Ok(rows)
148}
149
150/// Count routing decisions grouped by attribution label since a given datetime.
151pub fn attribution_breakdown(db: &Database, since: Option<&str>) -> Result<Vec<(String, i64)>> {
152    let conn = db.conn();
153    let (sql, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) = match since {
154        Some(dt) => (
155            "SELECT COALESCE(attribution, 'unknown'), COUNT(*)
156             FROM model_selection_events
157             WHERE created_at >= ?1
158             GROUP BY COALESCE(attribution, 'unknown')
159             ORDER BY COUNT(*) DESC",
160            vec![Box::new(dt.to_string())],
161        ),
162        None => (
163            "SELECT COALESCE(attribution, 'unknown'), COUNT(*)
164             FROM model_selection_events
165             GROUP BY COALESCE(attribution, 'unknown')
166             ORDER BY COUNT(*) DESC",
167            vec![],
168        ),
169    };
170    let mut stmt = conn.prepare(sql).db_err()?;
171    let rows = stmt
172        .query_map(rusqlite::params_from_iter(params.iter()), |r| {
173            Ok((r.get::<_, String>(0)?, r.get::<_, i64>(1)?))
174        })
175        .db_err()?
176        .collect::<std::result::Result<Vec<_>, _>>()
177        .db_err()?;
178    Ok(rows)
179}
180
#[cfg(test)]
mod tests {
    //! Round-trip tests for the model-selection event helpers, run against a
    //! fresh in-memory database per test.
    use super::*;

    /// Open an isolated in-memory database (schema set up by `Database::new`).
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Build a fully-populated event with the given `id`/`turn_id`; optional
    /// hardening fields are left as `None` so tests can opt into them.
    fn sample_event(id: &str, turn_id: &str) -> ModelSelectionEventRow {
        ModelSelectionEventRow {
            id: id.to_string(),
            turn_id: turn_id.to_string(),
            session_id: "sess-1".to_string(),
            agent_id: "agent-1".to_string(),
            channel: "cli".to_string(),
            selected_model: "claude-4".to_string(),
            strategy: "complexity".to_string(),
            primary_model: "claude-4".to_string(),
            override_model: None,
            complexity: Some("high".to_string()),
            user_excerpt: "Tell me about Rust".to_string(),
            candidates_json: r#"["claude-4","gpt-4"]"#.to_string(),
            created_at: "2025-06-01T00:00:00".to_string(),
            schema_version: ROUTING_SCHEMA_VERSION,
            attribution: None,
            metascore_json: None,
            features_json: None,
        }
    }

    #[test]
    fn record_and_get_by_turn_id() {
        let db = test_db();
        let evt = sample_event("mse-1", "turn-1");
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-1")
            .unwrap()
            .unwrap();
        assert_eq!(found.id, "mse-1");
        assert_eq!(found.selected_model, "claude-4");
        assert_eq!(found.strategy, "complexity");
        assert_eq!(found.complexity.as_deref(), Some("high"));
        assert_eq!(found.schema_version, ROUTING_SCHEMA_VERSION);
    }

    #[test]
    fn get_by_turn_id_returns_none_for_missing() {
        let db = test_db();
        let found = get_model_selection_by_turn_id(&db, "nonexistent").unwrap();
        assert!(found.is_none());
    }

    #[test]
    fn record_with_override_model() {
        let db = test_db();
        let mut evt = sample_event("mse-2", "turn-2");
        evt.override_model = Some("gpt-4".to_string());
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-2")
            .unwrap()
            .unwrap();
        assert_eq!(found.override_model.as_deref(), Some("gpt-4"));
    }

    #[test]
    fn record_with_no_complexity() {
        let db = test_db();
        let mut evt = sample_event("mse-3", "turn-3");
        evt.complexity = None;
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-3")
            .unwrap()
            .unwrap();
        assert!(found.complexity.is_none());
    }

    #[test]
    fn record_with_attribution_and_metascore() {
        let db = test_db();
        let mut evt = sample_event("mse-attr", "turn-attr");
        evt.attribution = Some("metascore".to_string());
        evt.metascore_json = Some(r#"{"efficacy":0.8,"cost":0.5}"#.to_string());
        evt.features_json = Some(r#"[0.3,0.5,0.1]"#.to_string());
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-attr")
            .unwrap()
            .unwrap();
        assert_eq!(found.attribution.as_deref(), Some("metascore"));
        // The JSON payloads must round-trip byte-for-byte, not merely be present.
        assert_eq!(
            found.metascore_json.as_deref(),
            Some(r#"{"efficacy":0.8,"cost":0.5}"#)
        );
        assert_eq!(found.features_json.as_deref(), Some(r#"[0.3,0.5,0.1]"#));
        assert_eq!(found.schema_version, ROUTING_SCHEMA_VERSION);
    }

    #[test]
    fn list_events_empty() {
        let db = test_db();
        let events = list_model_selection_events(&db, 10).unwrap();
        assert!(events.is_empty());
    }

    #[test]
    fn list_events_returns_all() {
        let db = test_db();
        for i in 0..3 {
            let mut evt = sample_event(&format!("mse-list-{i}"), &format!("turn-list-{i}"));
            evt.created_at = format!("2025-06-01T0{i}:00:00");
            record_model_selection_event(&db, &evt).unwrap();
        }

        let events = list_model_selection_events(&db, 10).unwrap();
        assert_eq!(events.len(), 3);
    }

    #[test]
    fn list_events_respects_limit() {
        let db = test_db();
        for i in 0..5 {
            let mut evt = sample_event(&format!("mse-lim-{i}"), &format!("turn-lim-{i}"));
            evt.created_at = format!("2025-06-01T0{i}:00:00");
            record_model_selection_event(&db, &evt).unwrap();
        }

        let events = list_model_selection_events(&db, 2).unwrap();
        assert_eq!(events.len(), 2);
    }

    #[test]
    fn list_events_zero_limit_returns_nothing() {
        let db = test_db();
        record_model_selection_event(&db, &sample_event("mse-zero", "turn-zero")).unwrap();

        let events = list_model_selection_events(&db, 0).unwrap();
        assert!(events.is_empty());
    }

    #[test]
    fn list_events_ordered_desc() {
        let db = test_db();
        let mut e1 = sample_event("mse-ord-1", "turn-ord-1");
        e1.created_at = "2025-06-01T01:00:00".to_string();
        let mut e2 = sample_event("mse-ord-2", "turn-ord-2");
        e2.created_at = "2025-06-01T02:00:00".to_string();
        record_model_selection_event(&db, &e1).unwrap();
        record_model_selection_event(&db, &e2).unwrap();

        let events = list_model_selection_events(&db, 10).unwrap();
        assert_eq!(events[0].id, "mse-ord-2", "most recent should be first");
        assert_eq!(events[1].id, "mse-ord-1");
    }

    #[test]
    fn all_fields_populated() {
        let db = test_db();
        let evt = sample_event("mse-fields", "turn-fields");
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-fields")
            .unwrap()
            .unwrap();
        assert_eq!(found.session_id, "sess-1");
        assert_eq!(found.agent_id, "agent-1");
        assert_eq!(found.channel, "cli");
        assert_eq!(found.primary_model, "claude-4");
        assert_eq!(found.user_excerpt, "Tell me about Rust");
        assert_eq!(found.candidates_json, r#"["claude-4","gpt-4"]"#);
        assert_eq!(found.created_at, "2025-06-01T00:00:00");
        // Optional hardening fields left unset must come back as NULL/None.
        assert!(found.attribution.is_none());
        assert!(found.metascore_json.is_none());
        assert!(found.features_json.is_none());
    }

    #[test]
    fn duplicate_id_fails() {
        let db = test_db();
        let evt = sample_event("mse-dup", "turn-dup");
        record_model_selection_event(&db, &evt).unwrap();
        // Same id should fail (PRIMARY KEY constraint)
        let result = record_model_selection_event(&db, &evt);
        assert!(result.is_err());
    }

    #[test]
    fn attribution_breakdown_counts_correctly() {
        let db = test_db();
        for (i, attr) in ["metascore", "metascore", "override", "fallback"]
            .iter()
            .enumerate()
        {
            let mut evt = sample_event(&format!("mse-ab-{i}"), &format!("turn-ab-{i}"));
            evt.attribution = Some(attr.to_string());
            evt.created_at = format!("2025-06-01T0{i}:00:00");
            record_model_selection_event(&db, &evt).unwrap();
        }

        let counts = attribution_breakdown(&db, None).unwrap();
        assert_eq!(counts.len(), 3);
        // metascore should be first (count=2)
        assert_eq!(counts[0].0, "metascore");
        assert_eq!(counts[0].1, 2);
    }

    #[test]
    fn attribution_breakdown_maps_null_to_unknown() {
        let db = test_db();
        let mut tagged = sample_event("mse-unk-1", "turn-unk-1");
        tagged.attribution = Some("override".to_string());
        record_model_selection_event(&db, &tagged).unwrap();
        // `sample_event` leaves `attribution` as None.
        record_model_selection_event(&db, &sample_event("mse-unk-2", "turn-unk-2")).unwrap();
        record_model_selection_event(&db, &sample_event("mse-unk-3", "turn-unk-3")).unwrap();

        let counts = attribution_breakdown(&db, None).unwrap();
        assert_eq!(
            counts,
            vec![("unknown".to_string(), 2), ("override".to_string(), 1)]
        );
    }

    #[test]
    fn attribution_breakdown_empty_table() {
        let db = test_db();
        assert!(attribution_breakdown(&db, None).unwrap().is_empty());
        assert!(attribution_breakdown(&db, Some("2025-01-01T00:00:00"))
            .unwrap()
            .is_empty());
    }

    #[test]
    fn attribution_breakdown_with_since_filter() {
        let db = test_db();
        let mut e1 = sample_event("mse-ab-old", "turn-ab-old");
        e1.attribution = Some("metascore".to_string());
        e1.created_at = "2024-01-01T00:00:00".to_string();
        let mut e2 = sample_event("mse-ab-new", "turn-ab-new");
        e2.attribution = Some("override".to_string());
        e2.created_at = "2025-06-01T00:00:00".to_string();
        record_model_selection_event(&db, &e1).unwrap();
        record_model_selection_event(&db, &e2).unwrap();

        let counts = attribution_breakdown(&db, Some("2025-01-01T00:00:00")).unwrap();
        assert_eq!(counts.len(), 1);
        assert_eq!(counts[0].0, "override");
    }
}
390}