tokr 0.1.0

Persistent token-usage ledger for AI coding agents. Captures on write, queries forever.
use anyhow::Result;
use rusqlite::params;

use crate::{db::Db, pricing::Pricing};

/// Summary of a recost pass, as returned by `run`.
///
/// All four counters start at zero (`Default`) and are only ever accumulated.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct RecostStats {
    /// Number of `usage_events` rows that matched the filters and were re-priced.
    pub rows_examined: u64,
    /// Rows whose recomputed cost or pricing version differed from the stored values
    /// (these are rewritten unless the run was a dry run).
    pub rows_changed: u64,
    /// Sum of the stored `cost_usd` over all examined rows, before recosting.
    pub old_total_usd: f64,
    /// Sum of the recomputed cost over all examined rows (changed or not).
    pub new_total_usd: f64,
}

/// Options selecting which usage events a recost pass touches.
///
/// `Default` yields "no filters, not a dry run", i.e. recost and persist every row.
#[derive(Debug, Clone, Copy, Default)]
pub struct RecostFilters<'a> {
    /// Only rows whose `model` column contains this substring (matched with SQLite `LIKE`).
    pub model_substr: Option<&'a str>,
    /// Only rows with `timestamp >= since_iso`; the comparison is done on the stored text.
    pub since_iso: Option<&'a str>,
    /// When `true`, compute statistics but roll back instead of writing any changes.
    pub dry_run: bool,
}

pub fn run(filters: RecostFilters) -> Result<RecostStats> {
    let mut db = Db::open()?;
    let pricing = Pricing::load()?;

    let mut clauses: Vec<String> = Vec::new();
    let mut binds: Vec<rusqlite::types::Value> = Vec::new();
    if let Some(m) = filters.model_substr {
        clauses.push(format!("model LIKE ?{}", binds.len() + 1));
        binds.push(rusqlite::types::Value::Text(format!("%{m}%")));
    }
    if let Some(t) = filters.since_iso {
        clauses.push(format!("timestamp >= ?{}", binds.len() + 1));
        binds.push(rusqlite::types::Value::Text(t.to_string()));
    }
    let where_sql = if clauses.is_empty() {
        String::new()
    } else {
        format!("WHERE {}", clauses.join(" AND "))
    };

    let select_sql = format!(
        "SELECT message_id, model, timestamp,
                input_tokens, output_tokens,
                cache_creation_5m, cache_creation_1h, cache_read_tokens,
                cost_usd, pricing_version
         FROM usage_events {where_sql}"
    );

    let mut stats = RecostStats::default();
    let tx = db.conn.transaction()?;

    {
        let mut stmt = tx.prepare(&select_sql)?;
        let rows: Vec<RowSnapshot> = stmt
            .query_map(rusqlite::params_from_iter(binds.iter()), |r| {
                Ok(RowSnapshot {
                    message_id: r.get(0)?,
                    model: r.get(1)?,
                    timestamp: r.get(2)?,
                    input: r.get::<_, i64>(3)? as u64,
                    output: r.get::<_, i64>(4)? as u64,
                    cw5: r.get::<_, i64>(5)? as u64,
                    cw1: r.get::<_, i64>(6)? as u64,
                    cread: r.get::<_, i64>(7)? as u64,
                    old_cost: r.get(8)?,
                    old_version: r.get(9)?,
                })
            })?
            .collect::<Result<Vec<_>, _>>()?;

        let mut update = tx.prepare(
            "UPDATE usage_events SET cost_usd = ?2, pricing_version = ?3 WHERE message_id = ?1",
        )?;

        for r in rows {
            stats.rows_examined += 1;
            stats.old_total_usd += r.old_cost;
            let cost = pricing.compute(
                &r.model,
                &r.timestamp,
                r.input,
                r.output,
                r.cw5,
                r.cw1,
                r.cread,
            );
            stats.new_total_usd += cost.usd;

            let changed = (cost.usd - r.old_cost).abs() > 1e-9 || cost.version != r.old_version;
            if changed {
                stats.rows_changed += 1;
                if !filters.dry_run {
                    update.execute(params![r.message_id, cost.usd, cost.version])?;
                }
            }
        }
    }

    if filters.dry_run {
        tx.rollback()?;
    } else {
        tx.commit()?;
    }
    Ok(stats)
}

/// One `usage_events` row as read by the recost SELECT, held in memory while it is re-priced.
struct RowSnapshot {
    // Identity and pricing inputs.
    /// Key used by the recost UPDATE (`WHERE message_id = ?`).
    message_id: String,
    /// Model name passed to the pricing lookup.
    model: String,
    /// Stored event timestamp text, also passed to the pricing lookup.
    timestamp: String,

    // Values currently stored in the ledger.
    /// Stored `cost_usd` before recosting.
    old_cost: f64,
    /// Stored `pricing_version` before recosting.
    old_version: String,

    // Token counts.
    /// `input_tokens` column.
    input: u64,
    /// `output_tokens` column.
    output: u64,
    /// `cache_creation_5m` column.
    cw5: u64,
    /// `cache_creation_1h` column.
    cw1: u64,
    /// `cache_read_tokens` column.
    cread: u64,
}