use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Token counts for one request, split by the rate each category is billed at.
///
/// `PricingDatabase::compute_cost` multiplies each counter by its matching per-token rate
/// and sums the products.
#[derive(Clone, Copy, Debug, Default)]
pub struct TokenUsage {
    /// Tokens billed at the input rate.
    pub input: u64,
    /// Tokens billed at the output rate.
    pub output: u64,
    /// Tokens billed at the cache-creation (cache write) rate.
    pub cache_creation: u64,
    /// Tokens billed at the cache-read rate.
    pub cache_read: u64,
}
/// One model's rates as deserialized from the LiteLLM pricing JSON.
///
/// Because of `#[serde(default)]`, a record that omits a field still parses: missing
/// per-token rates become `0.0` and missing optional fields become `None`.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
pub struct PricingEntry {
    /// Rate per input token; also used for cache creation when that rate is absent.
    pub input_cost_per_token: f64,
    /// Rate per output token.
    pub output_cost_per_token: f64,
    /// Rate per cache-creation token; `None` falls back to `input_cost_per_token`.
    pub cache_creation_input_token_cost: Option<f64>,
    /// Rate per cache-read token; `None` is billed as zero.
    pub cache_read_input_token_cost: Option<f64>,
    /// Provider label from LiteLLM; parsed but not consulted by `compute_cost`.
    pub litellm_provider: Option<String>,
}
/// Where a computed cost came from; serializes as a lowercase string.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CostSource {
    /// Priced from a LiteLLM entry.
    Litellm,
    /// Priced from the built-in Claude fallback table.
    Fallback,
    /// No price could be determined.
    None,
}
impl CostSource {
pub fn as_str(&self) -> &'static str {
match self {
CostSource::Litellm => "litellm",
CostSource::Fallback => "fallback",
CostSource::None => "none",
}
}
}
/// Outcome of a pricing attempt.
#[derive(Clone, Debug)]
pub struct CostResult {
    /// Computed cost, or `None` when no pricing was found.
    pub cost: Option<f64>,
    /// Which pricing table produced `cost`.
    pub source: CostSource,
    /// Human-readable explanation set only when `cost` is `None`.
    pub reason: Option<String>,
}
impl CostResult {
    /// Builds an unpriced result whose reason names the model and, when given, the system.
    ///
    /// A missing `model` produces a generic reason regardless of `system`.
    fn none(model: Option<&str>, system: Option<&str>) -> Self {
        let reason = if let Some(m) = model {
            match system {
                Some(s) => format!("no pricing data for {m} on {s}"),
                None => format!("no pricing data for {m}"),
            }
        } else {
            "no pricing data (missing model)".to_string()
        };
        Self {
            cost: None,
            source: CostSource::None,
            reason: Some(reason),
        }
    }
}
/// Built-in per-token rates for one Claude model family.
struct FallbackEntry {
    /// Rate per input token.
    input: f64,
    /// Rate per output token.
    output: f64,
    /// Rate per cache-write token for the 5-minute cache; applied to all cache creation.
    cache_5m: f64,
    /// Rate per cache-write token for the 1-hour cache; not applied by `compute_cost`.
    cache_1h: f64,
    /// Rate per cache-read token.
    cache_read: f64,
}
/// Divisor that turns a per-million-token price into a per-token rate.
const MILLION: f64 = 1e6;

/// Fallback rates for model names containing "opus".
///
/// Each rate is a per-million-token price divided by [`MILLION`] (presumably USD — confirm).
/// NOTE(review): a single entry covers every "opus" model; confirm it matches current prices.
const CLAUDE_OPUS_FALLBACK: FallbackEntry = FallbackEntry {
    input: 15.0 / MILLION,
    output: 75.0 / MILLION,
    cache_read: 1.5 / MILLION,
    cache_5m: 18.75 / MILLION,
    cache_1h: 30.0 / MILLION,
};

/// Fallback rates for model names containing "sonnet" (per-million price / [`MILLION`]).
const CLAUDE_SONNET_FALLBACK: FallbackEntry = FallbackEntry {
    input: 3.0 / MILLION,
    output: 15.0 / MILLION,
    cache_read: 0.3 / MILLION,
    cache_5m: 3.75 / MILLION,
    cache_1h: 6.0 / MILLION,
};

