use serde::{Deserialize, Serialize};
/// Pricing for a single model. All token rates are USD per one million
/// tokens (see the `/ 1_000_000.0` scaling in the cost methods below).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCosts {
    // USD per 1M input (prompt) tokens.
    pub input_tokens: f64,
    // USD per 1M output (completion) tokens.
    pub output_tokens: f64,
    // USD per 1M tokens written to the prompt cache.
    pub prompt_cache_write_tokens: f64,
    // USD per 1M tokens read from the prompt cache.
    pub prompt_cache_read_tokens: f64,
    // NOTE(review): presumably USD per web-search request (not per-million);
    // nothing in this file uses the field, so the unit is unconfirmed.
    pub web_search_requests: f64,
}
impl ModelCosts {
    /// Converts a raw token count into millions of tokens, the unit the
    /// rate fields are expressed in.
    fn millions(tokens: u32) -> f64 {
        f64::from(tokens) / 1_000_000.0
    }

    /// USD cost of `tokens` input tokens at this model's input rate.
    pub fn input_cost(&self, tokens: u32) -> f64 {
        Self::millions(tokens) * self.input_tokens
    }

    /// USD cost of `tokens` output tokens at this model's output rate.
    pub fn output_cost(&self, tokens: u32) -> f64 {
        Self::millions(tokens) * self.output_tokens
    }

    /// USD cost of `tokens` prompt-cache-write tokens.
    pub fn cache_write_cost(&self, tokens: u32) -> f64 {
        Self::millions(tokens) * self.prompt_cache_write_tokens
    }

    /// USD cost of `tokens` prompt-cache-read tokens.
    pub fn cache_read_cost(&self, tokens: u32) -> f64 {
        Self::millions(tokens) * self.prompt_cache_read_tokens
    }

    /// Total USD cost for `usage`: input + output + cache write + cache read.
    /// (`web_search_requests` is not part of token-based totals.)
    pub fn total_cost(&self, usage: &TokenUsage) -> f64 {
        let mut total = self.input_cost(usage.input_tokens);
        total += self.output_cost(usage.output_tokens);
        total += self.cache_write_cost(usage.prompt_cache_write_tokens);
        total += self.cache_read_cost(usage.prompt_cache_read_tokens);
        total
    }
}
/// Raw token counts for a request (or an accumulated set of requests).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    // NOTE(review): only the two cache fields carry camelCase serde renames;
    // input/output serialize as snake_case. Confirm this mixed wire format
    // matches the upstream payload rather than being an oversight.
    pub input_tokens: u32,
    pub output_tokens: u32,
    #[serde(rename = "promptCacheWriteTokens")]
    pub prompt_cache_write_tokens: u32,
    #[serde(rename = "promptCacheReadTokens")]
    pub prompt_cache_read_tokens: u32,
}
impl TokenUsage {
    /// Total token count across all four categories.
    ///
    /// Uses saturating addition: the original plain `+` chain can overflow
    /// `u32` for large accumulated counts (panic in debug builds, silent
    /// wrap in release). Saturating at `u32::MAX` is the safer behavior
    /// for a usage total.
    pub fn total(&self) -> u32 {
        self.input_tokens
            .saturating_add(self.output_tokens)
            .saturating_add(self.prompt_cache_write_tokens)
            .saturating_add(self.prompt_cache_read_tokens)
    }
}
/// $3 / 1M input, $15 / 1M output; cache write = 1.25x input, cache read
/// = 0.1x input (the same ratios hold for the other tiers in this group).
pub const COST_TIER_3_15: ModelCosts = ModelCosts {
    input_tokens: 3.0,
    output_tokens: 15.0,
    prompt_cache_write_tokens: 3.75,
    prompt_cache_read_tokens: 0.3,
    web_search_requests: 0.01,
};

/// $15 / 1M input, $75 / 1M output.
pub const COST_TIER_15_75: ModelCosts = ModelCosts {
    input_tokens: 15.0,
    output_tokens: 75.0,
    prompt_cache_write_tokens: 18.75,
    prompt_cache_read_tokens: 1.5,
    web_search_requests: 0.01,
};

/// $5 / 1M input, $25 / 1M output.
pub const COST_TIER_5_25: ModelCosts = ModelCosts {
    input_tokens: 5.0,
    output_tokens: 25.0,
    prompt_cache_write_tokens: 6.25,
    prompt_cache_read_tokens: 0.5,
    web_search_requests: 0.01,
};

/// $30 / 1M input, $150 / 1M output.
pub const COST_TIER_30_150: ModelCosts = ModelCosts {
    input_tokens: 30.0,
    output_tokens: 150.0,
    prompt_cache_write_tokens: 37.5,
    prompt_cache_read_tokens: 3.0,
    web_search_requests: 0.01,
};

/// Haiku 3.5 pricing: $0.80 / 1M input, $4 / 1M output.
pub const COST_HAIKU_35: ModelCosts = ModelCosts {
    input_tokens: 0.8,
    output_tokens: 4.0,
    prompt_cache_write_tokens: 1.0,
    prompt_cache_read_tokens: 0.08,
    web_search_requests: 0.01,
};

