1use roboticus_core::{RoboticusError, Result};
2use rusqlite::OptionalExtension;
3
4use crate::{Database, DbResultExt};
5
/// Version tag for the routing-event row format; intended to be stored in
/// [`ModelSelectionEventRow::schema_version`].
pub const ROUTING_SCHEMA_VERSION: i64 = 1;

/// One routing decision as persisted in the `model_selection_events` table.
///
/// Field order mirrors the column order used by the insert and select queries
/// in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSelectionEventRow {
    /// Unique event identifier; inserting a second row with the same id fails.
    pub id: String,
    /// Turn the selection was made for; lookup key for per-turn queries.
    pub turn_id: String,
    /// Session the turn belongs to.
    pub session_id: String,
    /// Agent that handled the turn.
    pub agent_id: String,
    /// Channel the turn arrived on (e.g. `"cli"`).
    pub channel: String,
    /// Model that was ultimately chosen.
    pub selected_model: String,
    /// Name of the routing strategy that made the choice (e.g. `"complexity"`).
    pub strategy: String,
    /// Primary model in effect when the decision was made.
    pub primary_model: String,
    /// Explicit model override, if one was in effect.
    pub override_model: Option<String>,
    /// Complexity label assigned to the request, if any.
    pub complexity: Option<String>,
    /// Excerpt of the user message that triggered the turn.
    pub user_excerpt: String,
    /// Candidate models considered, serialized as JSON (e.g. `["a","b"]`).
    pub candidates_json: String,
    /// Creation timestamp. Queries order and filter on this value directly, so
    /// it must be written in a sortable format (e.g. ISO-8601).
    pub created_at: String,
    /// Row format version; see [`ROUTING_SCHEMA_VERSION`].
    pub schema_version: i64,
    /// What drove the selection. `None` is reported as `"unknown"` in
    /// attribution breakdowns.
    pub attribution: Option<String>,
    /// Optional JSON payload with metascore details.
    pub metascore_json: Option<String>,
    /// Optional JSON payload with the feature vector used for routing.
    pub features_json: Option<String>,
}
31
32pub fn record_model_selection_event(db: &Database, row: &ModelSelectionEventRow) -> Result<()> {
33 let conn = db.conn();
34 conn.execute(
35 "INSERT INTO model_selection_events
36 (id, turn_id, session_id, agent_id, channel, selected_model, strategy, primary_model,
37 override_model, complexity, user_excerpt, candidates_json, created_at,
38 schema_version, attribution, metascore_json, features_json)
39 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17)",
40 rusqlite::params![
41 row.id,
42 row.turn_id,
43 row.session_id,
44 row.agent_id,
45 row.channel,
46 row.selected_model,
47 row.strategy,
48 row.primary_model,
49 row.override_model,
50 row.complexity,
51 row.user_excerpt,
52 row.candidates_json,
53 row.created_at,
54 row.schema_version,
55 row.attribution,
56 row.metascore_json,
57 row.features_json,
58 ],
59 )
60 .map_err(|e| RoboticusError::Database(format!("record model selection event: {e}")))?;
61 Ok(())
62}
63
64pub fn get_model_selection_by_turn_id(
65 db: &Database,
66 turn_id: &str,
67) -> Result<Option<ModelSelectionEventRow>> {
68 let conn = db.conn();
69 let mut stmt = conn
70 .prepare(
71 "SELECT id, turn_id, session_id, agent_id, channel, selected_model, strategy, primary_model,
72 override_model, complexity, user_excerpt, candidates_json, created_at,
73 schema_version, attribution, metascore_json, features_json
74 FROM model_selection_events
75 WHERE turn_id = ?1
76 ORDER BY created_at DESC
77 LIMIT 1",
78 )
79 .db_err()?;
80 let row = stmt
81 .query_row(rusqlite::params![turn_id], |r| {
82 Ok(ModelSelectionEventRow {
83 id: r.get(0)?,
84 turn_id: r.get(1)?,
85 session_id: r.get(2)?,
86 agent_id: r.get(3)?,
87 channel: r.get(4)?,
88 selected_model: r.get(5)?,
89 strategy: r.get(6)?,
90 primary_model: r.get(7)?,
91 override_model: r.get(8)?,
92 complexity: r.get(9)?,
93 user_excerpt: r.get(10)?,
94 candidates_json: r.get(11)?,
95 created_at: r.get(12)?,
96 schema_version: r.get(13)?,
97 attribution: r.get(14)?,
98 metascore_json: r.get(15)?,
99 features_json: r.get(16)?,
100 })
101 })
102 .optional()
103 .db_err()?;
104 Ok(row)
105}
106
107pub fn list_model_selection_events(
108 db: &Database,
109 limit: usize,
110) -> Result<Vec<ModelSelectionEventRow>> {
111 let conn = db.conn();
112 let mut stmt = conn
113 .prepare(
114 "SELECT id, turn_id, session_id, agent_id, channel, selected_model, strategy, primary_model,
115 override_model, complexity, user_excerpt, candidates_json, created_at,
116 schema_version, attribution, metascore_json, features_json
117 FROM model_selection_events
118 ORDER BY created_at DESC
119 LIMIT ?1",
120 )
121 .db_err()?;
122 let rows = stmt
123 .query_map(rusqlite::params![limit as i64], |r| {
124 Ok(ModelSelectionEventRow {
125 id: r.get(0)?,
126 turn_id: r.get(1)?,
127 session_id: r.get(2)?,
128 agent_id: r.get(3)?,
129 channel: r.get(4)?,
130 selected_model: r.get(5)?,
131 strategy: r.get(6)?,
132 primary_model: r.get(7)?,
133 override_model: r.get(8)?,
134 complexity: r.get(9)?,
135 user_excerpt: r.get(10)?,
136 candidates_json: r.get(11)?,
137 created_at: r.get(12)?,
138 schema_version: r.get(13)?,
139 attribution: r.get(14)?,
140 metascore_json: r.get(15)?,
141 features_json: r.get(16)?,
142 })
143 })
144 .db_err()?
145 .collect::<std::result::Result<Vec<_>, _>>()
146 .db_err()?;
147 Ok(rows)
148}
149
150pub fn attribution_breakdown(db: &Database, since: Option<&str>) -> Result<Vec<(String, i64)>> {
152 let conn = db.conn();
153 let (sql, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) = match since {
154 Some(dt) => (
155 "SELECT COALESCE(attribution, 'unknown'), COUNT(*)
156 FROM model_selection_events
157 WHERE created_at >= ?1
158 GROUP BY COALESCE(attribution, 'unknown')
159 ORDER BY COUNT(*) DESC",
160 vec![Box::new(dt.to_string())],
161 ),
162 None => (
163 "SELECT COALESCE(attribution, 'unknown'), COUNT(*)
164 FROM model_selection_events
165 GROUP BY COALESCE(attribution, 'unknown')
166 ORDER BY COUNT(*) DESC",
167 vec![],
168 ),
169 };
170 let mut stmt = conn.prepare(sql).db_err()?;
171 let rows = stmt
172 .query_map(rusqlite::params_from_iter(params.iter()), |r| {
173 Ok((r.get::<_, String>(0)?, r.get::<_, i64>(1)?))
174 })
175 .db_err()?
176 .collect::<std::result::Result<Vec<_>, _>>()
177 .db_err()?;
178 Ok(rows)
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database. Assumes `Database::new` creates the
    /// `model_selection_events` table — every test below relies on it.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    fn sample_event(id: &str, turn_id: &str) -> ModelSelectionEventRow {
        ModelSelectionEventRow {
            id: id.to_string(),
            turn_id: turn_id.to_string(),
            session_id: "sess-1".to_string(),
            agent_id: "agent-1".to_string(),
            channel: "cli".to_string(),
            selected_model: "claude-4".to_string(),
            strategy: "complexity".to_string(),
            primary_model: "claude-4".to_string(),
            override_model: None,
            complexity: Some("high".to_string()),
            user_excerpt: "Tell me about Rust".to_string(),
            candidates_json: r#"["claude-4","gpt-4"]"#.to_string(),
            created_at: "2025-06-01T00:00:00".to_string(),
            schema_version: ROUTING_SCHEMA_VERSION,
            attribution: None,
            metascore_json: None,
            features_json: None,
        }
    }

    #[test]
    fn record_and_get_by_turn_id() {
        let db = test_db();
        let evt = sample_event("mse-1", "turn-1");
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-1")
            .unwrap()
            .unwrap();
        assert_eq!(found.id, "mse-1");
        assert_eq!(found.selected_model, "claude-4");
        assert_eq!(found.strategy, "complexity");
        assert_eq!(found.complexity.as_deref(), Some("high"));
        assert_eq!(found.schema_version, ROUTING_SCHEMA_VERSION);
    }

    #[test]
    fn get_by_turn_id_returns_none_for_missing() {
        let db = test_db();
        let found = get_model_selection_by_turn_id(&db, "nonexistent").unwrap();
        assert!(found.is_none());
    }

    #[test]
    fn get_by_turn_id_returns_most_recent_event() {
        let db = test_db();
        let mut newer = sample_event("mse-new", "turn-shared");
        newer.created_at = "2025-06-01T02:00:00".to_string();
        let mut older = sample_event("mse-old", "turn-shared");
        older.created_at = "2025-06-01T01:00:00".to_string();
        // Insert the newer event first so insertion order cannot explain the result.
        record_model_selection_event(&db, &newer).unwrap();
        record_model_selection_event(&db, &older).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-shared")
            .unwrap()
            .unwrap();
        assert_eq!(found.id, "mse-new");
    }

    #[test]
    fn record_with_override_model() {
        let db = test_db();
        let mut evt = sample_event("mse-2", "turn-2");
        evt.override_model = Some("gpt-4".to_string());
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-2")
            .unwrap()
            .unwrap();
        assert_eq!(found.override_model.as_deref(), Some("gpt-4"));
    }

    #[test]
    fn record_with_no_complexity() {
        let db = test_db();
        let mut evt = sample_event("mse-3", "turn-3");
        evt.complexity = None;
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-3")
            .unwrap()
            .unwrap();
        assert!(found.complexity.is_none());
    }

    #[test]
    fn record_with_attribution_and_metascore() {
        let db = test_db();
        let mut evt = sample_event("mse-attr", "turn-attr");
        evt.attribution = Some("metascore".to_string());
        evt.metascore_json = Some(r#"{"efficacy":0.8,"cost":0.5}"#.to_string());
        evt.features_json = Some(r#"[0.3,0.5,0.1]"#.to_string());
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-attr")
            .unwrap()
            .unwrap();
        assert_eq!(found.attribution.as_deref(), Some("metascore"));
        assert_eq!(
            found.metascore_json.as_deref(),
            Some(r#"{"efficacy":0.8,"cost":0.5}"#)
        );
        assert_eq!(found.features_json.as_deref(), Some(r#"[0.3,0.5,0.1]"#));
        assert_eq!(found.schema_version, ROUTING_SCHEMA_VERSION);
    }

    #[test]
    fn list_events_empty() {
        let db = test_db();
        let events = list_model_selection_events(&db, 10).unwrap();
        assert!(events.is_empty());
    }

    #[test]
    fn list_events_returns_all() {
        let db = test_db();
        for i in 0..3 {
            let mut evt = sample_event(&format!("mse-list-{i}"), &format!("turn-list-{i}"));
            evt.created_at = format!("2025-06-01T0{i}:00:00");
            record_model_selection_event(&db, &evt).unwrap();
        }

        let events = list_model_selection_events(&db, 10).unwrap();
        assert_eq!(events.len(), 3);
    }

    #[test]
    fn list_events_respects_limit() {
        let db = test_db();
        for i in 0..5 {
            let mut evt = sample_event(&format!("mse-lim-{i}"), &format!("turn-lim-{i}"));
            evt.created_at = format!("2025-06-01T0{i}:00:00");
            record_model_selection_event(&db, &evt).unwrap();
        }

        let events = list_model_selection_events(&db, 2).unwrap();
        let ids: Vec<&str> = events.iter().map(|e| e.id.as_str()).collect();
        assert_eq!(ids, ["mse-lim-4", "mse-lim-3"], "limit keeps the newest rows");
    }

    #[test]
    fn list_events_zero_limit_returns_nothing() {
        let db = test_db();
        record_model_selection_event(&db, &sample_event("mse-zero", "turn-zero")).unwrap();

        let events = list_model_selection_events(&db, 0).unwrap();
        assert!(events.is_empty());
    }

    #[test]
    fn list_events_ordered_desc() {
        let db = test_db();
        let mut e1 = sample_event("mse-ord-1", "turn-ord-1");
        e1.created_at = "2025-06-01T01:00:00".to_string();
        let mut e2 = sample_event("mse-ord-2", "turn-ord-2");
        e2.created_at = "2025-06-01T02:00:00".to_string();
        record_model_selection_event(&db, &e1).unwrap();
        record_model_selection_event(&db, &e2).unwrap();

        let events = list_model_selection_events(&db, 10).unwrap();
        assert_eq!(events[0].id, "mse-ord-2", "most recent should be first");
        assert_eq!(events[1].id, "mse-ord-1");
    }

    #[test]
    fn all_fields_populated() {
        let db = test_db();
        let evt = sample_event("mse-fields", "turn-fields");
        record_model_selection_event(&db, &evt).unwrap();

        let found = get_model_selection_by_turn_id(&db, "turn-fields")
            .unwrap()
            .unwrap();
        assert_eq!(found.turn_id, "turn-fields");
        assert_eq!(found.session_id, "sess-1");
        assert_eq!(found.agent_id, "agent-1");
        assert_eq!(found.channel, "cli");
        assert_eq!(found.primary_model, "claude-4");
        assert_eq!(found.user_excerpt, "Tell me about Rust");
        assert_eq!(found.candidates_json, r#"["claude-4","gpt-4"]"#);
        assert_eq!(found.created_at, "2025-06-01T00:00:00");
        assert!(found.override_model.is_none());
        assert!(found.attribution.is_none());
        assert!(found.metascore_json.is_none());
        assert!(found.features_json.is_none());
    }

    #[test]
    fn duplicate_id_fails() {
        let db = test_db();
        let evt = sample_event("mse-dup", "turn-dup");
        record_model_selection_event(&db, &evt).unwrap();
        // Same primary key a second time must be rejected, not silently upserted.
        let result = record_model_selection_event(&db, &evt);
        assert!(result.is_err());
    }

    #[test]
    fn attribution_breakdown_counts_correctly() {
        let db = test_db();
        for (i, attr) in ["metascore", "metascore", "override", "fallback"]
            .iter()
            .enumerate()
        {
            let mut evt = sample_event(&format!("mse-ab-{i}"), &format!("turn-ab-{i}"));
            evt.attribution = Some(attr.to_string());
            evt.created_at = format!("2025-06-01T0{i}:00:00");
            record_model_selection_event(&db, &evt).unwrap();
        }

        let counts = attribution_breakdown(&db, None).unwrap();
        assert_eq!(counts.len(), 3);
        assert_eq!(counts[0], ("metascore".to_string(), 2));
        // The two single-count labels tie; compare them order-insensitively.
        let mut rest = counts[1..].to_vec();
        rest.sort();
        assert_eq!(
            rest,
            vec![("fallback".to_string(), 1), ("override".to_string(), 1)]
        );
    }

    #[test]
    fn attribution_breakdown_groups_null_as_unknown() {
        let db = test_db();
        record_model_selection_event(&db, &sample_event("mse-null-1", "turn-null-1")).unwrap();
        record_model_selection_event(&db, &sample_event("mse-null-2", "turn-null-2")).unwrap();
        let mut attributed = sample_event("mse-null-3", "turn-null-3");
        attributed.attribution = Some("metascore".to_string());
        record_model_selection_event(&db, &attributed).unwrap();

        let counts = attribution_breakdown(&db, None).unwrap();
        assert_eq!(
            counts,
            vec![("unknown".to_string(), 2), ("metascore".to_string(), 1)]
        );
    }

    #[test]
    fn attribution_breakdown_empty() {
        let db = test_db();
        assert!(attribution_breakdown(&db, None).unwrap().is_empty());
        assert!(attribution_breakdown(&db, Some("2025-01-01T00:00:00"))
            .unwrap()
            .is_empty());
    }

    #[test]
    fn attribution_breakdown_with_since_filter() {
        let db = test_db();
        let mut e1 = sample_event("mse-ab-old", "turn-ab-old");
        e1.attribution = Some("metascore".to_string());
        e1.created_at = "2024-01-01T00:00:00".to_string();
        let mut e2 = sample_event("mse-ab-new", "turn-ab-new");
        e2.attribution = Some("override".to_string());
        e2.created_at = "2025-06-01T00:00:00".to_string();
        record_model_selection_event(&db, &e1).unwrap();
        record_model_selection_event(&db, &e2).unwrap();

        let counts = attribution_breakdown(&db, Some("2025-01-01T00:00:00")).unwrap();
        assert_eq!(counts, vec![("override".to_string(), 1)]);
    }
}