use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use crate::paths;
/// Built-in pricing table: the text of `../data/pricing.toml` (path relative
/// to this source file), embedded into the binary at compile time.
const EMBEDDED_PRICING: &str = include_str!("../data/pricing.toml");
/// One price point for a model, as read from (and written back to) the
/// pricing TOML files.
///
/// Each `*_per_mtok` rate is applied per 1,000,000 tokens of the matching
/// category when a cost is computed.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ModelPrice {
/// Model identifier used as the lookup key for this price.
pub name: String,
/// Moment this price starts to apply: an RFC 3339 timestamp or a bare
/// `YYYY-MM-DD` date (the two formats the loader accepts).
pub effective_from: String,
/// Rate per million input tokens.
pub input_per_mtok: f64,
/// Rate per million output tokens.
pub output_per_mtok: f64,
/// Rate per million "5m" cache-write tokens.
// NOTE(review): "5m" presumably means a 5-minute cache TTL — confirm.
pub cache_write_5m_per_mtok: f64,
/// Rate per million "1h" cache-write tokens.
// NOTE(review): "1h" presumably means a 1-hour cache TTL — confirm.
pub cache_write_1h_per_mtok: f64,
/// Rate per million cache-read tokens.
pub cache_read_per_mtok: f64,
}
/// Deserialization shape of a pricing TOML document: a list of price entries
/// stored under the `model` key.
#[derive(Debug, Deserialize)]
struct PricingFile {
// The key is required: a document without any `model` entry fails to parse.
#[serde(rename = "model")]
models: Vec<ModelPrice>,
}
/// Serialization shape of the overrides file: the same `model` list layout
/// that `PricingFile` reads, so a written file can be loaded back.
#[derive(Debug, Serialize, Default)]
struct OverridesFile {
// NOTE(review): serde's `default` attribute affects deserialization, and this
// type only derives `Serialize` — confirm whether the attribute is still needed.
#[serde(rename = "model", default)]
models: Vec<ModelPrice>,
}
/// A `ModelPrice` together with its parsed effective instant, ready to be
/// placed on a per-model timeline.
#[derive(Debug, Clone)]
struct PricedEntry {
/// Parsed form of `price.effective_from`, in UTC; the timeline sort key.
effective_at: DateTime<Utc>,
/// `effective_at` rendered as `YYYY-MM-DD`. Despite the name this is the
/// normalized date, not the original string; it is used in `Cost::version`.
raw_effective_from: String,
/// The price entry exactly as it was loaded.
price: ModelPrice,
}
/// Lookup table of model prices over time.
pub struct Pricing {
/// Price history per model name, each list sorted by ascending
/// `effective_at`.
timelines: HashMap<String, Vec<PricedEntry>>,
/// Entry used for models that have no timeline: the earliest entry of the
/// reserved `__unknown__` model.
fallback: PricedEntry,
}
/// Result of pricing one usage event.
#[derive(Debug, Clone)]
pub struct Cost {
/// Total for the event: the sum of every token count times its per-million
/// rate.
pub usd: f64,
/// Identifies the price that was applied, formatted as
/// `<canonical model name>@<YYYY-MM-DD effective date>`.
pub version: String,
}
impl Pricing {
    /// Builds the table from the embedded `pricing.toml` plus any entries in
    /// the user overrides file.
    ///
    /// Override entries are appended after the embedded ones. Because the
    /// per-model sort is stable and lookups take the *last* qualifying entry,
    /// an override with the same model and effective instant as an embedded
    /// entry takes precedence over it.
    ///
    /// # Errors
    /// Fails if the embedded or overrides TOML cannot be read or parsed, or
    /// if [`Pricing::from_entries`] rejects the combined entries.
    pub fn load() -> Result<Self> {
        let mut entries: Vec<ModelPrice> =
            parse(EMBEDDED_PRICING).context("parsing embedded pricing.toml")?;
        if let Some(extras) = load_overrides()? {
            entries.extend(extras);
        }
        Self::from_entries(entries)
    }

