use std::collections::HashMap;
use std::sync::OnceLock;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Pricing for one model, valid from `effective_at` until a newer entry for
/// the same model takes over.
///
/// All rates are per million tokens, in whatever currency `data/pricing.toml`
/// is written in (presumably USD — confirm against the data source).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Rate per million input tokens.
    pub input_per_million: f64,
    /// Rate per million output tokens.
    pub output_per_million: f64,
    /// Rate per million cached input tokens; `None` when the catalog lists no
    /// cached-input rate for this model.
    pub cached_input_per_million: Option<f64>,
    /// Rate per million tokens written to the cache; `None` when the catalog
    /// lists no separate cache-write rate for this model.
    pub cache_write_per_million: Option<f64>,
    /// Instant from which this rate applies. A model's rate history is ordered
    /// by this field.
    pub effective_at: DateTime<Utc>,
}
/// Descriptive metadata for a model: identity, supported features and token
/// limits.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelInfo {
/// Model identifier (presumably the same string used as `model` in the
/// pricing catalog — confirm with callers).
pub id: String,
/// Provider name (presumably the same string used as `provider` in the
/// pricing catalog — confirm with callers).
pub provider: String,
/// Features this model supports.
pub capabilities: Vec<Capability>,
/// Upper limit on input tokens.
pub max_input_tokens: u64,
/// Upper limit on output tokens.
pub max_output_tokens: u64,
}
/// A feature a model may support.
///
/// Serialized in `snake_case`, e.g. `JsonMode` ↔ `"json_mode"` and
/// `PromptCaching` ↔ `"prompt_caching"`.
///
/// Derives `Hash` so capabilities can be collected into a `HashSet` or used
/// as map keys.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    Text,
    Vision,
    Audio,
    Tools,
    JsonMode,
    Streaming,
    Reasoning,
    PromptCaching,
}
/// Raw text of `data/pricing.toml`, embedded into the binary at compile time.
/// `include_str!` resolves the path relative to this source file; `catalog()`
/// parses this text on first use.
const PRICING_TOML: &str = include_str!("../data/pricing.toml");
// NOTE(review): serde ignores unknown keys unless `deny_unknown_fields` is set,
// so a misspelled optional key (e.g. `cached_input_per_milion`) would silently
// parse as `None`. Consider `#[serde(deny_unknown_fields)]` once the data file
// has been checked for extra keys.
/// One `[[entry]]` table from the pricing TOML, exactly as written.
///
/// `PricingCatalog::parse` groups these by `(provider, model)` and converts
/// each one into a `ModelPricing`.
#[derive(Debug, Deserialize)]
struct RawEntry {
provider: String,
model: String,
input_per_million: f64,
output_per_million: f64,
// Optional: a missing key deserializes to `None`.
#[serde(default)]
cached_input_per_million: Option<f64>,
// Optional: a missing key deserializes to `None`.
#[serde(default)]
cache_write_per_million: Option<f64>,
// Timestamp string such as "2026-01-01T00:00:00Z" (format used by the tests).
effective_at: DateTime<Utc>,
}
/// Top-level shape of the pricing TOML: an array of `[[entry]]` tables.
#[derive(Debug, Deserialize)]
struct RawCatalog {
// A document with no `[[entry]]` tables (including an empty string) yields an
// empty Vec instead of a missing-field error.
#[serde(default)]
entry: Vec<RawEntry>,
}
/// In-memory pricing catalog keyed by `(provider, model)`.
///
/// Each value is that pair's full rate history, sorted ascending by
/// `effective_at` when the catalog is built by `PricingCatalog::parse`.
#[derive(Debug)]
pub struct PricingCatalog {
// (provider, model) -> rate history, oldest first.
by_model: HashMap<(String, String), Vec<ModelPricing>>,
}
impl PricingCatalog {
    /// Parses a pricing catalog from TOML text.
    ///
    /// The input is an array of `[[entry]]` tables. Each table has `provider`,
    /// `model`, `input_per_million`, `output_per_million` and `effective_at`,
    /// plus optional `cached_input_per_million` / `cache_write_per_million`.
    /// Entries for the same `(provider, model)` pair are collected into one
    /// rate history sorted ascending by `effective_at`. An empty document
    /// yields an empty catalog.
    ///
    /// # Errors
    ///
    /// Returns the TOML deserialization error if the text is not valid TOML,
    /// or if an entry is missing a required key or has a mistyped value.
    pub fn parse(toml_text: &str) -> Result<Self, toml::de::Error> {
        let raw: RawCatalog = toml::from_str(toml_text)?;
        let mut by_model: HashMap<(String, String), Vec<ModelPricing>> = HashMap::new();
        for e in raw.entry {
            by_model
                .entry((e.provider, e.model))
                .or_default()
                .push(ModelPricing {
                    input_per_million: e.input_per_million,
                    output_per_million: e.output_per_million,
                    cached_input_per_million: e.cached_input_per_million,
                    cache_write_per_million: e.cache_write_per_million,
                    effective_at: e.effective_at,
                });
        }
        // `sort_by_key` is stable. Entries that share an `effective_at` keep
        // their file order, so the one listed last in the TOML wins in both
        // `latest` and `at`.
        for history in by_model.values_mut() {
            history.sort_by_key(|p| p.effective_at);
        }
        Ok(Self { by_model })
    }

