1use crate::Database;
9use roboticus_core::{RoboticusError, Result};
10use serde::Serialize;
11
12#[derive(Debug, Clone, Serialize)]
15pub struct RoutingDatasetRow {
16 pub event_id: String,
18 pub turn_id: String,
19 pub session_id: String,
20 pub agent_id: String,
21 pub channel: String,
22 pub selected_model: String,
23 pub strategy: String,
24 pub primary_model: String,
25 pub override_model: Option<String>,
26 pub complexity: Option<String>,
27 pub user_excerpt: String,
28 pub candidates_json: String,
29 pub attribution: Option<String>,
30 pub metascore_json: Option<String>,
31 pub features_json: Option<String>,
32 pub schema_version: i64,
33 pub decision_at: String,
34
35 pub total_tokens_in: i64,
37 pub total_tokens_out: i64,
38 pub total_cost: f64,
39 pub inference_count: i64,
40 pub any_cached: bool,
41 pub avg_latency_ms: Option<f64>,
42 pub avg_quality_score: Option<f64>,
43 pub any_escalation: bool,
44}
45
/// Optional constraints for selecting routing-dataset rows.
///
/// Each field defaults to `None`. For the time and schema filters, `None`
/// means "no constraint". For `limit`, `None` means the extractor falls back
/// to its default cap of 10 000 rows.
#[derive(Default, Clone, Debug)]
pub struct DatasetFilter {
    /// Inclusive lower bound on the decision's `created_at` (text comparison).
    pub since: Option<String>,
    /// Exclusive upper bound on the decision's `created_at` (text comparison).
    pub until: Option<String>,
    /// Keep only decisions recorded with exactly this schema version.
    pub schema_version: Option<i64>,
    /// Maximum number of rows to return.
    pub limit: Option<usize>,
}
58
59pub fn extract_routing_dataset(
65 db: &Database,
66 filter: &DatasetFilter,
67) -> Result<Vec<RoutingDatasetRow>> {
68 let conn = db.conn();
69
70 let mut conditions = Vec::new();
72 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
73 let mut idx = 1;
74
75 if let Some(ref since) = filter.since {
76 conditions.push(format!("mse.created_at >= ?{idx}"));
77 params.push(Box::new(since.clone()));
78 idx += 1;
79 }
80 if let Some(ref until) = filter.until {
81 conditions.push(format!("mse.created_at < ?{idx}"));
82 params.push(Box::new(until.clone()));
83 idx += 1;
84 }
85 if let Some(sv) = filter.schema_version {
86 conditions.push(format!("mse.schema_version = ?{idx}"));
87 params.push(Box::new(sv));
88 idx += 1;
89 }
90
91 let limit = filter.limit.unwrap_or(10_000) as i64;
92 let limit_placeholder = format!("?{idx}");
93 params.push(Box::new(limit));
94
95 let where_clause = if conditions.is_empty() {
96 String::new()
97 } else {
98 format!("WHERE {}", conditions.join(" AND "))
99 };
100
101 let sql = format!(
102 "SELECT
103 mse.id,
104 mse.turn_id,
105 mse.session_id,
106 mse.agent_id,
107 mse.channel,
108 mse.selected_model,
109 mse.strategy,
110 mse.primary_model,
111 mse.override_model,
112 mse.complexity,
113 mse.user_excerpt,
114 mse.candidates_json,
115 mse.attribution,
116 mse.metascore_json,
117 mse.features_json,
118 mse.schema_version,
119 mse.created_at,
120 COALESCE(SUM(ic.tokens_in), 0) AS total_tokens_in,
121 COALESCE(SUM(ic.tokens_out), 0) AS total_tokens_out,
122 COALESCE(SUM(ic.cost), 0.0) AS total_cost,
123 COUNT(ic.id) AS inference_count,
124 COALESCE(MAX(ic.cached), 0) AS any_cached,
125 AVG(ic.latency_ms) AS avg_latency_ms,
126 AVG(ic.quality_score) AS avg_quality_score,
127 COALESCE(MAX(ic.escalation), 0) AS any_escalation
128 FROM model_selection_events mse
129 INNER JOIN inference_costs ic ON ic.turn_id = mse.turn_id
130 {where_clause}
131 GROUP BY mse.id
132 ORDER BY mse.created_at ASC
133 LIMIT {limit_placeholder}"
134 );
135
136 let mut stmt = conn
137 .prepare(&sql)
138 .map_err(|e| RoboticusError::Database(format!("prepare routing dataset: {e}")))?;
139
140 let rows = stmt
141 .query_map(rusqlite::params_from_iter(params.iter()), |r| {
142 Ok(RoutingDatasetRow {
143 event_id: r.get(0)?,
144 turn_id: r.get(1)?,
145 session_id: r.get(2)?,
146 agent_id: r.get(3)?,
147 channel: r.get(4)?,
148 selected_model: r.get(5)?,
149 strategy: r.get(6)?,
150 primary_model: r.get(7)?,
151 override_model: r.get(8)?,
152 complexity: r.get(9)?,
153 user_excerpt: r.get(10)?,
154 candidates_json: r.get(11)?,
155 attribution: r.get(12)?,
156 metascore_json: r.get(13)?,
157 features_json: r.get(14)?,
158 schema_version: r.get(15)?,
159 decision_at: r.get(16)?,
160 total_tokens_in: r.get(17)?,
161 total_tokens_out: r.get(18)?,
162 total_cost: r.get(19)?,
163 inference_count: r.get(20)?,
164 any_cached: r.get::<_, i32>(21)? != 0,
165 avg_latency_ms: r.get(22)?,
166 avg_quality_score: r.get(23)?,
167 any_escalation: r.get::<_, i32>(24)? != 0,
168 })
169 })
170 .map_err(|e| RoboticusError::Database(format!("query routing dataset: {e}")))?
171 .collect::<std::result::Result<Vec<_>, _>>()
172 .map_err(|e| RoboticusError::Database(format!("collect routing dataset: {e}")))?;
173
174 Ok(rows)
175}
176
177#[derive(Debug, Clone, Serialize)]
179pub struct DatasetSummary {
180 pub total_rows: usize,
181 pub distinct_models: usize,
182 pub distinct_strategies: usize,
183 pub total_cost: f64,
184 pub avg_cost_per_decision: f64,
185 pub schema_versions: Vec<i64>,
186}
187
188pub fn dataset_summary(db: &Database, filter: &DatasetFilter) -> Result<DatasetSummary> {
190 let mut summary_filter = filter.clone();
191 summary_filter.limit = None;
193 let rows = extract_routing_dataset(db, &summary_filter)?;
194 if rows.is_empty() {
195 return Ok(DatasetSummary {
196 total_rows: 0,
197 distinct_models: 0,
198 distinct_strategies: 0,
199 total_cost: 0.0,
200 avg_cost_per_decision: 0.0,
201 schema_versions: vec![],
202 });
203 }
204
205 use std::collections::HashSet;
206 let models: HashSet<&str> = rows.iter().map(|r| r.selected_model.as_str()).collect();
207 let strategies: HashSet<&str> = rows.iter().map(|r| r.strategy.as_str()).collect();
208 let total_cost: f64 = rows.iter().map(|r| r.total_cost).sum();
209 let svs: HashSet<i64> = rows.iter().map(|r| r.schema_version).collect();
210 let mut sv_vec: Vec<i64> = svs.into_iter().collect();
211 sv_vec.sort();
212
213 Ok(DatasetSummary {
214 total_rows: rows.len(),
215 distinct_models: models.len(),
216 distinct_strategies: strategies.len(),
217 total_cost,
218 avg_cost_per_decision: total_cost / rows.len() as f64,
219 schema_versions: sv_vec,
220 })
221}
222
223pub fn extract_routing_dataset_tsv(db: &Database, filter: &DatasetFilter) -> Result<String> {
227 let rows = extract_routing_dataset(db, filter)?;
228 let mut out = String::with_capacity(rows.len() * 256);
229
230 out.push_str(
232 "event_id\tturn_id\tsession_id\tagent_id\tchannel\tselected_model\tstrategy\t\
233 primary_model\toverride_model\tcomplexity\tattribution\tschema_version\t\
234 decision_at\ttotal_tokens_in\ttotal_tokens_out\ttotal_cost\tinference_count\t\
235 any_cached\tavg_latency_ms\tavg_quality_score\tany_escalation\n",
236 );
237
238 for r in &rows {
239 out.push_str(&format!(
240 "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.6}\t{}\t{}\t{}\t{}\t{}\n",
241 r.event_id,
242 r.turn_id,
243 r.session_id,
244 r.agent_id,
245 r.channel,
246 r.selected_model,
247 r.strategy,
248 r.primary_model,
249 r.override_model.as_deref().unwrap_or(""),
250 r.complexity.as_deref().unwrap_or(""),
251 r.attribution.as_deref().unwrap_or(""),
252 r.schema_version,
253 r.decision_at,
254 r.total_tokens_in,
255 r.total_tokens_out,
256 r.total_cost,
257 r.inference_count,
258 r.any_cached as i32,
259 r.avg_latency_ms.map_or("".to_string(), |v| format!("{v:.1}")),
260 r.avg_quality_score.map_or("".to_string(), |v| format!("{v:.4}")),
261 r.any_escalation as i32,
262 ));
263 }
264
265 Ok(out)
266}
267
#[cfg(test)]
mod tests {
    //! Tests for routing-dataset extraction, summary and TSV export, run
    //! against a fresh in-memory database each time.

