// roboticus_db/routing_dataset.rs
1//! Historical routing dataset extraction pipeline.
2//!
3//! Joins `model_selection_events` with `inference_costs` via `turn_id` to produce
4//! flat, exportable rows that capture both the routing decision and its cost
5//! outcome. This is the training-data foundation for the shadow ML pipeline
6//! (Phase 2a) and the offline evaluation harness.
7
8use crate::Database;
9use roboticus_core::{RoboticusError, Result};
10use serde::Serialize;
11
/// A single row in the routing dataset — one routing decision joined with its
/// aggregated inference cost outcome.
///
/// Produced by `extract_routing_dataset`; `Serialize` so rows can be exported
/// directly (e.g. as JSON) in addition to the TSV path.
#[derive(Debug, Clone, Serialize)]
pub struct RoutingDatasetRow {
    // ── routing decision (from model_selection_events) ──
    pub event_id: String,
    /// Join key linking this decision to its `inference_costs` rows.
    pub turn_id: String,
    pub session_id: String,
    pub agent_id: String,
    pub channel: String,
    pub selected_model: String,
    pub strategy: String,
    pub primary_model: String,
    pub override_model: Option<String>,
    pub complexity: Option<String>,
    pub user_excerpt: String,
    pub candidates_json: String,
    pub attribution: Option<String>,
    pub metascore_json: Option<String>,
    pub features_json: Option<String>,
    pub schema_version: i64,
    /// Decision timestamp (`model_selection_events.created_at`), kept as the
    /// stored string.
    pub decision_at: String,