    /// Returns the rate history for `provider`/`model`, oldest first, or
    /// `None` if the pair is not in the catalog.
    fn history(&self, provider: &str, model: &str) -> Option<&[ModelPricing]> {
        // The map is keyed by owned `(String, String)`, so a lookup needs an
        // owned key; this costs two small allocations per call.
        self.by_model
            .get(&(provider.to_owned(), model.to_owned()))
            .map(Vec::as_slice)
    }

    /// Returns the newest rate for `provider`/`model`, regardless of date.
    ///
    /// Returns `None` if the pair is not in the catalog.
    pub fn latest(&self, provider: &str, model: &str) -> Option<ModelPricing> {
        self.history(provider, model)?.last().cloned()
    }

    /// Returns the rate in effect for `provider`/`model` at instant `at`.
    ///
    /// Picks the newest entry whose `effective_at <= at`. If `at` is before
    /// every entry, the earliest known rate is returned rather than `None`.
    /// Returns `None` only when the pair is not in the catalog.
    pub fn at(&self, provider: &str, model: &str, at: DateTime<Utc>) -> Option<ModelPricing> {
        let history = self.history(provider, model)?;
        // The history is sorted ascending, so the entries already in effect at
        // `at` form a prefix; `partition_point` returns that prefix's length.
        let in_effect = history.partition_point(|p| p.effective_at <= at);
        // An empty prefix means `at` predates the first entry: fall back to
        // index 0 (the earliest rate).
        history.get(in_effect.saturating_sub(1)).cloned()
    }

    /// Returns `(model, newest rate)` for every model of `provider`, sorted
    /// by model name.
    ///
    /// The sort makes the output deterministic; iteration over the backing
    /// `HashMap` is in arbitrary, per-process order.
    pub fn latest_for_provider(&self, provider: &str) -> Vec<(String, ModelPricing)> {
        let mut models: Vec<(String, ModelPricing)> = self
            .by_model
            .iter()
            .filter(|((p, _), _)| p == provider)
            .filter_map(|((_, model), history)| history.last().map(|p| (model.clone(), p.clone())))
            .collect();
        // Model names are unique within a provider, so an unstable sort is exact.
        models.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        models
    }

    /// Returns every `(provider, model)` pair in the catalog, sorted
    /// (provider first, then model) for deterministic output.
    pub fn pairs(&self) -> Vec<(String, String)> {
        let mut keys: Vec<(String, String)> = self.by_model.keys().cloned().collect();
        keys.sort_unstable();
        keys
    }

    /// Number of distinct `(provider, model)` pairs.
    ///
    /// This is not the number of `[[entry]]` tables, because one pair may have
    /// several dated rates.
    pub fn len(&self) -> usize {
        self.by_model.len()
    }

    /// `true` when the catalog contains no `(provider, model)` pairs.
    pub fn is_empty(&self) -> bool {
        self.by_model.is_empty()
    }

    /// Returns the newest `effective_at` across all pairs, or `None` for an
    /// empty catalog.
    ///
    /// The `Option<DateTime<Utc>>` result matches the `newest` parameter of
    /// `is_stale`.
    pub fn catalog_max_effective_at(&self) -> Option<DateTime<Utc>> {
        self.by_model
            .values()
            .filter_map(|history| history.last().map(|p| p.effective_at))
            .max()
    }
}
/// Returns the process-wide pricing catalog parsed from the embedded
/// `data/pricing.toml`.
///
/// The text is parsed once, on the first call; every later call returns the
/// same cached instance.
///
/// # Panics
///
/// Panics on the first call if the embedded TOML does not parse. The file is
/// compiled into the binary, so such a failure is a data bug, not a runtime
/// condition.
pub fn catalog() -> &'static PricingCatalog {
    fn load_embedded() -> PricingCatalog {
        PricingCatalog::parse(PRICING_TOML).expect("embedded data/pricing.toml must be valid")
    }

    static CATALOG: OnceLock<PricingCatalog> = OnceLock::new();
    CATALOG.get_or_init(load_embedded)
}
/// Reports whether `newest` is more than `max_days` whole days older than `now`.
///
/// - `None` (nothing to compare) is never stale.
/// - The age is measured in whole days via `num_days`, which truncates partial
///   days, and the comparison is strict.
/// - A `newest` later than `now` has a negative age, so it is never stale.
#[must_use]
pub fn is_stale(newest: Option<DateTime<Utc>>, now: DateTime<Utc>, max_days: i64) -> bool {
    newest.is_some_and(|stamp| {
        let age_in_days = (now - stamp).num_days();
        age_in_days > max_days
    })
}
// Unit tests for the pricing catalog. Some use the embedded data file; others
// use small inline TOML fixtures.
#[cfg(test)]
mod catalog_tests {
    use super::*;
    use chrono::{Duration, TimeZone};

