// roboticus_db/efficiency.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{Database, DbResultExt};
6use roboticus_core::Result;
7
/// Quality comparison between turns with and without memory context.
///
/// NOTE(review): field meanings are inferred from their names — presumably an
/// average grade for each group; confirm. `compute_quality_for_model` currently
/// sets both fields to `0.0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryImpact {
    /// Presumed quality score for turns that used memory (currently always `0.0`).
    pub with_memory: f64,
    /// Presumed quality score for turns without memory (currently always `0.0`).
    pub without_memory: f64,
}
13
/// Grade-based quality metrics for one model.
///
/// Built by `compute_quality_for_model` from `turn_feedback` rows joined to
/// `turns` for that model within the report window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Mean `turn_feedback.grade` over the feedback rows found.
    pub avg_grade: f64,
    /// Number of feedback rows found for the model.
    pub grade_count: i64,
    /// `grade_count / total_turns`, where `total_turns` counts `inference_costs`
    /// rows; `0.0` when there are no turns. Nothing clamps this to 1.0.
    pub grade_coverage: f64,
    /// Sum of `turns.cost` over the graded turns divided by the sum of grades;
    /// `0.0` when the grade sum is not positive.
    pub cost_per_quality_point: f64,
    /// Presumably average grade per task-complexity bucket; always empty today.
    pub by_complexity: HashMap<String, f64>,
    /// Quality with vs. without memory context; currently zeroed.
    pub memory_impact: MemoryImpact,
    /// `"increasing"`, `"decreasing"`, or `"stable"` from comparing the mean grade
    /// of the first half of the returned feedback rows with the second half;
    /// always `"stable"` with fewer than four rows.
    pub trend: String,
}
24
/// Efficiency summary for a single model over the report window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelEfficiency {
    /// Model name as stored in `inference_costs.model`.
    pub model: String,
    /// Number of `inference_costs` rows for this model in the window.
    pub total_turns: i64,
    /// Mean of `tokens_out / tokens_in` per row. Rows with `tokens_in = 0` are
    /// skipped; `0.0` when no row qualifies.
    pub avg_output_density: f64,
    /// Not computed yet; `compute_efficiency` always reports `0.0`.
    pub avg_budget_utilization: f64,
    /// Not computed yet; `compute_efficiency` always reports `0.0`.
    pub avg_memory_roi: f64,
    /// Not computed yet; `compute_efficiency` always reports `0.0`.
    pub avg_system_prompt_weight: f64,
    /// Fraction of rows with `cached = 1` (0.0–1.0).
    pub cache_hit_rate: f64,
    /// Not computed yet; `compute_efficiency` always reports `0.0`.
    pub context_pressure_rate: f64,
    /// Cost breakdown for this model.
    pub cost: CostMetrics,
    /// Direction labels comparing the earlier and later halves of this model's
    /// daily time series.
    pub trend: TrendMetrics,
    /// Grade-based metrics; `None` when there is no feedback or the quality
    /// query fails.
    pub quality: Option<QualityMetrics>,
}
39
/// Cost figures for a single model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostMetrics {
    /// Sum of `cost` over the model's rows, rounded to 6 decimal places.
    pub total: f64,
    /// Total cost divided by total output tokens; `0.0` when there are no
    /// output tokens.
    pub per_output_token: f64,
    /// Mean `cost` per row (`AVG(cost)`).
    pub effective_per_turn: f64,
    /// Heuristic estimate: cached rows × average cost per row × the input-token
    /// share of all tokens.
    pub cache_savings: f64,
    /// Copy of the model's `trend.cost_per_turn` label.
    pub cumulative_trend: String,
    /// Split of input-side cost across prompt sources.
    pub attribution: CostAttribution,
    /// Not computed yet; `compute_efficiency` always reports `0.0`.
    pub wasted_budget_cost: f64,
}
50
/// Split of a model's input-side tokens and cost across prompt sources.
///
/// Without context snapshots, `compute_efficiency` assigns all input tokens to
/// `history` (100%) and zeros `system_prompt` and `memories`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostAttribution {
    /// Share attributed to the system prompt.
    pub system_prompt: AttributionDetail,
    /// Share attributed to injected memories.
    pub memories: AttributionDetail,
    /// Share attributed to conversation history.
    pub history: AttributionDetail,
}
57
/// Tokens and cost attributed to one prompt source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttributionDetail {
    /// Input tokens attributed to this source.
    pub tokens: i64,
    /// Cost attributed to this source.
    pub cost: f64,
    /// Share of the attributed total as a percentage (0–100).
    pub pct: f64,
}
64
/// Direction labels (`"increasing"`, `"decreasing"`, `"stable"`) for a model.
///
/// Each label comes from `trend_label`, which compares the earlier half of the
/// model's daily points with the later half and needs more than a 5% relative
/// change to leave `"stable"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrendMetrics {
    /// Trend of the daily mean output density.
    pub output_density: String,
    /// Trend of daily cost divided by daily turns.
    pub cost_per_turn: String,
    /// Trend of the daily cached-row fraction.
    pub cache_hit_rate: String,
}
71
/// Aggregates for one model on one calendar day.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeSeriesPoint {
    /// Day as `YYYY-MM-DD` (SQLite `strftime('%Y-%m-%d', created_at)`).
    pub bucket: String,
    /// Model name.
    pub model: String,
    /// Mean `tokens_out / tokens_in` for the day; `0.0` when undefined.
    pub output_density: f64,
    /// Sum of `cost` for the day; `0.0` when NULL.
    pub cost: f64,
    /// Number of rows for the day.
    pub turns: i64,
    /// Not computed yet; always `0.0`.
    pub budget_utilization: f64,
    /// Number of rows with `cached = 1` for the day.
    pub cached_count: i64,
}
82
/// Report-wide totals across all models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficiencyTotals {
    /// Sum of every model's total cost, rounded to 6 decimal places.
    pub total_cost: f64,
    /// Sum of every model's `cost.cache_savings` estimate, rounded to 6 decimal places.
    pub total_cache_savings: f64,
    /// Sum of every model's `total_turns`.
    pub total_turns: i64,
    /// Model with the highest total cost; `None` when there is no data.
    pub most_expensive_model: Option<String>,
    /// Model with the highest `avg_output_density`. On ties, the first model in
    /// total-cost-descending order wins. `None` when there is no data.
    pub most_efficient_model: Option<String>,
    /// Same model as `most_expensive_model`, or `"none"` when there is no data.
    pub biggest_cost_driver: String,
}
92
/// Result of `compute_efficiency`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficiencyReport {
    /// The `period` argument, echoed back unchanged.
    pub period: String,
    /// Per-model summaries keyed by model name.
    pub models: HashMap<String, ModelEfficiency>,
    /// Daily per-model points, ordered by `bucket`.
    pub time_series: Vec<TimeSeriesPoint>,
    /// Totals across all models.
    pub totals: EfficiencyTotals,
}
100
/// Round `v` to six decimal places so reported figures don't carry
/// floating-point display noise (e.g. `0.1234567` becomes `0.123457`).
fn round6(v: f64) -> f64 {
    const SCALE: f64 = 1e6;
    let scaled = v * SCALE;
    scaled.round() / SCALE
}
105
/// Translate a report period into an SQLite expression for the earliest
/// `created_at` to include.
///
/// Recognized periods are `"1h"`, `"24h"`, `"7d"`, and `"30d"`. Any other value,
/// including `"all"`, maps to the Unix epoch so the whole table is covered.
/// Callers interpolate the result directly into SQL. That is only safe because
/// every possible return value is one of the fixed literals below.
fn cutoff_expr(period: &str) -> &'static str {
    const ALL_TIME: &str = "datetime('1970-01-01')";
    const WINDOWS: [(&str, &str); 4] = [
        ("1h", "datetime('now', '-1 hour')"),
        ("24h", "datetime('now', '-1 day')"),
        ("7d", "datetime('now', '-7 days')"),
        ("30d", "datetime('now', '-30 days')"),
    ];

    WINDOWS
        .iter()
        .find(|&&(name, _)| name == period)
        .map_or(ALL_TIME, |&(_, expr)| expr)
}
115
/// One row of the per-model aggregate query in `compute_efficiency`.
///
/// NULL aggregates are decoded as `0` / `0.0`.
struct RawModelRow {
    /// `inference_costs.model`.
    model: String,
    /// `COUNT(*)` of rows for the model.
    total_turns: i64,
    /// `AVG(tokens_out / NULLIF(tokens_in, 0))`.
    avg_output_density: f64,
    /// `SUM(cost)`.
    total_cost: f64,
    /// `SUM(tokens_out)`.
    total_tokens_out: i64,
    /// `SUM(tokens_in)`.
    total_tokens_in: i64,
    /// Count of rows with `cached = 1`.
    cached_count: i64,
    /// `AVG(cost)`.
    avg_cost_per_turn: f64,
}
126
/// Classify the change from `first_half` to `second_half` as `"increasing"`,
/// `"decreasing"`, or `"stable"`.
///
/// The change is measured relative to `first_half`, floored at `0.001` so a
/// near-zero baseline cannot divide by zero. Anything within ±5% is `"stable"`,
/// as is the degenerate case where both halves are exactly zero.
fn trend_label(first_half: f64, second_half: f64) -> String {
    const THRESHOLD: f64 = 0.05;

    // Two all-zero halves carry no signal at all.
    if first_half == 0.0 && second_half == 0.0 {
        return String::from("stable");
    }

    let relative_change = (second_half - first_half) / first_half.max(0.001);
    let label = if relative_change > THRESHOLD {
        "increasing"
    } else if relative_change < -THRESHOLD {
        "decreasing"
    } else {
        "stable"
    };
    label.to_string()
}
142
143fn compute_quality_for_model(
144    conn: &rusqlite::Connection,
145    model: &str,
146    cutoff: &str,
147    total_turns: i64,
148) -> Option<QualityMetrics> {
149    let sql = format!(
150        "SELECT tf.grade, t.model, t.cost \
151         FROM turn_feedback tf \
152         JOIN turns t ON t.id = tf.turn_id \
153         WHERE t.model = ?1 AND tf.created_at >= {cutoff}"
154    );
155    let mut stmt = conn.prepare(&sql)
156        .inspect_err(|e| tracing::warn!(model, error = %e, "efficiency: failed to prepare quality metrics query"))
157        .ok()?;
158    let rows: Vec<(i32, String, f64)> = stmt
159        .query_map(rusqlite::params![model], |row| {
160            Ok((
161                row.get::<_, i32>(0)?,
162                row.get::<_, String>(1)?,
163                row.get::<_, f64>(2).unwrap_or(0.0),
164            ))
165        })
166        .inspect_err(|e| tracing::warn!(model, error = %e, "efficiency: failed to execute quality metrics query"))
167        .ok()?
168        .collect::<std::result::Result<Vec<_>, _>>()
169        .inspect_err(|e| tracing::warn!(model, error = %e, "efficiency: failed to collect quality metrics rows"))
170        .ok()?;
171
172    if rows.is_empty() {
173        return None;
174    }
175
176    let grade_count = rows.len() as i64;
177    let sum_grade: i64 = rows.iter().map(|(g, _, _)| *g as i64).sum();
178    let avg_grade = sum_grade as f64 / grade_count as f64;
179    let grade_coverage = if total_turns > 0 {
180        grade_count as f64 / total_turns as f64
181    } else {
182        0.0
183    };
184
185    let total_cost: f64 = rows.iter().map(|(_, _, c)| c).sum();
186    let total_quality: f64 = rows.iter().map(|(g, _, _)| *g as f64).sum();
187    let cost_per_quality_point = if total_quality > 0.0 {
188        total_cost / total_quality
189    } else {
190        0.0
191    };
192
193    let half = rows.len() / 2;
194    let trend = if rows.len() >= 4 {
195        let first_avg = rows[..half].iter().map(|(g, _, _)| *g as f64).sum::<f64>() / half as f64;
196        let second_avg = rows[half..].iter().map(|(g, _, _)| *g as f64).sum::<f64>()
197            / (rows.len() - half) as f64;
198        trend_label(first_avg, second_avg)
199    } else {
200        "stable".into()
201    };
202
203    Some(QualityMetrics {
204        avg_grade,
205        grade_count,
206        grade_coverage,
207        cost_per_quality_point,
208        by_complexity: HashMap::new(),
209        memory_impact: MemoryImpact {
210            with_memory: 0.0,
211            without_memory: 0.0,
212        },
213        trend,
214    })
215}
216
217pub fn compute_efficiency(
218    db: &Database,
219    period: &str,
220    model_filter: Option<&str>,
221) -> Result<EfficiencyReport> {
222    let cutoff = cutoff_expr(period);
223    let conn = db.conn();
224
225    let model_clause = if model_filter.is_some() {
226        " AND model = ?1"
227    } else {
228        ""
229    };
230
231    // ── Per-model aggregates ─────────────────────────────────
232    let main_sql = format!(
233        "SELECT \
234            model, \
235            COUNT(*) AS total_turns, \
236            AVG(CAST(tokens_out AS REAL) / NULLIF(tokens_in, 0)) AS avg_output_density, \
237            SUM(cost) AS total_cost, \
238            SUM(tokens_out) AS total_tokens_out, \
239            SUM(tokens_in) AS total_tokens_in, \
240            SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS cached_count, \
241            AVG(cost) AS avg_cost_per_turn \
242         FROM inference_costs \
243         WHERE created_at >= {cutoff}{model_clause} \
244         GROUP BY model \
245         ORDER BY total_cost DESC"
246    );
247
248    let mut stmt = conn.prepare(&main_sql).db_err()?;
249
250    let map_row = |row: &rusqlite::Row| -> rusqlite::Result<RawModelRow> {
251        Ok(RawModelRow {
252            model: row.get(0)?,
253            total_turns: row.get(1)?,
254            avg_output_density: row.get::<_, Option<f64>>(2)?.unwrap_or(0.0),
255            total_cost: row.get::<_, Option<f64>>(3)?.unwrap_or(0.0),
256            total_tokens_out: row.get::<_, Option<i64>>(4)?.unwrap_or(0),
257            total_tokens_in: row.get::<_, Option<i64>>(5)?.unwrap_or(0),
258            cached_count: row.get::<_, Option<i64>>(6)?.unwrap_or(0),
259            avg_cost_per_turn: row.get::<_, Option<f64>>(7)?.unwrap_or(0.0),
260        })
261    };
262
263    let rows: Vec<RawModelRow> = if let Some(mf) = model_filter {
264        stmt.query_map(rusqlite::params![mf], map_row)
265    } else {
266        stmt.query_map([], map_row)
267    }
268    .db_err()?
269    .collect::<std::result::Result<Vec<_>, _>>()
270    .db_err()?;
271
272    // ── Time-series (daily buckets) ──────────────────────────
273    let ts_sql = format!(
274        "SELECT \
275            strftime('%Y-%m-%d', created_at) AS bucket, \
276            model, \
277            AVG(CAST(tokens_out AS REAL) / NULLIF(tokens_in, 0)) AS output_density, \
278            SUM(cost) AS cost, \
279            COUNT(*) AS turns, \
280            SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS cached_count \
281         FROM inference_costs \
282         WHERE created_at >= {cutoff}{model_clause} \
283         GROUP BY bucket, model \
284         ORDER BY bucket"
285    );
286
287    let mut ts_stmt = conn.prepare(&ts_sql).db_err()?;
288
289    let ts_map = |row: &rusqlite::Row| -> rusqlite::Result<TimeSeriesPoint> {
290        Ok(TimeSeriesPoint {
291            bucket: row.get(0)?,
292            model: row.get(1)?,
293            output_density: row.get::<_, Option<f64>>(2)?.unwrap_or(0.0),
294            cost: row.get::<_, Option<f64>>(3)?.unwrap_or(0.0),
295            turns: row.get(4)?,
296            budget_utilization: 0.0,
297            cached_count: row.get::<_, Option<i64>>(5)?.unwrap_or(0),
298        })
299    };
300
301    let time_series: Vec<TimeSeriesPoint> = if let Some(mf) = model_filter {
302        ts_stmt.query_map(rusqlite::params![mf], ts_map)
303    } else {
304        ts_stmt.query_map([], ts_map)
305    }
306    .db_err()?
307    .collect::<std::result::Result<Vec<_>, _>>()
308    .db_err()?;
309
310    // ── Build per-model trend data from time series ──────────
311    let mut model_ts: HashMap<String, Vec<&TimeSeriesPoint>> = HashMap::new();
312    for pt in &time_series {
313        model_ts.entry(pt.model.clone()).or_default().push(pt);
314    }
315
316    // ── Assemble ModelEfficiency map ─────────────────────────
317    let mut models: HashMap<String, ModelEfficiency> = HashMap::new();
318    let mut grand_total_cost = 0.0_f64;
319    let mut grand_total_turns = 0_i64;
320    let mut most_expensive: Option<(String, f64)> = None;
321    let mut most_efficient: Option<(String, f64)> = None;
322
323    for r in &rows {
324        let cache_hit_rate = if r.total_turns > 0 {
325            r.cached_count as f64 / r.total_turns as f64
326        } else {
327            0.0
328        };
329
330        let per_output_token = if r.total_tokens_out > 0 {
331            r.total_cost / r.total_tokens_out as f64
332        } else {
333            0.0
334        };
335
336        // Estimate cache savings: cached requests would have cost roughly the
337        // average per-turn cost, so savings ≈ cached_count × avg_cost_per_turn × input_fraction.
338        let input_fraction = if r.total_tokens_in + r.total_tokens_out > 0 {
339            r.total_tokens_in as f64 / (r.total_tokens_in + r.total_tokens_out) as f64
340        } else {
341            0.5
342        };
343        let cache_savings = r.cached_count as f64 * r.avg_cost_per_turn * input_fraction;
344
345        // Trends from time-series split
346        let pts = model_ts.get(&r.model).cloned().unwrap_or_default();
347        let trend = if pts.len() >= 2 {
348            let mid = pts.len() / 2;
349            let (first, second) = pts.split_at(mid);
350
351            let avg = |slice: &[&TimeSeriesPoint], f: fn(&TimeSeriesPoint) -> f64| -> f64 {
352                if slice.is_empty() {
353                    return 0.0;
354                }
355                slice.iter().map(|p| f(p)).sum::<f64>() / slice.len() as f64
356            };
357
358            let first_density = avg(first, |p| p.output_density);
359            let second_density = avg(second, |p| p.output_density);
360
361            let first_cpt = avg(first, |p| {
362                if p.turns > 0 {
363                    p.cost / p.turns as f64
364                } else {
365                    0.0
366                }
367            });
368            let second_cpt = avg(second, |p| {
369                if p.turns > 0 {
370                    p.cost / p.turns as f64
371                } else {
372                    0.0
373                }
374            });
375
376            let first_cache = avg(first, |p| {
377                if p.turns > 0 {
378                    p.cached_count as f64 / p.turns as f64
379                } else {
380                    0.0
381                }
382            });
383            let second_cache = avg(second, |p| {
384                if p.turns > 0 {
385                    p.cached_count as f64 / p.turns as f64
386                } else {
387                    0.0
388                }
389            });
390
391            TrendMetrics {
392                output_density: trend_label(first_density, second_density),
393                cost_per_turn: trend_label(first_cpt, second_cpt),
394                cache_hit_rate: trend_label(first_cache, second_cache),
395            }
396        } else {
397            TrendMetrics {
398                output_density: "stable".into(),
399                cost_per_turn: "stable".into(),
400                cache_hit_rate: "stable".into(),
401            }
402        };
403
404        let cumulative_trend = trend.cost_per_turn.clone();
405
406        // Without context_snapshots, attribute all input tokens to "history".
407        let attribution = CostAttribution {
408            system_prompt: AttributionDetail {
409                tokens: 0,
410                cost: 0.0,
411                pct: 0.0,
412            },
413            memories: AttributionDetail {
414                tokens: 0,
415                cost: 0.0,
416                pct: 0.0,
417            },
418            history: AttributionDetail {
419                tokens: r.total_tokens_in,
420                cost: r.total_cost * input_fraction,
421                pct: 100.0,
422            },
423        };
424
425        let quality = compute_quality_for_model(&conn, &r.model, cutoff, r.total_turns);
426
427        let eff = ModelEfficiency {
428            model: r.model.clone(),
429            total_turns: r.total_turns,
430            avg_output_density: r.avg_output_density,
431            avg_budget_utilization: 0.0,
432            avg_memory_roi: 0.0,
433            avg_system_prompt_weight: 0.0,
434            cache_hit_rate,
435            context_pressure_rate: 0.0,
436            cost: CostMetrics {
437                total: round6(r.total_cost),
438                per_output_token: round6(per_output_token),
439                effective_per_turn: round6(r.avg_cost_per_turn),
440                cache_savings: round6(cache_savings),
441                cumulative_trend,
442                attribution,
443                wasted_budget_cost: 0.0,
444            },
445            trend,
446            quality,
447        };
448
449        grand_total_cost += r.total_cost;
450        grand_total_turns += r.total_turns;
451
452        match &most_expensive {
453            None => most_expensive = Some((r.model.clone(), r.total_cost)),
454            Some((_, c)) if r.total_cost > *c => {
455                most_expensive = Some((r.model.clone(), r.total_cost));
456            }
457            _ => {}
458        }
459
460        let density = r.avg_output_density;
461        match &most_efficient {
462            None => most_efficient = Some((r.model.clone(), density)),
463            Some((_, d)) if density > *d => {
464                most_efficient = Some((r.model.clone(), density));
465            }
466            _ => {}
467        }
468
469        models.insert(r.model.clone(), eff);
470    }
471
472    let total_cache_savings: f64 = models.values().map(|m| m.cost.cache_savings).sum();
473
474    let biggest_cost_driver = most_expensive
475        .as_ref()
476        .map(|(m, _)| m.clone())
477        .unwrap_or_else(|| "none".into());
478
479    let totals = EfficiencyTotals {
480        total_cost: round6(grand_total_cost),
481        total_cache_savings: round6(total_cache_savings),
482        total_turns: grand_total_turns,
483        most_expensive_model: most_expensive.map(|(m, _)| m),
484        most_efficient_model: most_efficient.map(|(m, _)| m),
485        biggest_cost_driver,
486    };
487
488    Ok(EfficiencyReport {
489        period: period.to_string(),
490        models,
491        time_series,
492        totals,
493    })
494}
495
// ── UserProfile types for recommendations ────────────────────
497
/// Per-model usage statistics inside a [`RecommendationUserProfile`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecommendationModelStats {
    /// Number of `inference_costs` rows for the model in the window.
    pub turns: i64,
    /// Mean `cost` per row.
    pub avg_cost: f64,
    /// Per-model quality; `build_user_profile` currently always sets `None`.
    pub avg_quality: Option<f64>,
    /// Fraction of rows with `cached = 1` (0.0–1.0).
    pub cache_hit_rate: f64,
    /// Mean `tokens_out / tokens_in`; `0.0` when undefined.
    pub avg_output_density: f64,
}
506
/// Usage summary over a period, produced by `build_user_profile` as input for
/// recommendations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecommendationUserProfile {
    /// Sessions whose `created_at` falls in the window.
    pub total_sessions: i64,
    /// `inference_costs` rows in the window.
    pub total_turns: i64,
    /// Sum of `cost` over those rows (`0` when none).
    pub total_cost: f64,
    /// Mean feedback grade; `None` when there are no grades.
    pub avg_quality: Option<f64>,
    /// Graded turns divided by `total_turns`; `0.0` when there are no turns.
    pub grade_coverage: f64,
    /// Model names, ordered by turn count (descending).
    pub models_used: Vec<String>,
    /// Per-model statistics keyed by model name.
    pub model_stats: HashMap<String, RecommendationModelStats>,
    /// Mean `session_messages` count per session in the window.
    pub avg_session_length: f64,
    /// Mean `tokens_in + tokens_out` per row.
    pub avg_tokens_per_turn: f64,
    /// Fraction of `tool_calls` with status `'success'`; `1.0` when there are none.
    pub tool_success_rate: f64,
    /// Fraction of rows with `cached = 1`.
    pub cache_hit_rate: f64,
    /// Fraction of `context_snapshots` with `memory_tokens > 0`; `0.5` when there
    /// are none or the query fails.
    pub memory_retrieval_rate: f64,
}
522
523pub fn build_user_profile(db: &Database, period: &str) -> Result<RecommendationUserProfile> {
524    let cutoff = cutoff_expr(period);
525    let conn = db.conn();
526
527    let (total_sessions, avg_session_length): (i64, f64) = conn
528        .query_row(
529            &format!(
530                "SELECT COUNT(*), COALESCE(AVG(msg_count), 0) FROM (\
531                   SELECT s.id, COUNT(m.id) AS msg_count \
532                   FROM sessions s \
533                   LEFT JOIN session_messages m ON m.session_id = s.id \
534                   WHERE s.created_at >= {cutoff} \
535                   GROUP BY s.id\
536                 )"
537            ),
538            [],
539            |row| Ok((row.get(0)?, row.get(1)?)),
540        )
541        .db_err()?;
542
543    let (total_turns, total_cost, avg_tokens_per_turn, cache_hit_rate): (i64, f64, f64, f64) = conn
544        .query_row(
545            &format!(
546                "SELECT \
547                   COUNT(*), \
548                   COALESCE(SUM(cost), 0), \
549                   COALESCE(AVG(tokens_in + tokens_out), 0), \
550                   CASE WHEN COUNT(*) > 0 \
551                     THEN CAST(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
552                     ELSE 0.0 END \
553                 FROM inference_costs \
554                 WHERE created_at >= {cutoff}"
555            ),
556            [],
557            |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
558        )
559        .db_err()?;
560
561    let mut model_stmt = conn
562        .prepare(&format!(
563            "SELECT \
564               model, \
565               COUNT(*) AS turns, \
566               AVG(cost) AS avg_cost, \
567               CASE WHEN COUNT(*) > 0 \
568                 THEN CAST(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
569                 ELSE 0.0 END AS cache_rate, \
570               AVG(CAST(tokens_out AS REAL) / NULLIF(tokens_in, 0)) AS avg_density \
571             FROM inference_costs \
572             WHERE created_at >= {cutoff} \
573             GROUP BY model \
574             ORDER BY turns DESC"
575        ))
576        .db_err()?;
577
578    let mut models_used = Vec::new();
579    let mut model_stats = HashMap::new();
580
581    let rows = model_stmt
582        .query_map([], |row| {
583            Ok((
584                row.get::<_, String>(0)?,
585                row.get::<_, i64>(1)?,
586                row.get::<_, f64>(2)?,
587                row.get::<_, f64>(3)?,
588                row.get::<_, Option<f64>>(4)?,
589            ))
590        })
591        .db_err()?;
592
593    for row in rows {
594        let (model, turns, avg_cost, cache_rate, avg_density) = row.db_err()?;
595        models_used.push(model.clone());
596        model_stats.insert(
597            model,
598            RecommendationModelStats {
599                turns,
600                avg_cost,
601                avg_quality: None,
602                cache_hit_rate: cache_rate,
603                avg_output_density: avg_density.unwrap_or(0.0),
604            },
605        );
606    }
607
608    let tool_success_rate: f64 = conn
609        .query_row(
610            &format!(
611                "SELECT CASE WHEN COUNT(*) > 0 \
612                   THEN CAST(SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
613                   ELSE 1.0 END \
614                 FROM tool_calls WHERE created_at >= {cutoff}"
615            ),
616            [],
617            |row| row.get(0),
618        )
619        .db_err()?;
620
621    let (graded_turns, avg_quality): (i64, Option<f64>) = conn
622        .query_row(
623            &format!(
624                "SELECT COUNT(*), AVG(CAST(tf.grade AS REAL)) \
625                 FROM turn_feedback tf \
626                 JOIN turns t ON t.id = tf.turn_id \
627                 JOIN sessions s ON s.id = t.session_id \
628                 WHERE s.created_at >= {cutoff}"
629            ),
630            [],
631            |row| Ok((row.get(0)?, row.get(1)?)),
632        )
633        .unwrap_or((0, None));
634
635    let grade_coverage = if total_turns > 0 {
636        graded_turns as f64 / total_turns as f64
637    } else {
638        0.0
639    };
640
641    // Compute memory retrieval rate from context_snapshots if available,
642    // otherwise default to 0.5
643    let memory_retrieval_rate: f64 = conn
644        .query_row(
645            &format!(
646                "SELECT CASE WHEN COUNT(*) > 0 \
647                   THEN CAST(SUM(CASE WHEN memory_tokens > 0 THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
648                   ELSE 0.5 END \
649                 FROM context_snapshots WHERE created_at >= {cutoff}"
650            ),
651            [],
652            |row| row.get(0),
653        )
654        .unwrap_or(0.5);
655
656    Ok(RecommendationUserProfile {
657        total_sessions,
658        total_turns,
659        total_cost,
660        avg_quality,
661        grade_coverage,
662        models_used,
663        model_stats,
664        avg_session_length,
665        avg_tokens_per_turn,
666        tool_success_rate,
667        cache_hit_rate,
668        memory_retrieval_rate,
669    })
670}
671
672#[cfg(test)]
673mod tests {
674    use super::*;
675    use crate::metrics::record_inference_cost;
676
677    fn test_db() -> Database {
678        Database::new(":memory:").unwrap()
679    }
680
681    #[test]
682    fn empty_database_returns_empty_report() {
683        let db = test_db();
684        let report = compute_efficiency(&db, "7d", None).unwrap();
685        assert!(report.models.is_empty());
686        assert!(report.time_series.is_empty());
687        assert_eq!(report.totals.total_turns, 0);
688        assert_eq!(report.totals.total_cost, 0.0);
689    }
690
691    #[test]
692    fn single_model_report() {
693        let db = test_db();
694        record_inference_cost(
695            &db,
696            "claude-4",
697            "anthropic",
698            1000,
699            500,
700            0.015,
701            Some("T1"),
702            false,
703            None,
704            None,
705            false,
706            None,
707        )
708        .unwrap();
709        record_inference_cost(
710            &db,
711            "claude-4",
712            "anthropic",
713            2000,
714            800,
715            0.025,
716            Some("T1"),
717            true,
718            None,
719            None,
720            false,
721            None,
722        )
723        .unwrap();
724
725        let report = compute_efficiency(&db, "all", None).unwrap();
726        assert_eq!(report.models.len(), 1);
727        let m = &report.models["claude-4"];
728        assert_eq!(m.total_turns, 2);
729        assert!(m.avg_output_density > 0.0);
730        assert!((m.cost.total - 0.04).abs() < 1e-9);
731        assert_eq!(m.cache_hit_rate, 0.5);
732    }
733
734    #[test]
735    fn model_filter_works() {
736        let db = test_db();
737        record_inference_cost(
738            &db,
739            "claude-4",
740            "anthropic",
741            100,
742            50,
743            0.01,
744            None,
745            false,
746            None,
747            None,
748            false,
749            None,
750        )
751        .unwrap();
752        record_inference_cost(
753            &db, "gpt-4", "openai", 200, 100, 0.02, None, false, None, None, false, None,
754        )
755        .unwrap();
756
757        let report = compute_efficiency(&db, "all", Some("gpt-4")).unwrap();
758        assert_eq!(report.models.len(), 1);
759        assert!(report.models.contains_key("gpt-4"));
760    }
761
762    #[test]
763    fn multiple_models_totals() {
764        let db = test_db();
765        record_inference_cost(
766            &db,
767            "claude-4",
768            "anthropic",
769            100,
770            50,
771            0.01,
772            None,
773            false,
774            None,
775            None,
776            false,
777            None,
778        )
779        .unwrap();
780        record_inference_cost(
781            &db, "gpt-4", "openai", 200, 100, 0.02, None, false, None, None, false, None,
782        )
783        .unwrap();
784
785        let report = compute_efficiency(&db, "all", None).unwrap();
786        assert_eq!(report.totals.total_turns, 2);
787        assert!((report.totals.total_cost - 0.03).abs() < 1e-9);
788        assert_eq!(report.totals.most_expensive_model.as_deref(), Some("gpt-4"));
789    }
790
791    #[test]
792    fn time_series_has_entries() {
793        let db = test_db();
794        record_inference_cost(
795            &db,
796            "claude-4",
797            "anthropic",
798            100,
799            50,
800            0.01,
801            None,
802            false,
803            None,
804            None,
805            false,
806            None,
807        )
808        .unwrap();
809
810        let report = compute_efficiency(&db, "all", None).unwrap();
811        assert!(!report.time_series.is_empty());
812        assert_eq!(report.time_series[0].model, "claude-4");
813    }
814
815    #[test]
816    fn trend_label_logic() {
817        assert_eq!(trend_label(1.0, 1.5), "increasing");
818        assert_eq!(trend_label(1.0, 0.5), "decreasing");
819        assert_eq!(trend_label(1.0, 1.02), "stable");
820        assert_eq!(trend_label(0.0, 0.0), "stable");
821    }
822
823    #[test]
824    fn all_cached_full_rate() {
825        let db = test_db();
826        record_inference_cost(
827            &db, "m1", "p1", 100, 50, 0.01, None, true, None, None, false, None,
828        )
829        .unwrap();
830        record_inference_cost(
831            &db, "m1", "p1", 100, 50, 0.01, None, true, None, None, false, None,
832        )
833        .unwrap();
834
835        let report = compute_efficiency(&db, "all", None).unwrap();
836        assert_eq!(report.models["m1"].cache_hit_rate, 1.0);
837    }
838
839    #[test]
840    fn period_all_vs_default() {
841        assert_eq!(cutoff_expr("all"), "datetime('1970-01-01')");
842        assert_eq!(cutoff_expr("7d"), "datetime('now', '-7 days')");
843        assert_eq!(cutoff_expr("unknown"), "datetime('1970-01-01')");
844    }
845
846    #[test]
847    fn zero_tokens_in_no_division_by_zero() {
848        let db = test_db();
849        record_inference_cost(
850            &db, "m1", "p1", 0, 50, 0.01, None, false, None, None, false, None,
851        )
852        .unwrap();
853
854        let report = compute_efficiency(&db, "all", None).unwrap();
855        let m = &report.models["m1"];
856        assert!(m.avg_output_density.is_finite());
857    }
858
859    #[test]
860    fn cutoff_expr_1h() {
861        assert_eq!(cutoff_expr("1h"), "datetime('now', '-1 hour')");
862    }
863
864    #[test]
865    fn cutoff_expr_24h() {
866        assert_eq!(cutoff_expr("24h"), "datetime('now', '-1 day')");
867    }
868
869    #[test]
870    fn cutoff_expr_30d() {
871        assert_eq!(cutoff_expr("30d"), "datetime('now', '-30 days')");
872    }
873
874    #[test]
875    fn trend_label_edge_cases() {
876        // Clearly within stable band (3% change)
877        assert_eq!(trend_label(1.0, 1.03), "stable");
878        // Clearly above threshold (10% increase)
879        assert_eq!(trend_label(1.0, 1.10), "increasing");
880        // Clearly within stable band (3% decrease)
881        assert_eq!(trend_label(1.0, 0.97), "stable");
882        // Clearly below threshold (10% decrease)
883        assert_eq!(trend_label(1.0, 0.90), "decreasing");
884        // First half near zero uses 0.001 as base
885        assert_eq!(trend_label(0.001, 0.5), "increasing");
886        // Both near zero
887        assert_eq!(trend_label(0.0001, 0.0001), "stable");
888    }
889
890    #[test]
891    fn zero_tokens_out_no_division_by_zero() {
892        let db = test_db();
893        record_inference_cost(
894            &db, "m1", "p1", 1000, 0, 0.01, None, false, None, None, false, None,
895        )
896        .unwrap();
897
898        let report = compute_efficiency(&db, "all", None).unwrap();
899        let m = &report.models["m1"];
900        assert_eq!(m.cost.per_output_token, 0.0);
901        assert!(m.cost.total.is_finite());
902    }
903
904    #[test]
905    fn no_cached_zero_rate() {
906        let db = test_db();
907        record_inference_cost(
908            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
909        )
910        .unwrap();
911        record_inference_cost(
912            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
913        )
914        .unwrap();
915
916        let report = compute_efficiency(&db, "all", None).unwrap();
917        assert_eq!(report.models["m1"].cache_hit_rate, 0.0);
918    }
919
920    #[test]
921    fn multiple_models_identifies_most_efficient() {
922        let db = test_db();
923        // m1: 1000 in, 500 out -> density ~0.5
924        record_inference_cost(
925            &db, "m1", "p1", 1000, 500, 0.01, None, false, None, None, false, None,
926        )
927        .unwrap();
928        // m2: 100 in, 200 out -> density ~2.0  (more efficient)
929        record_inference_cost(
930            &db, "m2", "p2", 100, 200, 0.005, None, false, None, None, false, None,
931        )
932        .unwrap();
933
934        let report = compute_efficiency(&db, "all", None).unwrap();
935        assert_eq!(
936            report.totals.most_efficient_model.as_deref(),
937            Some("m2"),
938            "m2 has higher output density"
939        );
940    }
941
942    #[test]
943    fn trend_metrics_with_time_series() {
944        let db = test_db();
945        // Insert enough records over multiple days to generate time-series buckets
946        let conn = db.conn();
947        for i in 0..6 {
948            let day = format!("2025-01-{:02}T12:00:00", i + 1);
949            conn.execute(
950                "INSERT INTO inference_costs (id, model, provider, tokens_in, tokens_out, cost, cached, created_at) \
951                 VALUES (?1, 'claude-4', 'anthropic', ?2, ?3, ?4, 0, ?5)",
952                rusqlite::params![
953                    format!("ic-{i}"),
954                    1000 + i * 100,
955                    500 + i * 50,
956                    0.01 + i as f64 * 0.005,
957                    day,
958                ],
959            )
960            .unwrap();
961        }
962        drop(conn);
963
964        let report = compute_efficiency(&db, "all", None).unwrap();
965        let m = &report.models["claude-4"];
966        // With 6 data points over 6 different days, we should have time series data
967        assert!(report.time_series.len() >= 2);
968        // Trend should be computed (not just "stable")
969        assert!(!m.trend.output_density.is_empty());
970        assert!(!m.trend.cost_per_turn.is_empty());
971        assert!(!m.trend.cache_hit_rate.is_empty());
972    }
973
974    #[test]
975    fn build_user_profile_empty_db() {
976        let db = test_db();
977        let profile = build_user_profile(&db, "7d").unwrap();
978        assert_eq!(profile.total_sessions, 0);
979        assert_eq!(profile.total_turns, 0);
980        assert_eq!(profile.total_cost, 0.0);
981        assert!(profile.models_used.is_empty());
982        assert!(profile.model_stats.is_empty());
983        assert_eq!(profile.avg_session_length, 0.0);
984        assert_eq!(profile.avg_tokens_per_turn, 0.0);
985        assert_eq!(profile.tool_success_rate, 1.0); // No tools => default 1.0
986    }
987
988    #[test]
989    fn build_user_profile_with_data() {
990        let db = test_db();
991        let conn = db.conn();
992        // Create a session
993        conn.execute(
994            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'agent-1', 'agent', 'active')",
995            [],
996        )
997        .unwrap();
998        // Create messages for the session
999        conn.execute(
1000            "INSERT INTO session_messages (id, session_id, role, content) VALUES ('m1', 's1', 'user', 'hello')",
1001            [],
1002        )
1003        .unwrap();
1004        conn.execute(
1005            "INSERT INTO session_messages (id, session_id, role, content) VALUES ('m2', 's1', 'assistant', 'hi')",
1006            [],
1007        )
1008        .unwrap();
1009        drop(conn);
1010
1011        // Add inference costs
1012        record_inference_cost(
1013            &db,
1014            "claude-4",
1015            "anthropic",
1016            1000,
1017            500,
1018            0.015,
1019            Some("T1"),
1020            false,
1021            None,
1022            None,
1023            false,
1024            None,
1025        )
1026        .unwrap();
1027        record_inference_cost(
1028            &db,
1029            "claude-4",
1030            "anthropic",
1031            2000,
1032            800,
1033            0.025,
1034            Some("T1"),
1035            true,
1036            None,
1037            None,
1038            false,
1039            None,
1040        )
1041        .unwrap();
1042        record_inference_cost(
1043            &db, "gpt-4", "openai", 500, 200, 0.01, None, false, None, None, false, None,
1044        )
1045        .unwrap();
1046
1047        // Add a tool call
1048        {
1049            let conn = db.conn();
1050            conn.execute("INSERT INTO turns (id, session_id) VALUES ('t1', 's1')", [])
1051                .unwrap();
1052            conn.execute(
1053                "INSERT INTO tool_calls (id, turn_id, tool_name, input, status) VALUES ('tc1', 't1', 'bash', '{}', 'success')",
1054                [],
1055            )
1056            .unwrap();
1057            conn.execute(
1058                "INSERT INTO tool_calls (id, turn_id, tool_name, input, status) VALUES ('tc2', 't1', 'bash', '{}', 'error')",
1059                [],
1060            )
1061            .unwrap();
1062        }
1063
1064        let profile = build_user_profile(&db, "all").unwrap();
1065        assert_eq!(profile.total_sessions, 1);
1066        assert_eq!(profile.total_turns, 3);
1067        assert!((profile.total_cost - 0.05).abs() < 1e-9);
1068        assert!(profile.models_used.contains(&"claude-4".to_string()));
1069        assert!(profile.models_used.contains(&"gpt-4".to_string()));
1070        assert_eq!(profile.model_stats.len(), 2);
1071        assert_eq!(profile.model_stats["claude-4"].turns, 2);
1072        assert_eq!(profile.model_stats["gpt-4"].turns, 1);
1073        assert_eq!(profile.avg_session_length, 2.0); // 2 messages
1074        assert!(profile.avg_tokens_per_turn > 0.0);
1075        assert_eq!(profile.tool_success_rate, 0.5); // 1 success out of 2
1076        assert!(profile.cache_hit_rate > 0.0);
1077    }
1078
1079    #[test]
1080    fn build_user_profile_grade_coverage() {
1081        let db = test_db();
1082        let conn = db.conn();
1083        conn.execute(
1084            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'agent-1', 'agent', 'active')",
1085            [],
1086        )
1087        .unwrap();
1088        conn.execute("INSERT INTO turns (id, session_id) VALUES ('t1', 's1')", [])
1089            .unwrap();
1090        conn.execute(
1091            "INSERT INTO turn_feedback (id, turn_id, session_id, grade, source) VALUES ('tf1', 't1', 's1', 4, 'dashboard')",
1092            [],
1093        )
1094        .unwrap();
1095        drop(conn);
1096
1097        record_inference_cost(
1098            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
1099        )
1100        .unwrap();
1101        record_inference_cost(
1102            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
1103        )
1104        .unwrap();
1105
1106        let profile = build_user_profile(&db, "all").unwrap();
1107        assert!(profile.avg_quality.is_some());
1108        assert!((profile.avg_quality.unwrap() - 4.0).abs() < 1e-9);
1109        assert!(profile.grade_coverage > 0.0);
1110        assert!(profile.grade_coverage <= 1.0);
1111    }
1112
1113    #[test]
1114    fn compute_quality_for_model_with_feedback() {
1115        let db = test_db();
1116        let conn = db.conn();
1117        conn.execute(
1118            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'a1', 'agent', 'active')",
1119            [],
1120        )
1121        .unwrap();
1122        conn.execute(
1123            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t1', 's1', 'claude-4', 0.01)",
1124            [],
1125        )
1126        .unwrap();
1127        conn.execute(
1128            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t2', 's1', 'claude-4', 0.02)",
1129            [],
1130        )
1131        .unwrap();
1132        conn.execute(
1133            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t3', 's1', 'claude-4', 0.015)",
1134            [],
1135        )
1136        .unwrap();
1137        conn.execute(
1138            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t4', 's1', 'claude-4', 0.025)",
1139            [],
1140        )
1141        .unwrap();
1142        // Add feedback for all 4 turns
1143        conn.execute(
1144            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f1', 't1', 's1', 3)",
1145            [],
1146        )
1147        .unwrap();
1148        conn.execute(
1149            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f2', 't2', 's1', 4)",
1150            [],
1151        )
1152        .unwrap();
1153        conn.execute(
1154            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f3', 't3', 's1', 5)",
1155            [],
1156        )
1157        .unwrap();
1158        conn.execute(
1159            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f4', 't4', 's1', 5)",
1160            [],
1161        )
1162        .unwrap();
1163
1164        let quality = compute_quality_for_model(&conn, "claude-4", "datetime('1970-01-01')", 4);
1165        drop(conn);
1166
1167        assert!(quality.is_some());
1168        let q = quality.unwrap();
1169        assert_eq!(q.grade_count, 4);
1170        assert!((q.avg_grade - 4.25).abs() < 1e-9);
1171        assert_eq!(q.grade_coverage, 1.0);
1172        assert!(q.cost_per_quality_point > 0.0);
1173        // With 4 feedback entries and improvement from first half (3,4) to second half (5,5),
1174        // trend should be "increasing"
1175        assert_eq!(q.trend, "increasing");
1176    }
1177
1178    #[test]
1179    fn compute_quality_for_model_no_feedback() {
1180        let db = test_db();
1181        let conn = db.conn();
1182        let quality = compute_quality_for_model(&conn, "claude-4", "datetime('1970-01-01')", 10);
1183        drop(conn);
1184        assert!(quality.is_none());
1185    }
1186
1187    #[test]
1188    fn compute_quality_few_entries_stable_trend() {
1189        let db = test_db();
1190        let conn = db.conn();
1191        conn.execute(
1192            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'a1', 'agent', 'active')",
1193            [],
1194        )
1195        .unwrap();
1196        conn.execute(
1197            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t1', 's1', 'claude-4', 0.01)",
1198            [],
1199        )
1200        .unwrap();
1201        conn.execute(
1202            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f1', 't1', 's1', 4)",
1203            [],
1204        )
1205        .unwrap();
1206
1207        let quality = compute_quality_for_model(&conn, "claude-4", "datetime('1970-01-01')", 1);
1208        drop(conn);
1209
1210        assert!(quality.is_some());
1211        let q = quality.unwrap();
1212        // With fewer than 4 entries, trend should be "stable"
1213        assert_eq!(q.trend, "stable");
1214    }
1215
1216    #[test]
1217    fn report_cost_attribution_all_history() {
1218        let db = test_db();
1219        record_inference_cost(
1220            &db, "m1", "p1", 1000, 500, 0.03, None, false, None, None, false, None,
1221        )
1222        .unwrap();
1223
1224        let report = compute_efficiency(&db, "all", None).unwrap();
1225        let m = &report.models["m1"];
1226        // Without context_snapshots, all input tokens are attributed to "history"
1227        assert_eq!(m.cost.attribution.history.pct, 100.0);
1228        assert!(m.cost.attribution.history.tokens > 0);
1229        assert_eq!(m.cost.attribution.system_prompt.tokens, 0);
1230        assert_eq!(m.cost.attribution.memories.tokens, 0);
1231    }
1232
1233    #[test]
1234    fn report_biggest_cost_driver_with_no_models() {
1235        let db = test_db();
1236        let report = compute_efficiency(&db, "all", None).unwrap();
1237        assert_eq!(report.totals.biggest_cost_driver, "none");
1238        assert!(report.totals.most_expensive_model.is_none());
1239        assert!(report.totals.most_efficient_model.is_none());
1240    }
1241
1242    #[test]
1243    fn build_user_profile_memory_retrieval_default() {
1244        let db = test_db();
1245        let profile = build_user_profile(&db, "all").unwrap();
1246        // Without context_snapshots, memory_retrieval_rate defaults to 0.5
1247        assert!((profile.memory_retrieval_rate - 0.5).abs() < 1e-9);
1248    }
1249}