use serde::Deserialize;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use crate::pricing_data;
/// Token pricing for a single model, quoted per million tokens.
///
/// `PriceTable::cost` divides token counts by 1,000,000 before multiplying by
/// these rates. The currency is presumably USD, since `format_cost` renders
/// results with a `$` sign.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct ModelPrice {
    /// Price per million input (prompt) tokens.
    pub input_per_mtok: f64,
    /// Price per million output (completion) tokens.
    pub output_per_mtok: f64,
}
/// A collection of per-model prices keyed by model id.
///
/// Deserializes from a document with an optional `models` map, for example in TOML:
///
/// ```toml
/// [models."my-model"]
/// input_per_mtok = 1.0
/// output_per_mtok = 2.0
/// ```
///
/// When `models` is absent the table is simply empty.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct PriceTable {
    /// Model id → price. `PriceTable::lookup` queries this map with exact,
    /// case-sensitive keys.
    #[serde(default)]
    pub models: HashMap<String, ModelPrice>,
}
/// Returns the `PRICES_UPDATED` string embedded in the generated pricing data.
///
/// Presumably this is the date the generated table was last refreshed; confirm
/// against `pricing_data`.
pub fn prices_updated() -> &'static str {
    pricing_data::PRICES_UPDATED
}
/// Returns the `PRICES_SOURCE` string embedded in the generated pricing data.
///
/// Presumably this names where the generated table came from; confirm against
/// `pricing_data`.
pub fn prices_source() -> &'static str {
    pricing_data::PRICES_SOURCE
}
/// Substrings that mark a model id as local; `is_local_model` treats any id
/// containing one of them as local.
///
/// Every entry must be lowercase, because `is_local_model` lowercases the id
/// before searching.
// NOTE(review): `huggingface/` is classified as local here. Confirm that is
// intended, since the prefix may also be used for hosted inference.
const LOCAL_MARKERS: &[&str] = &[
    "ollama/",
    "ollama_chat/",
    "ollama:",
    "lmstudio/",
    "lm-studio/",
    "vllm/",
    "llama-cpp/",
    "llamacpp/",
    "localhost:",
    "127.0.0.1:",
    "huggingface/",
];