    #[test]
    fn is_stale_thresholds() {
        let now: DateTime<Utc> = "2026-06-05T00:00:00Z".parse().unwrap();
        assert!(!is_stale(None, now, 90));
        assert!(!is_stale(Some(now - Duration::days(10)), now, 90));
        assert!(is_stale(Some(now - Duration::days(100)), now, 90));
    }

    #[test]
    fn is_stale_boundaries() {
        let now: DateTime<Utc> = "2026-06-05T00:00:00Z".parse().unwrap();
        // Exactly `max_days` old is not yet stale: the comparison is strict.
        assert!(!is_stale(Some(now - Duration::days(90)), now, 90));
        assert!(is_stale(Some(now - Duration::days(91)), now, 90));
        // Partial days are truncated, so 90 days 23 hours still counts as 90.
        assert!(!is_stale(
            Some(now - Duration::days(90) - Duration::hours(23)),
            now,
            90
        ));
        // A future timestamp has a negative age and is never stale.
        assert!(!is_stale(Some(now + Duration::days(5)), now, 90));
    }

    #[test]
    fn embedded_catalog_parses_and_is_populated() {
        let c = catalog();
        assert!(!c.is_empty(), "embedded catalog should not be empty");
        assert_eq!(
            c.len(),
            36,
            "unexpected catalog size — update if intentional"
        );
    }

    #[test]
    fn catalog_max_effective_at_is_present() {
        let c = catalog();
        let max_date = c
            .catalog_max_effective_at()
            .expect("non-empty catalog must have a max effective_at");
        let floor = Utc.with_ymd_and_hms(2026, 1, 1, 0, 0, 0).unwrap();
        assert!(
            max_date >= floor,
            "catalog_max_effective_at = {max_date} is older than expected floor {floor}"
        );
    }

    #[test]
    fn catalog_max_effective_at_picks_newest() {
        let toml = r#"
            [[entry]]
            provider = "p"
            model = "m1"
            input_per_million = 1.0
            output_per_million = 2.0
            effective_at = "2026-03-01T00:00:00Z"

            [[entry]]
            provider = "p"
            model = "m2"
            input_per_million = 3.0
            output_per_million = 4.0
            effective_at = "2026-05-01T00:00:00Z"
        "#;
        let c = PricingCatalog::parse(toml).expect("valid");
        let max = c.catalog_max_effective_at().expect("present");
        assert_eq!(
            max,
            Utc.with_ymd_and_hms(2026, 5, 1, 0, 0, 0).unwrap(),
            "should return the newest effective_at across all models"
        );
    }

    #[test]
    fn catalog_max_effective_at_empty_catalog() {
        let c = PricingCatalog::parse("").expect("empty TOML is valid");
        assert!(c.catalog_max_effective_at().is_none());
    }

    #[test]
    fn latest_returns_known_rates() {
        let c = catalog();
        let p = c.latest("openai", "gpt-4o").expect("gpt-4o present");
        assert_eq!(p.input_per_million, 2.50);
        assert_eq!(p.output_per_million, 10.00);
        assert_eq!(p.cached_input_per_million, Some(1.25));
        let g = c.latest("groq", "llama-3.1-8b-instant").expect("present");
        assert_eq!(g.cached_input_per_million, None);
    }

    #[test]
    fn anthropic_models_have_cache_write_rate() {
        let c = catalog();
        let haiku = c.latest("anthropic", "claude-haiku-4-5").expect("present");
        assert_eq!(
            haiku.cache_write_per_million,
            Some(1.25),
            "haiku write rate = 1.25× base input (1.00)"
        );
        let sonnet = c.latest("anthropic", "claude-sonnet-4-6").expect("present");
        assert_eq!(
            sonnet.cache_write_per_million,
            Some(3.75),
            "sonnet write rate = 1.25× base input (3.00)"
        );
        let opus = c.latest("anthropic", "claude-opus-4-7").expect("present");
        assert_eq!(
            opus.cache_write_per_million,
            Some(6.25),
            "opus write rate = 1.25× base input (5.00)"
        );
        let gpt4o = c.latest("openai", "gpt-4o").expect("gpt-4o present");
        assert_eq!(
            gpt4o.cache_write_per_million, None,
            "OpenAI has no cache-write premium"
        );
        let groq_llama = c.latest("groq", "llama-3.1-8b-instant").expect("present");
        assert_eq!(
            groq_llama.cache_write_per_million, None,
            "Groq has no cache-write premium"
        );
    }

