use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Token pricing for a single model.
///
/// Prices are expressed per one million tokens. The currency is not stated in
/// this module (NOTE(review): presumably USD — confirm). Cache traffic has no
/// rate of its own: `calculate_cost` bills it at `input_price_per_million`
/// scaled by one of the two multipliers below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Price per one million input tokens.
    pub input_price_per_million: f64,
    /// Price per one million output tokens.
    pub output_price_per_million: f64,
    /// Factor applied to the input price for cache-read tokens.
    pub cache_read_multiplier: f64,
    /// Factor applied to the input price for cache-creation (cache-write) tokens.
    pub cache_write_multiplier: f64,
}
impl ModelPricing {
    /// Fallback pricing for model ids that are not in the pricing table.
    ///
    /// Returns 3.2 per million input tokens and 16.0 per million output tokens,
    /// with a 0.1 cache-read multiplier and a 1.25 cache-write multiplier.
    /// `get_model_pricing` returns this value for any unrecognised model id.
    pub fn default_average() -> Self {
        let input_price_per_million = 3.2;
        let output_price_per_million = 16.0;

        Self {
            input_price_per_million,
            output_price_per_million,
            cache_read_multiplier: 0.1,
            cache_write_multiplier: 1.25,
        }
    }
}
/// Internal lookup table from model id (full snapshot id or short alias) to
/// its pricing. `get_model_pricing` matches keys exactly.
///
/// Every tier uses the same cache multipliers (0.1 for cache reads, 1.25 for
/// cache writes); only the per-million input/output rates differ.
static PRICING_TABLE: Lazy<HashMap<&'static str, ModelPricing>> = Lazy::new(|| {
    // Each entry: (input price per million, output price per million, ids billed at that rate).
    let tiers: &[(f64, f64, &[&'static str])] = &[
        // Opus 4.5 tier.
        // NOTE(review): the "claude-opus-4-6-20250212" date suffix is earlier than
        // the 4.5 snapshot's 20251101, and the bare "opus-4" / "claude-opus-4"
        // aliases resolve to this tier rather than the legacy one below —
        // confirm both are intended.
        (
            5.0,
            25.0,
            &[
                "claude-opus-4-5-20251101",
                "claude-opus-4-6-20250212",
                "opus-4",
                "claude-opus-4",
            ],
        ),
        // Legacy Opus 4.0 / 4.1 snapshots.
        (
            15.0,
            75.0,
            &["claude-opus-4-0-20250514", "claude-opus-4-1-20250805"],
        ),
        // Sonnet tier.
        (
            3.0,
            15.0,
            &["claude-sonnet-4-5-20250929", "sonnet-4", "claude-sonnet-4"],
        ),
        // Haiku tier.
        (
            1.0,
            5.0,
            &["claude-haiku-4-5-20251001", "haiku-4", "claude-haiku-4"],
        ),
    ];

    let mut table = HashMap::new();
    for &(input, output, ids) in tiers {
        let pricing = ModelPricing {
            input_price_per_million: input,
            output_price_per_million: output,
            cache_read_multiplier: 0.1,
            cache_write_multiplier: 1.25,
        };
        for &id in ids {
            table.insert(id, pricing.clone());
        }
    }
    table
});
/// Public copy of the internal pricing table with owned `String` keys.
///
/// It holds exactly the same ids and pricings as the table that
/// `get_model_pricing` reads. It is built lazily, on first access.
pub static MODEL_PRICING: Lazy<HashMap<String, ModelPricing>> = Lazy::new(|| {
    let mut owned = HashMap::new();
    for (&name, pricing) in PRICING_TABLE.iter() {
        owned.insert(String::from(name), pricing.clone());
    }
    owned
});
/// Returns the pricing for `model`.
///
/// The lookup is an exact, case-sensitive match against the ids and aliases
/// in the internal pricing table. No prefix matching or other normalisation
/// is done. Any id not in the table gets [`ModelPricing::default_average`].
pub fn get_model_pricing(model: &str) -> ModelPricing {
    match PRICING_TABLE.get(model) {
        Some(known) => known.clone(),
        None => ModelPricing::default_average(),
    }
}
#[deprecated(note = "Use get_model_pricing instead")]
pub fn get_pricing(model: &str) -> ModelPricing {
get_model_pricing(model)
}
/// Computes the cost of a request against `model` from its token counts.
///
/// Each count is converted to millions of tokens and multiplied by the
/// matching per-million rate:
/// - `input` uses `input_price_per_million`;
/// - `output` uses `output_price_per_million`;
/// - `cache_create` uses the input rate times `cache_write_multiplier`;
/// - `cache_read` uses the input rate times `cache_read_multiplier`.
///
/// Pricing comes from [`get_model_pricing`], so unknown models are billed at
/// [`ModelPricing::default_average`]. Counts are converted with `as f64`,
/// which is exact for counts up to 2^53.
pub fn calculate_cost(
    model: &str,
    input: u64,
    output: u64,
    cache_create: u64,
    cache_read: u64,
) -> f64 {
    fn millions(tokens: u64) -> f64 {
        tokens as f64 / 1_000_000.0
    }

    let rates = get_model_pricing(model);
    let base_input = rates.input_price_per_million;

    let input_cost = millions(input) * base_input;
    let output_cost = millions(output) * rates.output_price_per_million;
    let write_cost = millions(cache_create) * base_input * rates.cache_write_multiplier;
    let read_cost = millions(cache_read) * base_input * rates.cache_read_multiplier;

    input_cost + output_cost + write_cost + read_cost
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing computed costs. Costs are sums of
    /// floating-point products, so exact `==` comparisons are brittle.
    const EPSILON: f64 = 1e-9;

    /// Asserts that `actual` is within `EPSILON` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Asserts the input/output rates and the shared cache multipliers
    /// (0.1 read, 1.25 write). These fields are copied from literals, so exact
    /// equality is appropriate here.
    fn assert_rates(pricing: &ModelPricing, input: f64, output: f64) {
        assert_eq!(pricing.input_price_per_million, input);
        assert_eq!(pricing.output_price_per_million, output);
        assert_eq!(pricing.cache_read_multiplier, 0.1);
        assert_eq!(pricing.cache_write_multiplier, 1.25);
    }

    /// Field-by-field equality check that does not rely on `PartialEq`.
    fn assert_same_pricing(a: &ModelPricing, b: &ModelPricing) {
        assert_eq!(a.input_price_per_million, b.input_price_per_million);
        assert_eq!(a.output_price_per_million, b.output_price_per_million);
        assert_eq!(a.cache_read_multiplier, b.cache_read_multiplier);
        assert_eq!(a.cache_write_multiplier, b.cache_write_multiplier);
    }

    #[test]
    fn test_opus_pricing() {
        assert_rates(&get_model_pricing("opus-4"), 5.0, 25.0);
    }

    #[test]
    fn test_opus_legacy_pricing() {
        assert_rates(&get_model_pricing("claude-opus-4-1-20250805"), 15.0, 75.0);
    }

    #[test]
    fn test_sonnet_pricing() {
        assert_rates(&get_model_pricing("sonnet-4"), 3.0, 15.0);
    }

    #[test]
    fn test_haiku_pricing() {
        assert_rates(&get_model_pricing("haiku-4"), 1.0, 5.0);
    }

    #[test]
    fn test_full_model_id() {
        assert_rates(&get_model_pricing("claude-sonnet-4-5-20250929"), 3.0, 15.0);
    }

    #[test]
    fn test_aliases_share_family_pricing() {
        for id in [
            "claude-opus-4-5-20251101",
            "claude-opus-4-6-20250212",
            "opus-4",
            "claude-opus-4",
        ]
        .iter()
        {
            assert_rates(&get_model_pricing(id), 5.0, 25.0);
        }
        for id in ["claude-opus-4-0-20250514", "claude-opus-4-1-20250805"].iter() {
            assert_rates(&get_model_pricing(id), 15.0, 75.0);
        }
        for id in ["claude-sonnet-4-5-20250929", "sonnet-4", "claude-sonnet-4"].iter() {
            assert_rates(&get_model_pricing(id), 3.0, 15.0);
        }
        for id in ["claude-haiku-4-5-20251001", "haiku-4", "claude-haiku-4"].iter() {
            assert_rates(&get_model_pricing(id), 1.0, 5.0);
        }
    }

    #[test]
    fn test_unknown_model_fallback() {
        assert_rates(&get_model_pricing("unknown-model-xyz"), 3.2, 16.0);
    }

    #[test]
    #[allow(deprecated)] // intentionally exercising the deprecated wrapper
    fn test_deprecated_get_pricing_delegates() {
        for id in [
            "opus-4",
            "claude-opus-4-1-20250805",
            "sonnet-4",
            "haiku-4",
            "unknown-model-xyz",
        ]
        .iter()
        {
            assert_same_pricing(&get_pricing(id), &get_model_pricing(id));
        }
    }

    #[test]
    fn test_model_pricing_mirrors_table() {
        assert_eq!(MODEL_PRICING.len(), PRICING_TABLE.len());
        for (name, pricing) in PRICING_TABLE.iter() {
            let owned = MODEL_PRICING
                .get(*name)
                .expect("every PRICING_TABLE id must appear in MODEL_PRICING");
            assert_same_pricing(owned, pricing);
        }
    }

    #[test]
    fn test_cost_calculation_opus_basic() {
        assert_close(calculate_cost("opus-4", 1_000_000, 1_000_000, 0, 0), 30.0);
    }

    #[test]
    fn test_cost_calculation_opus_legacy() {
        assert_close(
            calculate_cost("claude-opus-4-1-20250805", 1_000_000, 1_000_000, 0, 0),
            90.0,
        );
    }

    #[test]
    fn test_cost_calculation_sonnet_basic() {
        assert_close(calculate_cost("sonnet-4", 1_000_000, 1_000_000, 0, 0), 18.0);
    }

    #[test]
    fn test_cost_calculation_haiku_basic() {
        assert_close(calculate_cost("haiku-4", 1_000_000, 1_000_000, 0, 0), 6.0);
    }

    #[test]
    fn test_cost_calculation_unknown_model() {
        // Falls back to default_average: 3.2 + 16.0.
        assert_close(
            calculate_cost("unknown-model-xyz", 1_000_000, 1_000_000, 0, 0),
            19.2,
        );
    }

    #[test]
    fn test_cost_calculation_with_cache() {
        // 5.0 input + 6.25 cache write (5.0 * 1.25) + 5.0 cache read (10 * 5.0 * 0.1)
        let cost = calculate_cost("opus-4", 1_000_000, 0, 1_000_000, 10_000_000);
        assert_close(cost, 16.25);
    }

    #[test]
    fn test_cost_calculation_zero_tokens() {
        // Every term is exactly 0.0, so an exact comparison is safe here.
        assert_eq!(calculate_cost("opus-4", 0, 0, 0, 0), 0.0);
    }

    #[test]
    fn test_cost_calculation_small_numbers() {
        assert_close(calculate_cost("sonnet-4", 10_000, 0, 0, 0), 0.03);
    }

    #[test]
    fn test_cost_calculation_mixed_tokens() {
        let cost = calculate_cost("sonnet-4", 500_000, 100_000, 50_000, 1_000_000);
        let expected = 1.5 + 1.5 + 0.1875 + 0.3;
        assert_close(cost, expected);
    }

    #[test]
    fn test_total_tokens_includes_cache_read() {
        let input = 1000u64;
        let output = 500u64;
        let cache_create = 100u64;
        let cache_read = 50000u64;
        let total = input + output + cache_create + cache_read;
        assert_eq!(total, 51600);

        let cost = calculate_cost("sonnet-4", input, output, cache_create, cache_read);
        // 0.003 input + 0.0075 output + 0.000375 cache write + 0.015 cache read
        assert_close(cost, 0.025875);

        // Cache-read tokens must actually contribute to the cost.
        let without_read = calculate_cost("sonnet-4", input, output, cache_create, 0);
        assert_close(cost - without_read, 0.015);
    }
}