Skip to main content

roboticus_db/
efficiency.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{Database, DbResultExt};
6use roboticus_core::Result;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct MemoryImpact {
10    pub with_memory: f64,
11    pub without_memory: f64,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct QualityMetrics {
16    pub avg_grade: f64,
17    pub grade_count: i64,
18    pub grade_coverage: f64,
19    pub cost_per_quality_point: f64,
20    pub by_complexity: HashMap<String, f64>,
21    pub memory_impact: MemoryImpact,
22    pub trend: String,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ModelEfficiency {
27    pub model: String,
28    pub total_turns: i64,
29    pub last_invoked_at: Option<String>,
30    pub success_count: i64,
31    pub error_count: i64,
32    pub avg_output_density: f64,
33    pub avg_budget_utilization: f64,
34    pub avg_memory_roi: f64,
35    pub avg_system_prompt_weight: f64,
36    pub cache_hit_rate: f64,
37    pub context_pressure_rate: f64,
38    pub cost: CostMetrics,
39    pub trend: TrendMetrics,
40    pub quality: Option<QualityMetrics>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CostMetrics {
45    pub total: f64,
46    pub per_output_token: f64,
47    pub effective_per_turn: f64,
48    pub cache_savings: f64,
49    pub cumulative_trend: String,
50    pub attribution: CostAttribution,
51    pub wasted_budget_cost: f64,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct CostAttribution {
56    pub system_prompt: AttributionDetail,
57    pub memories: AttributionDetail,
58    pub history: AttributionDetail,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct AttributionDetail {
63    pub tokens: i64,
64    pub cost: f64,
65    pub pct: f64,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct TrendMetrics {
70    pub output_density: String,
71    pub cost_per_turn: String,
72    pub cache_hit_rate: String,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct TimeSeriesPoint {
77    pub bucket: String,
78    pub model: String,
79    pub output_density: f64,
80    pub cost: f64,
81    pub turns: i64,
82    pub budget_utilization: f64,
83    pub cached_count: i64,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct EfficiencyTotals {
88    pub total_cost: f64,
89    pub total_cache_savings: f64,
90    pub total_turns: i64,
91    pub most_expensive_model: Option<String>,
92    pub most_efficient_model: Option<String>,
93    pub biggest_cost_driver: String,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct EfficiencyReport {
98    pub period: String,
99    pub models: HashMap<String, ModelEfficiency>,
100    pub time_series: Vec<TimeSeriesPoint>,
101    pub totals: EfficiencyTotals,
102}
103
/// Rounds `v` to six decimal places, trimming floating-point noise such as
/// `0.030000000000000002` down to `0.03` before values are reported.
fn round6(v: f64) -> f64 {
    const SCALE: f64 = 1_000_000.0;
    let scaled = v * SCALE;
    scaled.round() / SCALE
}
108
/// Maps a reporting period (`"1h"`, `"24h"`, `"7d"`, `"30d"`) to the SQLite
/// expression for the start of that window. Any other value, including
/// `"all"`, falls back to the Unix epoch, which effectively disables the
/// time filter.
///
/// Callers splice the result directly into SQL text. That is safe only
/// because every result is a fixed literal; `period` itself never reaches
/// the query.
fn cutoff_expr(period: &str) -> &'static str {
    const WINDOWS: [(&str, &str); 4] = [
        ("1h", "datetime('now', '-1 hour')"),
        ("24h", "datetime('now', '-1 day')"),
        ("7d", "datetime('now', '-7 days')"),
        ("30d", "datetime('now', '-30 days')"),
    ];

    WINDOWS
        .iter()
        .find(|&&(key, _)| key == period)
        .map_or("datetime('1970-01-01')", |&(_, expr)| expr)
}
118
/// One row of the per-model aggregate query in `compute_efficiency`.
/// Numeric aggregates that SQL returned as `NULL` are already collapsed to
/// zero by the row mapper.
struct RawModelRow {
    // Identity and recency.
    model: String,
    last_invoked_at: Option<String>,
    // Volume counters.
    total_turns: i64,
    cached_count: i64,
    total_tokens_in: i64,
    total_tokens_out: i64,
    // Cost and density aggregates.
    total_cost: f64,
    avg_cost_per_turn: f64,
    avg_output_density: f64,
}
130
/// Labels the move from `first_half` to `second_half` as `"increasing"`,
/// `"decreasing"` or `"stable"`.
///
/// The change is taken relative to `first_half`, and anything within ±5% is
/// stable. The base is clamped to at least `0.001`, so a zero or tiny first
/// half cannot cause a division by zero. Two exact zeros are always stable.
fn trend_label(first_half: f64, second_half: f64) -> String {
    const THRESHOLD: f64 = 0.05;
    const MIN_BASE: f64 = 0.001;

    let both_zero = first_half == 0.0 && second_half == 0.0;
    if both_zero {
        return "stable".into();
    }

    let relative_change = (second_half - first_half) / first_half.max(MIN_BASE);
    let label = match relative_change {
        r if r > THRESHOLD => "increasing",
        r if r < -THRESHOLD => "decreasing",
        _ => "stable",
    };
    label.to_string()
}
146
147fn compute_quality_for_model(
148    conn: &rusqlite::Connection,
149    model: &str,
150    cutoff: &str,
151    total_turns: i64,
152) -> Option<QualityMetrics> {
153    let sql = format!(
154        "SELECT tf.grade, t.model, t.cost \
155         FROM turn_feedback tf \
156         JOIN turns t ON t.id = tf.turn_id \
157         WHERE t.model = ?1 AND tf.created_at >= {cutoff}"
158    );
159    let mut stmt = conn.prepare(&sql)
160        .inspect_err(|e| tracing::warn!(model, error = %e, "efficiency: failed to prepare quality metrics query"))
161        .ok()?;
162    let rows: Vec<(i32, String, f64)> = stmt
163        .query_map(rusqlite::params![model], |row| {
164            Ok((
165                row.get::<_, i32>(0)?,
166                row.get::<_, String>(1)?,
167                row.get::<_, f64>(2).unwrap_or(0.0),
168            ))
169        })
170        .inspect_err(|e| tracing::warn!(model, error = %e, "efficiency: failed to execute quality metrics query"))
171        .ok()?
172        .collect::<std::result::Result<Vec<_>, _>>()
173        .inspect_err(|e| tracing::warn!(model, error = %e, "efficiency: failed to collect quality metrics rows"))
174        .ok()?;
175
176    if rows.is_empty() {
177        return None;
178    }
179
180    let grade_count = rows.len() as i64;
181    let sum_grade: i64 = rows.iter().map(|(g, _, _)| *g as i64).sum();
182    let avg_grade = sum_grade as f64 / grade_count as f64;
183    let grade_coverage = if total_turns > 0 {
184        grade_count as f64 / total_turns as f64
185    } else {
186        0.0
187    };
188
189    let total_cost: f64 = rows.iter().map(|(_, _, c)| c).sum();
190    let total_quality: f64 = rows.iter().map(|(g, _, _)| *g as f64).sum();
191    let cost_per_quality_point = if total_quality > 0.0 {
192        total_cost / total_quality
193    } else {
194        0.0
195    };
196
197    let half = rows.len() / 2;
198    let trend = if rows.len() >= 4 {
199        let first_avg = rows[..half].iter().map(|(g, _, _)| *g as f64).sum::<f64>() / half as f64;
200        let second_avg = rows[half..].iter().map(|(g, _, _)| *g as f64).sum::<f64>()
201            / (rows.len() - half) as f64;
202        trend_label(first_avg, second_avg)
203    } else {
204        "stable".into()
205    };
206
207    Some(QualityMetrics {
208        avg_grade,
209        grade_count,
210        grade_coverage,
211        cost_per_quality_point,
212        by_complexity: HashMap::new(),
213        memory_impact: MemoryImpact {
214            with_memory: 0.0,
215            without_memory: 0.0,
216        },
217        trend,
218    })
219}
220
221pub fn compute_efficiency(
222    db: &Database,
223    period: &str,
224    model_filter: Option<&str>,
225) -> Result<EfficiencyReport> {
226    let cutoff = cutoff_expr(period);
227    let conn = db.conn();
228
229    let model_clause = if model_filter.is_some() {
230        " AND model = ?1"
231    } else {
232        ""
233    };
234
235    // ── Per-model aggregates ─────────────────────────────────
236    let main_sql = format!(
237        "SELECT \
238            model, \
239            COUNT(*) AS total_turns, \
240            MAX(created_at) AS last_invoked_at, \
241            AVG(CAST(tokens_out AS REAL) / NULLIF(tokens_in, 0)) AS avg_output_density, \
242            SUM(cost) AS total_cost, \
243            SUM(tokens_out) AS total_tokens_out, \
244            SUM(tokens_in) AS total_tokens_in, \
245            SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS cached_count, \
246            AVG(cost) AS avg_cost_per_turn \
247         FROM inference_costs \
248         WHERE created_at >= {cutoff}{model_clause} \
249         GROUP BY model \
250         ORDER BY total_cost DESC"
251    );
252
253    let mut stmt = conn.prepare(&main_sql).db_err()?;
254
255    let map_row = |row: &rusqlite::Row| -> rusqlite::Result<RawModelRow> {
256        Ok(RawModelRow {
257            model: row.get(0)?,
258            total_turns: row.get(1)?,
259            last_invoked_at: row.get::<_, Option<String>>(2)?,
260            avg_output_density: row.get::<_, Option<f64>>(3)?.unwrap_or(0.0),
261            total_cost: row.get::<_, Option<f64>>(4)?.unwrap_or(0.0),
262            total_tokens_out: row.get::<_, Option<i64>>(5)?.unwrap_or(0),
263            total_tokens_in: row.get::<_, Option<i64>>(6)?.unwrap_or(0),
264            cached_count: row.get::<_, Option<i64>>(7)?.unwrap_or(0),
265            avg_cost_per_turn: row.get::<_, Option<f64>>(8)?.unwrap_or(0.0),
266        })
267    };
268
269    let rows: Vec<RawModelRow> = if let Some(mf) = model_filter {
270        stmt.query_map(rusqlite::params![mf], map_row)
271    } else {
272        stmt.query_map([], map_row)
273    }
274    .db_err()?
275    .collect::<std::result::Result<Vec<_>, _>>()
276    .db_err()?;
277
278    // ── Time-series (daily buckets) ──────────────────────────
279    let ts_sql = format!(
280        "SELECT \
281            strftime('%Y-%m-%d', created_at) AS bucket, \
282            model, \
283            AVG(CAST(tokens_out AS REAL) / NULLIF(tokens_in, 0)) AS output_density, \
284            SUM(cost) AS cost, \
285            COUNT(*) AS turns, \
286            SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS cached_count \
287         FROM inference_costs \
288         WHERE created_at >= {cutoff}{model_clause} \
289         GROUP BY bucket, model \
290         ORDER BY bucket"
291    );
292
293    let mut ts_stmt = conn.prepare(&ts_sql).db_err()?;
294
295    let ts_map = |row: &rusqlite::Row| -> rusqlite::Result<TimeSeriesPoint> {
296        Ok(TimeSeriesPoint {
297            bucket: row.get(0)?,
298            model: row.get(1)?,
299            output_density: row.get::<_, Option<f64>>(2)?.unwrap_or(0.0),
300            cost: row.get::<_, Option<f64>>(3)?.unwrap_or(0.0),
301            turns: row.get(4)?,
302            budget_utilization: 0.0,
303            cached_count: row.get::<_, Option<i64>>(5)?.unwrap_or(0),
304        })
305    };
306
307    let time_series: Vec<TimeSeriesPoint> = if let Some(mf) = model_filter {
308        ts_stmt.query_map(rusqlite::params![mf], ts_map)
309    } else {
310        ts_stmt.query_map([], ts_map)
311    }
312    .db_err()?
313    .collect::<std::result::Result<Vec<_>, _>>()
314    .db_err()?;
315
316    // ── Build per-model trend data from time series ──────────
317    let mut model_ts: HashMap<String, Vec<&TimeSeriesPoint>> = HashMap::new();
318    for pt in &time_series {
319        model_ts.entry(pt.model.clone()).or_default().push(pt);
320    }
321
322    // ── Assemble ModelEfficiency map ─────────────────────────
323    let mut models: HashMap<String, ModelEfficiency> = HashMap::new();
324    let mut grand_total_cost = 0.0_f64;
325    let mut grand_total_turns = 0_i64;
326    let mut most_expensive: Option<(String, f64)> = None;
327    let mut most_efficient: Option<(String, f64)> = None;
328
329    for r in &rows {
330        let cache_hit_rate = if r.total_turns > 0 {
331            r.cached_count as f64 / r.total_turns as f64
332        } else {
333            0.0
334        };
335
336        let per_output_token = if r.total_tokens_out > 0 {
337            r.total_cost / r.total_tokens_out as f64
338        } else {
339            0.0
340        };
341
342        // Estimate cache savings: cached requests would have cost roughly the
343        // average per-turn cost, so savings ≈ cached_count × avg_cost_per_turn × input_fraction.
344        let input_fraction = if r.total_tokens_in + r.total_tokens_out > 0 {
345            r.total_tokens_in as f64 / (r.total_tokens_in + r.total_tokens_out) as f64
346        } else {
347            0.5
348        };
349        let cache_savings = r.cached_count as f64 * r.avg_cost_per_turn * input_fraction;
350
351        // Trends from time-series split
352        let pts = model_ts.get(&r.model).cloned().unwrap_or_default();
353        let trend = if pts.len() >= 2 {
354            let mid = pts.len() / 2;
355            let (first, second) = pts.split_at(mid);
356
357            let avg = |slice: &[&TimeSeriesPoint], f: fn(&TimeSeriesPoint) -> f64| -> f64 {
358                if slice.is_empty() {
359                    return 0.0;
360                }
361                slice.iter().map(|p| f(p)).sum::<f64>() / slice.len() as f64
362            };
363
364            let first_density = avg(first, |p| p.output_density);
365            let second_density = avg(second, |p| p.output_density);
366
367            let first_cpt = avg(first, |p| {
368                if p.turns > 0 {
369                    p.cost / p.turns as f64
370                } else {
371                    0.0
372                }
373            });
374            let second_cpt = avg(second, |p| {
375                if p.turns > 0 {
376                    p.cost / p.turns as f64
377                } else {
378                    0.0
379                }
380            });
381
382            let first_cache = avg(first, |p| {
383                if p.turns > 0 {
384                    p.cached_count as f64 / p.turns as f64
385                } else {
386                    0.0
387                }
388            });
389            let second_cache = avg(second, |p| {
390                if p.turns > 0 {
391                    p.cached_count as f64 / p.turns as f64
392                } else {
393                    0.0
394                }
395            });
396
397            TrendMetrics {
398                output_density: trend_label(first_density, second_density),
399                cost_per_turn: trend_label(first_cpt, second_cpt),
400                cache_hit_rate: trend_label(first_cache, second_cache),
401            }
402        } else {
403            TrendMetrics {
404                output_density: "stable".into(),
405                cost_per_turn: "stable".into(),
406                cache_hit_rate: "stable".into(),
407            }
408        };
409
410        let cumulative_trend = trend.cost_per_turn.clone();
411
412        // Without context_snapshots, attribute all input tokens to "history".
413        let attribution = CostAttribution {
414            system_prompt: AttributionDetail {
415                tokens: 0,
416                cost: 0.0,
417                pct: 0.0,
418            },
419            memories: AttributionDetail {
420                tokens: 0,
421                cost: 0.0,
422                pct: 0.0,
423            },
424            history: AttributionDetail {
425                tokens: r.total_tokens_in,
426                cost: r.total_cost * input_fraction,
427                pct: 100.0,
428            },
429        };
430
431        let quality = compute_quality_for_model(&conn, &r.model, cutoff, r.total_turns);
432
433        let eff = ModelEfficiency {
434            model: r.model.clone(),
435            total_turns: r.total_turns,
436            last_invoked_at: r.last_invoked_at.clone(),
437            success_count: r.total_turns,
438            error_count: 0,
439            avg_output_density: r.avg_output_density,
440            avg_budget_utilization: 0.0,
441            avg_memory_roi: 0.0,
442            avg_system_prompt_weight: 0.0,
443            cache_hit_rate,
444            context_pressure_rate: 0.0,
445            cost: CostMetrics {
446                total: round6(r.total_cost),
447                per_output_token: round6(per_output_token),
448                effective_per_turn: round6(r.avg_cost_per_turn),
449                cache_savings: round6(cache_savings),
450                cumulative_trend,
451                attribution,
452                wasted_budget_cost: 0.0,
453            },
454            trend,
455            quality,
456        };
457
458        grand_total_cost += r.total_cost;
459        grand_total_turns += r.total_turns;
460
461        match &most_expensive {
462            None => most_expensive = Some((r.model.clone(), r.total_cost)),
463            Some((_, c)) if r.total_cost > *c => {
464                most_expensive = Some((r.model.clone(), r.total_cost));
465            }
466            _ => {}
467        }
468
469        let density = r.avg_output_density;
470        match &most_efficient {
471            None => most_efficient = Some((r.model.clone(), density)),
472            Some((_, d)) if density > *d => {
473                most_efficient = Some((r.model.clone(), density));
474            }
475            _ => {}
476        }
477
478        models.insert(r.model.clone(), eff);
479    }
480
481    let total_cache_savings: f64 = models.values().map(|m| m.cost.cache_savings).sum();
482
483    let biggest_cost_driver = most_expensive
484        .as_ref()
485        .map(|(m, _)| m.clone())
486        .unwrap_or_else(|| "none".into());
487
488    let totals = EfficiencyTotals {
489        total_cost: round6(grand_total_cost),
490        total_cache_savings: round6(total_cache_savings),
491        total_turns: grand_total_turns,
492        most_expensive_model: most_expensive.map(|(m, _)| m),
493        most_efficient_model: most_efficient.map(|(m, _)| m),
494        biggest_cost_driver,
495    };
496
497    Ok(EfficiencyReport {
498        period: period.to_string(),
499        models,
500        time_series,
501        totals,
502    })
503}
504
505// ── UserProfile types for recommendations ────────────────────
506
507#[derive(Debug, Clone, Serialize, Deserialize)]
508pub struct RecommendationModelStats {
509    pub turns: i64,
510    pub avg_cost: f64,
511    pub avg_quality: Option<f64>,
512    pub cache_hit_rate: f64,
513    pub avg_output_density: f64,
514}
515
516#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct RecommendationUserProfile {
518    pub total_sessions: i64,
519    pub total_turns: i64,
520    pub total_cost: f64,
521    pub avg_quality: Option<f64>,
522    pub grade_coverage: f64,
523    pub models_used: Vec<String>,
524    pub model_stats: HashMap<String, RecommendationModelStats>,
525    pub avg_session_length: f64,
526    pub avg_tokens_per_turn: f64,
527    pub tool_success_rate: f64,
528    pub cache_hit_rate: f64,
529    pub memory_retrieval_rate: f64,
530}
531
532pub fn build_user_profile(db: &Database, period: &str) -> Result<RecommendationUserProfile> {
533    let cutoff = cutoff_expr(period);
534    let conn = db.conn();
535
536    let (total_sessions, avg_session_length): (i64, f64) = conn
537        .query_row(
538            &format!(
539                "SELECT COUNT(*), COALESCE(AVG(msg_count), 0) FROM (\
540                   SELECT s.id, COUNT(m.id) AS msg_count \
541                   FROM sessions s \
542                   LEFT JOIN session_messages m ON m.session_id = s.id \
543                   WHERE s.created_at >= {cutoff} \
544                   GROUP BY s.id\
545                 )"
546            ),
547            [],
548            |row| Ok((row.get(0)?, row.get(1)?)),
549        )
550        .db_err()?;
551
552    let (total_turns, total_cost, avg_tokens_per_turn, cache_hit_rate): (i64, f64, f64, f64) = conn
553        .query_row(
554            &format!(
555                "SELECT \
556                   COUNT(*), \
557                   COALESCE(SUM(cost), 0), \
558                   COALESCE(AVG(tokens_in + tokens_out), 0), \
559                   CASE WHEN COUNT(*) > 0 \
560                     THEN CAST(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
561                     ELSE 0.0 END \
562                 FROM inference_costs \
563                 WHERE created_at >= {cutoff}"
564            ),
565            [],
566            |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
567        )
568        .db_err()?;
569
570    let mut model_stmt = conn
571        .prepare(&format!(
572            "SELECT \
573               model, \
574               COUNT(*) AS turns, \
575               AVG(cost) AS avg_cost, \
576               CASE WHEN COUNT(*) > 0 \
577                 THEN CAST(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
578                 ELSE 0.0 END AS cache_rate, \
579               AVG(CAST(tokens_out AS REAL) / NULLIF(tokens_in, 0)) AS avg_density \
580             FROM inference_costs \
581             WHERE created_at >= {cutoff} \
582             GROUP BY model \
583             ORDER BY turns DESC"
584        ))
585        .db_err()?;
586
587    let mut models_used = Vec::new();
588    let mut model_stats = HashMap::new();
589
590    let rows = model_stmt
591        .query_map([], |row| {
592            Ok((
593                row.get::<_, String>(0)?,
594                row.get::<_, i64>(1)?,
595                row.get::<_, f64>(2)?,
596                row.get::<_, f64>(3)?,
597                row.get::<_, Option<f64>>(4)?,
598            ))
599        })
600        .db_err()?;
601
602    for row in rows {
603        let (model, turns, avg_cost, cache_rate, avg_density) = row.db_err()?;
604        models_used.push(model.clone());
605        model_stats.insert(
606            model,
607            RecommendationModelStats {
608                turns,
609                avg_cost,
610                avg_quality: None,
611                cache_hit_rate: cache_rate,
612                avg_output_density: avg_density.unwrap_or(0.0),
613            },
614        );
615    }
616
617    let tool_success_rate: f64 = conn
618        .query_row(
619            &format!(
620                "SELECT CASE WHEN COUNT(*) > 0 \
621                   THEN CAST(SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
622                   ELSE 1.0 END \
623                 FROM tool_calls WHERE created_at >= {cutoff}"
624            ),
625            [],
626            |row| row.get(0),
627        )
628        .db_err()?;
629
630    let (graded_turns, avg_quality): (i64, Option<f64>) = conn
631        .query_row(
632            &format!(
633                "SELECT COUNT(*), AVG(CAST(tf.grade AS REAL)) \
634                 FROM turn_feedback tf \
635                 JOIN turns t ON t.id = tf.turn_id \
636                 JOIN sessions s ON s.id = t.session_id \
637                 WHERE s.created_at >= {cutoff}"
638            ),
639            [],
640            |row| Ok((row.get(0)?, row.get(1)?)),
641        )
642        .unwrap_or((0, None));
643
644    let grade_coverage = if total_turns > 0 {
645        graded_turns as f64 / total_turns as f64
646    } else {
647        0.0
648    };
649
650    // Compute memory retrieval rate from context_snapshots if available,
651    // otherwise default to 0.5
652    let memory_retrieval_rate: f64 = conn
653        .query_row(
654            &format!(
655                "SELECT CASE WHEN COUNT(*) > 0 \
656                   THEN CAST(SUM(CASE WHEN memory_tokens > 0 THEN 1 ELSE 0 END) AS REAL) / COUNT(*) \
657                   ELSE 0.5 END \
658                 FROM context_snapshots WHERE created_at >= {cutoff}"
659            ),
660            [],
661            |row| row.get(0),
662        )
663        .unwrap_or(0.5);
664
665    Ok(RecommendationUserProfile {
666        total_sessions,
667        total_turns,
668        total_cost,
669        avg_quality,
670        grade_coverage,
671        models_used,
672        model_stats,
673        avg_session_length,
674        avg_tokens_per_turn,
675        tool_success_rate,
676        cache_hit_rate,
677        memory_retrieval_rate,
678    })
679}
680
681#[cfg(test)]
682mod tests {
683    use super::*;
684    use crate::metrics::record_inference_cost;
685
686    fn test_db() -> Database {
687        Database::new(":memory:").unwrap()
688    }
689
690    #[test]
691    fn empty_database_returns_empty_report() {
692        let db = test_db();
693        let report = compute_efficiency(&db, "7d", None).unwrap();
694        assert!(report.models.is_empty());
695        assert!(report.time_series.is_empty());
696        assert_eq!(report.totals.total_turns, 0);
697        assert_eq!(report.totals.total_cost, 0.0);
698    }
699
700    #[test]
701    fn single_model_report() {
702        let db = test_db();
703        record_inference_cost(
704            &db,
705            "claude-4",
706            "anthropic",
707            1000,
708            500,
709            0.015,
710            Some("T1"),
711            false,
712            None,
713            None,
714            false,
715            None,
716        )
717        .unwrap();
718        record_inference_cost(
719            &db,
720            "claude-4",
721            "anthropic",
722            2000,
723            800,
724            0.025,
725            Some("T1"),
726            true,
727            None,
728            None,
729            false,
730            None,
731        )
732        .unwrap();
733
734        let report = compute_efficiency(&db, "all", None).unwrap();
735        assert_eq!(report.models.len(), 1);
736        let m = &report.models["claude-4"];
737        assert_eq!(m.total_turns, 2);
738        assert!(m.last_invoked_at.is_some());
739        assert_eq!(m.success_count, 2);
740        assert_eq!(m.error_count, 0);
741        assert!(m.avg_output_density > 0.0);
742        assert!((m.cost.total - 0.04).abs() < 1e-9);
743        assert_eq!(m.cache_hit_rate, 0.5);
744    }
745
746    #[test]
747    fn model_filter_works() {
748        let db = test_db();
749        record_inference_cost(
750            &db,
751            "claude-4",
752            "anthropic",
753            100,
754            50,
755            0.01,
756            None,
757            false,
758            None,
759            None,
760            false,
761            None,
762        )
763        .unwrap();
764        record_inference_cost(
765            &db, "gpt-4", "openai", 200, 100, 0.02, None, false, None, None, false, None,
766        )
767        .unwrap();
768
769        let report = compute_efficiency(&db, "all", Some("gpt-4")).unwrap();
770        assert_eq!(report.models.len(), 1);
771        assert!(report.models.contains_key("gpt-4"));
772    }
773
774    #[test]
775    fn multiple_models_totals() {
776        let db = test_db();
777        record_inference_cost(
778            &db,
779            "claude-4",
780            "anthropic",
781            100,
782            50,
783            0.01,
784            None,
785            false,
786            None,
787            None,
788            false,
789            None,
790        )
791        .unwrap();
792        record_inference_cost(
793            &db, "gpt-4", "openai", 200, 100, 0.02, None, false, None, None, false, None,
794        )
795        .unwrap();
796
797        let report = compute_efficiency(&db, "all", None).unwrap();
798        assert_eq!(report.totals.total_turns, 2);
799        assert!((report.totals.total_cost - 0.03).abs() < 1e-9);
800        assert_eq!(report.totals.most_expensive_model.as_deref(), Some("gpt-4"));
801    }
802
803    #[test]
804    fn time_series_has_entries() {
805        let db = test_db();
806        record_inference_cost(
807            &db,
808            "claude-4",
809            "anthropic",
810            100,
811            50,
812            0.01,
813            None,
814            false,
815            None,
816            None,
817            false,
818            None,
819        )
820        .unwrap();
821
822        let report = compute_efficiency(&db, "all", None).unwrap();
823        assert!(!report.time_series.is_empty());
824        assert_eq!(report.time_series[0].model, "claude-4");
825    }
826
827    #[test]
828    fn trend_label_logic() {
829        assert_eq!(trend_label(1.0, 1.5), "increasing");
830        assert_eq!(trend_label(1.0, 0.5), "decreasing");
831        assert_eq!(trend_label(1.0, 1.02), "stable");
832        assert_eq!(trend_label(0.0, 0.0), "stable");
833    }
834
835    #[test]
836    fn all_cached_full_rate() {
837        let db = test_db();
838        record_inference_cost(
839            &db, "m1", "p1", 100, 50, 0.01, None, true, None, None, false, None,
840        )
841        .unwrap();
842        record_inference_cost(
843            &db, "m1", "p1", 100, 50, 0.01, None, true, None, None, false, None,
844        )
845        .unwrap();
846
847        let report = compute_efficiency(&db, "all", None).unwrap();
848        assert_eq!(report.models["m1"].cache_hit_rate, 1.0);
849    }
850
851    #[test]
852    fn period_all_vs_default() {
853        assert_eq!(cutoff_expr("all"), "datetime('1970-01-01')");
854        assert_eq!(cutoff_expr("7d"), "datetime('now', '-7 days')");
855        assert_eq!(cutoff_expr("unknown"), "datetime('1970-01-01')");
856    }
857
858    #[test]
859    fn zero_tokens_in_no_division_by_zero() {
860        let db = test_db();
861        record_inference_cost(
862            &db, "m1", "p1", 0, 50, 0.01, None, false, None, None, false, None,
863        )
864        .unwrap();
865
866        let report = compute_efficiency(&db, "all", None).unwrap();
867        let m = &report.models["m1"];
868        assert!(m.avg_output_density.is_finite());
869    }
870
871    #[test]
872    fn cutoff_expr_1h() {
873        assert_eq!(cutoff_expr("1h"), "datetime('now', '-1 hour')");
874    }
875
876    #[test]
877    fn cutoff_expr_24h() {
878        assert_eq!(cutoff_expr("24h"), "datetime('now', '-1 day')");
879    }
880
881    #[test]
882    fn cutoff_expr_30d() {
883        assert_eq!(cutoff_expr("30d"), "datetime('now', '-30 days')");
884    }
885
886    #[test]
887    fn trend_label_edge_cases() {
888        // Clearly within stable band (3% change)
889        assert_eq!(trend_label(1.0, 1.03), "stable");
890        // Clearly above threshold (10% increase)
891        assert_eq!(trend_label(1.0, 1.10), "increasing");
892        // Clearly within stable band (3% decrease)
893        assert_eq!(trend_label(1.0, 0.97), "stable");
894        // Clearly below threshold (10% decrease)
895        assert_eq!(trend_label(1.0, 0.90), "decreasing");
896        // First half near zero uses 0.001 as base
897        assert_eq!(trend_label(0.001, 0.5), "increasing");
898        // Both near zero
899        assert_eq!(trend_label(0.0001, 0.0001), "stable");
900    }
901
902    #[test]
903    fn zero_tokens_out_no_division_by_zero() {
904        let db = test_db();
905        record_inference_cost(
906            &db, "m1", "p1", 1000, 0, 0.01, None, false, None, None, false, None,
907        )
908        .unwrap();
909
910        let report = compute_efficiency(&db, "all", None).unwrap();
911        let m = &report.models["m1"];
912        assert_eq!(m.cost.per_output_token, 0.0);
913        assert!(m.cost.total.is_finite());
914    }
915
916    #[test]
917    fn no_cached_zero_rate() {
918        let db = test_db();
919        record_inference_cost(
920            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
921        )
922        .unwrap();
923        record_inference_cost(
924            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
925        )
926        .unwrap();
927
928        let report = compute_efficiency(&db, "all", None).unwrap();
929        assert_eq!(report.models["m1"].cache_hit_rate, 0.0);
930    }
931
932    #[test]
933    fn multiple_models_identifies_most_efficient() {
934        let db = test_db();
935        // m1: 1000 in, 500 out -> density ~0.5
936        record_inference_cost(
937            &db, "m1", "p1", 1000, 500, 0.01, None, false, None, None, false, None,
938        )
939        .unwrap();
940        // m2: 100 in, 200 out -> density ~2.0  (more efficient)
941        record_inference_cost(
942            &db, "m2", "p2", 100, 200, 0.005, None, false, None, None, false, None,
943        )
944        .unwrap();
945
946        let report = compute_efficiency(&db, "all", None).unwrap();
947        assert_eq!(
948            report.totals.most_efficient_model.as_deref(),
949            Some("m2"),
950            "m2 has higher output density"
951        );
952    }
953
954    #[test]
955    fn trend_metrics_with_time_series() {
956        let db = test_db();
957        // Insert enough records over multiple days to generate time-series buckets
958        let conn = db.conn();
959        for i in 0..6 {
960            let day = format!("2025-01-{:02}T12:00:00", i + 1);
961            conn.execute(
962                "INSERT INTO inference_costs (id, model, provider, tokens_in, tokens_out, cost, cached, created_at) \
963                 VALUES (?1, 'claude-4', 'anthropic', ?2, ?3, ?4, 0, ?5)",
964                rusqlite::params![
965                    format!("ic-{i}"),
966                    1000 + i * 100,
967                    500 + i * 50,
968                    0.01 + i as f64 * 0.005,
969                    day,
970                ],
971            )
972            .unwrap();
973        }
974        drop(conn);
975
976        let report = compute_efficiency(&db, "all", None).unwrap();
977        let m = &report.models["claude-4"];
978        // With 6 data points over 6 different days, we should have time series data
979        assert!(report.time_series.len() >= 2);
980        // Trend should be computed (not just "stable")
981        assert!(!m.trend.output_density.is_empty());
982        assert!(!m.trend.cost_per_turn.is_empty());
983        assert!(!m.trend.cache_hit_rate.is_empty());
984    }
985
986    #[test]
987    fn build_user_profile_empty_db() {
988        let db = test_db();
989        let profile = build_user_profile(&db, "7d").unwrap();
990        assert_eq!(profile.total_sessions, 0);
991        assert_eq!(profile.total_turns, 0);
992        assert_eq!(profile.total_cost, 0.0);
993        assert!(profile.models_used.is_empty());
994        assert!(profile.model_stats.is_empty());
995        assert_eq!(profile.avg_session_length, 0.0);
996        assert_eq!(profile.avg_tokens_per_turn, 0.0);
997        assert_eq!(profile.tool_success_rate, 1.0); // No tools => default 1.0
998    }
999
1000    #[test]
1001    fn build_user_profile_with_data() {
1002        let db = test_db();
1003        let conn = db.conn();
1004        // Create a session
1005        conn.execute(
1006            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'agent-1', 'agent', 'active')",
1007            [],
1008        )
1009        .unwrap();
1010        // Create messages for the session
1011        conn.execute(
1012            "INSERT INTO session_messages (id, session_id, role, content) VALUES ('m1', 's1', 'user', 'hello')",
1013            [],
1014        )
1015        .unwrap();
1016        conn.execute(
1017            "INSERT INTO session_messages (id, session_id, role, content) VALUES ('m2', 's1', 'assistant', 'hi')",
1018            [],
1019        )
1020        .unwrap();
1021        drop(conn);
1022
1023        // Add inference costs
1024        record_inference_cost(
1025            &db,
1026            "claude-4",
1027            "anthropic",
1028            1000,
1029            500,
1030            0.015,
1031            Some("T1"),
1032            false,
1033            None,
1034            None,
1035            false,
1036            None,
1037        )
1038        .unwrap();
1039        record_inference_cost(
1040            &db,
1041            "claude-4",
1042            "anthropic",
1043            2000,
1044            800,
1045            0.025,
1046            Some("T1"),
1047            true,
1048            None,
1049            None,
1050            false,
1051            None,
1052        )
1053        .unwrap();
1054        record_inference_cost(
1055            &db, "gpt-4", "openai", 500, 200, 0.01, None, false, None, None, false, None,
1056        )
1057        .unwrap();
1058
1059        // Add a tool call
1060        {
1061            let conn = db.conn();
1062            conn.execute("INSERT INTO turns (id, session_id) VALUES ('t1', 's1')", [])
1063                .unwrap();
1064            conn.execute(
1065                "INSERT INTO tool_calls (id, turn_id, tool_name, input, status) VALUES ('tc1', 't1', 'bash', '{}', 'success')",
1066                [],
1067            )
1068            .unwrap();
1069            conn.execute(
1070                "INSERT INTO tool_calls (id, turn_id, tool_name, input, status) VALUES ('tc2', 't1', 'bash', '{}', 'error')",
1071                [],
1072            )
1073            .unwrap();
1074        }
1075
1076        let profile = build_user_profile(&db, "all").unwrap();
1077        assert_eq!(profile.total_sessions, 1);
1078        assert_eq!(profile.total_turns, 3);
1079        assert!((profile.total_cost - 0.05).abs() < 1e-9);
1080        assert!(profile.models_used.contains(&"claude-4".to_string()));
1081        assert!(profile.models_used.contains(&"gpt-4".to_string()));
1082        assert_eq!(profile.model_stats.len(), 2);
1083        assert_eq!(profile.model_stats["claude-4"].turns, 2);
1084        assert_eq!(profile.model_stats["gpt-4"].turns, 1);
1085        assert_eq!(profile.avg_session_length, 2.0); // 2 messages
1086        assert!(profile.avg_tokens_per_turn > 0.0);
1087        assert_eq!(profile.tool_success_rate, 0.5); // 1 success out of 2
1088        assert!(profile.cache_hit_rate > 0.0);
1089    }
1090
1091    #[test]
1092    fn build_user_profile_grade_coverage() {
1093        let db = test_db();
1094        let conn = db.conn();
1095        conn.execute(
1096            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'agent-1', 'agent', 'active')",
1097            [],
1098        )
1099        .unwrap();
1100        conn.execute("INSERT INTO turns (id, session_id) VALUES ('t1', 's1')", [])
1101            .unwrap();
1102        conn.execute(
1103            "INSERT INTO turn_feedback (id, turn_id, session_id, grade, source) VALUES ('tf1', 't1', 's1', 4, 'dashboard')",
1104            [],
1105        )
1106        .unwrap();
1107        drop(conn);
1108
1109        record_inference_cost(
1110            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
1111        )
1112        .unwrap();
1113        record_inference_cost(
1114            &db, "m1", "p1", 100, 50, 0.01, None, false, None, None, false, None,
1115        )
1116        .unwrap();
1117
1118        let profile = build_user_profile(&db, "all").unwrap();
1119        assert!(profile.avg_quality.is_some());
1120        assert!((profile.avg_quality.unwrap() - 4.0).abs() < 1e-9);
1121        assert!(profile.grade_coverage > 0.0);
1122        assert!(profile.grade_coverage <= 1.0);
1123    }
1124
1125    #[test]
1126    fn compute_quality_for_model_with_feedback() {
1127        let db = test_db();
1128        let conn = db.conn();
1129        conn.execute(
1130            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'a1', 'agent', 'active')",
1131            [],
1132        )
1133        .unwrap();
1134        conn.execute(
1135            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t1', 's1', 'claude-4', 0.01)",
1136            [],
1137        )
1138        .unwrap();
1139        conn.execute(
1140            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t2', 's1', 'claude-4', 0.02)",
1141            [],
1142        )
1143        .unwrap();
1144        conn.execute(
1145            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t3', 's1', 'claude-4', 0.015)",
1146            [],
1147        )
1148        .unwrap();
1149        conn.execute(
1150            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t4', 's1', 'claude-4', 0.025)",
1151            [],
1152        )
1153        .unwrap();
1154        // Add feedback for all 4 turns
1155        conn.execute(
1156            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f1', 't1', 's1', 3)",
1157            [],
1158        )
1159        .unwrap();
1160        conn.execute(
1161            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f2', 't2', 's1', 4)",
1162            [],
1163        )
1164        .unwrap();
1165        conn.execute(
1166            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f3', 't3', 's1', 5)",
1167            [],
1168        )
1169        .unwrap();
1170        conn.execute(
1171            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f4', 't4', 's1', 5)",
1172            [],
1173        )
1174        .unwrap();
1175
1176        let quality = compute_quality_for_model(&conn, "claude-4", "datetime('1970-01-01')", 4);
1177        drop(conn);
1178
1179        assert!(quality.is_some());
1180        let q = quality.unwrap();
1181        assert_eq!(q.grade_count, 4);
1182        assert!((q.avg_grade - 4.25).abs() < 1e-9);
1183        assert_eq!(q.grade_coverage, 1.0);
1184        assert!(q.cost_per_quality_point > 0.0);
1185        // With 4 feedback entries and improvement from first half (3,4) to second half (5,5),
1186        // trend should be "increasing"
1187        assert_eq!(q.trend, "increasing");
1188    }
1189
1190    #[test]
1191    fn compute_quality_for_model_no_feedback() {
1192        let db = test_db();
1193        let conn = db.conn();
1194        let quality = compute_quality_for_model(&conn, "claude-4", "datetime('1970-01-01')", 10);
1195        drop(conn);
1196        assert!(quality.is_none());
1197    }
1198
1199    #[test]
1200    fn compute_quality_few_entries_stable_trend() {
1201        let db = test_db();
1202        let conn = db.conn();
1203        conn.execute(
1204            "INSERT INTO sessions (id, agent_id, scope_key, status) VALUES ('s1', 'a1', 'agent', 'active')",
1205            [],
1206        )
1207        .unwrap();
1208        conn.execute(
1209            "INSERT INTO turns (id, session_id, model, cost) VALUES ('t1', 's1', 'claude-4', 0.01)",
1210            [],
1211        )
1212        .unwrap();
1213        conn.execute(
1214            "INSERT INTO turn_feedback (id, turn_id, session_id, grade) VALUES ('f1', 't1', 's1', 4)",
1215            [],
1216        )
1217        .unwrap();
1218
1219        let quality = compute_quality_for_model(&conn, "claude-4", "datetime('1970-01-01')", 1);
1220        drop(conn);
1221
1222        assert!(quality.is_some());
1223        let q = quality.unwrap();
1224        // With fewer than 4 entries, trend should be "stable"
1225        assert_eq!(q.trend, "stable");
1226    }
1227
1228    #[test]
1229    fn report_cost_attribution_all_history() {
1230        let db = test_db();
1231        record_inference_cost(
1232            &db, "m1", "p1", 1000, 500, 0.03, None, false, None, None, false, None,
1233        )
1234        .unwrap();
1235
1236        let report = compute_efficiency(&db, "all", None).unwrap();
1237        let m = &report.models["m1"];
1238        // Without context_snapshots, all input tokens are attributed to "history"
1239        assert_eq!(m.cost.attribution.history.pct, 100.0);
1240        assert!(m.cost.attribution.history.tokens > 0);
1241        assert_eq!(m.cost.attribution.system_prompt.tokens, 0);
1242        assert_eq!(m.cost.attribution.memories.tokens, 0);
1243    }
1244
1245    #[test]
1246    fn report_biggest_cost_driver_with_no_models() {
1247        let db = test_db();
1248        let report = compute_efficiency(&db, "all", None).unwrap();
1249        assert_eq!(report.totals.biggest_cost_driver, "none");
1250        assert!(report.totals.most_expensive_model.is_none());
1251        assert!(report.totals.most_efficient_model.is_none());
1252    }
1253
1254    #[test]
1255    fn build_user_profile_memory_retrieval_default() {
1256        let db = test_db();
1257        let profile = build_user_profile(&db, "all").unwrap();
1258        // Without context_snapshots, memory_retrieval_rate defaults to 0.5
1259        assert!((profile.memory_retrieval_rate - 0.5).abs() < 1e-9);
1260    }
1261}