    // ── cost outcome (aggregated from inference_costs) ──
    pub total_tokens_in: i64,
    pub total_tokens_out: i64,
    pub total_cost: f64,
    /// Number of inference-cost rows recorded for this turn.
    pub inference_count: i64,
    /// True if any inference in the turn was served from cache (MAX over the
    /// integer `cached` column).
    pub any_cached: bool,
    /// Mean latency over the turn's inferences; `None` when no latency was
    /// recorded (SQL `AVG` of all-NULL is NULL).
    pub avg_latency_ms: Option<f64>,
    /// Mean quality score over the turn's inferences; `None` when none recorded.
    pub avg_quality_score: Option<f64>,
    /// True if any inference in the turn was flagged as an escalation.
    pub any_escalation: bool,
}
45
/// Filter parameters for dataset extraction.
///
/// All fields are optional; every set field is ANDed into the query's WHERE
/// clause. `Default` yields an unfiltered extraction (subject to the default
/// row limit).
#[derive(Debug, Clone, Default)]
pub struct DatasetFilter {
    /// Only include rows with `created_at >= since` (string comparison on the
    /// stored timestamp text).
    pub since: Option<String>,
    /// Only include rows with `created_at < until` — half-open interval, so
    /// `[since, until)` ranges can be tiled without overlap.
    pub until: Option<String>,
    /// Only include rows at this schema version.
    pub schema_version: Option<i64>,
    /// Maximum rows to return (default: 10_000).
    pub limit: Option<usize>,
}
58
59/// Extract the routing dataset by joining routing decisions with cost outcomes.
60///
61/// Each row represents one routing decision with aggregated cost metrics from
62/// all inference calls made during that turn. Decisions with no matching
63/// inference costs are excluded (INNER JOIN) since they provide no cost signal.
64pub fn extract_routing_dataset(
65    db: &Database,
66    filter: &DatasetFilter,
67) -> Result<Vec<RoutingDatasetRow>> {
68    let conn = db.conn();
69
70    // Build WHERE clause dynamically
71    let mut conditions = Vec::new();
72    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
73    let mut idx = 1;
74
75    if let Some(ref since) = filter.since {
76        conditions.push(format!("mse.created_at >= ?{idx}"));
77        params.push(Box::new(since.clone()));
78        idx += 1;
79    }
80    if let Some(ref until) = filter.until {
81        conditions.push(format!("mse.created_at < ?{idx}"));
82        params.push(Box::new(until.clone()));
83        idx += 1;
84    }
85    if let Some(sv) = filter.schema_version {
86        conditions.push(format!("mse.schema_version = ?{idx}"));
87        params.push(Box::new(sv));
88        idx += 1;
89    }
90
91    let limit = filter.limit.unwrap_or(10_000) as i64;
92    let limit_placeholder = format!("?{idx}");
93    params.push(Box::new(limit));
94
95    let where_clause = if conditions.is_empty() {
96        String::new()
97    } else {
98        format!("WHERE {}", conditions.join(" AND "))
99    };
100
101    let sql = format!(
102        "SELECT
103            mse.id,
104            mse.turn_id,
105            mse.session_id,
106            mse.agent_id,
107            mse.channel,
108            mse.selected_model,
109            mse.strategy,
110            mse.primary_model,
111            mse.override_model,
112            mse.complexity,
113            mse.user_excerpt,
114            mse.candidates_json,
115            mse.attribution,
116            mse.metascore_json,
117            mse.features_json,
118            mse.schema_version,
119            mse.created_at,
120            COALESCE(SUM(ic.tokens_in), 0)  AS total_tokens_in,
121            COALESCE(SUM(ic.tokens_out), 0) AS total_tokens_out,
122            COALESCE(SUM(ic.cost), 0.0)     AS total_cost,
123            COUNT(ic.id)                     AS inference_count,
124            COALESCE(MAX(ic.cached), 0)      AS any_cached,
125            AVG(ic.latency_ms)              AS avg_latency_ms,
126            AVG(ic.quality_score)           AS avg_quality_score,
127            COALESCE(MAX(ic.escalation), 0) AS any_escalation
128         FROM model_selection_events mse
129         INNER JOIN inference_costs ic ON ic.turn_id = mse.turn_id
130         {where_clause}
131         GROUP BY mse.id
132         ORDER BY mse.created_at ASC
133         LIMIT {limit_placeholder}"
134    );
135
136    let mut stmt = conn
137        .prepare(&sql)
138        .map_err(|e| RoboticusError::Database(format!("prepare routing dataset: {e}")))?;
139
140    let rows = stmt
141        .query_map(rusqlite::params_from_iter(params.iter()), |r| {
142            Ok(RoutingDatasetRow {
143                event_id: r.get(0)?,
144                turn_id: r.get(1)?,
145                session_id: r.get(2)?,
146                agent_id: r.get(3)?,
147                channel: r.get(4)?,
148                selected_model: r.get(5)?,
149                strategy: r.get(6)?,
150                primary_model: r.get(7)?,
151                override_model: r.get(8)?,
152                complexity: r.get(9)?,
153                user_excerpt: r.get(10)?,
154                candidates_json: r.get(11)?,
155                attribution: r.get(12)?,
156                metascore_json: r.get(13)?,
157                features_json: r.get(14)?,
158                schema_version: r.get(15)?,
159                decision_at: r.get(16)?,
160                total_tokens_in: r.get(17)?,
161                total_tokens_out: r.get(18)?,
162                total_cost: r.get(19)?,
163                inference_count: r.get(20)?,
164                any_cached: r.get::<_, i32>(21)? != 0,
165                avg_latency_ms: r.get(22)?,
166                avg_quality_score: r.get(23)?,
167                any_escalation: r.get::<_, i32>(24)? != 0,
168            })
169        })
170        .map_err(|e| RoboticusError::Database(format!("query routing dataset: {e}")))?
171        .collect::<std::result::Result<Vec<_>, _>>()
172        .map_err(|e| RoboticusError::Database(format!("collect routing dataset: {e}")))?;
173
174    Ok(rows)
175}
176
/// Summary statistics for the extracted dataset.
///
/// Returned by `dataset_summary`; `Serialize` for direct export.
#[derive(Debug, Clone, Serialize)]
pub struct DatasetSummary {
    pub total_rows: usize,
    /// Count of distinct `selected_model` values.
    pub distinct_models: usize,
    /// Count of distinct `strategy` values.
    pub distinct_strategies: usize,
    pub total_cost: f64,
    /// `total_cost / total_rows`; 0.0 when the dataset is empty.
    pub avg_cost_per_decision: f64,
    /// Distinct schema versions present, sorted ascending.
    pub schema_versions: Vec<i64>,
}
187
188/// Compute summary statistics for the routing dataset.
189pub fn dataset_summary(db: &Database, filter: &DatasetFilter) -> Result<DatasetSummary> {
190    let mut summary_filter = filter.clone();
191    // Summary stats should represent the full filtered dataset, not a pagination cap.
192    summary_filter.limit = None;
193    let rows = extract_routing_dataset(db, &summary_filter)?;
194    if rows.is_empty() {
195        return Ok(DatasetSummary {
196            total_rows: 0,
197            distinct_models: 0,
198            distinct_strategies: 0,
199            total_cost: 0.0,
200            avg_cost_per_decision: 0.0,
201            schema_versions: vec![],
202        });
203    }
204
205    use std::collections::HashSet;
206    let models: HashSet<&str> = rows.iter().map(|r| r.selected_model.as_str()).collect();
207    let strategies: HashSet<&str> = rows.iter().map(|r| r.strategy.as_str()).collect();
208    let total_cost: f64 = rows.iter().map(|r| r.total_cost).sum();
209    let svs: HashSet<i64> = rows.iter().map(|r| r.schema_version).collect();
210    let mut sv_vec: Vec<i64> = svs.into_iter().collect();
211    sv_vec.sort();
212
213    Ok(DatasetSummary {
214        total_rows: rows.len(),
215        distinct_models: models.len(),
216        distinct_strategies: strategies.len(),
217        total_cost,
218        avg_cost_per_decision: total_cost / rows.len() as f64,
219        schema_versions: sv_vec,
220    })
221}
222
223/// Export the dataset as tab-separated values (header + rows).
224///
225/// TSV chosen over CSV because user_excerpt may contain commas.
226pub fn extract_routing_dataset_tsv(db: &Database, filter: &DatasetFilter) -> Result<String> {
227    let rows = extract_routing_dataset(db, filter)?;
228    let mut out = String::with_capacity(rows.len() * 256);
229
230    // Header
231    out.push_str(
232        "event_id\tturn_id\tsession_id\tagent_id\tchannel\tselected_model\tstrategy\t\
233                   primary_model\toverride_model\tcomplexity\tattribution\tschema_version\t\
234                   decision_at\ttotal_tokens_in\ttotal_tokens_out\ttotal_cost\tinference_count\t\
235                   any_cached\tavg_latency_ms\tavg_quality_score\tany_escalation\n",
236    );
237
238    for r in &rows {
239        out.push_str(&format!(
240            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.6}\t{}\t{}\t{}\t{}\t{}\n",
241            r.event_id,
242            r.turn_id,
243            r.session_id,
244            r.agent_id,
245            r.channel,
246            r.selected_model,
247            r.strategy,
248            r.primary_model,
249            r.override_model.as_deref().unwrap_or(""),
250            r.complexity.as_deref().unwrap_or(""),
251            r.attribution.as_deref().unwrap_or(""),
252            r.schema_version,
253            r.decision_at,
254            r.total_tokens_in,
255            r.total_tokens_out,
256            r.total_cost,
257            r.inference_count,
258            r.any_cached as i32,
259            r.avg_latency_ms.map_or("".to_string(), |v| format!("{v:.1}")),
260            r.avg_quality_score.map_or("".to_string(), |v| format!("{v:.4}")),
261            r.any_escalation as i32,
262        ));
263    }
264
265    Ok(out)
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::record_inference_cost;
    use crate::model_selection::{
        ModelSelectionEventRow, ROUTING_SCHEMA_VERSION, record_model_selection_event,
    };

    /// Fresh in-memory SQLite database per test — no cross-test state.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Insert a routing-decision row. Only the fields that vary across tests
    /// are parameters; session/agent/channel/strategy/complexity are fixed.
    fn insert_decision(
        db: &Database,
        event_id: &str,
        turn_id: &str,
        model: &str,
        attribution: Option<&str>,
        created_at: &str,
    ) {
        let evt = ModelSelectionEventRow {
            id: event_id.to_string(),
            turn_id: turn_id.to_string(),
            session_id: "sess-ds".to_string(),
            agent_id: "agent-ds".to_string(),
            channel: "cli".to_string(),
            selected_model: model.to_string(),
            strategy: "complexity".to_string(),
            primary_model: model.to_string(),
            override_model: None,
            complexity: Some("medium".to_string()),
            user_excerpt: "test query".to_string(),
            candidates_json: format!(r#"["{model}"]"#),
            created_at: created_at.to_string(),
            schema_version: ROUTING_SCHEMA_VERSION,
            attribution: attribution.map(|s| s.to_string()),
            metascore_json: None,
            features_json: None,
        };
        record_model_selection_event(db, &evt).unwrap();
    }

    /// Insert one inference-cost row tied to `turn_id`, with fixed latency
    /// (100) and quality (0.85), not cached, not an escalation.
    fn insert_cost(
        db: &Database,
        turn_id: &str,
        model: &str,
        tokens_in: i64,
        tokens_out: i64,
        cost: f64,
    ) {
        record_inference_cost(
            db,
            model,
            "test-provider",
            tokens_in,
            tokens_out,
            cost,
            None,
            false,
            Some(100),
            Some(0.85),
            false,
            Some(turn_id),
        )
        .unwrap();
    }

    // No data at all → empty dataset.
    #[test]
    fn empty_dataset() {
        let db = test_db();
        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert!(rows.is_empty());
    }

    // The INNER JOIN drops decisions that have no cost rows.
    #[test]
    fn decision_without_cost_excluded() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-1",
            "turn-1",
            "claude-4",
            Some("metascore"),
            "2025-06-01T00:00:00",
        );
        // No inference_cost for turn-1
        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert!(
            rows.is_empty(),
            "decisions with no cost should be excluded (INNER JOIN)"
        );
    }

    // One decision + one cost row: all joined fields come through.
    #[test]
    fn basic_join() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-1",
            "turn-1",
            "claude-4",
            Some("metascore"),
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-1", "claude-4", 1000, 500, 0.03);

        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(rows.len(), 1);
        let r = &rows[0];
        assert_eq!(r.event_id, "evt-1");
        assert_eq!(r.selected_model, "claude-4");
        assert_eq!(r.total_tokens_in, 1000);
        assert_eq!(r.total_tokens_out, 500);
        assert!((r.total_cost - 0.03).abs() < 1e-9);
        assert_eq!(r.inference_count, 1);
        assert!(!r.any_cached);
        assert!(r.avg_latency_ms.is_some());
        assert!(r.avg_quality_score.is_some());
    }

    // Multiple cost rows for the same turn are SUMmed/COUNTed into one row.
    #[test]
    fn multiple_costs_per_turn_aggregate() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-agg",
            "turn-agg",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-agg", "claude-4", 500, 200, 0.01);
        insert_cost(&db, "turn-agg", "claude-4", 300, 100, 0.005);

        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(rows.len(), 1);
        let r = &rows[0];
        assert_eq!(r.total_tokens_in, 800);
        assert_eq!(r.total_tokens_out, 300);
        assert!((r.total_cost - 0.015).abs() < 1e-9);
        assert_eq!(r.inference_count, 2);
    }

    // `since` is inclusive: only the newer row survives.
    #[test]
    fn filter_since() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-old",
            "turn-old",
            "gpt-4",
            None,
            "2024-01-01T00:00:00",
        );
        insert_cost(&db, "turn-old", "gpt-4", 100, 50, 0.01);
        insert_decision(
            &db,
            "evt-new",
            "turn-new",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-new", "claude-4", 200, 100, 0.02);

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                since: Some("2025-01-01T00:00:00".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].event_id, "evt-new");
    }

    // `until` is exclusive: only the older row survives.
    #[test]
    fn filter_until() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-old",
            "turn-old",
            "gpt-4",
            None,
            "2024-01-01T00:00:00",
        );
        insert_cost(&db, "turn-old", "gpt-4", 100, 50, 0.01);
        insert_decision(
            &db,
            "evt-new",
            "turn-new",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-new", "claude-4", 200, 100, 0.02);

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                until: Some("2025-01-01T00:00:00".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].event_id, "evt-old");
    }

    // Schema-version filter: a mismatched version excludes, a match includes.
    #[test]
    fn filter_schema_version() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-v1",
            "turn-v1",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-v1", "claude-4", 100, 50, 0.01);

        // Filter for a non-existent schema version
        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                schema_version: Some(99),
                ..Default::default()
            },
        )
        .unwrap();
        assert!(rows.is_empty());

        // Filter for the actual schema version
        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                schema_version: Some(ROUTING_SCHEMA_VERSION),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 1);
    }

    // LIMIT caps the row count.
    #[test]
    fn filter_limit() {
        let db = test_db();
        for i in 0..5 {
            let eid = format!("evt-lim-{i}");
            let tid = format!("turn-lim-{i}");
            insert_decision(
                &db,
                &eid,
                &tid,
                "claude-4",
                None,
                // NOTE(review): i=0 yields "2025-06-00…", not a real date —
                // harmless here since timestamps are compared lexically.
                &format!("2025-06-0{i}T00:00:00"),
            );
            insert_cost(&db, &tid, "claude-4", 100, 50, 0.01);
        }

        let rows = extract_routing_dataset(
            &db,
            &DatasetFilter {
                limit: Some(2),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn dataset_summary_empty() {
        let db = test_db();
        let s = dataset_summary(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(s.total_rows, 0);
        assert_eq!(s.distinct_models, 0);
    }

    // Two decisions with different models but the same strategy.
    #[test]
    fn dataset_summary_populated() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-s1",
            "turn-s1",
            "claude-4",
            Some("metascore"),
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-s1", "claude-4", 1000, 500, 0.03);
        insert_decision(
            &db,
            "evt-s2",
            "turn-s2",
            "gpt-4",
            Some("fallback"),
            "2025-06-02T00:00:00",
        );
        insert_cost(&db, "turn-s2", "gpt-4", 500, 200, 0.01);

        let s = dataset_summary(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(s.total_rows, 2);
        assert_eq!(s.distinct_models, 2);
        assert_eq!(s.distinct_strategies, 1); // both "complexity"
        assert!((s.total_cost - 0.04).abs() < 1e-9);
        assert!((s.avg_cost_per_decision - 0.02).abs() < 1e-9);
        assert_eq!(s.schema_versions, vec![ROUTING_SCHEMA_VERSION]);
    }

    // A limit in the filter must not truncate the summary.
    #[test]
    fn dataset_summary_ignores_limit_cap() {
        let db = test_db();
        for i in 0..3 {
            let eid = format!("evt-sum-{i}");
            let tid = format!("turn-sum-{i}");
            insert_decision(
                &db,
                &eid,
                &tid,
                "claude-4",
                Some("metascore"),
                &format!("2025-06-0{}T00:00:00", i + 1),
            );
            insert_cost(&db, &tid, "claude-4", 100, 50, 0.01);
        }
        let s = dataset_summary(
            &db,
            &DatasetFilter {
                limit: Some(1),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(s.total_rows, 3);
    }

    // TSV export: header line first, then one tab-separated row per decision.
    #[test]
    fn tsv_export_header_and_rows() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-tsv",
            "turn-tsv",
            "claude-4",
            Some("primary_usable"),
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-tsv", "claude-4", 100, 50, 0.005);

        let tsv = extract_routing_dataset_tsv(&db, &DatasetFilter::default()).unwrap();
        let lines: Vec<&str> = tsv.lines().collect();
        assert!(lines.len() >= 2, "should have header + at least 1 row");
        assert!(lines[0].starts_with("event_id\t"));
        assert!(lines[1].starts_with("evt-tsv\t"));
        assert!(lines[1].contains("primary_usable"));
    }

    // Rows come back ordered by created_at ASC regardless of insert order.
    #[test]
    fn ordering_is_ascending() {
        let db = test_db();
        insert_decision(
            &db,
            "evt-asc-2",
            "turn-asc-2",
            "claude-4",
            None,
            "2025-06-02T00:00:00",
        );
        insert_cost(&db, "turn-asc-2", "claude-4", 100, 50, 0.01);
        insert_decision(
            &db,
            "evt-asc-1",
            "turn-asc-1",
            "claude-4",
            None,
            "2025-06-01T00:00:00",
        );
        insert_cost(&db, "turn-asc-1", "claude-4", 100, 50, 0.01);

        let rows = extract_routing_dataset(&db, &DatasetFilter::default()).unwrap();
        assert_eq!(
            rows[0].event_id, "evt-asc-1",
            "oldest first (ASC for training data)"
        );
        assert_eq!(rows[1].event_id, "evt-asc-2");
    }
}
659}