    use super::*;
    use crate::metrics::record_inference_cost;
    use crate::model_selection::{
        ModelSelectionEventRow, ROUTING_SCHEMA_VERSION, record_model_selection_event,
    };

    /// Opens a fresh in-memory database.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Records one model-selection event. Session, agent, channel, strategy
    /// and complexity are fixed; `model` is used as both selected and primary model.
    fn insert_decision(
        db: &Database,
        event_id: &str,
        turn_id: &str,
        model: &str,
        attribution: Option<&str>,
        created_at: &str,
    ) {
        let evt = ModelSelectionEventRow {
            id: event_id.to_string(),
            turn_id: turn_id.to_string(),
            session_id: "sess-ds".to_string(),
            agent_id: "agent-ds".to_string(),
            channel: "cli".to_string(),
            selected_model: model.to_string(),
            strategy: "complexity".to_string(),
            primary_model: model.to_string(),
            override_model: None,
            complexity: Some("medium".to_string()),
            user_excerpt: "test query".to_string(),
            candidates_json: format!(r#"["{model}"]"#),
            created_at: created_at.to_string(),
            schema_version: ROUTING_SCHEMA_VERSION,
            attribution: attribution.map(|s| s.to_string()),
            metascore_json: None,
            features_json: None,
        };
        record_model_selection_event(db, &evt).unwrap();
    }

    /// Records one inference-cost row for `turn_id`, with a fixed provider and
    /// fixed values for the remaining optional fields.
    fn insert_cost(
        db: &Database,
        turn_id: &str,
        model: &str,
        tokens_in: i64,
        tokens_out: i64,
        cost: f64,
    ) {
        record_inference_cost(
            db,
            model,
            "test-provider",
            tokens_in,
            tokens_out,
            cost,
            None,
            false,
            Some(100),
            Some(0.85),
            false,
            Some(turn_id),
        )
        .unwrap();
    }

    #[test]
    fn empty_dataset() {
        let db = test_db();
        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert!(rows.is_empty());
    }

    #[test]
    fn decision_without_cost_excluded() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-1",
            "turn-1",
            "claude-4",
            Some("metascore"),
            "2025-06-01T00:00:00",
        );
        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert!(
            rows.is_empty(),
            "decisions with no cost should be excluded (INNER JOIN)"
        );
    }

    #[test]
    fn basic_join() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-1",
            "turn-1",
            "claude-4",
            Some("metascore"),
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-1", "claude-4", 1000, 500, 0.03);

        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(rows.len(), 1);
        let r = &rows[0];
        assert_eq!(r.event_id, "evt-1");
        assert_eq!(r.selected_model, "claude-4");
        assert_eq!(r.total_tokens_in, 1000);
        assert_eq!(r.total_tokens_out, 500);
        assert!((r.total_cost - 0.03).abs() < 1e-9);
        assert_eq!(r.inference_count, 1);
        assert!(!r.any_cached);
        assert!(r.avg_latency_ms.is_some());
        assert!(r.avg_quality_score.is_some());
    }

    #[test]
    fn multiple_costs_per_turn_aggregate() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-agg",
            "turn-agg",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-agg", "claude-4", 500, 200, 0.01);
        insert_cost(&db, "turn-agg", "claude-4", 300, 100, 0.005);

        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(rows.len(), 1);
        let r = &rows[0];
        assert_eq!(r.total_tokens_in, 800);
        assert_eq!(r.total_tokens_out, 300);
        assert!((r.total_cost - 0.015).abs() < 1e-9);
        assert_eq!(r.inference_count, 2);
    }

    #[test]
    fn filter_since() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-old",
            "turn-old",
            "gpt-4",
            None,
            "2024-01-01T00:00:00",
        );
        insert_cost(&db, "turn-old", "gpt-4", 100, 50, 0.01);
        insert_decision(
            &db,
            "evt-new",
            "turn-new",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-new", "claude-4", 200, 100, 0.02);

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                since: Some("2025-01-01T00:00:00".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].event_id, "evt-new");
    }

    #[test]
    fn filter_until() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-old",
            "turn-old",
            "gpt-4",
            None,
            "2024-01-01T00:00:00",
        );
        insert_cost(&db, "turn-old", "gpt-4", 100, 50, 0.01);
        insert_decision(
            &db,
            "evt-new",
            "turn-new",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-new", "claude-4", 200, 100, 0.02);

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                until: Some("2025-01-01T00:00:00".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].event_id, "evt-old");
    }

    #[test]
    fn time_bounds_since_inclusive_until_exclusive() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-edge",
            "turn-edge",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-edge", "claude-4", 100, 50, 0.01);

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                until: Some("2025-06-01T00:00:00".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert!(rows.is_empty(), "`until` is an exclusive bound");

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                since: Some("2025-06-01T00:00:00".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 1, "`since` is an inclusive bound");
    }

    #[test]
    fn filter_schema_version() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-v1",
            "turn-v1",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-v1", "claude-4", 100, 50, 0.01);

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                schema_version: Some(99),
                ..Default::default()
            },
        )
        .unwrap();
        assert!(rows.is_empty());

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                schema_version: Some(ROUTING_SCHEMA_VERSION),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn filter_limit() {
        let db = test_db();
        for i in 0..5 {
            let eid = format!("evt-lim-{i}");
            let tid = format!("turn-lim-{i}");
            // Days 01..05 — keep the generated timestamps valid calendar dates.
            insert_decision(
                &db,
                &eid,
                &tid,
                "claude-4",
                None,
                &format!("2025-06-0{}T00:00:00", i + 1),
            );
            insert_cost(&db, &tid, "claude-4", 100, 50, 0.01);
        }

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                limit: Some(2),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 2);
        // Ordering is applied before LIMIT, so the oldest decisions are kept.
        assert_eq!(rows[0].event_id, "evt-lim-0");
        assert_eq!(rows[1].event_id, "evt-lim-1");
    }

    #[test]
    fn dataset_summary_empty() {
        let db = test_db();
        let s = dataset_summary(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(s.total_rows, 0);
        assert_eq!(s.distinct_models, 0);
    }

    #[test]
    fn dataset_summary_populated() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-s1",
            "turn-s1",
            "claude-4",
            Some("metascore"),
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-s1", "claude-4", 1000, 500, 0.03);
        insert_decision(
            &db,
            "evt-s2",
            "turn-s2",
            "gpt-4",
            Some("fallback"),
            "2025-06-02T00:00:00",
        );
        insert_cost(&db, "turn-s2", "gpt-4", 500, 200, 0.01);

        let s = dataset_summary(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(s.total_rows, 2);
        assert_eq!(s.distinct_models, 2);
        assert_eq!(s.distinct_strategies, 1);
        assert!((s.total_cost - 0.04).abs() < 1e-9);
        assert!((s.avg_cost_per_decision - 0.02).abs() < 1e-9);
        assert_eq!(s.schema_versions, vec![ROUTING_SCHEMA_VERSION]);
    }

    #[test]
    fn dataset_summary_ignores_limit_cap() {
        let db = test_db();
        for i in 0..3 {
            let eid = format!("evt-sum-{i}");
            let tid = format!("turn-sum-{i}");
            insert_decision(
                &db,
                &eid,
                &tid,
                "claude-4",
                Some("metascore"),
                &format!("2025-06-0{}T00:00:00", i + 1),
            );
            insert_cost(&db, &tid, "claude-4", 100, 50, 0.01);
        }
        let s = dataset_summary(
            &db,
            &DatasetFilter {
                limit: Some(1),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(s.total_rows, 3);
    }

    #[test]
    fn tsv_export_header_and_rows() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-tsv",
            "turn-tsv",
            "claude-4",
            Some("primary_usable"),
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-tsv", "claude-4", 100, 50, 0.005);

        let tsv = extract_routing_dataset_tsv(&db, &DatasetFilter::default()).unwrap();
        let lines: Vec<&str> = tsv.lines().collect();
        assert!(lines.len() >= 2, "should have header + at least 1 row");
        assert!(lines[0].starts_with("event_id\t"));
        assert!(lines[1].starts_with("evt-tsv\t"));
        assert!(lines[1].contains("primary_usable"));
    }

    #[test]
    fn tsv_rows_match_header_width() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-w",
            "turn-w",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-w", "claude-4", 100, 50, 0.005);

        let tsv = extract_routing_dataset_tsv(&db, &DatasetFilter::default()).unwrap();
        let lines: Vec<&str> = tsv.lines().collect();
        assert_eq!(lines.len(), 2);
        let header: Vec<&str> = lines[0].split('\t').collect();
        let row: Vec<&str> = lines[1].split('\t').collect();
        assert_eq!(header.len(), row.len(), "every row has one field per header column");
        let override_idx = header
            .iter()
            .position(|h| *h == "override_model")
            .unwrap();
        assert_eq!(row[override_idx], "", "missing optional values are empty fields");
    }

    #[test]
    fn ordering_is_ascending() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-asc-2",
            "turn-asc-2",
            "claude-4",
            None,
            "2025-06-02T00:00:00",
        );
        insert_cost(&db, "turn-asc-2", "claude-4", 100, 50, 0.01);
        insert_decision(
            &db,
            "evt-asc-1",
            "turn-asc-1",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-asc-1", "claude-4", 100, 50, 0.01);

        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(
            rows[0].event_id, "evt-asc-1",
            "oldest first (ASC for training data)"
        );
        assert_eq!(rows[1].event_id, "evt-asc-2");
    }
}