/// Fallback rates for model names containing "haiku" (per-million price / [`MILLION`]).
const CLAUDE_HAIKU_FALLBACK: FallbackEntry = FallbackEntry {
    input: 1.0 / MILLION,
    output: 5.0 / MILLION,
    cache_read: 0.1 / MILLION,
    cache_5m: 1.25 / MILLION,
    cache_1h: 2.0 / MILLION,
};
/// Human-facing location of the upstream LiteLLM pricing file.
pub const LITELLM_SOURCE_URL: &str =
    "https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json";

/// Raw-content URL of the same LiteLLM pricing file.
pub const LITELLM_RAW_URL: &str =
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";

/// License attribution for the LiteLLM pricing data.
pub const LITELLM_LICENSE: &str = "MIT — © 2023 Berri AI";

/// Date (YYYY-MM-DD) the built-in Claude fallback rates were last verified.
pub const FALLBACK_LAST_VERIFIED: &str = "2026-05-07";
/// Picks built-in Claude rates by family keyword, matched case-insensitively.
///
/// Keywords are checked in order — "opus", then "sonnet", then "haiku" — and the first
/// keyword found anywhere in `model` wins. Returns `None` when no keyword appears.
fn claude_fallback(model: &str) -> Option<&'static FallbackEntry> {
    let families: [(&str, &'static FallbackEntry); 3] = [
        ("opus", &CLAUDE_OPUS_FALLBACK),
        ("sonnet", &CLAUDE_SONNET_FALLBACK),
        ("haiku", &CLAUDE_HAIKU_FALLBACK),
    ];
    let lowered = model.to_ascii_lowercase();
    families
        .iter()
        .find(|(keyword, _)| lowered.contains(*keyword))
        .map(|&(_, rates)| rates)
}
/// Model pricing table loaded from LiteLLM data, or empty.
#[derive(Clone, Debug, Default)]
pub struct PricingDatabase {
    /// Parsed entries keyed exactly as they appear in the source JSON.
    entries: HashMap<String, PricingEntry>,
    /// `true` when built by `from_litellm_json`; `false` for `empty()`/`default()`.
    loaded_from_litellm: bool,
}
impl PricingDatabase {
    /// Creates a database with no entries.
    ///
    /// Every lookup misses, so [`PricingDatabase::compute_cost`] relies only on the built-in
    /// Claude fallback rates, and [`PricingDatabase::is_litellm`] returns `false`.
    pub fn empty() -> Self {
        Self::default()
    }

    /// Parses a LiteLLM pricing document (a JSON object keyed by model name).
    ///
    /// Values that fail to deserialize into a [`PricingEntry`] are skipped, so one malformed
    /// record does not abort the whole load.
    ///
    /// # Errors
    ///
    /// Returns an error if `raw` is not valid JSON or its top level is not an object.
    pub fn from_litellm_json(raw: &str) -> serde_json::Result<Self> {
        let map: HashMap<String, serde_json::Value> = serde_json::from_str(raw)?;
        let mut entries = HashMap::with_capacity(map.len());
        for (k, v) in map {
            // Best-effort: skip records whose shape does not match `PricingEntry`.
            if let Ok(entry) = serde_json::from_value::<PricingEntry>(v) {
                entries.insert(k, entry);
            }
        }
        Ok(Self {
            entries,
            loaded_from_litellm: true,
        })
    }

    /// Number of parsed entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the database holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Whether this database was built by [`PricingDatabase::from_litellm_json`].
    pub fn is_litellm(&self) -> bool {
        self.loaded_from_litellm
    }

    /// Finds the pricing entry for `model`, optionally scoped by `system`.
    ///
    /// Exact keys from [`Self::candidate_keys`] are probed first. If none of them hits, the
    /// best case-insensitive substring match from [`Self::substring_match`] is used.
    fn lookup(&self, model: &str, system: Option<&str>) -> Option<&PricingEntry> {
        if self.entries.is_empty() {
            return None;
        }
        let lower = model.to_ascii_lowercase();
        let keys = Self::candidate_keys(model, &lower, system);
        keys.iter()
            .find_map(|key| self.entries.get(key))
            .or_else(|| self.substring_match(&lower))
    }

