use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use crate::types::{ModelId, Usage};
/// Lookup table from a model identifier to the rates used to price its usage.
///
/// `PricingTable::default()` is filled from `bundled_rates()`; `PricingTable::custom`
/// and `PricingTable::set` let callers supply or override individual entries.
#[derive(Debug, Clone)]
pub struct PricingTable {
    /// Rates for each known model, keyed by model identifier.
    rates: HashMap<ModelId, ModelPricing>,
}
/// Rates applied to one model's usage.
///
/// Every `*_per_mtok` field is a price per one million tokens;
/// `web_search_per_request` is a price per individual request.
// NOTE(review): the currency is not stated anywhere in this file — presumably USD; confirm.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub struct ModelPricing {
    /// Price per million regular input tokens.
    pub input_per_mtok: f64,
    /// Price per million output tokens.
    pub output_per_mtok: f64,
    /// Price per million tokens written to the 5-minute cache.
    pub cache_creation_5m_per_mtok: f64,
    /// Price per million tokens written to the 1-hour cache.
    pub cache_creation_1h_per_mtok: f64,
    /// Price per million tokens read back from the cache.
    pub cache_read_per_mtok: f64,
    /// Price of a single web-search request.
    pub web_search_per_request: f64,
}

impl ModelPricing {
    /// Builds a full rate card from the base input/output rates and the
    /// per-request web-search price.
    ///
    /// The three cache rates are derived from `input_per_mtok`:
    ///
    /// * 5-minute cache write: `input_per_mtok * 1.25`
    /// * 1-hour cache write: `input_per_mtok * 2.0`
    /// * cache read: `input_per_mtok * 0.1`
    #[must_use]
    pub const fn from_input_output(
        input_per_mtok: f64,
        output_per_mtok: f64,
        web_search_per_request: f64,
    ) -> Self {
        // Cache rates are fixed multiples of the base input rate.
        const WRITE_5M_FACTOR: f64 = 1.25;
        const WRITE_1H_FACTOR: f64 = 2.0;
        const READ_FACTOR: f64 = 0.1;

        let cache_creation_5m_per_mtok = input_per_mtok * WRITE_5M_FACTOR;
        let cache_creation_1h_per_mtok = input_per_mtok * WRITE_1H_FACTOR;
        let cache_read_per_mtok = input_per_mtok * READ_FACTOR;

        Self {
            input_per_mtok,
            output_per_mtok,
            cache_creation_5m_per_mtok,
            cache_creation_1h_per_mtok,
            cache_read_per_mtok,
            web_search_per_request,
        }
    }
}
// Splices in `pricing_data.rs` from the directory named by the `OUT_DIR`
// environment variable at compile time.
// NOTE(review): that file is presumably generated by the build script and is
// expected to define `bundled_rates()`, which `PricingTable::default` calls —
// confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/pricing_data.rs"));
impl Default for PricingTable {
    /// Creates a table holding every entry yielded by `bundled_rates()`.
    fn default() -> Self {
        let rates = bundled_rates().into_iter().collect();
        Self { rates }
    }
}
impl PricingTable {
    /// Creates a table that contains exactly the supplied `rates`.
    #[must_use]
    pub fn custom(rates: HashMap<ModelId, ModelPricing>) -> Self {
        Self { rates }
    }

    /// Inserts the rates for `model`, replacing any entry already present.
    pub fn set(&mut self, model: ModelId, rates: ModelPricing) {
        self.rates.insert(model, rates);
    }

    /// Returns the rates registered for `model`, or `None` if there are none.
    #[must_use]
    pub fn get(&self, model: &ModelId) -> Option<&ModelPricing> {
        self.rates.get(model)
    }

    /// Returns the total cost of `usage` for `model`.
    ///
    /// This is the `total` field of [`Self::cost_breakdown`]; it is `0.0`
    /// when the table has no entry for `model`.
    #[must_use]
    pub fn cost(&self, model: &ModelId, usage: &Usage) -> f64 {
        let CostBreakdown { total, .. } = self.cost_breakdown(model, usage);
        total
    }

