Skip to main content

claude_api/
cost_preview.rs

1//! `CostPreview` -- estimate the USD cost of a request before sending it.
2//!
3//! The input side is exact: it hits `/v1/messages/count_tokens` to get the
4//! tokenizer's actual count. The output side is bounded: we use
5//! `request.max_tokens` as the upper bound, since the actual number of
6//! output tokens is unknown until generation finishes. Use
7//! [`CostPreview::cost_for`] for a point estimate at any specific output
8//! count.
9//!
10//! Obtain via [`Messages::cost_preview`](crate::messages::Messages::cost_preview).
11
12#![cfg(all(feature = "async", feature = "pricing"))]
13
14use std::collections::{HashMap, VecDeque};
15use std::hash::{DefaultHasher, Hash, Hasher};
16use std::sync::Mutex;
17
18use crate::pricing::PricingTable;
19use crate::types::{ModelId, Usage};
20
21/// Pre-flight cost estimate for a request.
22#[non_exhaustive]
23#[derive(Debug, Clone, PartialEq)]
24pub struct CostPreview {
25    /// Model the estimate was computed for.
26    pub model: ModelId,
27    /// Server-counted input tokens (from `/v1/messages/count_tokens`).
28    pub input_tokens: u32,
29    /// Output upper bound, taken from `request.max_tokens`.
30    pub max_output_tokens: u32,
31    /// USD cost of the input tokens alone.
32    pub input_cost_usd: f64,
33    /// USD cost if the model emits exactly `max_output_tokens` output tokens.
34    pub max_output_cost_usd: f64,
35    /// `input_cost_usd + max_output_cost_usd`. The largest amount this
36    /// request could cost in vanilla usage (excludes cache/server-tool
37    /// charges since those are runtime-determined).
38    pub max_total_usd: f64,
39}
40
41impl CostPreview {
42    /// USD cost if the model produces exactly `output_tokens` tokens. Useful
43    /// for plotting expected cost against an empirical output-size estimate.
44    #[must_use]
45    pub fn cost_for(&self, output_tokens: u32, pricing: &PricingTable) -> f64 {
46        pricing.cost(
47            &self.model,
48            &Usage {
49                input_tokens: self.input_tokens,
50                output_tokens,
51                ..Usage::default()
52            },
53        )
54    }
55}
56
/// Bounded cache for `count_tokens` results, keyed by a hash of the request
/// body (for example one produced by [`hash_request`]). Use it to skip the
/// network round-trip on repeated previews against unchanged inputs
/// (long-running agent sessions, IDE integrations, etc.).
///
/// Eviction is FIFO once `capacity` is reached -- not a true LRU, but
/// adequate for the common pattern of "preview the same prompt many
/// times, occasionally see a new one." Replacing the value of a key that is
/// already cached does not move it in the eviction order, and lookups do not
/// refresh entries either.
///
/// A `capacity` of `0` disables caching: [`put`](Self::put) stores nothing
/// and every [`get`](Self::get) misses.
///
/// Wrap in [`std::sync::Arc`] to share across tasks; the inner state is
/// [`Mutex`]-protected, and a poisoned lock is recovered rather than
/// propagated as a panic.
#[derive(Debug)]
pub struct CountTokensCache {
    inner: Mutex<CacheInner>,
    capacity: usize,
}

/// State behind the cache's mutex. `order` lists the keys of `map` from
/// oldest to newest insertion; every mutation updates both, so they always
/// hold the same set of keys.
#[derive(Debug)]
struct CacheInner {
    map: HashMap<u64, u32>,
    order: VecDeque<u64>,
}

