1use std::collections::HashMap;
29use std::sync::{Mutex, OnceLock};
30
31use crate::types::{ModelId, Usage};
32
33#[derive(Debug, Clone)]
35pub struct PricingTable {
36 rates: HashMap<ModelId, ModelPricing>,
37}
38
/// Rates charged for a single model.
///
/// Token rates are per million tokens ("mtok"); web search is charged per
/// request. All values share whatever currency unit the table was built with.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub struct ModelPricing {
    /// Rate applied to `Usage::input_tokens`, per million tokens.
    pub input_per_mtok: f64,
    /// Rate applied to `Usage::output_tokens`, per million tokens.
    pub output_per_mtok: f64,
    /// Rate for cache writes with a 5-minute TTL (also used for the aggregate
    /// cache-creation count when no per-TTL split is reported), per million tokens.
    pub cache_creation_5m_per_mtok: f64,
    /// Rate for cache writes with a 1-hour TTL, per million tokens.
    pub cache_creation_1h_per_mtok: f64,
    /// Rate applied to cache-read input tokens, per million tokens.
    pub cache_read_per_mtok: f64,
    /// Flat charge per server-side web search request.
    pub web_search_per_request: f64,
}

impl ModelPricing {
    /// Builds pricing from just the input/output token rates and the web-search
    /// fee, deriving the cache rates from the input rate:
    ///
    /// * 5-minute cache writes: 1.25 × input
    /// * 1-hour cache writes: 2 × input
    /// * cache reads: 0.1 × input
    #[must_use]
    pub const fn from_input_output(
        input_per_mtok: f64,
        output_per_mtok: f64,
        web_search_per_request: f64,
    ) -> Self {
        const WRITE_5M_MULTIPLIER: f64 = 1.25;
        const WRITE_1H_MULTIPLIER: f64 = 2.0;
        const READ_MULTIPLIER: f64 = 0.1;

        let cache_creation_5m_per_mtok = input_per_mtok * WRITE_5M_MULTIPLIER;
        let cache_creation_1h_per_mtok = input_per_mtok * WRITE_1H_MULTIPLIER;
        let cache_read_per_mtok = input_per_mtok * READ_MULTIPLIER;

        Self {
            input_per_mtok,
            output_per_mtok,
            cache_creation_5m_per_mtok,
            cache_creation_1h_per_mtok,
            cache_read_per_mtok,
            web_search_per_request,
        }
    }
}
76
77include!(concat!(env!("OUT_DIR"), "/pricing_data.rs"));
80
81impl Default for PricingTable {
82 fn default() -> Self {
83 Self {
87 rates: bundled_rates().into_iter().collect(),
88 }
89 }
90}
91
92impl PricingTable {
93 #[must_use]
95 pub fn custom(rates: HashMap<ModelId, ModelPricing>) -> Self {
96 Self { rates }
97 }
98
99 pub fn set(&mut self, model: ModelId, rates: ModelPricing) {
101 self.rates.insert(model, rates);
102 }
103
104 #[must_use]
106 pub fn get(&self, model: &ModelId) -> Option<&ModelPricing> {
107 self.rates.get(model)
108 }
109
110 #[must_use]
114 pub fn cost(&self, model: &ModelId, usage: &Usage) -> f64 {
115 self.cost_breakdown(model, usage).total
116 }
117
118 #[must_use]
120 pub fn cost_breakdown(&self, model: &ModelId, usage: &Usage) -> CostBreakdown {
121 let Some(rates) = self.rates.get(model) else {
122 warn_missing_once(model.as_str());
123 return CostBreakdown::default();
124 };
125
126 let input = f64::from(usage.input_tokens) / 1_000_000.0 * rates.input_per_mtok;
127 let output = f64::from(usage.output_tokens) / 1_000_000.0 * rates.output_per_mtok;
128
129 let cache_creation = match &usage.cache_creation {
130 Some(b) => {
131 f64::from(b.ephemeral_5m_input_tokens) / 1_000_000.0
132 * rates.cache_creation_5m_per_mtok
133 + f64::from(b.ephemeral_1h_input_tokens) / 1_000_000.0
134 * rates.cache_creation_1h_per_mtok
135 }
136 None => {
137 f64::from(usage.cache_creation_input_tokens.unwrap_or(0)) / 1_000_000.0
140 * rates.cache_creation_5m_per_mtok
141 }
142 };
143
144 let cache_read = f64::from(usage.cache_read_input_tokens.unwrap_or(0)) / 1_000_000.0
145 * rates.cache_read_per_mtok;
146
147 let server_tool_use = usage.server_tool_use.as_ref().map_or(0.0, |s| {
148 f64::from(s.web_search_requests) * rates.web_search_per_request
149 });
150
151 let total = input + output + cache_creation + cache_read + server_tool_use;
152 CostBreakdown {
153 input,
154 output,
155 cache_creation,
156 cache_read,
157 server_tool_use,
158 total,
159 }
160 }
161}
162
/// Itemised cost of one request, as produced by [`PricingTable::cost_breakdown`].
///
/// Every component is in the same unit as the [`ModelPricing`] rates that
/// produced it. `total` is the sum of the five components. For a model with no
/// pricing entry, every field is `0.0`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[non_exhaustive]
pub struct CostBreakdown {
    /// Charge for `Usage::input_tokens` at the input rate.
    pub input: f64,
    /// Charge for `Usage::output_tokens` at the output rate.
    pub output: f64,
    /// Charge for cache writes (5-minute and 1-hour TTLs combined).
    pub cache_creation: f64,
    /// Charge for cache-read input tokens.
    pub cache_read: f64,
    /// Charge for server-side tool use (web search requests).
    pub server_tool_use: f64,
    /// Sum of all the components above.
    pub total: f64,
}
180
181fn warn_missing_once(model: &str) {
182 static WARNED: OnceLock<Mutex<std::collections::HashSet<String>>> = OnceLock::new();
183 let warned = WARNED.get_or_init(|| Mutex::new(std::collections::HashSet::new()));
184 let mut guard = warned
185 .lock()
186 .unwrap_or_else(std::sync::PoisonError::into_inner);
187 if guard.insert(model.to_owned()) {
188 tracing::warn!(
189 model,
190 "claude-api: no bundled pricing data; cost() will return 0. \
191 Override via PricingTable::custom or PricingTable::set."
192 );
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CacheCreationBreakdown, ServerToolUseUsage};

    // The Sonnet 4.6 expectations below encode the bundled rates:
    // input 3.00 / output 15.00 per MTok, cache writes 3.75 (5m) and 6.00 (1h),
    // cache reads 0.30 per MTok, and 0.01 per web search request.

    /// Asserts `a` is within 1e-9 of `b`. Cost math is floating point, so
    /// derived values are compared with a tolerance rather than exactly.
    fn approx(a: f64, b: f64) {
        assert!((a - b).abs() < 1e-9, "expected {b} (within 1e-9), got {a}");
    }

    #[test]
    fn default_pricing_includes_known_models() {
        let p = PricingTable::default();
        assert!(p.get(&ModelId::OPUS_4_7).is_some());
        assert!(p.get(&ModelId::SONNET_4_6).is_some());
        assert!(p.get(&ModelId::HAIKU_4_5).is_some());
    }

    #[test]
    fn cost_input_and_output_only() {
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 10.5);
    }

    #[test]
    fn cost_uses_per_ttl_breakdown_when_present() {
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 0,
            cache_creation: Some(CacheCreationBreakdown {
                ephemeral_5m_input_tokens: 1_000_000,
                ephemeral_1h_input_tokens: 1_000_000,
            }),
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 3.0 + 3.75 + 6.0);
    }

    #[test]
    fn per_ttl_breakdown_takes_precedence_over_legacy_field() {
        let p = PricingTable::default();
        // Both fields populated: only the per-TTL split may be billed, otherwise
        // the same cache writes would be charged twice.
        let usage = Usage {
            cache_creation_input_tokens: Some(1_000_000),
            cache_creation: Some(CacheCreationBreakdown {
                ephemeral_5m_input_tokens: 0,
                ephemeral_1h_input_tokens: 1_000_000,
            }),
            ..Usage::default()
        };
        let b = p.cost_breakdown(&ModelId::SONNET_4_6, &usage);
        approx(b.cache_creation, 6.0);
        approx(b.total, 6.0);
    }

    #[test]
    fn cost_falls_back_to_legacy_cache_field_when_no_breakdown() {
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: Some(1_000_000),
            cache_creation: None,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 3.75);
    }

    #[test]
    fn cost_includes_cache_reads() {
        let p = PricingTable::default();
        let usage = Usage {
            cache_read_input_tokens: Some(1_000_000),
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 0.30);
    }

    #[test]
    fn cost_includes_web_search_requests() {
        let p = PricingTable::default();
        let usage = Usage {
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 50,
            }),
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 0.50);
    }

    #[test]
    fn breakdown_components_sum_to_total() {
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            cache_creation_input_tokens: Some(20_000),
            cache_read_input_tokens: Some(80_000),
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 3,
            }),
            ..Usage::default()
        };
        let b = p.cost_breakdown(&ModelId::SONNET_4_6, &usage);
        // Pin each component; the sum check alone would pass even if every
        // component were wrong, since `total` is computed from them.
        approx(b.input, 0.3); // 0.10 MTok * 3.00
        approx(b.output, 0.75); // 0.05 MTok * 15.00
        approx(b.cache_creation, 0.075); // 0.02 MTok * 3.75 (legacy field, 5m rate)
        approx(b.cache_read, 0.024); // 0.08 MTok * 0.30
        approx(b.server_tool_use, 0.03); // 3 requests * 0.01
        approx(
            b.input + b.output + b.cache_creation + b.cache_read + b.server_tool_use,
            b.total,
        );
        approx(b.total, 1.179);
    }

    #[test]
    fn unknown_model_returns_zero_cost() {
        let p = PricingTable::default();
        let model = ModelId::custom("claude-future-foo");
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Usage::default()
        };
        assert!(p.get(&model).is_none());
        approx(p.cost(&model, &usage), 0.0);
        assert_eq!(p.cost_breakdown(&model, &usage), CostBreakdown::default());
    }

    #[test]
    fn custom_table_overrides_bundled_rates() {
        let mut rates = HashMap::new();
        rates.insert(
            ModelId::SONNET_4_6,
            ModelPricing::from_input_output(2.00, 10.00, 0.005),
        );
        let p = PricingTable::custom(rates);
        let usage = Usage {
            input_tokens: 1_000_000,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 2.0);
        // A custom table contains only what it was given; bundled rates are not merged in.
        assert!(p.get(&ModelId::HAIKU_4_5).is_none());
    }

    #[test]
    fn set_inserts_or_replaces_a_single_model() {
        let mut p = PricingTable::default();
        let haiku_before = p.get(&ModelId::HAIKU_4_5).copied();
        let usage = Usage {
            input_tokens: 1_000_000,
            ..Usage::default()
        };

        // Replace an existing entry.
        p.set(
            ModelId::SONNET_4_6,
            ModelPricing::from_input_output(99.99, 99.99, 0.0),
        );
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 99.99);

        // Insert an entry the table has never seen.
        assert!(p.get(&ModelId::custom("claude-test-insert")).is_none());
        p.set(
            ModelId::custom("claude-test-insert"),
            ModelPricing::from_input_output(1.0, 2.0, 0.0),
        );
        approx(p.cost(&ModelId::custom("claude-test-insert"), &usage), 1.0);

        // Other models are untouched.
        assert_eq!(p.get(&ModelId::HAIKU_4_5).copied(), haiku_before);
    }

    #[test]
    fn from_input_output_derives_cache_multipliers() {
        let r = ModelPricing::from_input_output(10.0, 50.0, 0.01);
        approx(r.input_per_mtok, 10.0);
        approx(r.output_per_mtok, 50.0);
        approx(r.cache_creation_5m_per_mtok, 12.5);
        approx(r.cache_creation_1h_per_mtok, 20.0);
        approx(r.cache_read_per_mtok, 1.0);
        approx(r.web_search_per_request, 0.01);
    }
}