/// Reports whether `model` looks like a locally served model.
///
/// The check is an ASCII-case-insensitive substring search for any entry of
/// `LOCAL_MARKERS`. An empty id is never considered local.
pub fn is_local_model(model: &str) -> bool {
    match model {
        "" => false,
        id => {
            let folded = id.to_ascii_lowercase();
            for marker in LOCAL_MARKERS {
                if folded.contains(*marker) {
                    return true;
                }
            }
            false
        }
    }
}
impl PriceTable {
pub fn builtin() -> Self {
let mut m: HashMap<String, ModelPrice> = HashMap::with_capacity(
pricing_data::GENERATED.len() + 32,
);
for (k, p) in pricing_data::GENERATED {
m.insert((*k).to_string(), *p);
}
let put = |m: &mut HashMap<String, ModelPrice>, k: &str, i: f64, o: f64| {
m.insert(k.into(), ModelPrice { input_per_mtok: i, output_per_mtok: o });
};
put(&mut m, "claude-sonnet-4-5", 3.00, 15.00);
put(&mut m, "claude-sonnet-4-6", 3.00, 15.00);
put(&mut m, "claude-sonnet-4-7", 3.00, 15.00);
put(&mut m, "claude-opus-4-1", 15.00, 75.00);
put(&mut m, "claude-opus-4-7", 15.00, 75.00);
put(&mut m, "claude-haiku-4-5", 0.80, 4.00);
put(&mut m, "claude-3-5-sonnet", 3.00, 15.00);
put(&mut m, "claude-3-5-haiku", 0.80, 4.00);
put(&mut m, "claude-3-opus", 15.00, 75.00);
put(&mut m, "gpt-5", 1.25, 10.00);
put(&mut m, "gpt-5-mini", 0.25, 2.00);
put(&mut m, "gpt-5-nano", 0.05, 0.40);
put(&mut m, "gpt-4o", 2.50, 10.00);
put(&mut m, "gpt-4o-mini", 0.15, 0.60);
put(&mut m, "gpt-4-turbo", 10.00, 30.00);
put(&mut m, "o1", 15.00, 60.00);
put(&mut m, "o1-mini", 1.10, 4.40);
put(&mut m, "o3", 2.00, 8.00);
put(&mut m, "o3-mini", 1.10, 4.40);
put(&mut m, "gemini-2.0-flash", 0.10, 0.40);
put(&mut m, "gemini-1.5-pro", 1.25, 5.00);
put(&mut m, "gemini-1.5-flash", 0.075, 0.30);
Self { models: m }
}
pub fn load(path: &Path) -> anyhow::Result<Self> {
let text = fs::read_to_string(path)?;
let parsed: PriceTable = toml::from_str(&text)?;
Ok(parsed)
}
pub fn merge(mut self, other: PriceTable) -> Self {
for (k, v) in other.models {
self.models.insert(k, v);
}
self
}
pub fn lookup(&self, model: &str) -> Option<ModelPrice> {
if let Some(p) = self.models.get(model) { return Some(*p); }
let mut s = model;
for _ in 0..4 {
let Some(i) = s.rfind('-') else { break };
s = &s[..i];
if let Some(p) = self.models.get(s) { return Some(*p); }
}
None
}
pub fn cost(&self, model: &str, in_tok: u64, out_tok: u64) -> f64 {
if is_local_model(model) {
return 0.0;
}
match self.lookup(model) {
Some(p) => (in_tok as f64 / 1_000_000.0) * p.input_per_mtok
+ (out_tok as f64 / 1_000_000.0) * p.output_per_mtok,
None => 0.0,
}
}
}
/// Formats a cost as a short dollar string for display.
///
/// Buckets:
/// - zero, negative or NaN → `"—"`
/// - below one cent → `"<$0.01"`
/// - below $10 → two decimals, e.g. `"$1.23"`
/// - below $1,000 → one decimal, e.g. `"$42.1"`
/// - below $1,000,000 → thousands with one decimal, e.g. `"$1.2k"`
/// - otherwise → millions with one decimal, e.g. `"$3.4M"`
///
/// Each bucket edge sits where the narrower format would round up into the next
/// bucket. For example, `999.96` renders as `"$1.0k"` rather than `"$1000.0"`,
/// and `9.999` renders as `"$10.0"` rather than `"$10.00"`.
pub fn format_cost(usd: f64) -> String {
    // NaN fails every `<` comparison below and would otherwise print "$NaNM".
    if usd.is_nan() || usd <= 0.0 {
        return "—".into();
    }
    if usd < 0.01 {
        return "<$0.01".into();
    }
    // From 9.995 up, `{:.2}` would round to "10.00".
    if usd < 9.995 {
        return format!("${:.2}", usd);
    }
    // From 999.95 up, `{:.1}` would round to "1000.0".
    if usd < 999.95 {
        return format!("${:.1}", usd);
    }
    // From 999 950 up, the thousands form would round to "1000.0k".
    if usd < 999_950.0 {
        return format!("${:.1}k", usd / 1000.0);
    }
    format!("${:.1}M", usd / 1_000_000.0)
}
/// How a model's cost figure is grounded, as classified by `cost_basis`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CostBasis {
    /// The price table has an entry for the model, either exact or after
    /// suffix stripping.
    Api,
    /// The model id matches a local-runtime marker.
    Local,
    /// The id is empty or has no entry in the price table.
    Unknown,
}
/// Classifies where a cost figure for `model` would come from.
///
/// Local-runtime ids are `Local`. Empty ids, and ids the table cannot price,
/// are `Unknown`. Everything else is `Api`. The table is not consulted for
/// local or empty ids.
pub fn cost_basis(table: &PriceTable, model: &str) -> CostBasis {
    if is_local_model(model) {
        CostBasis::Local
    } else if model.is_empty() || table.lookup(model).is_none() {
        CostBasis::Unknown
    } else {
        CostBasis::Api
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lookup_strips_date_suffixes() {
        let t = PriceTable::builtin();
        let p = t.lookup("claude-sonnet-4-7-20260101").unwrap();
        assert_eq!(p.input_per_mtok, 3.0);
    }

    #[test]
    fn cost_math_is_per_million() {
        let t = PriceTable::builtin();
        let c = t.cost("claude-sonnet-4-7", 1_000_000, 0);
        assert!((c - 3.0).abs() < 1e-6);
        let c = t.cost("claude-sonnet-4-7", 0, 1_000_000);
        assert!((c - 15.0).abs() < 1e-6);
    }

    #[test]
    fn unknown_model_is_zero_cost() {
        let t = PriceTable::builtin();
        assert_eq!(t.cost("totally-made-up-model", 999_999, 999_999), 0.0);
    }

    #[test]
    fn format_cost_buckets() {
        assert_eq!(format_cost(0.0), "—");
        assert_eq!(format_cost(0.001), "<$0.01");
        assert_eq!(format_cost(0.04), "$0.04");
        assert_eq!(format_cost(1.23), "$1.23");
        assert_eq!(format_cost(42.10), "$42.1");
        assert_eq!(format_cost(1234.0), "$1.2k");
    }

    #[test]
    fn local_models_short_circuit_to_zero() {
        let t = PriceTable::builtin();
        assert!(is_local_model("ollama/llama3"));
        assert!(is_local_model("Ollama:codellama"));
        assert!(is_local_model("vllm/mistral-7b"));
        assert_eq!(t.cost("ollama/llama3", 5_000_000, 5_000_000), 0.0);
        assert_eq!(cost_basis(&t, "ollama/llama3"), CostBasis::Local);
    }

    #[test]
    fn local_detection_ignores_plain_and_empty_ids() {
        assert!(!is_local_model(""));
        assert!(!is_local_model("claude-sonnet-4-7"));
        assert!(is_local_model("http://LOCALHOST:11434/llama3"));
    }

    #[test]
    fn cost_basis_classifies_three_buckets() {
        let t = PriceTable::builtin();
        assert_eq!(cost_basis(&t, "claude-sonnet-4-7"), CostBasis::Api);
        assert_eq!(cost_basis(&t, "ollama/llama3"), CostBasis::Local);
        assert_eq!(cost_basis(&t, "totally-made-up"), CostBasis::Unknown);
        assert_eq!(cost_basis(&t, ""), CostBasis::Unknown);
    }

    #[test]
    fn empty_table_prices_nothing() {
        let t = PriceTable::default();
        assert!(t.lookup("claude-sonnet-4-7").is_none());
        assert_eq!(t.cost("claude-sonnet-4-7", 1_000_000, 1_000_000), 0.0);
        assert_eq!(cost_basis(&t, "claude-sonnet-4-7"), CostBasis::Unknown);
    }

    #[test]
    fn merge_prefers_other_on_conflict() {
        let mut extra = PriceTable::default();
        extra.models.insert(
            "claude-sonnet-4-7".to_string(),
            ModelPrice {
                input_per_mtok: 1.0,
                output_per_mtok: 2.0,
            },
        );
        extra.models.insert(
            "my-custom-model".to_string(),
            ModelPrice {
                input_per_mtok: 4.0,
                output_per_mtok: 8.0,
            },
        );
        let t = PriceTable::builtin().merge(extra);
        let p = t.lookup("claude-sonnet-4-7").unwrap();
        assert_eq!((p.input_per_mtok, p.output_per_mtok), (1.0, 2.0));
        assert_eq!(t.lookup("my-custom-model").unwrap().input_per_mtok, 4.0);
        // Entries absent from `other` survive the merge untouched.
        assert_eq!(t.lookup("gpt-4o").unwrap().input_per_mtok, 2.5);
    }

    #[test]
    fn generated_table_has_substantial_coverage() {
        assert!(
            pricing_data::GENERATED.len() > 500,
            "generated table only has {} models",
            pricing_data::GENERATED.len()
        );
    }
}