impl CountTokensCache {
    /// Build a new cache that holds at most `capacity` entries.
    ///
    /// Space for `capacity` entries is reserved up front. A `capacity` of `0`
    /// produces a cache that never stores anything.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            inner: Mutex::new(CacheInner {
                map: HashMap::with_capacity(capacity),
                order: VecDeque::with_capacity(capacity),
            }),
            capacity,
        }
    }

    /// Number of entries currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock().map.len()
    }

    /// `true` when no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.lock().map.is_empty()
    }

    /// Drop all cached entries and reset the eviction order.
    pub fn clear(&self) {
        let mut inner = self.lock();
        inner.map.clear();
        inner.order.clear();
    }

    /// Look up a cached input-token count by request hash.
    ///
    /// A hit does not change the entry's position in the eviction order.
    #[must_use]
    pub fn get(&self, key: u64) -> Option<u32> {
        self.lock().map.get(&key).copied()
    }

    /// Insert (or replace) the cached count for `key`.
    ///
    /// Replacing keeps the key's original eviction position. Inserting a new
    /// key into a full cache first evicts the oldest entry by insertion
    /// order. Does nothing when the cache was built with capacity `0`.
    pub fn put(&self, key: u64, value: u32) {
        // Capacity 0 means "caching disabled". Without this guard the
        // eviction step below would find nothing to pop on an empty cache,
        // and the cache would end up holding one entry.
        if self.capacity == 0 {
            return;
        }
        let mut inner = self.lock();
        // Existing key: update in place with a single lookup and leave the
        // eviction order untouched.
        if let Some(slot) = inner.map.get_mut(&key) {
            *slot = value;
            return;
        }
        // New key: make room first. `order` mirrors `map`, so a full cache
        // (capacity >= 1 here) always has an oldest key to drop.
        let evicted = if inner.order.len() >= self.capacity {
            inner.order.pop_front()
        } else {
            None
        };
        if let Some(oldest) = evicted {
            inner.map.remove(&oldest);
        }
        inner.map.insert(key, value);
        inner.order.push_back(key);
    }

    /// Lock the inner state, recovering the guard if a previous holder
    /// panicked while holding the lock.
    fn lock(&self) -> std::sync::MutexGuard<'_, CacheInner> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
143
144/// Hash a value's serde-JSON serialization to a stable u64. Suitable as a
145/// cache key for [`CountTokensCache`]. Returns the empty-string hash on
146/// serialization failure (effectively groups malformed inputs together;
147/// shouldn't happen for crate-owned types).
148#[must_use]
149pub fn hash_request<T: serde::Serialize>(value: &T) -> u64 {
150    let bytes = serde_json::to_vec(value).unwrap_or_default();
151    let mut hasher = DefaultHasher::new();
152    bytes.hash(&mut hasher);
153    hasher.finish()
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_put_and_get_round_trips() {
        let cache = CountTokensCache::new(4);
        cache.put(1, 100);
        cache.put(2, 200);
        assert_eq!(cache.get(1), Some(100));
        assert_eq!(cache.get(2), Some(200));
        assert_eq!(cache.get(3), None);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn cache_evicts_oldest_when_full() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(2, 20);
        cache.put(3, 30); // evicts 1
        assert_eq!(cache.get(1), None);
        assert_eq!(cache.get(2), Some(20));
        assert_eq!(cache.get(3), Some(30));
    }

    #[test]
    fn cache_replace_does_not_change_eviction_order() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(2, 20);
        cache.put(1, 11); // replace; should not push 1 to the back
        cache.put(3, 30); // evicts 1 (oldest by insertion)
        assert_eq!(cache.get(1), None);
        assert_eq!(cache.get(2), Some(20));
        assert_eq!(cache.get(3), Some(30));
    }

    #[test]
    fn cache_replace_updates_value_in_place() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(1, 11);
        assert_eq!(cache.get(1), Some(11));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn cache_len_never_exceeds_capacity() {
        let cache = CountTokensCache::new(3);
        for (key, value) in (0u64..10).zip(0u32..) {
            cache.put(key, value);
            assert!(cache.len() <= 3);
        }
        assert_eq!(cache.len(), 3);
        // FIFO: only the three most recently inserted keys survive.
        assert_eq!(cache.get(6), None);
        assert_eq!(cache.get(7), Some(7));
        assert_eq!(cache.get(9), Some(9));
    }

    #[test]
    fn cache_clear_drops_all_entries() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.get(1), None);
    }

    #[test]
    fn cache_clear_resets_eviction_order() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(2, 20);
        cache.clear();
        cache.put(3, 30);
        cache.put(4, 40);
        cache.put(5, 50); // evicts 3, not a stale pre-clear key
        assert_eq!(cache.get(3), None);
        assert_eq!(cache.get(4), Some(40));
        assert_eq!(cache.get(5), Some(50));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn hash_request_is_stable_across_calls() {
        let v = serde_json::json!({"a": 1, "b": [1, 2, 3]});
        assert_eq!(hash_request(&v), hash_request(&v));
    }

    #[test]
    fn hash_request_distinguishes_different_payloads() {
        let a = serde_json::json!({"a": 1});
        let b = serde_json::json!({"a": 2});
        assert_ne!(hash_request(&a), hash_request(&b));
    }
}