    /// Builds the exact-match keys to probe, in priority order.
    ///
    /// The order is:
    /// 1. `model` as given, then lowercased.
    /// 2. When `system` is present, `{system}/{model}` and `{system}/{lower}`, with the system
    ///    lowercased.
    /// 3. LiteLLM's `bedrock/` or `vertex_ai/` prefix for the recognised system aliases.
    ///
    /// Consecutive duplicates are dropped; they occur when `model` is already lowercase or
    /// `system` is already the canonical prefix.
    fn candidate_keys(model: &str, lower: &str, system: Option<&str>) -> Vec<String> {
        let mut keys = Vec::with_capacity(5);
        keys.push(model.to_string());
        keys.push(lower.to_string());
        if let Some(sys) = system.map(str::to_ascii_lowercase) {
            keys.push(format!("{sys}/{model}"));
            keys.push(format!("{sys}/{lower}"));
            match sys.as_str() {
                "aws.bedrock" | "bedrock" => keys.push(format!("bedrock/{lower}")),
                "gcp.vertex_ai" | "vertex" | "vertex_ai" => {
                    keys.push(format!("vertex_ai/{lower}"));
                },
                _ => {},
            }
        }
        keys.dedup();
        keys
    }

    /// Returns the entry whose key contains `lower` case-insensitively.
    ///
    /// The shortest such key wins. Ties go to the lexicographically smallest key, so the
    /// result does not depend on `HashMap`'s randomized iteration order. An empty needle
    /// matches nothing, because `contains("")` would otherwise accept every key.
    fn substring_match(&self, lower: &str) -> Option<&PricingEntry> {
        if lower.is_empty() {
            return None;
        }
        self.entries
            .iter()
            .filter(|(key, _)| key.to_ascii_lowercase().contains(lower))
            .min_by(|(a, _), (b, _)| a.len().cmp(&b.len()).then_with(|| a.cmp(b)))
            .map(|(_, entry)| entry)
    }

