//! tokr 0.1.0
//!
//! Persistent token-usage ledger for AI coding agents. Captures on write, queries forever.
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;

use crate::paths;

/// Contents of `../data/pricing.toml`, baked into the binary at compile time
/// via `include_str!`. `Pricing::load` parses this as the base price table.
const EMBEDDED_PRICING: &str = include_str!("../data/pricing.toml");

/// One price row for a model, as read from (and written back to) the pricing
/// TOML files.
///
/// Every `*_per_mtok` rate is applied by `Pricing::compute` as
/// `tokens / 1_000_000 * rate`, i.e. it is a price per million tokens; the
/// summed result is reported as `Cost::usd`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ModelPrice {
    /// Model identifier; rows sharing a name form that model's price timeline.
    pub name: String,
    /// When this row starts to apply. Must be parseable by `parse_iso`:
    /// an RFC 3339 timestamp or a bare `YYYY-MM-DD` date.
    pub effective_from: String,
    /// Rate applied to the `input` token count.
    pub input_per_mtok: f64,
    /// Rate applied to the `output` token count.
    pub output_per_mtok: f64,
    /// Rate applied to the `cache_write_5m` token count
    /// (presumably cache writes with a 5-minute TTL — confirm against the data file).
    pub cache_write_5m_per_mtok: f64,
    /// Rate applied to the `cache_write_1h` token count
    /// (presumably cache writes with a 1-hour TTL — confirm against the data file).
    pub cache_write_1h_per_mtok: f64,
    /// Rate applied to the `cache_read` token count.
    pub cache_read_per_mtok: f64,
}

/// Deserialization shape of a pricing TOML document: a list of price rows
/// stored under the `model` key. `parse` uses it for both the embedded
/// `pricing.toml` and the user's overrides file.
#[derive(Debug, Deserialize)]
struct PricingFile {
    // `default` makes a document without any `model` key deserialize to an
    // empty list instead of failing with "missing field `model`". Without it,
    // an empty or comment-only overrides file would make every
    // `Pricing::load` call fail. An embedded file with no rows is still
    // rejected later, because `Pricing::from_entries` requires the
    // `__unknown__` entry.
    #[serde(rename = "model", default)]
    models: Vec<ModelPrice>,
}

/// Serialization shape of the overrides file written by `append_overrides`:
/// the price rows are emitted under the `model` key.
#[derive(Debug, Serialize, Default)]
struct OverridesFile {
    // NOTE(review): this struct only derives `Serialize`; serde's `default`
    // attribute is a deserialization setting, so it appears to have no effect
    // here — confirm whether it was meant for the struct used when reading.
    #[serde(rename = "model", default)]
    models: Vec<ModelPrice>,
}

/// A `ModelPrice` together with its pre-parsed effective timestamp, as stored
/// in a model's timeline inside `Pricing`.
#[derive(Debug, Clone)]
struct PricedEntry {
    /// `price.effective_from` parsed into a UTC instant; the timeline sort key.
    effective_at: DateTime<Utc>,
    /// `effective_at` formatted as `%Y-%m-%d`. Despite the name this is the
    /// normalized date, not the original string; it is used as the suffix of
    /// `Cost::version`.
    raw_effective_from: String,
    /// The price row itself.
    price: ModelPrice,
}

/// In-memory price table: one chronologically sorted timeline of price rows
/// per model name, plus a fallback row for models that have no timeline.
pub struct Pricing {
    /// Model name -> price rows sorted by ascending `effective_at`.
    timelines: HashMap<String, Vec<PricedEntry>>,
    /// Earliest row of the `__unknown__` timeline; used when a model has no
    /// timeline of its own.
    fallback: PricedEntry,
}

/// Result of pricing a single usage event.
#[derive(Debug, Clone)]
pub struct Cost {
    /// Total cost: the sum over all token categories of
    /// `tokens / 1_000_000 * rate`.
    pub usd: f64,
    /// Identifies the price row that was applied, formatted as
    /// `<canonical model>@<YYYY-MM-DD effective date>`.
    pub version: String,
}

