use anyhow::Result;
use rusqlite::params;
use crate::{db::Db, pricing::Pricing};
/// Aggregate outcome of one recost pass over `usage_events`.
#[derive(Default)]
#[derive(Debug)]
pub struct RecostStats {
    /// How many rows matched the filters and had their cost recomputed.
    pub rows_examined: u64,
    /// How many examined rows had a different cost or pricing version.
    /// Counted in dry-run mode too, even though nothing is written then.
    pub rows_changed: u64,
    /// Sum of the previously stored `cost_usd` across all examined rows.
    pub old_total_usd: f64,
    /// Sum of the freshly computed cost across all examined rows.
    pub new_total_usd: f64,
}
/// Selection and mode options for a recost pass; string filters are borrowed.
pub struct RecostFilters<'f> {
    /// If set, only rows whose `model` contains this text (matched via SQL `LIKE`).
    pub model_substr: Option<&'f str>,
    /// If set, only rows whose `timestamp` compares `>=` this value.
    pub since_iso: Option<&'f str>,
    /// When true, statistics are computed but the transaction is rolled back.
    pub dry_run: bool,
}
pub fn run(filters: RecostFilters) -> Result<RecostStats> {
let mut db = Db::open()?;
let pricing = Pricing::load()?;
let mut clauses: Vec<String> = Vec::new();
let mut binds: Vec<rusqlite::types::Value> = Vec::new();
if let Some(m) = filters.model_substr {
clauses.push(format!("model LIKE ?{}", binds.len() + 1));
binds.push(rusqlite::types::Value::Text(format!("%{m}%")));
}
if let Some(t) = filters.since_iso {
clauses.push(format!("timestamp >= ?{}", binds.len() + 1));
binds.push(rusqlite::types::Value::Text(t.to_string()));
}
let where_sql = if clauses.is_empty() {
String::new()
} else {
format!("WHERE {}", clauses.join(" AND "))
};
let select_sql = format!(
"SELECT message_id, model, timestamp,
input_tokens, output_tokens,
cache_creation_5m, cache_creation_1h, cache_read_tokens,
cost_usd, pricing_version
FROM usage_events {where_sql}"
);
let mut stats = RecostStats::default();
let tx = db.conn.transaction()?;
{
let mut stmt = tx.prepare(&select_sql)?;
let rows: Vec<RowSnapshot> = stmt
.query_map(rusqlite::params_from_iter(binds.iter()), |r| {
Ok(RowSnapshot {
message_id: r.get(0)?,
model: r.get(1)?,
timestamp: r.get(2)?,
input: r.get::<_, i64>(3)? as u64,
output: r.get::<_, i64>(4)? as u64,
cw5: r.get::<_, i64>(5)? as u64,
cw1: r.get::<_, i64>(6)? as u64,
cread: r.get::<_, i64>(7)? as u64,
old_cost: r.get(8)?,
old_version: r.get(9)?,
})
})?
.collect::<Result<Vec<_>, _>>()?;
let mut update = tx.prepare(
"UPDATE usage_events SET cost_usd = ?2, pricing_version = ?3 WHERE message_id = ?1",
)?;
for r in rows {
stats.rows_examined += 1;
stats.old_total_usd += r.old_cost;
let cost = pricing.compute(
&r.model,
&r.timestamp,
r.input,
r.output,
r.cw5,
r.cw1,
r.cread,
);
stats.new_total_usd += cost.usd;
let changed = (cost.usd - r.old_cost).abs() > 1e-9 || cost.version != r.old_version;
if changed {
stats.rows_changed += 1;
if !filters.dry_run {
update.execute(params![r.message_id, cost.usd, cost.version])?;
}
}
}
}
if filters.dry_run {
tx.rollback()?;
} else {
tx.commit()?;
}
Ok(stats)
}
/// One `usage_events` row as loaded by `run`, captured before re-pricing.
struct RowSnapshot {
    /// Value of the `message_id` column; used as the UPDATE key.
    message_id: String,
    /// Value of the `model` column.
    model: String,
    /// Value of the `timestamp` column, passed through to pricing.
    timestamp: String,
    /// `input_tokens` column.
    input: u64,
    /// `output_tokens` column.
    output: u64,
    /// `cache_creation_5m` column.
    cw5: u64,
    /// `cache_creation_1h` column.
    cw1: u64,
    /// `cache_read_tokens` column.
    cread: u64,
    /// Previously stored `cost_usd`.
    old_cost: f64,
    /// Previously stored `pricing_version`.
    old_version: String,
}