    /// Prices each component of `usage` for `model`.
    ///
    /// When the table has no entry for `model`, a warning is logged (once per
    /// model name) and an all-zero breakdown is returned.
    ///
    /// Cache writes are priced from the per-TTL breakdown in
    /// `usage.cache_creation` when it is present. Otherwise the aggregate
    /// `usage.cache_creation_input_tokens` count is charged at the 5-minute
    /// rate.
    #[must_use]
    pub fn cost_breakdown(&self, model: &ModelId, usage: &Usage) -> CostBreakdown {
        // Scales a raw token count to millions of tokens, then applies the rate.
        fn mtok_cost(tokens: impl Into<f64>, rate_per_mtok: f64) -> f64 {
            let tokens: f64 = tokens.into();
            tokens / 1_000_000.0 * rate_per_mtok
        }

        let Some(pricing) = self.rates.get(model) else {
            warn_missing_once(model.as_str());
            return CostBreakdown::default();
        };

        let input = mtok_cost(usage.input_tokens, pricing.input_per_mtok);
        let output = mtok_cost(usage.output_tokens, pricing.output_per_mtok);

        let cache_creation = if let Some(split) = &usage.cache_creation {
            mtok_cost(
                split.ephemeral_5m_input_tokens,
                pricing.cache_creation_5m_per_mtok,
            ) + mtok_cost(
                split.ephemeral_1h_input_tokens,
                pricing.cache_creation_1h_per_mtok,
            )
        } else {
            // No per-TTL split available: charge the aggregate count at the 5-minute rate.
            mtok_cost(
                usage.cache_creation_input_tokens.unwrap_or(0),
                pricing.cache_creation_5m_per_mtok,
            )
        };

        let cache_read = mtok_cost(
            usage.cache_read_input_tokens.unwrap_or(0),
            pricing.cache_read_per_mtok,
        );

        // Web searches are billed per request rather than per token.
        let server_tool_use = match usage.server_tool_use.as_ref() {
            Some(tools) => f64::from(tools.web_search_requests) * pricing.web_search_per_request,
            None => 0.0,
        };

        CostBreakdown {
            input,
            output,
            cache_creation,
            cache_read,
            server_tool_use,
            total: input + output + cache_creation + cache_read + server_tool_use,
        }
    }
}
/// Itemized cost of one `Usage` record, as produced by `PricingTable::cost_breakdown`.
///
/// The derived `Default` sets every field to `0.0`; that all-zero value is
/// what `cost_breakdown` returns for a model with no pricing entry.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[non_exhaustive]
pub struct CostBreakdown {
    /// Cost of the regular input tokens.
    pub input: f64,
    /// Cost of the output tokens.
    pub output: f64,
    /// Cost of tokens written to the cache (5-minute and 1-hour writes combined).
    pub cache_creation: f64,
    /// Cost of tokens read from the cache.
    pub cache_read: f64,
    /// Cost of server-side tool use (web-search requests).
    pub server_tool_use: f64,
    /// Sum of all the components above.
    pub total: f64,
}
/// Logs a warning that `model` has no pricing entry, at most once per
/// distinct model name for the lifetime of the process.
///
/// The names already reported live in a process-wide set behind a mutex. A
/// poisoned mutex is recovered rather than propagated, because the set is
/// always left in a valid state and a lost warning is harmless.
fn warn_missing_once(model: &str) {
    static WARNED: OnceLock<Mutex<std::collections::HashSet<String>>> = OnceLock::new();

    let first_sighting = {
        let mut seen = WARNED
            .get_or_init(|| Mutex::new(std::collections::HashSet::new()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Check with the borrowed `&str` first so the steady-state path (model
        // already reported) does not allocate a `String` on every call.
        !seen.contains(model) && seen.insert(model.to_owned())
    };

    // The guard is dropped above, so the log event is emitted without holding
    // the lock: a subscriber that is slow, or that itself prices an unknown
    // model, can neither stall other callers nor deadlock on `WARNED`.
    if first_sighting {
        tracing::warn!(
            model,
            "claude-api: no bundled pricing data; cost() will return 0. \
             Override via PricingTable::custom or PricingTable::set."
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CacheCreationBreakdown, ServerToolUseUsage};

    // NOTE(review): the expected figures below rely on the bundled Sonnet 4.6
    // rates, which are not visible in this file (they appear to be 3.0 input /
    // 15.0 output per MTok and 0.01 per web search) — confirm against the
    // generated pricing data if a test starts failing.

    /// Asserts that `a` is within an absolute tolerance of `1e-9` of the
    /// expected value `b`.
    fn approx(a: f64, b: f64) {
        let delta = (a - b).abs();
        assert!(delta < 1e-9, "expected {b} (within 1e-9), got {a}");
    }

    /// Prices `usage` for Sonnet 4.6 using the bundled default table.
    fn sonnet_cost(usage: &Usage) -> f64 {
        PricingTable::default().cost(&ModelId::SONNET_4_6, usage)
    }

    /// The bundled table must contain an entry for each well-known model.
    #[test]
    fn default_pricing_includes_known_models() {
        let table = PricingTable::default();
        assert!(table.get(&ModelId::OPUS_4_7).is_some());
        assert!(table.get(&ModelId::SONNET_4_6).is_some());
        assert!(table.get(&ModelId::HAIKU_4_5).is_some());
    }

    /// Plain input and output tokens are each charged at their own rate.
    #[test]
    fn cost_input_and_output_only() {
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            ..Usage::default()
        };
        approx(sonnet_cost(&usage), 10.5);
    }

    /// A per-TTL breakdown charges 5-minute and 1-hour writes separately.
    #[test]
    fn cost_uses_per_ttl_breakdown_when_present() {
        let per_ttl = CacheCreationBreakdown {
            ephemeral_5m_input_tokens: 1_000_000,
            ephemeral_1h_input_tokens: 1_000_000,
        };
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 0,
            cache_creation: Some(per_ttl),
            ..Usage::default()
        };
        approx(sonnet_cost(&usage), 3.0 + 3.75 + 6.0);
    }

    /// Without a breakdown, the aggregate cache-write count is charged at the
    /// 5-minute rate.
    #[test]
    fn cost_falls_back_to_legacy_cache_field_when_no_breakdown() {
        let usage = Usage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: Some(1_000_000),
            cache_creation: None,
            ..Usage::default()
        };
        approx(sonnet_cost(&usage), 3.75);
    }

    /// Cache reads contribute to the total.
    #[test]
    fn cost_includes_cache_reads() {
        let usage = Usage {
            cache_read_input_tokens: Some(1_000_000),
            ..Usage::default()
        };
        approx(sonnet_cost(&usage), 0.30);
    }

    /// Web-search requests are billed per request.
    #[test]
    fn cost_includes_web_search_requests() {
        let searches = ServerToolUseUsage {
            web_search_requests: 50,
        };
        let usage = Usage {
            server_tool_use: Some(searches),
            ..Usage::default()
        };
        approx(sonnet_cost(&usage), 0.50);
    }

    /// The reported total equals the sum of the individual components.
    #[test]
    fn breakdown_components_sum_to_total() {
        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            cache_creation_input_tokens: Some(20_000),
            cache_read_input_tokens: Some(80_000),
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 3,
            }),
            ..Usage::default()
        };
        let CostBreakdown {
            input,
            output,
            cache_creation,
            cache_read,
            server_tool_use,
            total,
        } = PricingTable::default().cost_breakdown(&ModelId::SONNET_4_6, &usage);
        approx(
            input + output + cache_creation + cache_read + server_tool_use,
            total,
        );
    }

    /// A model with no pricing entry costs nothing.
    #[test]
    fn unknown_model_returns_zero_cost() {
        let table = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Usage::default()
        };
        let unknown = ModelId::custom("claude-future-foo");
        approx(table.cost(&unknown, &usage), 0.0);
    }

    /// A custom table uses the caller's rates, not the bundled ones.
    #[test]
    fn custom_table_overrides_bundled_rates() {
        let table = PricingTable::custom(HashMap::from([(
            ModelId::SONNET_4_6,
            ModelPricing::from_input_output(2.00, 10.00, 0.005),
        )]));
        let usage = Usage {
            input_tokens: 1_000_000,
            ..Usage::default()
        };
        approx(table.cost(&ModelId::SONNET_4_6, &usage), 2.0);
    }

    /// `set` replaces the entry for a model that is already in the table.
    #[test]
    fn set_inserts_or_replaces_a_single_model() {
        let mut table = PricingTable::default();
        let replacement = ModelPricing::from_input_output(99.99, 99.99, 0.0);
        table.set(ModelId::SONNET_4_6, replacement);
        let usage = Usage {
            input_tokens: 1_000_000,
            ..Usage::default()
        };
        approx(table.cost(&ModelId::SONNET_4_6, &usage), 99.99);
    }

    /// Cache rates are derived as 1.25x, 2x and 0.1x of the input rate.
    #[test]
    fn from_input_output_derives_cache_multipliers() {
        let derived = ModelPricing::from_input_output(10.0, 50.0, 0.01);
        approx(derived.cache_creation_5m_per_mtok, 12.5);
        approx(derived.cache_creation_1h_per_mtok, 20.0);
        approx(derived.cache_read_per_mtok, 1.0);
    }
}