impl Pricing {
    /// Builds the price table from the embedded `pricing.toml` plus any rows
    /// found in the user's overrides file.
    ///
    /// Override rows are appended after the embedded ones, so when an override
    /// has exactly the same effective instant as an embedded row for the same
    /// model, the override is the one that gets selected (see `from_entries`).
    ///
    /// # Errors
    /// Fails if the embedded file or the overrides file cannot be read or
    /// parsed, or if `from_entries` rejects the combined rows.
    pub fn load() -> Result<Self> {
        let mut entries: Vec<ModelPrice> =
            parse(EMBEDDED_PRICING).context("parsing embedded pricing.toml")?;
        if let Some(extras) = load_overrides()? {
            entries.extend(extras);
        }
        Self::from_entries(entries)
    }

    /// Groups `entries` by model name into per-model timelines sorted by
    /// effective instant.
    ///
    /// # Errors
    /// Fails if any row's `effective_from` is neither an RFC 3339 timestamp
    /// nor a `YYYY-MM-DD` date, or if there is no row named `__unknown__`
    /// (its earliest row becomes the fallback price).
    pub fn from_entries(entries: Vec<ModelPrice>) -> Result<Self> {
        let mut timelines: HashMap<String, Vec<PricedEntry>> = HashMap::new();
        for price in entries {
            let effective_at = parse_iso(&price.effective_from).with_context(|| {
                format!(
                    "model `{}`: effective_from `{}` is not a valid ISO timestamp",
                    price.name, price.effective_from
                )
            })?;
            timelines
                .entry(price.name.clone())
                .or_default()
                .push(PricedEntry {
                    effective_at,
                    raw_effective_from: effective_at.format("%Y-%m-%d").to_string(),
                    price,
                });
        }
        for timeline in timelines.values_mut() {
            // This must remain a *stable* sort: rows with identical instants
            // keep their input order, and `pick_entry` selects the last of
            // them, which is how a later-supplied row wins a tie.
            timeline.sort_by_key(|entry| entry.effective_at);
        }
        let fallback = timelines
            .get("__unknown__")
            .and_then(|timeline| timeline.first())
            .cloned()
            .context("pricing.toml is missing the __unknown__ fallback entry")?;

        Ok(Self {
            timelines,
            fallback,
        })
    }

    /// Prices one usage event.
    ///
    /// `model` is canonicalized first (see `canonicalize`), then the price row
    /// in force at `event_timestamp` is looked up and each token count is
    /// charged at `tokens / 1_000_000 * rate`.
    ///
    /// The returned `Cost::version` is `<canonical model>@<effective date>`.
    /// For a model without a timeline the fallback row's rates and date are
    /// used, but the model part of `version` is still the (unrecognized)
    /// model name that was passed in.
    pub fn compute(
        &self,
        model: &str,
        event_timestamp: &str,
        input: u64,
        output: u64,
        cache_write_5m: u64,
        cache_write_1h: u64,
        cache_read: u64,
    ) -> Cost {
        let canonical = self.canonicalize(model);
        let entry = self.pick_entry(canonical, event_timestamp);
        let rates = &entry.price;
        // Rates are quoted per million tokens.
        let charge = |tokens: u64, rate: f64| (tokens as f64) / 1_000_000.0 * rate;
        let usd = charge(input, rates.input_per_mtok)
            + charge(output, rates.output_per_mtok)
            + charge(cache_write_5m, rates.cache_write_5m_per_mtok)
            + charge(cache_write_1h, rates.cache_write_1h_per_mtok)
            + charge(cache_read, rates.cache_read_per_mtok);
        Cost {
            usd,
            version: format!("{}@{}", canonical, entry.raw_effective_from),
        }
    }

