1use std::collections::HashMap;
2
3use chrono::NaiveDate;
4
5use crate::data::models::TokenUsage;
6
/// Reference date, in `YYYY-MM-DD` form, from which the age of the pricing data is measured
/// (see `PricingCalculator::pricing_age_days`).
///
/// NOTE(review): presumably the day the built-in price table was captured from
/// [`PRICING_SOURCE`] — confirm, and bump it whenever the table is refreshed.
pub const PRICING_FETCH_DATE: &str = "2026-03-21";
/// Label identifying where the pricing data comes from.
///
/// NOTE(review): this looks like the documentation page the built-in rates were copied from;
/// nothing in this file reads the constant, so confirm how callers display it.
pub const PRICING_SOURCE: &str = "platform.claude.com/docs/en/about-claude/pricing";
11
/// Rates for a single model.
///
/// Every field is a price per one million tokens: the cost calculation in this file divides
/// each token count by 1,000,000 before multiplying it by the matching rate. The currency is
/// not stated in the code (presumably USD — confirm against the pricing source).
#[derive(Debug, Clone)]
pub struct ModelPrice {
    /// Rate applied to `input_tokens`.
    pub base_input: f64,
    /// Rate applied to 5-minute cache-write tokens, and to all cache-creation tokens when no
    /// 5m/1h breakdown is available.
    pub cache_write_5m: f64,
    /// Rate applied to 1-hour cache-write tokens.
    pub cache_write_1h: f64,
    /// Rate applied to `cache_read_input_tokens`.
    pub cache_read: f64,
    /// Rate applied to `output_tokens`.
    pub output: f64,
}
28
/// Cost of one turn, split by token category.
///
/// Each `*_cost` field is the token count for that category (in millions) multiplied by the
/// matching [`ModelPrice`] rate. When no price is known for the model, every amount is `0.0`
/// and `price_source` is [`PriceSource::Unknown`].
#[derive(Debug, Clone)]
pub struct CostBreakdown {
    /// Cost of the regular input tokens.
    pub input_cost: f64,
    /// Cost of cache writes billed at the 5-minute rate.
    pub cache_write_5m_cost: f64,
    /// Cost of cache writes billed at the 1-hour rate.
    pub cache_write_1h_cost: f64,
    /// Cost of tokens read from the cache.
    pub cache_read_cost: f64,
    /// Cost of the output tokens.
    pub output_cost: f64,
    /// Sum of the five component costs above.
    pub total: f64,
    /// Which price table supplied the rates used for this breakdown.
    pub price_source: PriceSource,
}
40
/// Identifies which price table a [`ModelPrice`] was taken from.
#[derive(Debug, Clone, PartialEq)]
pub enum PriceSource {
    /// The rates came from the table compiled into this file.
    Builtin,
    /// The rates came from the override table passed to `PricingCalculator::with_overrides`
    /// (presumably populated from user configuration — confirm against the caller).
    Config,
    /// No table had an entry for the model; the reported cost is zero.
    Unknown,
}
51
52fn builtin_prices() -> HashMap<String, ModelPrice> {
55 let entries: Vec<(&str, ModelPrice)> = vec![
56 (
57 "claude-opus-4-6",
58 ModelPrice {
59 base_input: 5.0,
60 cache_write_5m: 6.25,
61 cache_write_1h: 10.0,
62 cache_read: 0.50,
63 output: 25.0,
64 },
65 ),
66 (
67 "claude-opus-4-5",
68 ModelPrice {
69 base_input: 5.0,
70 cache_write_5m: 6.25,
71 cache_write_1h: 10.0,
72 cache_read: 0.50,
73 output: 25.0,
74 },
75 ),
76 (
77 "claude-opus-4-1",
78 ModelPrice {
79 base_input: 15.0,
80 cache_write_5m: 18.75,
81 cache_write_1h: 30.0,
82 cache_read: 1.50,
83 output: 75.0,
84 },
85 ),
86 (
87 "claude-opus-4",
88 ModelPrice {
89 base_input: 15.0,
90 cache_write_5m: 18.75,
91 cache_write_1h: 30.0,
92 cache_read: 1.50,
93 output: 75.0,
94 },
95 ),
96 (
97 "claude-sonnet-4-6",
98 ModelPrice {
99 base_input: 3.0,
100 cache_write_5m: 3.75,
101 cache_write_1h: 6.0,
102 cache_read: 0.30,
103 output: 15.0,
104 },
105 ),
106 (
107 "claude-sonnet-4-5",
108 ModelPrice {
109 base_input: 3.0,
110 cache_write_5m: 3.75,
111 cache_write_1h: 6.0,
112 cache_read: 0.30,
113 output: 15.0,
114 },
115 ),
116 (
117 "claude-sonnet-4",
118 ModelPrice {
119 base_input: 3.0,
120 cache_write_5m: 3.75,
121 cache_write_1h: 6.0,
122 cache_read: 0.30,
123 output: 15.0,
124 },
125 ),
126 (
127 "claude-haiku-4-5",
128 ModelPrice {
129 base_input: 1.0,
130 cache_write_5m: 1.25,
131 cache_write_1h: 2.0,
132 cache_read: 0.10,
133 output: 5.0,
134 },
135 ),
136 (
137 "claude-haiku-3-5",
138 ModelPrice {
139 base_input: 0.80,
140 cache_write_5m: 1.0,
141 cache_write_1h: 1.60,
142 cache_read: 0.08,
143 output: 4.0,
144 },
145 ),
146 (
147 "claude-3-haiku",
148 ModelPrice {
149 base_input: 0.25,
150 cache_write_5m: 0.30,
151 cache_write_1h: 0.50,
152 cache_read: 0.03,
153 output: 1.25,
154 },
155 ),
156 ];
157
158 entries
159 .into_iter()
160 .map(|(k, v)| (k.to_string(), v))
161 .collect()
162}
163
/// Looks up per-model rates and turns token usage into a cost breakdown.
///
/// Two tables are consulted, in this order: `overrides` first, then `prices`.
pub struct PricingCalculator {
    /// Built-in price table, filled from `builtin_prices()` by `new()`.
    prices: HashMap<String, ModelPrice>,
    /// Caller-supplied rates that take precedence over `prices`; empty unless set through
    /// `with_overrides`.
    overrides: HashMap<String, ModelPrice>,
}
171
/// `Default` is identical to [`PricingCalculator::new`]: built-in prices and no overrides.
impl Default for PricingCalculator {
    fn default() -> Self {
        // Delegate so there is a single place that defines the initial state.
        Self::new()
    }
}
177
178impl PricingCalculator {
179 pub fn new() -> Self {
181 Self {
182 prices: builtin_prices(),
183 overrides: HashMap::new(),
184 }
185 }
186
187 pub fn with_overrides(mut self, overrides: HashMap<String, ModelPrice>) -> Self {
189 self.overrides = overrides;
190 self
191 }
192
193 pub fn get_price(&self, model: &str) -> Option<(&ModelPrice, PriceSource)> {
201 if let Some(p) = self.overrides.get(model) {
203 return Some((p, PriceSource::Config));
204 }
205 if let Some(p) = Self::prefix_lookup(&self.overrides, model) {
207 return Some((p, PriceSource::Config));
208 }
209 if let Some(p) = self.prices.get(model) {
211 return Some((p, PriceSource::Builtin));
212 }
213 if let Some(p) = Self::prefix_lookup(&self.prices, model) {
215 return Some((p, PriceSource::Builtin));
216 }
217 None
218 }
219
220 fn prefix_lookup<'a>(
222 map: &'a HashMap<String, ModelPrice>,
223 model: &str,
224 ) -> Option<&'a ModelPrice> {
225 map.iter()
226 .filter(|(key, _)| model.starts_with(key.as_str()))
227 .max_by_key(|(key, _)| key.len())
228 .map(|(_, v)| v)
229 }
230
231 pub fn calculate_turn_cost(&self, model: &str, usage: &TokenUsage) -> CostBreakdown {
233 let (price, source) = match self.get_price(model) {
234 Some((p, s)) => (p, s),
235 None => {
236 return CostBreakdown {
237 input_cost: 0.0,
238 cache_write_5m_cost: 0.0,
239 cache_write_1h_cost: 0.0,
240 cache_read_cost: 0.0,
241 output_cost: 0.0,
242 total: 0.0,
243 price_source: PriceSource::Unknown,
244 };
245 }
246 };
247
248 let input_mtok = usage.input_tokens.unwrap_or(0) as f64 / 1_000_000.0;
249 let output_mtok = usage.output_tokens.unwrap_or(0) as f64 / 1_000_000.0;
250 let cache_read_mtok = usage.cache_read_input_tokens.unwrap_or(0) as f64 / 1_000_000.0;
251
252 let (cw_5m, cw_1h) = match &usage.cache_creation {
254 Some(detail) => (
255 detail.ephemeral_5m_input_tokens.unwrap_or(0) as f64 / 1_000_000.0,
256 detail.ephemeral_1h_input_tokens.unwrap_or(0) as f64 / 1_000_000.0,
257 ),
258 None => {
259 let total_cw = usage.cache_creation_input_tokens.unwrap_or(0) as f64 / 1_000_000.0;
261 (total_cw, 0.0)
262 }
263 };
264
265 let input_cost = input_mtok * price.base_input;
266 let cache_write_5m_cost = cw_5m * price.cache_write_5m;
267 let cache_write_1h_cost = cw_1h * price.cache_write_1h;
268 let cache_read_cost = cache_read_mtok * price.cache_read;
269 let output_cost = output_mtok * price.output;
270
271 let total =
272 input_cost + cache_write_5m_cost + cache_write_1h_cost + cache_read_cost + output_cost;
273
274 CostBreakdown {
275 input_cost,
276 cache_write_5m_cost,
277 cache_write_1h_cost,
278 cache_read_cost,
279 output_cost,
280 total,
281 price_source: source,
282 }
283 }
284
285 pub fn pricing_age_days() -> i64 {
287 let fetch_date =
288 NaiveDate::parse_from_str(PRICING_FETCH_DATE, "%Y-%m-%d").expect("valid date constant");
289 let today = chrono::Utc::now().date_naive();
290 (today - fetch_date).num_days()
291 }
292
293 pub fn is_pricing_stale() -> bool {
295 Self::pricing_age_days() > 90
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::models::{CacheCreationDetail, TokenUsage};

    /// Returns `true` when `actual` is within `1e-9` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    /// Builds a `TokenUsage` fixture.
    ///
    /// The 5m/1h breakdown is attached only when at least one of `cw_5m` / `cw_1h` is
    /// non-zero; otherwise `cache_creation` is `None`.
    fn make_usage(
        input: u64,
        output: u64,
        cache_create: u64,
        cache_read: u64,
        cw_5m: u64,
        cw_1h: u64,
    ) -> TokenUsage {
        let cache_creation = match (cw_5m, cw_1h) {
            (0, 0) => None,
            (five_min, one_hour) => Some(CacheCreationDetail {
                ephemeral_5m_input_tokens: Some(five_min),
                ephemeral_1h_input_tokens: Some(one_hour),
            }),
        };

        TokenUsage {
            input_tokens: Some(input),
            output_tokens: Some(output),
            cache_creation_input_tokens: Some(cache_create),
            cache_read_input_tokens: Some(cache_read),
            cache_creation,
            server_tool_use: None,
            service_tier: None,
            speed: None,
            inference_geo: None,
        }
    }

    /// One million tokens in every category yields exactly the listed Opus 4.6 rates.
    #[test]
    fn opus_46_pricing() {
        let tokens = make_usage(1_000_000, 1_000_000, 1_000_000, 1_000_000, 1_000_000, 0);
        let cost = PricingCalculator::new().calculate_turn_cost("claude-opus-4-6", &tokens);

        assert!(
            close(cost.input_cost, 5.0),
            "input_cost: {}",
            cost.input_cost
        );
        assert!(
            close(cost.cache_write_5m_cost, 6.25),
            "cache_write_5m_cost: {}",
            cost.cache_write_5m_cost
        );
        assert!(
            close(cost.cache_write_1h_cost, 0.0),
            "cache_write_1h_cost: {}",
            cost.cache_write_1h_cost
        );
        assert!(
            close(cost.cache_read_cost, 0.50),
            "cache_read_cost: {}",
            cost.cache_read_cost
        );
        assert!(
            close(cost.output_cost, 25.0),
            "output_cost: {}",
            cost.output_cost
        );
        assert!(close(cost.total, 36.75), "total: {}", cost.total);
        assert_eq!(cost.price_source, PriceSource::Builtin);
    }

    /// The 5-minute and 1-hour cache writes are billed at their own rates.
    #[test]
    fn distinguishes_5m_and_1h_cache() {
        let tokens = make_usage(0, 0, 1_000_000, 0, 500_000, 500_000);
        let cost = PricingCalculator::new().calculate_turn_cost("claude-opus-4-6", &tokens);

        assert!(
            close(cost.cache_write_5m_cost, 3.125),
            "cache_write_5m_cost: {}",
            cost.cache_write_5m_cost
        );
        assert!(
            close(cost.cache_write_1h_cost, 5.0),
            "cache_write_1h_cost: {}",
            cost.cache_write_1h_cost
        );
        assert!(close(cost.total, 8.125), "total: {}", cost.total);
    }

    /// A dated model ID resolves through prefix matching to its undated table key.
    #[test]
    fn prefix_matching() {
        let tokens = make_usage(1_000_000, 0, 0, 0, 0, 0);
        let cost =
            PricingCalculator::new().calculate_turn_cost("claude-opus-4-5-20251101", &tokens);

        assert!(
            close(cost.input_cost, 5.0),
            "input_cost: {}",
            cost.input_cost
        );
        assert_eq!(cost.price_source, PriceSource::Builtin);
    }

    /// A model absent from every table costs nothing and is flagged as unknown.
    #[test]
    fn unknown_model_zero() {
        let tokens = make_usage(1_000_000, 1_000_000, 1_000_000, 1_000_000, 1_000_000, 0);
        let cost = PricingCalculator::new().calculate_turn_cost("gpt-99-turbo", &tokens);

        assert!(close(cost.total, 0.0), "total: {}", cost.total);
        assert_eq!(cost.price_source, PriceSource::Unknown);
    }

    /// An override for a model wins over the built-in entry with the same key.
    #[test]
    fn config_override_priority() {
        let custom_rates = ModelPrice {
            base_input: 99.0,
            cache_write_5m: 0.0,
            cache_write_1h: 0.0,
            cache_read: 0.0,
            output: 0.0,
        };
        let mut overrides = HashMap::new();
        overrides.insert("claude-opus-4-6".to_string(), custom_rates);

        let pricing = PricingCalculator::new().with_overrides(overrides);
        let tokens = make_usage(1_000_000, 0, 0, 0, 0, 0);
        let cost = pricing.calculate_turn_cost("claude-opus-4-6", &tokens);

        assert!(
            close(cost.input_cost, 99.0),
            "input_cost: {}",
            cost.input_cost
        );
        assert_eq!(cost.price_source, PriceSource::Config);
    }
}