/// Haiku 4.5 pricing: $1 / 1M input, $5 / 1M output.
pub const COST_HAIKU_45: ModelCosts = ModelCosts {
    input_tokens: 1.0,
    output_tokens: 5.0,
    prompt_cache_write_tokens: 1.25,
    prompt_cache_read_tokens: 0.1,
    web_search_requests: 0.01,
};

/// Fallback pricing used by the registry for unrecognized model names.
pub const COST_DEFAULT: ModelCosts = COST_TIER_5_25;
/// Maps model-name strings to their [`ModelCosts`] pricing.
pub struct ModelCostRegistry {
    // Keyed by model name (e.g. "claude-sonnet-4-5"); `get` also performs
    // prefix-based fallback matching against these keys.
    costs: std::collections::HashMap<String, ModelCosts>,
}
impl ModelCostRegistry {
    /// Builds a registry pre-populated with known model pricing.
    pub fn new() -> Self {
        let mut costs = std::collections::HashMap::new();
        costs.insert("claude-opus-4-6".to_string(), COST_TIER_5_25);
        costs.insert("claude-opus-4-5".to_string(), COST_TIER_5_25);
        costs.insert("claude-opus-4-1".to_string(), COST_TIER_15_75);
        costs.insert("claude-opus-4".to_string(), COST_TIER_15_75);
        costs.insert("claude-sonnet-4-6".to_string(), COST_TIER_3_15);
        costs.insert("claude-sonnet-4-5".to_string(), COST_TIER_3_15);
        costs.insert("claude-sonnet-4".to_string(), COST_TIER_3_15);
        costs.insert("claude-sonnet-3-5".to_string(), COST_TIER_3_15);
        costs.insert("claude-haiku-4-5".to_string(), COST_HAIKU_45);
        costs.insert("claude-haiku-3-5".to_string(), COST_HAIKU_35);
        costs.insert("MiniMaxAI/MiniMax-M2.5".to_string(), COST_TIER_3_15);
        costs.insert("MiniMaxAI/MiniMax-M2".to_string(), COST_TIER_3_15);
        costs.insert("gpt-4o".to_string(), COST_TIER_5_25);
        costs.insert("gpt-4o-mini".to_string(), COST_HAIKU_35);
        costs.insert("gpt-4-turbo".to_string(), COST_TIER_10_30);
        costs.insert("gpt-4".to_string(), COST_TIER_30_60);
        Self { costs }
    }

    /// Looks up pricing for `model`.
    ///
    /// Resolution order:
    /// 1. Exact name match.
    /// 2. Longest key where one of model/key is a prefix of the other, so
    ///    e.g. "claude-sonnet-4-5-latest" resolves to "claude-sonnet-4-5"
    ///    rather than "claude-sonnet-4". Length ties break lexicographically.
    /// 3. [`COST_DEFAULT`] for unknown models.
    pub fn get(&self, model: &str) -> &ModelCosts {
        if let Some(cost) = self.costs.get(model) {
            return cost;
        }
        // BUG FIX: the previous fallback returned the FIRST prefix match in
        // HashMap iteration order, which is non-deterministic when several
        // keys match; pick the longest (most specific) key deterministically.
        self.costs
            .iter()
            .filter(|(key, _)| model.starts_with(key.as_str()) || key.starts_with(model))
            .max_by(|(a, _), (b, _)| a.len().cmp(&b.len()).then_with(|| a.cmp(b)))
            .map(|(_, cost)| cost)
            .unwrap_or(&COST_DEFAULT)
    }

    /// Registers (or overrides) pricing for a model name.
    pub fn register(&mut self, model: &str, costs: ModelCosts) {
        self.costs.insert(model.to_string(), costs);
    }
}
impl Default for ModelCostRegistry {
    /// Equivalent to [`ModelCostRegistry::new`]: a registry pre-populated
    /// with the known model pricing.
    fn default() -> Self {
        Self::new()
    }
}
// NOTE(review): these two tiers are referenced by ModelCostRegistry::new
// above their declaration (legal in Rust — items are order-independent),
// but consider grouping them with the other COST_TIER_* consts.

/// $30 / 1M input, $60 / 1M output (used for "gpt-4").
/// NOTE(review): cache write equals the input rate and cache read is
/// 1/3 of input, unlike the 1.25x / 0.1x pattern of the tiers above —
/// confirm these rates are intentional.
pub const COST_TIER_30_60: ModelCosts = ModelCosts {
    input_tokens: 30.0,
    output_tokens: 60.0,
    prompt_cache_write_tokens: 30.0,
    prompt_cache_read_tokens: 10.0,
    web_search_requests: 0.01,
};