    /// Maps a reported model name onto a name that has a price timeline.
    ///
    /// Returns `model` unchanged if it has a timeline. Otherwise, if `model`
    /// ends in `-` followed by eight ASCII digits and the name without that
    /// suffix has a timeline, returns the shortened name. In every other case
    /// `model` is returned as-is.
    pub fn canonicalize<'a>(&self, model: &'a str) -> &'a str {
        if self.timelines.contains_key(model) {
            return model;
        }
        match strip_date_suffix(model) {
            Some(base) if self.timelines.contains_key(base) => base,
            _ => model,
        }
    }

    /// Selects the price row to apply for `canonical_model` at
    /// `event_timestamp`.
    ///
    /// * No timeline for the model: the fallback row.
    /// * Timestamp parses: the latest row whose effective instant is not after
    ///   the event; if the event predates every row, the earliest row.
    /// * Timestamp does not parse: the latest row.
    ///
    /// Returns a borrow so that pricing an event does not copy the row's
    /// strings.
    fn pick_entry(&self, canonical_model: &str, event_timestamp: &str) -> &PricedEntry {
        let timeline = match self.timelines.get(canonical_model) {
            Some(t) if !t.is_empty() => t,
            _ => return &self.fallback,
        };

        let index = match parse_iso(event_timestamp) {
            // The timeline is sorted ascending, so `partition_point` yields the
            // number of rows already in force at `event_at`; the last of those
            // is at `count - 1`. A count of zero clamps to the earliest row.
            Ok(event_at) => timeline
                .partition_point(|entry| entry.effective_at <= event_at)
                .saturating_sub(1),
            // `timeline` is non-empty here, so `len() - 1` cannot underflow.
            Err(_) => timeline.len() - 1,
        };
        &timeline[index]
    }

    /// Returns every known price row, ordered by model name and, within a
    /// model, by ascending effective instant.
    pub fn list_all(&self) -> Vec<&ModelPrice> {
        let mut by_name: Vec<(&String, &Vec<PricedEntry>)> = self.timelines.iter().collect();
        // Map keys are unique, so an unstable sort gives a deterministic order.
        by_name.sort_unstable_by(|a, b| a.0.cmp(b.0));
        by_name
            .into_iter()
            .flat_map(|(_, timeline)| timeline.iter().map(|entry| &entry.price))
            .collect()
    }
}

fn parse(s: &str) -> Result<Vec<ModelPrice>> {
    let f: PricingFile = toml::from_str(s)?;
    Ok(f.models)
}

/// Parses `s` into a UTC instant.
///
/// Two forms are accepted, tried in this order:
/// 1. an RFC 3339 timestamp, converted to UTC;
/// 2. a bare `YYYY-MM-DD` date, interpreted as midnight UTC of that day.
///
/// # Errors
/// Fails with "not a recognized timestamp" when neither form matches.
fn parse_iso(s: &str) -> Result<DateTime<Utc>> {
    match DateTime::parse_from_rfc3339(s) {
        Ok(stamp) => Ok(stamp.with_timezone(&Utc)),
        Err(_) => match chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
            Ok(day) => {
                let midnight = day.and_hms_opt(0, 0, 0).context("invalid date")?;
                Ok(DateTime::<Utc>::from_naive_utc_and_offset(midnight, Utc))
            }
            Err(_) => anyhow::bail!("not a recognized timestamp: `{s}`"),
        },
    }
}

/// Removes a trailing `-` plus eight ASCII digits (e.g. `-20251001`) from `s`.
///
/// Returns the part before the dash, or `None` when `s` does not end in that
/// exact nine-byte pattern. The digits are not checked for being a valid
/// calendar date.
fn strip_date_suffix(s: &str) -> Option<&str> {
    // One dash followed by eight digits.
    const SUFFIX_LEN: usize = 9;

    let cut = s.len().checked_sub(SUFFIX_LEN)?;
    match &s.as_bytes()[cut..] {
        // The byte at `cut` is an ASCII `-`, so `cut` is a char boundary and
        // slicing the `str` there cannot panic.
        [b'-', digits @ ..] if digits.iter().all(u8::is_ascii_digit) => Some(&s[..cut]),
        _ => None,
    }
}

/// Location of the user's pricing overrides file:
/// `pricing_overrides.toml` inside the directory returned by `paths::data_dir`.
///
/// # Errors
/// Propagates any failure from `paths::data_dir`.
pub fn overrides_path() -> Result<PathBuf> {
    let data_dir = paths::data_dir()?;
    Ok(data_dir.join("pricing_overrides.toml"))
}

fn load_overrides() -> Result<Option<Vec<ModelPrice>>> {
    let path = overrides_path()?;
    if !path.exists() {
        return Ok(None);
    }
    let text =
        std::fs::read_to_string(&path).with_context(|| format!("reading {}", path.display()))?;
    let entries = parse(&text).with_context(|| format!("parsing {}", path.display()))?;
    Ok(Some(entries))
}

