claude_api/
cost_preview.rs1#![cfg(all(feature = "async", feature = "pricing"))]
13
14use std::collections::{HashMap, VecDeque};
15use std::hash::{DefaultHasher, Hash, Hasher};
16use std::sync::Mutex;
17
18use crate::pricing::PricingTable;
19use crate::types::{ModelId, Usage};
20
21#[non_exhaustive]
23#[derive(Debug, Clone, PartialEq)]
24pub struct CostPreview {
25 pub model: ModelId,
27 pub input_tokens: u32,
29 pub max_output_tokens: u32,
31 pub input_cost_usd: f64,
33 pub max_output_cost_usd: f64,
35 pub max_total_usd: f64,
39}
40
41impl CostPreview {
42 #[must_use]
45 pub fn cost_for(&self, output_tokens: u32, pricing: &PricingTable) -> f64 {
46 pricing.cost(
47 &self.model,
48 &Usage {
49 input_tokens: self.input_tokens,
50 output_tokens,
51 ..Usage::default()
52 },
53 )
54 }
55}
56
/// Thread-safe, bounded cache from request fingerprints to token counts.
///
/// Keys are opaque `u64` fingerprints (for example from [`hash_request`]) and
/// values are token counts. All state sits behind a single [`Mutex`], so every
/// method takes `&self`.
///
/// Eviction is first-in, first-out by insertion:
/// - Once `capacity` distinct keys are stored, inserting a new key drops the
///   key that was inserted earliest.
/// - Reads never change eviction order.
/// - Overwriting an existing key keeps its original position.
///
/// A `capacity` of `0` disables caching: [`put`](Self::put) stores nothing.
#[derive(Debug)]
pub struct CountTokensCache {
    inner: Mutex<CacheInner>,
    capacity: usize,
}

/// Lock-protected state of [`CountTokensCache`].
///
/// Invariant: `order` contains exactly the keys of `map`, oldest insertion
/// first, and never holds more than the owning cache's `capacity` entries.
#[derive(Debug)]
struct CacheInner {
    map: HashMap<u64, u32>,
    order: VecDeque<u64>,
}

impl CountTokensCache {
    /// Upper bound on the number of slots reserved eagerly by [`new`](Self::new).
    ///
    /// `capacity` is a limit, not an expected size. Reserving all of it up
    /// front would make a large bound (e.g. `usize::MAX` meaning "effectively
    /// unbounded") allocate huge amounts of memory or panic with a capacity
    /// overflow. Beyond this many entries the collections simply grow on demand.
    const MAX_PREALLOCATION: usize = 1024;

    /// Creates an empty cache that holds at most `capacity` entries.
    ///
    /// A `capacity` of `0` yields a cache that never stores anything.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        let reserve = capacity.min(Self::MAX_PREALLOCATION);
        Self {
            inner: Mutex::new(CacheInner {
                map: HashMap::with_capacity(reserve),
                order: VecDeque::with_capacity(reserve),
            }),
            capacity,
        }
    }

    /// Returns the number of cached entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock().map.len()
    }

    /// Returns `true` if the cache holds no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.lock().map.is_empty()
    }

    /// Removes every entry; the cache stays usable afterwards.
    pub fn clear(&self) {
        let mut inner = self.lock();
        inner.map.clear();
        inner.order.clear();
    }

    /// Returns the token count cached under `key`, if present.
    ///
    /// Lookups do not affect eviction order.
    #[must_use]
    pub fn get(&self, key: u64) -> Option<u32> {
        self.lock().map.get(&key).copied()
    }

    /// Stores `value` under `key`.
    ///
    /// If `key` is already cached, its value is updated in place and its
    /// eviction position is unchanged. Otherwise, when the cache is full, the
    /// oldest inserted key is evicted before the new one is appended. With a
    /// capacity of `0` this is a no-op.
    pub fn put(&self, key: u64, value: u32) {
        if self.capacity == 0 {
            // Without this guard the eviction loop below would find nothing to
            // pop and a "zero-capacity" cache would still retain one entry.
            return;
        }
        let mut inner = self.lock();
        if let Some(slot) = inner.map.get_mut(&key) {
            *slot = value;
            return;
        }
        while inner.order.len() >= self.capacity {
            // `capacity > 0` here, so a full queue is never empty.
            let Some(oldest) = inner.order.pop_front() else {
                break;
            };
            inner.map.remove(&oldest);
        }
        inner.map.insert(key, value);
        inner.order.push_back(key);
    }

    /// Acquires the state lock, recovering the guard if a previous holder
    /// panicked (poisoning is ignored rather than propagated).
    fn lock(&self) -> std::sync::MutexGuard<'_, CacheInner> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
143
144#[must_use]
149pub fn hash_request<T: serde::Serialize>(value: &T) -> u64 {
150 let bytes = serde_json::to_vec(value).unwrap_or_default();
151 let mut hasher = DefaultHasher::new();
152 bytes.hash(&mut hasher);
153 hasher.finish()
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_put_and_get_round_trips() {
        let cache = CountTokensCache::new(4);
        cache.put(1, 100);
        cache.put(2, 200);
        assert_eq!(cache.get(1), Some(100));
        assert_eq!(cache.get(2), Some(200));
        assert_eq!(cache.get(3), None);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn cache_evicts_oldest_when_full() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(2, 20);
        // A third distinct key pushes out the first one inserted.
        cache.put(3, 30);
        assert_eq!(cache.get(1), None);
        assert_eq!(cache.get(2), Some(20));
        assert_eq!(cache.get(3), Some(30));
    }

    #[test]
    fn cache_replace_updates_value_in_place() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(1, 11);
        assert_eq!(cache.get(1), Some(11));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn cache_replace_does_not_change_eviction_order() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(2, 20);
        // Overwriting key 1 keeps it in the "oldest" slot...
        cache.put(1, 11);
        // ...so inserting key 3 still evicts key 1, not key 2.
        cache.put(3, 30);
        assert_eq!(cache.get(1), None);
        assert_eq!(cache.get(2), Some(20));
        assert_eq!(cache.get(3), Some(30));
    }

    #[test]
    fn cache_reads_do_not_refresh_eviction_order() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.put(2, 20);
        assert_eq!(cache.get(1), Some(10));
        cache.put(3, 30);
        assert_eq!(cache.get(1), None);
        assert_eq!(cache.get(2), Some(20));
    }

    #[test]
    fn cache_len_never_exceeds_capacity() {
        let cache = CountTokensCache::new(3);
        for key in 0..10u64 {
            cache.put(key, 1);
        }
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(6), None);
        assert_eq!(cache.get(7), Some(1));
        assert_eq!(cache.get(9), Some(1));
    }

    #[test]
    fn cache_clear_drops_all_entries() {
        let cache = CountTokensCache::new(2);
        cache.put(1, 10);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.get(1), None);
        // The cache remains usable after clearing.
        cache.put(2, 20);
        assert_eq!(cache.get(2), Some(20));
    }

    #[test]
    fn hash_request_is_stable_across_calls() {
        let v = serde_json::json!({"a": 1, "b": [1, 2, 3]});
        assert_eq!(hash_request(&v), hash_request(&v));
    }

    #[test]
    fn hash_request_distinguishes_different_payloads() {
        let a = serde_json::json!({"a": 1});
        let b = serde_json::json!({"a": 2});
        assert_ne!(hash_request(&a), hash_request(&b));
    }
}