/// $10 / 1M input, $30 / 1M output (used for "gpt-4-turbo").
pub const COST_TIER_10_30: ModelCosts = ModelCosts {
    input_tokens: 10.0,
    output_tokens: 30.0,
    prompt_cache_write_tokens: 10.0,
    prompt_cache_read_tokens: 3.0,
    web_search_requests: 0.01,
};
pub fn calculate_cost(model: &str, usage: &TokenUsage) -> f64 {
let registry = ModelCostRegistry::new();
let costs = registry.get(model);
costs.total_cost(usage)
}
/// Renders a USD cost as a dollar string.
///
/// Sub-cent values get four decimal places so tiny costs don't collapse
/// to "$0.00"; values between one cent and one dollar use the standard
/// two-decimal currency form.
///
/// NOTE(review): the >= $1 branch also emits four decimals ("$1.5000"),
/// which is inconsistent with the middle branch. The existing unit test
/// pins this behavior, so it is preserved here — confirm whether "$1.50"
/// was actually intended.
pub fn format_cost(cost: f64) -> String {
    let amount = match cost {
        c if c < 0.01 => format!("{c:.4}"),
        c if c < 1.0 => format!("{c:.2}"),
        c => format!("{c:.4}"),
    };
    format!("${amount}")
}
/// Per-category USD cost breakdown for one [`TokenUsage`] under one model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostSummary {
    pub input_cost: f64,
    pub output_cost: f64,
    pub cache_write_cost: f64,
    pub cache_read_cost: f64,
    // Sum of the four component costs above.
    pub total_cost: f64,
}
impl CostSummary {
    /// Computes a per-category cost breakdown for `usage` under `model`.
    pub fn from_usage(model: &str, usage: &TokenUsage) -> Self {
        let registry = ModelCostRegistry::new();
        let costs = registry.get(model);
        let input_cost = costs.input_cost(usage.input_tokens);
        let output_cost = costs.output_cost(usage.output_tokens);
        let cache_write_cost = costs.cache_write_cost(usage.prompt_cache_write_tokens);
        let cache_read_cost = costs.cache_read_cost(usage.prompt_cache_read_tokens);
        Self {
            input_cost,
            output_cost,
            cache_write_cost,
            cache_read_cost,
            // Sum the components computed above instead of calling
            // costs.total_cost(usage), which would redo all four
            // multiplications. Addition order matches total_cost exactly,
            // so the result is bit-identical.
            total_cost: input_cost + output_cost + cache_write_cost + cache_read_cost,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Tier 3/15: 1M tokens costs exactly the per-million rate.
    #[test]
    fn test_model_costs_input() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.input_cost(1_000_000), 3.0);
        assert_eq!(costs.input_cost(500_000), 1.5);
    }

    #[test]
    fn test_model_costs_output() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.output_cost(1_000_000), 15.0);
    }

    // total() sums all four token categories: 100 + 50 + 25 + 75 = 250.
    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            prompt_cache_write_tokens: 25,
            prompt_cache_read_tokens: 75,
        };
        assert_eq!(usage.total(), 250);
    }

    // Exact-name lookups resolve to the expected tiers.
    #[test]
    fn test_model_cost_registry() {
        let registry = ModelCostRegistry::new();
        let costs = registry.get("claude-sonnet-4-6");
        assert_eq!(costs.input_tokens, 3.0);
        let costs = registry.get("claude-haiku-4-5");
        assert_eq!(costs.input_tokens, 1.0);
    }

    // Unrecognized names fall back to COST_DEFAULT.
    #[test]
    fn test_model_cost_registry_unknown() {
        let registry = ModelCostRegistry::new();
        let costs = registry.get("unknown-model");
        assert_eq!(costs.input_tokens, COST_DEFAULT.input_tokens);
    }

    // Sonnet tier: 1M input @ $3/M + 0.5M output @ $15/M = $10.50.
    #[test]
    fn test_calculate_cost() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 0,
            prompt_cache_read_tokens: 0,
        };
        let cost = calculate_cost("claude-sonnet-4-6", &usage);
        assert!((cost - 10.5).abs() < 0.01);
    }

    // NOTE(review): "$1.5000" pins format_cost's current behavior where the
    // >= $1 branch emits four decimals while the mid range gets two
    // ("$0.50") — confirm this inconsistency is intentional.
    #[test]
    fn test_format_cost() {
        assert_eq!(format_cost(0.001), "$0.0010");
        assert_eq!(format_cost(0.5), "$0.50");
        assert_eq!(format_cost(1.5), "$1.5000");
    }

    // Components: $3 input, $7.50 output, $0.375 cache write
    // (100k @ $3.75/M), $0.06 cache read (200k @ $0.30/M).
    #[test]
    fn test_cost_summary() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 100_000,
            prompt_cache_read_tokens: 200_000,
        };
        let summary = CostSummary::from_usage("claude-sonnet-4-6", &usage);
        assert!((summary.input_cost - 3.0).abs() < 0.01);
        assert!((summary.output_cost - 7.5).abs() < 0.01);
        assert!((summary.cache_write_cost - 0.375).abs() < 0.01);
        assert!((summary.cache_read_cost - 0.06).abs() < 0.01);
    }
}