Skip to main content

claude_api/
pricing.rs

1//! Per-model pricing and cost calculation.
2//!
3//! [`PricingTable`] is a snapshot of Anthropic's published rates as of the
4//! crate release. Rates are provided in **USD per million tokens**
5//! (`input_per_mtok`, `output_per_mtok`, etc.). Server-tool fees are billed
6//! per-request.
7//!
//! Anthropic adjusts pricing periodically; treat the bundled rates as a
//! best-effort default and pin your own via [`PricingTable::custom`] (or
//! patch individual rows with [`PricingTable::set`]) for billing-critical
//! workloads. When [`PricingTable::cost`] is called for a model the table
//! doesn't know about, the function returns `0.0` and emits a
//! `tracing::warn!` the first time that model name is seen in the process.
13//!
14//! Gated on the `pricing` feature.
15//!
16//! ```
17//! use claude_api::pricing::PricingTable;
18//! use claude_api::types::{ModelId, Usage};
19//!
20//! let pricing = PricingTable::default();
21//! let mut usage = Usage::default();
22//! usage.input_tokens = 1_000_000;
23//! usage.output_tokens = 200_000;
24//! let usd = pricing.cost(&ModelId::SONNET_4_6, &usage);
25//! assert!(usd > 0.0);
26//! ```
27
28use std::collections::HashMap;
29use std::sync::{Mutex, OnceLock};
30
31use crate::types::{ModelId, Usage};
32
33/// Per-model pricing snapshot.
34#[derive(Debug, Clone)]
35pub struct PricingTable {
36    rates: HashMap<ModelId, ModelPricing>,
37}
38
39/// Rates for a single model, all in USD per million tokens unless noted.
40#[derive(Debug, Clone, Copy, PartialEq)]
41#[non_exhaustive]
42pub struct ModelPricing {
43    /// USD per million input tokens.
44    pub input_per_mtok: f64,
45    /// USD per million output tokens.
46    pub output_per_mtok: f64,
47    /// USD per million tokens written to a 5-minute-TTL cache entry.
48    pub cache_creation_5m_per_mtok: f64,
49    /// USD per million tokens written to a 1-hour-TTL cache entry.
50    pub cache_creation_1h_per_mtok: f64,
51    /// USD per million tokens read from any cache entry.
52    pub cache_read_per_mtok: f64,
53    /// USD per server-side `web_search` request.
54    pub web_search_per_request: f64,
55}
56
57impl ModelPricing {
58    /// Build a pricing record from input + output rates and the standard
59    /// cache multipliers (5m = 1.25x input, 1h = 2x input, read = 0.1x input).
60    #[must_use]
61    pub const fn from_input_output(
62        input_per_mtok: f64,
63        output_per_mtok: f64,
64        web_search_per_request: f64,
65    ) -> Self {
66        Self {
67            input_per_mtok,
68            output_per_mtok,
69            cache_creation_5m_per_mtok: input_per_mtok * 1.25,
70            cache_creation_1h_per_mtok: input_per_mtok * 2.0,
71            cache_read_per_mtok: input_per_mtok * 0.1,
72            web_search_per_request,
73        }
74    }
75}
76
// Pulls in `bundled_rates() -> Vec<(ModelId, ModelPricing)>` generated by
// `build.rs` from `pricing.toml` at compile time. The file is spliced in
// verbatim from `$OUT_DIR/pricing_data.rs`.
// NOTE(review): the generated file is not visible here. The only call site
// (`Default for PricingTable`) needs `bundled_rates()` to yield something
// iterable over `(ModelId, ModelPricing)` pairs. Confirm the exact return
// type against `build.rs`.
include!(concat!(env!("OUT_DIR"), "/pricing_data.rs"));
80
81impl Default for PricingTable {
82    fn default() -> Self {
83        // Bundled rates from `pricing.toml` -- best-effort. Override via
84        // PricingTable::custom or PricingTable::set for billing-critical
85        // workloads.
86        Self {
87            rates: bundled_rates().into_iter().collect(),
88        }
89    }
90}
91
92impl PricingTable {
93    /// Build a custom pricing table from a fully populated map.
94    #[must_use]
95    pub fn custom(rates: HashMap<ModelId, ModelPricing>) -> Self {
96        Self { rates }
97    }
98
99    /// Override or insert a rate for a single model.
100    pub fn set(&mut self, model: ModelId, rates: ModelPricing) {
101        self.rates.insert(model, rates);
102    }
103
104    /// Borrow the rate row for a model, if known.
105    #[must_use]
106    pub fn get(&self, model: &ModelId) -> Option<&ModelPricing> {
107        self.rates.get(model)
108    }
109
110    /// Total cost in USD for the given usage on the given model. Returns
111    /// `0.0` when the model is unknown to the table; a `tracing::warn!` is
112    /// emitted once per process per missing model.
113    #[must_use]
114    pub fn cost(&self, model: &ModelId, usage: &Usage) -> f64 {
115        self.cost_breakdown(model, usage).total
116    }
117
118    /// Detailed cost breakdown.
119    #[must_use]
120    pub fn cost_breakdown(&self, model: &ModelId, usage: &Usage) -> CostBreakdown {
121        let Some(rates) = self.rates.get(model) else {
122            warn_missing_once(model.as_str());
123            return CostBreakdown::default();
124        };
125
126        let input = f64::from(usage.input_tokens) / 1_000_000.0 * rates.input_per_mtok;
127        let output = f64::from(usage.output_tokens) / 1_000_000.0 * rates.output_per_mtok;
128
129        let cache_creation = match &usage.cache_creation {
130            Some(b) => {
131                f64::from(b.ephemeral_5m_input_tokens) / 1_000_000.0
132                    * rates.cache_creation_5m_per_mtok
133                    + f64::from(b.ephemeral_1h_input_tokens) / 1_000_000.0
134                        * rates.cache_creation_1h_per_mtok
135            }
136            None => {
137                // No per-TTL breakdown: fall back to the legacy total field
138                // and assume 5-minute TTL (the more common default).
139                f64::from(usage.cache_creation_input_tokens.unwrap_or(0)) / 1_000_000.0
140                    * rates.cache_creation_5m_per_mtok
141            }
142        };
143
144        let cache_read = f64::from(usage.cache_read_input_tokens.unwrap_or(0)) / 1_000_000.0
145            * rates.cache_read_per_mtok;
146
147        let server_tool_use = usage.server_tool_use.as_ref().map_or(0.0, |s| {
148            f64::from(s.web_search_requests) * rates.web_search_per_request
149        });
150
151        let total = input + output + cache_creation + cache_read + server_tool_use;
152        CostBreakdown {
153            input,
154            output,
155            cache_creation,
156            cache_read,
157            server_tool_use,
158            total,
159        }
160    }
161}
162
/// USD cost of a single usage record, split by billing category.
///
/// Produced by [`PricingTable::cost_breakdown`]. There, `total` is the sum of
/// the five category fields. The `Default` value is all zeros.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[non_exhaustive]
pub struct CostBreakdown {
    /// Input (prompt) tokens, USD.
    pub input: f64,
    /// Output (completion) tokens, USD.
    pub output: f64,
    /// Cache writes, USD (5-minute and 1-hour TTLs combined).
    pub cache_creation: f64,
    /// Cache reads, USD.
    pub cache_read: f64,
    /// Server-side tool invocations such as `web_search`, USD.
    pub server_tool_use: f64,
    /// Grand total across every category above, USD.
    pub total: f64,
}
180
181fn warn_missing_once(model: &str) {
182    static WARNED: OnceLock<Mutex<std::collections::HashSet<String>>> = OnceLock::new();
183    let warned = WARNED.get_or_init(|| Mutex::new(std::collections::HashSet::new()));
184    let mut guard = warned
185        .lock()
186        .unwrap_or_else(std::sync::PoisonError::into_inner);
187    if guard.insert(model.to_owned()) {
188        tracing::warn!(
189            model,
190            "claude-api: no bundled pricing data; cost() will return 0. \
191             Override via PricingTable::custom or PricingTable::set."
192        );
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CacheCreationBreakdown, ServerToolUseUsage};

    /// Assert `a` equals `b` within an absolute tolerance of 1e-9 USD.
    fn approx(a: f64, b: f64) {
        assert!((a - b).abs() < 1e-9, "expected {b} (within 1e-9), got {a}");
    }

    #[test]
    fn default_pricing_includes_known_models() {
        let p = PricingTable::default();
        assert!(p.get(&ModelId::OPUS_4_7).is_some());
        assert!(p.get(&ModelId::SONNET_4_6).is_some());
        assert!(p.get(&ModelId::HAIKU_4_5).is_some());
    }

    #[test]
    fn cost_input_and_output_only() {
        // 1M input @ $3/MTok + 0.5M output @ $15/MTok = $3 + $7.5 = $10.5
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 10.5);
    }

    #[test]
    fn cost_uses_per_ttl_breakdown_when_present() {
        // 1M input + breakdown 1M @ 5m + 1M @ 1h on Sonnet 4.6
        // input: $3, 5m write: $3.75, 1h write: $6, total $12.75
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 0,
            cache_creation: Some(CacheCreationBreakdown {
                ephemeral_5m_input_tokens: 1_000_000,
                ephemeral_1h_input_tokens: 1_000_000,
            }),
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 3.0 + 3.75 + 6.0);
    }

    #[test]
    fn per_ttl_breakdown_takes_precedence_over_legacy_total() {
        // When both fields are present, only the per-TTL split is priced:
        // 1M @ 5m ($3.75/MTok) + 0 @ 1h = $3.75; the legacy 9M total is ignored.
        let p = PricingTable::default();
        let usage = Usage {
            cache_creation_input_tokens: Some(9_000_000),
            cache_creation: Some(CacheCreationBreakdown {
                ephemeral_5m_input_tokens: 1_000_000,
                ephemeral_1h_input_tokens: 0,
            }),
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 3.75);
    }

    #[test]
    fn cost_falls_back_to_legacy_cache_field_when_no_breakdown() {
        // Without breakdown we assume 5m TTL: 1M @ $3.75/MTok = $3.75
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: Some(1_000_000),
            cache_creation: None,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 3.75);
    }

    #[test]
    fn cost_includes_cache_reads() {
        // 1M cache read @ $0.30/MTok = $0.30
        let p = PricingTable::default();
        let usage = Usage {
            cache_read_input_tokens: Some(1_000_000),
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 0.30);
    }

    #[test]
    fn cost_includes_web_search_requests() {
        let p = PricingTable::default();
        let usage = Usage {
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 50,
            }),
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 0.50);
    }

    #[test]
    fn breakdown_components_sum_to_total() {
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            cache_creation_input_tokens: Some(20_000),
            cache_read_input_tokens: Some(80_000),
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 3,
            }),
            ..Usage::default()
        };
        let b = p.cost_breakdown(&ModelId::SONNET_4_6, &usage);
        approx(
            b.input + b.output + b.cache_creation + b.cache_read + b.server_tool_use,
            b.total,
        );
    }

    #[test]
    fn breakdown_reports_each_component() {
        // Sonnet 4.6: $3 input, $15 output, $0.30 cache read per MTok.
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            cache_read_input_tokens: Some(1_000_000),
            ..Usage::default()
        };
        let b = p.cost_breakdown(&ModelId::SONNET_4_6, &usage);
        approx(b.input, 3.0);
        approx(b.output, 15.0);
        approx(b.cache_creation, 0.0);
        approx(b.cache_read, 0.30);
        approx(b.server_tool_use, 0.0);
        approx(b.total, 18.30);
    }

    #[test]
    fn unknown_model_returns_zero_cost() {
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Usage::default()
        };
        let cost = p.cost(&ModelId::custom("claude-future-foo"), &usage);
        approx(cost, 0.0);
    }

    #[test]
    fn unknown_model_breakdown_is_all_zero() {
        let p = PricingTable::default();
        let usage = Usage {
            input_tokens: 1_000_000,
            cache_read_input_tokens: Some(1_000_000),
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 10,
            }),
            ..Usage::default()
        };
        let b = p.cost_breakdown(&ModelId::custom("claude-future-bar"), &usage);
        assert_eq!(b, CostBreakdown::default());
    }

    #[test]
    fn custom_table_overrides_bundled_rates() {
        let mut rates = HashMap::new();
        rates.insert(
            ModelId::SONNET_4_6,
            ModelPricing::from_input_output(2.00, 10.00, 0.005),
        );
        let p = PricingTable::custom(rates);
        let usage = Usage {
            input_tokens: 1_000_000,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 2.0);
    }

    #[test]
    fn empty_custom_table_knows_no_models() {
        // `custom` does not merge in bundled rates.
        let p = PricingTable::custom(HashMap::new());
        assert!(p.get(&ModelId::SONNET_4_6).is_none());
        let usage = Usage {
            input_tokens: 1_000_000,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 0.0);
    }

    #[test]
    fn set_inserts_or_replaces_a_single_model() {
        let mut p = PricingTable::default();
        p.set(
            ModelId::SONNET_4_6,
            ModelPricing::from_input_output(99.99, 99.99, 0.0),
        );
        let usage = Usage {
            input_tokens: 1_000_000,
            ..Usage::default()
        };
        approx(p.cost(&ModelId::SONNET_4_6, &usage), 99.99);
    }

    #[test]
    fn set_adds_model_missing_from_table() {
        let mut p = PricingTable::custom(HashMap::new());
        let model = ModelId::custom("claude-in-house");
        let row = ModelPricing::from_input_output(1.0, 2.0, 0.0);
        p.set(model.clone(), row);
        assert_eq!(p.get(&model), Some(&row));
        let usage = Usage {
            output_tokens: 1_000_000,
            ..Usage::default()
        };
        approx(p.cost(&model, &usage), 2.0);
    }

    #[test]
    fn from_input_output_derives_cache_multipliers() {
        let r = ModelPricing::from_input_output(10.0, 50.0, 0.01);
        approx(r.cache_creation_5m_per_mtok, 12.5);
        approx(r.cache_creation_1h_per_mtok, 20.0);
        approx(r.cache_read_per_mtok, 1.0);
    }
}