    /// Groups `entries` by model name into timelines sorted by effective
    /// instant.
    ///
    /// # Errors
    /// Fails if any entry's `effective_from` is neither an RFC 3339 timestamp
    /// nor a `YYYY-MM-DD` date, or if no `__unknown__` entry is present (it is
    /// required as the fallback price).
    pub fn from_entries(entries: Vec<ModelPrice>) -> Result<Self> {
        let mut timelines: HashMap<String, Vec<PricedEntry>> = HashMap::new();
        for p in entries {
            let effective_at = parse_iso(&p.effective_from).with_context(|| {
                format!(
                    "model `{}`: effective_from `{}` is not a valid ISO timestamp",
                    p.name, p.effective_from
                )
            })?;
            // Normalized UTC calendar date; this is what ends up in `Cost::version`.
            let raw_effective_from = effective_at.format("%Y-%m-%d").to_string();
            timelines
                .entry(p.name.clone())
                .or_default()
                .push(PricedEntry {
                    effective_at,
                    raw_effective_from,
                    price: p,
                });
        }
        for timeline in timelines.values_mut() {
            // Must stay a *stable* sort: entries with the same instant keep
            // their input order, which is what lets later entries win.
            timeline.sort_by_key(|e| e.effective_at);
        }
        let fallback = timelines
            .get("__unknown__")
            .and_then(|v| v.first().cloned())
            .context("pricing.toml is missing the __unknown__ fallback entry")?;
        Ok(Self {
            timelines,
            fallback,
        })
    }

    /// Prices one usage event.
    ///
    /// * `model` - model name as reported by the event; it is canonicalized
    ///   with [`Pricing::canonicalize`] before lookup.
    /// * `event_timestamp` - RFC 3339 timestamp or `YYYY-MM-DD` date of the
    ///   event. If it cannot be parsed, the model's latest price is used.
    /// * `input`, `output`, `cache_write_5m`, `cache_write_1h`, `cache_read` -
    ///   token counts, each billed at the matching per-million-token rate.
    ///
    /// A model without a timeline is billed at the `__unknown__` fallback
    /// rates; the returned `version` still carries the requested model's
    /// canonical name, combined with the fallback entry's effective date.
    #[allow(clippy::too_many_arguments)] // public signature kept as-is for callers
    pub fn compute(
        &self,
        model: &str,
        event_timestamp: &str,
        input: u64,
        output: u64,
        cache_write_5m: u64,
        cache_write_1h: u64,
        cache_read: u64,
    ) -> Cost {
        let canonical = self.canonicalize(model);
        let entry = self.pick_entry(canonical, event_timestamp);
        let p = &entry.price;
        let mt = 1_000_000.0;
        let usd = (input as f64) / mt * p.input_per_mtok
            + (output as f64) / mt * p.output_per_mtok
            + (cache_write_5m as f64) / mt * p.cache_write_5m_per_mtok
            + (cache_write_1h as f64) / mt * p.cache_write_1h_per_mtok
            + (cache_read as f64) / mt * p.cache_read_per_mtok;
        Cost {
            usd,
            version: format!("{}@{}", canonical, entry.raw_effective_from),
        }
    }