    /// Prices `usage` for `model`, optionally scoped by `system` (provider).
    ///
    /// Resolution order:
    /// 1. A LiteLLM entry with a positive input or output rate. A missing cache-creation
    ///    rate falls back to the input rate; a missing cache-read rate is billed as zero.
    /// 2. The built-in Claude fallback rates, selected by "opus"/"sonnet"/"haiku" in the
    ///    name. All cache creation is billed at the 5-minute write rate.
    /// 3. Otherwise `cost` is `None` and `reason` explains why.
    ///
    /// A `model` that is `None` or blank (empty or whitespace-only) is reported as missing.
    pub fn compute_cost(
        &self,
        model: Option<&str>,
        usage: TokenUsage,
        system: Option<&str>,
    ) -> CostResult {
        // A blank name carries no information and would otherwise match arbitrary keys.
        let Some(model) = model.filter(|m| !m.trim().is_empty()) else {
            return CostResult::none(None, system);
        };
        if let Some(entry) = self.lookup(model, system) {
            // An entry without any positive input/output rate is treated as unpriced, not free.
            if entry.input_cost_per_token > 0.0 || entry.output_cost_per_token > 0.0 {
                let cache_write_rate = entry
                    .cache_creation_input_token_cost
                    .unwrap_or(entry.input_cost_per_token);
                let cache_read_rate = entry.cache_read_input_token_cost.unwrap_or(0.0);
                let cost = (usage.input as f64) * entry.input_cost_per_token
                    + (usage.output as f64) * entry.output_cost_per_token
                    + (usage.cache_creation as f64) * cache_write_rate
                    + (usage.cache_read as f64) * cache_read_rate;
                return CostResult {
                    cost: Some(cost),
                    source: CostSource::Litellm,
                    reason: None,
                };
            }
        }
        if let Some(fb) = claude_fallback(model) {
            // `TokenUsage` has a single cache-creation counter, so writes are billed at the
            // 5-minute rate. `cache_1h` is read only so the field does not trip dead-code lints.
            let _ = fb.cache_1h;
            let cost = (usage.input as f64) * fb.input
                + (usage.output as f64) * fb.output
                + (usage.cache_creation as f64) * fb.cache_5m
                + (usage.cache_read as f64) * fb.cache_read;
            return CostResult {
                cost: Some(cost),
                source: CostSource::Fallback,
                reason: None,
            };
        }
        CostResult::none(Some(model), system)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for usage with only input/output tokens set.
    fn u(input: u64, output: u64) -> TokenUsage {
        TokenUsage {
            input,
            output,
            ..Default::default()
        }
    }

    #[test]
    fn empty_db_uses_claude_fallback_for_sonnet() {
        let db = PricingDatabase::empty();
        let result = db.compute_cost(Some("claude-sonnet-4"), u(1_000_000, 1_000_000), None);
        assert_eq!(result.source, CostSource::Fallback);
        let cost = result.cost.unwrap();
        assert!((cost - 18.0).abs() < 1e-9, "got {cost}");
    }

    #[test]
    fn claude_fallback_matches_haiku() {
        let db = PricingDatabase::empty();
        let result = db.compute_cost(Some("claude-haiku-4.5"), u(1_000_000, 0), None);
        assert_eq!(result.source, CostSource::Fallback);
        assert!((result.cost.unwrap() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn claude_fallback_matches_opus_case_insensitively() {
        let db = PricingDatabase::empty();
        let result = db.compute_cost(Some("Claude-OPUS-4"), u(1_000_000, 1_000_000), None);
        assert_eq!(result.source, CostSource::Fallback);
        assert!((result.cost.unwrap() - 90.0).abs() < 1e-9);
    }

    #[test]
    fn fallback_applies_cache_creation_and_read_rates() {
        let db = PricingDatabase::empty();
        let usage = TokenUsage {
            input: 0,
            output: 0,
            cache_creation: 1_000_000,
            cache_read: 1_000_000,
        };
        let result = db.compute_cost(Some("claude-sonnet-4"), usage, None);
        assert!((result.cost.unwrap() - 4.05).abs() < 1e-9);
    }

    #[test]
    fn unknown_model_returns_none_with_reason() {
        let db = PricingDatabase::empty();
        let result = db.compute_cost(Some("gpt-9000"), u(1, 1), None);
        assert_eq!(result.source, CostSource::None);
        assert!(result.cost.is_none());
        assert!(result.reason.unwrap().contains("gpt-9000"));
    }

    #[test]
    fn unknown_model_reason_mentions_system() {
        let db = PricingDatabase::empty();
        let result = db.compute_cost(Some("gpt-9000"), u(1, 1), Some("openai"));
        assert_eq!(
            result.reason.as_deref(),
            Some("no pricing data for gpt-9000 on openai")
        );
    }

    #[test]
    fn missing_model_returns_none() {
        let db = PricingDatabase::empty();
        let result = db.compute_cost(None, u(1, 1), Some("openai"));
        assert_eq!(result.source, CostSource::None);
        assert!(result.cost.is_none());
    }

    #[test]
    fn missing_model_reason_is_explicit() {
        let db = PricingDatabase::empty();
        let result = db.compute_cost(None, u(1, 1), None);
        assert_eq!(
            result.reason.as_deref(),
            Some("no pricing data (missing model)")
        );
    }

    #[test]
    fn empty_db_is_not_litellm() {
        let db = PricingDatabase::empty();
        assert!(db.is_empty());
        assert_eq!(db.len(), 0);
        assert!(!db.is_litellm());
    }

    #[test]
    fn empty_litellm_object_is_litellm_but_empty() {
        let db = PricingDatabase::from_litellm_json("{}").unwrap();
        assert!(db.is_empty());
        assert!(db.is_litellm());
    }

    #[test]
    fn non_object_or_invalid_json_is_an_error() {
        assert!(PricingDatabase::from_litellm_json("not json").is_err());
        assert!(PricingDatabase::from_litellm_json("[]").is_err());
    }

    #[test]
    fn litellm_parse_and_lookup_exact_key() {
        let json = r#"{
            "gpt-4o": {
                "input_cost_per_token": 2.5e-6,
                "output_cost_per_token": 1.0e-5,
                "litellm_provider": "openai"
            },
            "sample_spec": { "note": "ignored" }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        assert!(db.is_litellm());
        let result = db.compute_cost(Some("gpt-4o"), u(1_000_000, 1_000_000), None);
        assert_eq!(result.source, CostSource::Litellm);
        assert!((result.cost.unwrap() - 12.5).abs() < 1e-6);
    }

    #[test]
    fn litellm_applies_cache_token_costs_when_present() {
        let json = r#"{
            "claude-sonnet-4": {
                "input_cost_per_token": 3e-6,
                "output_cost_per_token": 1.5e-5,
                "cache_creation_input_token_cost": 3.75e-6,
                "cache_read_input_token_cost": 3e-7
            }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        let usage = TokenUsage {
            input: 0,
            output: 0,
            cache_creation: 1_000_000,
            cache_read: 1_000_000,
        };
        let result = db.compute_cost(Some("claude-sonnet-4"), usage, None);
        assert_eq!(result.source, CostSource::Litellm);
        assert!((result.cost.unwrap() - 4.05).abs() < 1e-6);
    }

    #[test]
    fn litellm_missing_cache_rates_default_to_input_and_free() {
        let json = r#"{
            "priced-model": {
                "input_cost_per_token": 2e-6,
                "output_cost_per_token": 8e-6
            }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        let usage = TokenUsage {
            input: 0,
            output: 0,
            cache_creation: 1_000_000,
            cache_read: 1_000_000,
        };
        let result = db.compute_cost(Some("priced-model"), usage, None);
        assert_eq!(result.source, CostSource::Litellm);
        // Cache creation at the input rate (2.0); cache reads billed as zero.
        assert!((result.cost.unwrap() - 2.0).abs() < 1e-6);
    }

    #[test]
    fn litellm_lookup_tries_bedrock_prefix() {
        let json = r#"{
            "bedrock/anthropic.claude-sonnet-4-v1:0": {
                "input_cost_per_token": 5e-6,
                "output_cost_per_token": 2e-5
            }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        let result = db.compute_cost(
            Some("anthropic.claude-sonnet-4-v1:0"),
            u(1_000_000, 0),
            Some("bedrock"),
        );
        assert_eq!(result.source, CostSource::Litellm);
        assert!((result.cost.unwrap() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn litellm_lookup_tries_vertex_prefix() {
        let json = r#"{
            "vertex_ai/gemini-1.5-pro": {
                "input_cost_per_token": 1.25e-6,
                "output_cost_per_token": 5e-6
            }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        let result = db.compute_cost(
            Some("gemini-1.5-pro"),
            u(1_000_000, 0),
            Some("gcp.vertex_ai"),
        );
        assert_eq!(result.source, CostSource::Litellm);
        assert!((result.cost.unwrap() - 1.25).abs() < 1e-6);
    }

    #[test]
    fn litellm_substring_match_as_last_resort() {
        let json = r#"{
            "claude-sonnet-4-20260101": {
                "input_cost_per_token": 3e-6,
                "output_cost_per_token": 1.5e-5
            }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        let result = db.compute_cost(Some("sonnet-4"), u(1_000_000, 0), None);
        assert_eq!(result.source, CostSource::Litellm);
        assert!((result.cost.unwrap() - 3.0).abs() < 1e-6);
    }

    #[test]
    fn invalid_litellm_entries_are_skipped_not_fatal() {
        let json = r#"{
            "sample_spec": { "input_cost_per_token": "unknown" },
            "gpt-4o": {
                "input_cost_per_token": 2.5e-6,
                "output_cost_per_token": 1e-5
            }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        assert!(!db.is_empty());
        let result = db.compute_cost(Some("gpt-4o"), u(1_000_000, 0), None);
        assert_eq!(result.source, CostSource::Litellm);
    }

    #[test]
    fn litellm_zero_cost_entry_falls_through_to_fallback() {
        let json = r#"{
            "claude-sonnet-4": {
                "input_cost_per_token": 0,
                "output_cost_per_token": 0
            }
        }"#;
        let db = PricingDatabase::from_litellm_json(json).unwrap();
        let result = db.compute_cost(Some("claude-sonnet-4"), u(1_000_000, 0), None);
        assert_eq!(result.source, CostSource::Fallback);
    }

    #[test]
    fn cost_source_serializes_lowercase() {
        let json = serde_json::to_string(&CostSource::Litellm).unwrap();
        assert_eq!(json, "\"litellm\"");
    }

    #[test]
    fn cost_source_as_str_agrees_with_serde() {
        // `as_str` and the serde rename are two sources of truth; keep them in lockstep.
        for source in [CostSource::Litellm, CostSource::Fallback, CostSource::None] {
            let json = serde_json::to_string(&source).unwrap();
            assert_eq!(json, format!("\"{}\"", source.as_str()));
        }
    }
}