/// Adds `new_entries` to the overrides file, skipping rows that are already
/// present, and returns how many rows were actually added.
///
/// Two rows count as the same when they share a model name and an effective
/// date. The date is `effective_from` normalized to `%Y-%m-%d` in UTC when it
/// parses, otherwise the raw string. Duplicates are detected both against the
/// rows already in the file and among `new_entries` themselves; the first
/// occurrence wins.
///
/// The file is rewritten only when at least one row was added.
///
/// # Errors
/// Fails if the data directory cannot be created, the existing overrides
/// cannot be loaded, or the merged file cannot be serialized or written.
pub fn append_overrides(new_entries: &[ModelPrice]) -> Result<usize> {
    paths::ensure_data_dir()?;
    let mut merged: Vec<ModelPrice> = load_overrides()?.unwrap_or_default();

    let key = |m: &ModelPrice| {
        let day = parse_iso(&m.effective_from).map_or_else(
            |_| m.effective_from.clone(),
            |t| t.format("%Y-%m-%d").to_string(),
        );
        (m.name.clone(), day)
    };
    let mut seen: std::collections::HashSet<(String, String)> =
        merged.iter().map(key).collect();

    let before = merged.len();
    for m in new_entries {
        // `insert` returns false for a key that is already known. Recording
        // each accepted row's key also filters repeats within `new_entries`,
        // which would otherwise all be written and counted.
        if seen.insert(key(m)) {
            merged.push(m.clone());
        }
    }
    let added = merged.len() - before;
    if added == 0 {
        return Ok(0);
    }

    // NOTE(review): rows are not validated here; a row whose `effective_from`
    // does not parse is written as-is and would be rejected later by
    // `Pricing::from_entries` — confirm callers validate beforehand.
    let out = OverridesFile { models: merged };
    let text = toml::to_string_pretty(&out).context("serializing overrides")?;
    let path = overrides_path()?;
    std::fs::write(&path, text).with_context(|| format!("writing {}", path.display()))?;
    Ok(added)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a price row whose input and output rates are both `rate` and
    /// whose cache rates are all zero.
    fn flat_price(name: &str, effective_from: &str, rate: f64) -> ModelPrice {
        ModelPrice {
            name: name.into(),
            effective_from: effective_from.into(),
            input_per_mtok: rate,
            output_per_mtok: rate,
            cache_write_5m_per_mtok: 0.0,
            cache_write_1h_per_mtok: 0.0,
            cache_read_per_mtok: 0.0,
        }
    }

    #[test]
    fn strips_date_suffix() {
        let cases = [
            ("claude-haiku-4-5-20251001", Some("claude-haiku-4-5")),
            ("claude-opus-4-7", None),
            ("foo-12345678", Some("foo")),
            ("short", None),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_date_suffix(input), expected);
        }
    }

    #[test]
    fn picks_most_recent_qualifying_entry() {
        let pricing = Pricing::from_entries(vec![
            flat_price("x", "2025-01-01T00:00:00Z", 1.0),
            flat_price("x", "2026-01-01T00:00:00Z", 2.0),
            flat_price("__unknown__", "2024-01-01T00:00:00Z", 99.0),
        ])
        .unwrap();

        // Between the two rows: the first row applies.
        let mid_2025 = pricing.compute("x", "2025-06-01T00:00:00Z", 1_000_000, 0, 0, 0, 0);
        assert!((mid_2025.usd - 1.0).abs() < 1e-9);
        assert_eq!(mid_2025.version, "x@2025-01-01");

        // After the second row: the second row applies.
        let mid_2026 = pricing.compute("x", "2026-06-01T00:00:00Z", 1_000_000, 0, 0, 0, 0);
        assert!((mid_2026.usd - 2.0).abs() < 1e-9);
        assert_eq!(mid_2026.version, "x@2026-01-01");

        // Before every row: the earliest row is used.
        let mid_2024 = pricing.compute("x", "2024-06-01T00:00:00Z", 1_000_000, 0, 0, 0, 0);
        assert!((mid_2024.usd - 1.0).abs() < 1e-9);
    }
}