    /// Maps a reported model name onto a name that has a timeline.
    ///
    /// Returns `model` unchanged when it is known as-is. Otherwise, if `model`
    /// ends in `-` followed by eight digits and the name without that suffix
    /// is known, returns the shortened name. In every other case `model` is
    /// returned unchanged.
    pub fn canonicalize<'a>(&self, model: &'a str) -> &'a str {
        if self.timelines.contains_key(model) {
            return model;
        }
        match strip_date_suffix(model) {
            Some(stripped) if self.timelines.contains_key(stripped) => stripped,
            _ => model,
        }
    }

    /// Selects the price entry that applies to `canonical_model` at
    /// `event_timestamp`, borrowing it from the table (no per-call clone).
    ///
    /// * Unknown model: the `__unknown__` fallback.
    /// * Unparseable timestamp: the model's latest entry.
    /// * Otherwise: the latest entry whose effective instant is not after the
    ///   event, or the model's earliest entry when the event predates all of
    ///   them.
    fn pick_entry(&self, canonical_model: &str, event_timestamp: &str) -> &PricedEntry {
        let timeline = match self.timelines.get(canonical_model) {
            Some(t) if !t.is_empty() => t.as_slice(),
            _ => return &self.fallback,
        };
        match parse_iso(event_timestamp) {
            Ok(event_at) => {
                // The timeline is sorted ascending, so the entries with
                // `effective_at <= event_at` form a prefix of it.
                let qualifying = timeline.partition_point(|e| e.effective_at <= event_at);
                // `timeline` is non-empty, so index 0 is always valid when
                // nothing qualifies.
                &timeline[qualifying.saturating_sub(1)]
            }
            Err(_) => timeline.last().unwrap_or(&self.fallback),
        }
    }

    /// Returns every loaded price, ordered by model name and, within a model,
    /// by ascending effective instant.
    pub fn list_all(&self) -> Vec<&ModelPrice> {
        let mut by_name: Vec<(&String, &Vec<PricedEntry>)> = self.timelines.iter().collect();
        // Map keys are unique, so an unstable sort gives a deterministic order.
        by_name.sort_unstable_by_key(|&(name, _)| name);
        by_name
            .into_iter()
            .flat_map(|(_, timeline)| timeline.iter().map(|e| &e.price))
            .collect()
    }
}
fn parse(s: &str) -> Result<Vec<ModelPrice>> {
let f: PricingFile = toml::from_str(s)?;
Ok(f.models)
}
/// Parses `s` as a UTC instant.
///
/// Two formats are accepted, tried in this order:
/// 1. an RFC 3339 timestamp, converted to UTC;
/// 2. a bare `YYYY-MM-DD` date, interpreted as midnight UTC of that day.
///
/// # Errors
/// Fails when `s` matches neither format.
fn parse_iso(s: &str) -> Result<DateTime<Utc>> {
    match DateTime::parse_from_rfc3339(s) {
        Ok(stamp) => Ok(stamp.with_timezone(&Utc)),
        Err(_) => match chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
            Ok(day) => {
                let midnight = day.and_hms_opt(0, 0, 0).context("invalid date")?;
                Ok(DateTime::<Utc>::from_naive_utc_and_offset(midnight, Utc))
            }
            Err(_) => anyhow::bail!("not a recognized timestamp: `{s}`"),
        },
    }
}
/// Removes a trailing `-` followed by exactly eight ASCII digits (for example
/// `-20251001`) from `s`.
///
/// Returns the part before that suffix, or `None` when `s` does not end in
/// such a suffix. The digits are not validated as a calendar date.
fn strip_date_suffix(s: &str) -> Option<&str> {
    // Splitting at the *last* dash is sufficient: an all-digit suffix cannot
    // contain a dash itself, so the dash preceding it is always the last one.
    let (stem, suffix) = s.rsplit_once('-')?;
    let is_date_like = suffix.len() == 8 && suffix.bytes().all(|b| b.is_ascii_digit());
    if is_date_like {
        Some(stem)
    } else {
        None
    }
}
/// Location of the user pricing overrides file: `pricing_overrides.toml`
/// inside the directory returned by `paths::data_dir()`.
///
/// # Errors
/// Propagates any error from `paths::data_dir()`.
pub fn overrides_path() -> Result<PathBuf> {
    let data_dir = paths::data_dir()?;
    Ok(data_dir.join("pricing_overrides.toml"))
}
fn load_overrides() -> Result<Option<Vec<ModelPrice>>> {
let path = overrides_path()?;
if !path.exists() {
return Ok(None);
}
let text =
std::fs::read_to_string(&path).with_context(|| format!("reading {}", path.display()))?;
let entries = parse(&text).with_context(|| format!("parsing {}", path.display()))?;
Ok(Some(entries))
}
/// Adds `new_entries` to the overrides file, skipping entries that are
/// already present.
///
/// Two entries count as the same when they share the model name and the same
/// effective date (normalized to `YYYY-MM-DD`; the raw string is compared
/// when it cannot be parsed). The first occurrence wins: an entry already in
/// the file is never replaced, and a key repeated inside `new_entries` is
/// added only once.
///
/// Returns the number of entries actually added. The file is rewritten only
/// when that number is non-zero.
///
/// # Errors
/// Fails if the data directory cannot be prepared, or if the overrides file
/// cannot be loaded, serialized, or written.
pub fn append_overrides(new_entries: &[ModelPrice]) -> Result<usize> {
    paths::ensure_data_dir()?;
    let mut merged: Vec<ModelPrice> = load_overrides()?.unwrap_or_default();
    let key = |m: &ModelPrice| {
        let dt = parse_iso(&m.effective_from).map_or_else(
            |_| m.effective_from.clone(),
            |t| t.format("%Y-%m-%d").to_string(),
        );
        (m.name.clone(), dt)
    };
    let mut seen: std::collections::HashSet<(String, String)> =
        merged.iter().map(key).collect();
    let already_stored = merged.len();
    for m in new_entries {
        // `insert` returns false for a key that is already known, which also
        // catches duplicates inside `new_entries` itself.
        if seen.insert(key(m)) {
            merged.push(m.clone());
        }
    }
    let added = merged.len() - already_stored;
    if added == 0 {
        return Ok(0);
    }
    let out = OverridesFile { models: merged };
    let text = toml::to_string_pretty(&out).context("serializing overrides")?;
    let path = overrides_path()?;
    std::fs::write(&path, text).with_context(|| format!("writing {}", path.display()))?;
    Ok(added)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a price entry whose input and output rates are both `rate`
    /// and whose cache rates are all zero.
    fn flat_rate(name: &str, effective_from: &str, rate: f64) -> ModelPrice {
        ModelPrice {
            name: name.into(),
            effective_from: effective_from.into(),
            input_per_mtok: rate,
            output_per_mtok: rate,
            cache_write_5m_per_mtok: 0.0,
            cache_write_1h_per_mtok: 0.0,
            cache_read_per_mtok: 0.0,
        }
    }

    #[test]
    fn strips_date_suffix() {
        let cases = [
            ("claude-haiku-4-5-20251001", Some("claude-haiku-4-5")),
            ("claude-opus-4-7", None),
            ("foo-12345678", Some("foo")),
            ("short", None),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_date_suffix(input), expected);
        }
    }

    #[test]
    fn picks_most_recent_qualifying_entry() {
        let pricing = Pricing::from_entries(vec![
            flat_rate("x", "2025-01-01T00:00:00Z", 1.0),
            flat_rate("x", "2026-01-01T00:00:00Z", 2.0),
            flat_rate("__unknown__", "2024-01-01T00:00:00Z", 99.0),
        ])
        .unwrap();
        // Every event bills exactly one million input tokens, so `usd` equals
        // the input rate of the entry that was selected.
        let price_at = |timestamp: &str| pricing.compute("x", timestamp, 1_000_000, 0, 0, 0, 0);

        // Between the two entries: the older price applies.
        let mid_2025 = price_at("2025-06-01T00:00:00Z");
        assert!((mid_2025.usd - 1.0).abs() < 1e-9);
        assert_eq!(mid_2025.version, "x@2025-01-01");

        // After the second entry: the newer price applies.
        let mid_2026 = price_at("2026-06-01T00:00:00Z");
        assert!((mid_2026.usd - 2.0).abs() < 1e-9);
        assert_eq!(mid_2026.version, "x@2026-01-01");

        // Before any entry: the earliest price is used.
        let mid_2024 = price_at("2024-06-01T00:00:00Z");
        assert!((mid_2024.usd - 1.0).abs() < 1e-9);
    }
}