1use std::collections::HashMap;
7
8use crate::span::{CostDetails, TokenUsage};
9
/// Token prices for a single model, each rate expressed per 1,000,000 tokens.
///
/// The per-million unit is the scale [`PriceTable::compute`] divides by. The
/// currency is not fixed by this type (presumably USD — confirm with callers).
/// A rate of `0.0` makes that token class contribute nothing to a computed cost.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct ModelPrice {
    /// Rate applied to `TokenUsage::input` tokens.
    pub input: f64,
    /// Rate applied to `TokenUsage::output` tokens.
    pub output: f64,
    /// Rate applied to `TokenUsage::cache_read` tokens.
    pub cache_read: f64,
    /// Rate applied to `TokenUsage::cache_write` tokens.
    pub cache_write: f64,
}
22
23#[derive(Debug, Clone, Default)]
25pub struct PriceTable {
26 inner: HashMap<String, ModelPrice>,
27}
28
29impl PriceTable {
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn with_defaults() -> Self {
37 default_pricing_2026_05()
38 }
39
40 pub fn insert(&mut self, model_id: impl Into<String>, price: ModelPrice) -> &mut Self {
42 self.inner.insert(model_id.into(), price);
43 self
44 }
45
46 pub fn get(&self, model_id: &str) -> Option<&ModelPrice> {
49 if let Some(p) = self.inner.get(model_id) {
50 return Some(p);
51 }
52 self.inner
53 .iter()
54 .filter(|(k, _)| model_id.starts_with(k.as_str()))
55 .max_by_key(|(k, _)| k.len())
56 .map(|(_, v)| v)
57 }
58
59 pub fn compute(&self, model_id: &str, usage: TokenUsage) -> Option<CostDetails> {
62 let p = self.get(model_id)?;
63 let scale = 1_000_000.0;
64 let input = (usage.input as f64) * p.input / scale;
65 let output = (usage.output as f64) * p.output / scale;
66 let cache_read = (usage.cache_read as f64) * p.cache_read / scale;
67 let cache_write = (usage.cache_write as f64) * p.cache_write / scale;
68 Some(CostDetails {
69 input,
70 output,
71 cache_read,
72 cache_write,
73 total: input + output + cache_read + cache_write,
74 })
75 }
76}
77
78pub fn default_pricing_2026_05() -> PriceTable {
81 let mut t = PriceTable::new();
82
83 t.insert(
85 "gpt-4o",
86 ModelPrice {
87 input: 2.50,
88 output: 10.00,
89 cache_read: 1.25,
90 cache_write: 0.0,
91 },
92 );
93 t.insert(
94 "gpt-4o-mini",
95 ModelPrice {
96 input: 0.15,
97 output: 0.60,
98 cache_read: 0.075,
99 cache_write: 0.0,
100 },
101 );
102 t.insert(
103 "o1",
104 ModelPrice {
105 input: 15.00,
106 output: 60.00,
107 cache_read: 7.50,
108 cache_write: 0.0,
109 },
110 );
111 t.insert(
112 "o1-mini",
113 ModelPrice {
114 input: 3.00,
115 output: 12.00,
116 cache_read: 1.50,
117 cache_write: 0.0,
118 },
119 );
120
121 t.insert(
123 "claude-opus-4",
124 ModelPrice {
125 input: 15.00,
126 output: 75.00,
127 cache_read: 1.50,
128 cache_write: 18.75,
129 },
130 );
131 t.insert(
132 "claude-sonnet-4",
133 ModelPrice {
134 input: 3.00,
135 output: 15.00,
136 cache_read: 0.30,
137 cache_write: 3.75,
138 },
139 );
140 t.insert(
141 "claude-haiku-4",
142 ModelPrice {
143 input: 0.80,
144 output: 4.00,
145 cache_read: 0.08,
146 cache_write: 1.00,
147 },
148 );
149
150 t.insert(
152 "gemini-2.0-flash",
153 ModelPrice {
154 input: 0.10,
155 output: 0.40,
156 cache_read: 0.025,
157 cache_write: 0.0,
158 },
159 );
160 t.insert(
161 "gemini-1.5-pro",
162 ModelPrice {
163 input: 1.25,
164 output: 5.00,
165 cache_read: 0.3125,
166 cache_write: 0.0,
167 },
168 );
169
170 t
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `ModelPrice` with rates per million tokens.
    fn price(input: f64, output: f64, cache_read: f64, cache_write: f64) -> ModelPrice {
        ModelPrice {
            input,
            output,
            cache_read,
            cache_write,
        }
    }

    /// Float comparison tolerant of rounding in the per-million division.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn empty_table_returns_none() {
        let t = PriceTable::new();
        assert!(t.compute("gpt-4o", TokenUsage::default()).is_none());
    }

    #[test]
    fn exact_match_used_first() {
        let mut t = PriceTable::new();
        t.insert("gpt-4o-2024-08-06", price(1.0, 2.0, 0.0, 0.0));
        t.insert("gpt-4o", price(99.0, 99.0, 0.0, 0.0));
        let p = t.get("gpt-4o-2024-08-06").unwrap();
        assert_eq!(p.input, 1.0);
    }

    #[test]
    fn prefix_match_falls_back() {
        let mut t = PriceTable::new();
        t.insert("gpt-4o", price(2.50, 10.00, 1.25, 0.0));
        let p = t.get("gpt-4o-2024-08-06").unwrap();
        assert_eq!(p.input, 2.50);
    }

    #[test]
    fn longest_prefix_wins() {
        let mut t = PriceTable::new();
        t.insert("gpt", price(1.0, 1.0, 0.0, 0.0));
        t.insert("gpt-4o", price(2.50, 10.00, 0.0, 0.0));
        let p = t.get("gpt-4o-2024-08-06").unwrap();
        assert_eq!(p.input, 2.50);
    }

    #[test]
    fn longer_key_does_not_match_shorter_id() {
        // Only keys that prefix the requested id may match, never the reverse.
        let mut t = PriceTable::new();
        t.insert("gpt-4o-mini", price(0.15, 0.60, 0.075, 0.0));
        assert!(t.get("gpt-4o").is_none());
    }

    #[test]
    fn unknown_model_returns_none() {
        let mut t = PriceTable::new();
        t.insert("gpt-4o", price(2.50, 10.00, 1.25, 0.0));
        assert!(t.get("claude-sonnet-4").is_none());
        assert!(t.compute("claude-sonnet-4", TokenUsage::default()).is_none());
    }

    #[test]
    fn insert_is_chainable() {
        let mut t = PriceTable::new();
        t.insert("a", price(1.0, 0.0, 0.0, 0.0))
            .insert("b", price(2.0, 0.0, 0.0, 0.0));
        assert_eq!(t.get("a").unwrap().input, 1.0);
        assert_eq!(t.get("b").unwrap().input, 2.0);
    }

    #[test]
    fn compute_scales_per_million_tokens() {
        let mut t = PriceTable::new();
        t.insert("gpt-4o", price(2.50, 10.00, 1.25, 0.0));
        let usage = TokenUsage {
            input: 1_000_000,
            output: 500_000,
            cache_read: 0,
            cache_write: 0,
        };
        let c = t.compute("gpt-4o", usage).unwrap();
        assert!(approx_eq(c.input, 2.50));
        assert!(approx_eq(c.output, 5.00));
        assert!(approx_eq(c.total, 7.50));
    }

    #[test]
    fn compute_prices_cache_tokens() {
        let mut t = PriceTable::new();
        t.insert("claude-sonnet-4", price(3.00, 15.00, 0.30, 3.75));
        let usage = TokenUsage {
            input: 0,
            output: 0,
            cache_read: 1_000_000,
            cache_write: 2_000_000,
        };
        let c = t.compute("claude-sonnet-4-20250514", usage).unwrap();
        assert!(approx_eq(c.input, 0.0));
        assert!(approx_eq(c.output, 0.0));
        assert!(approx_eq(c.cache_read, 0.30));
        assert!(approx_eq(c.cache_write, 7.50));
        assert!(approx_eq(c.total, 7.80));
    }

    #[test]
    fn defaults_contains_mainstream_models() {
        let t = default_pricing_2026_05();
        assert!(t.get("gpt-4o").is_some());
        assert!(t.get("claude-sonnet-4").is_some());
        assert!(t.get("gemini-2.0-flash").is_some());
    }

    #[test]
    fn with_defaults_matches_dated_snapshot() {
        let a = PriceTable::with_defaults();
        let b = default_pricing_2026_05();
        for id in ["gpt-4o", "gpt-4o-mini", "o1", "claude-sonnet-4", "gemini-1.5-pro"] {
            assert_eq!(a.get(id), b.get(id));
        }
    }
}