    #[test]
    fn unknown_provider_or_model_is_none() {
        let c = catalog();
        assert!(c.latest("openai", "no-such-model").is_none());
        assert!(c.latest("no-such-provider", "gpt-4o").is_none());
    }

    #[test]
    fn at_selects_rate_effective_at_timestamp() {
        let toml = r#"
            [[entry]]
            provider = "p"
            model = "m"
            input_per_million = 1.0
            output_per_million = 2.0
            effective_at = "2026-01-01T00:00:00Z"

            [[entry]]
            provider = "p"
            model = "m"
            input_per_million = 3.0
            output_per_million = 4.0
            effective_at = "2026-06-01T00:00:00Z"
        "#;
        let c = PricingCatalog::parse(toml).expect("valid");
        // Before any entry: falls back to the earliest known rate.
        let before = c
            .at("p", "m", Utc.with_ymd_and_hms(2025, 1, 1, 0, 0, 0).unwrap())
            .unwrap();
        assert_eq!(before.input_per_million, 1.0);
        let mid = c
            .at("p", "m", Utc.with_ymd_and_hms(2026, 3, 1, 0, 0, 0).unwrap())
            .unwrap();
        assert_eq!(mid.input_per_million, 1.0);
        // A rate applies from its own effective_at, inclusive.
        let exact = c
            .at("p", "m", Utc.with_ymd_and_hms(2026, 6, 1, 0, 0, 0).unwrap())
            .unwrap();
        assert_eq!(exact.input_per_million, 3.0);
        let after = c
            .at("p", "m", Utc.with_ymd_and_hms(2026, 9, 1, 0, 0, 0).unwrap())
            .unwrap();
        assert_eq!(after.input_per_million, 3.0);
        assert_eq!(c.latest("p", "m").unwrap().input_per_million, 3.0);
        assert!(c
            .at("p", "no-such-model", Utc.with_ymd_and_hms(2026, 3, 1, 0, 0, 0).unwrap())
            .is_none());
    }

    #[test]
    fn duplicate_effective_at_later_entry_wins() {
        let toml = r#"
            [[entry]]
            provider = "p"
            model = "m"
            input_per_million = 1.0
            output_per_million = 2.0
            effective_at = "2026-01-01T00:00:00Z"

            [[entry]]
            provider = "p"
            model = "m"
            input_per_million = 9.0
            output_per_million = 9.0
            effective_at = "2026-01-01T00:00:00Z"
        "#;
        let c = PricingCatalog::parse(toml).expect("valid");
        assert_eq!(c.latest("p", "m").unwrap().input_per_million, 9.0);
        let at = Utc.with_ymd_and_hms(2026, 1, 1, 0, 0, 0).unwrap();
        assert_eq!(c.at("p", "m", at).unwrap().input_per_million, 9.0);
    }

    #[test]
    fn latest_for_provider_filters_and_counts_pairs() {
        let toml = r#"
            [[entry]]
            provider = "a"
            model = "m1"
            input_per_million = 1.0
            output_per_million = 2.0
            effective_at = "2026-01-01T00:00:00Z"

            [[entry]]
            provider = "a"
            model = "m1"
            input_per_million = 5.0
            output_per_million = 6.0
            effective_at = "2026-02-01T00:00:00Z"

            [[entry]]
            provider = "a"
            model = "m2"
            input_per_million = 3.0
            output_per_million = 4.0
            effective_at = "2026-01-01T00:00:00Z"

            [[entry]]
            provider = "b"
            model = "m1"
            input_per_million = 7.0
            output_per_million = 8.0
            effective_at = "2026-01-01T00:00:00Z"
        "#;
        let c = PricingCatalog::parse(toml).expect("valid");

        // Sort locally so the assertion does not depend on the returned order.
        let mut rates: Vec<(String, f64)> = c
            .latest_for_provider("a")
            .into_iter()
            .map(|(model, p)| (model, p.input_per_million))
            .collect();
        rates.sort_by(|x, y| x.0.cmp(&y.0));
        assert_eq!(
            rates,
            vec![("m1".to_string(), 5.0), ("m2".to_string(), 3.0)]
        );
        assert!(c.latest_for_provider("no-such-provider").is_empty());

        // Four entries, but only three distinct (provider, model) pairs.
        assert_eq!(c.len(), 3);
        let mut pairs = c.pairs();
        pairs.sort();
        assert_eq!(
            pairs,
            vec![
                ("a".to_string(), "m1".to_string()),
                ("a".to_string(), "m2".to_string()),
                ("b".to_string(), "m1".to_string()),
            ]
        );
    }
}