1use crate::session::Session;
2use serde::{Deserialize, Serialize};
3
// Fallback per-million-token rates (USD). These are the same numbers
// `default_sonnet_tier` exposes; unknown models are billed at this tier.
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;

/// Per-million-token prices, in USD, for one model tier.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    pub input_cost_per_million: f64,
    pub output_cost_per_million: f64,
    pub cache_creation_cost_per_million: f64,
    pub cache_read_cost_per_million: f64,
}

impl ModelPricing {
    /// Pricing used when no model-specific rates are known
    /// (Sonnet-tier defaults).
    #[must_use]
    pub const fn default_sonnet_tier() -> Self {
        Self {
            cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
            cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
            output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
            input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
        }
    }

    /// Pricing for providers without prompt-cache billing: both cache
    /// rates are zero, so cached tokens contribute nothing to cost.
    #[must_use]
    pub const fn no_cache(input_cost_per_million: f64, output_cost_per_million: f64) -> Self {
        Self {
            input_cost_per_million,
            output_cost_per_million,
            cache_creation_cost_per_million: 0.0,
            cache_read_cost_per_million: 0.0,
        }
    }
}
38
/// Token counts reported by the provider for a single API response.
/// Derives `Serialize`/`Deserialize`, so the field names form the
/// serde wire format — presumably persisted in session files via
/// `ConversationMessage.usage`; renaming fields would break old data.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct TokenUsage {
    // Non-cached prompt tokens, billed at the input rate.
    pub input_tokens: u32,
    // Completion tokens, billed at the output rate.
    pub output_tokens: u32,
    // Prompt tokens written to the provider cache (cache-write rate).
    pub cache_creation_input_tokens: u32,
    // Prompt tokens served from the provider cache (cache-read rate).
    pub cache_read_input_tokens: u32,
}
46
/// USD cost of one usage record, broken down by billing category.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UsageCostEstimate {
    pub input_cost_usd: f64,
    pub output_cost_usd: f64,
    pub cache_creation_cost_usd: f64,
    pub cache_read_cost_usd: f64,
}

impl UsageCostEstimate {
    /// Sum of all four cost categories, in USD.
    #[must_use]
    pub fn total_cost_usd(self) -> f64 {
        // Left-to-right fold, same order as adding the fields directly.
        [
            self.input_cost_usd,
            self.output_cost_usd,
            self.cache_creation_cost_usd,
            self.cache_read_cost_usd,
        ]
        .into_iter()
        .sum()
    }
}
64
65#[must_use]
67pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
68 let normalized = model.to_ascii_lowercase();
69
70 if normalized.contains("haiku") {
71 return Some(ModelPricing {
72 input_cost_per_million: 1.0,
73 output_cost_per_million: 5.0,
74 cache_creation_cost_per_million: 1.25,
75 cache_read_cost_per_million: 0.1,
76 });
77 }
78 if normalized.contains("opus") {
79 return Some(ModelPricing {
80 input_cost_per_million: 15.0,
81 output_cost_per_million: 75.0,
82 cache_creation_cost_per_million: 18.75,
83 cache_read_cost_per_million: 1.5,
84 });
85 }
86 if normalized.contains("sonnet") {
87 return Some(ModelPricing::default_sonnet_tier());
88 }
89
90 if normalized.starts_with("gpt-4o-mini") {
91 return Some(ModelPricing::no_cache(0.15, 0.60));
92 }
93 if normalized.starts_with("gpt-4o") || normalized.starts_with("chatgpt-4o") {
94 return Some(ModelPricing::no_cache(2.5, 10.0));
95 }
96 if normalized.starts_with("gpt-4.1") {
97 return Some(ModelPricing::no_cache(2.0, 8.0));
98 }
99 if normalized.starts_with("o3-mini") || normalized.starts_with("o4-mini") {
100 return Some(ModelPricing::no_cache(1.1, 4.4));
101 }
102 if normalized.starts_with("o3") {
103 return Some(ModelPricing::no_cache(10.0, 40.0));
104 }
105
106 if normalized.starts_with("grok-3-mini") {
107 return Some(ModelPricing::no_cache(0.30, 0.50));
108 }
109 if normalized.starts_with("grok-3") {
110 return Some(ModelPricing::no_cache(3.0, 15.0));
111 }
112
113 None
114}
115
116impl TokenUsage {
117 #[must_use]
118 pub fn total_tokens(self) -> u32 {
119 self.input_tokens
120 + self.output_tokens
121 + self.cache_creation_input_tokens
122 + self.cache_read_input_tokens
123 }
124
125 #[must_use]
126 pub fn estimate_cost_usd(self) -> UsageCostEstimate {
127 self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
128 }
129
130 #[must_use]
131 pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
132 UsageCostEstimate {
133 input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
134 output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
135 cache_creation_cost_usd: cost_for_tokens(
136 self.cache_creation_input_tokens,
137 pricing.cache_creation_cost_per_million,
138 ),
139 cache_read_cost_usd: cost_for_tokens(
140 self.cache_read_input_tokens,
141 pricing.cache_read_cost_per_million,
142 ),
143 }
144 }
145
146 #[must_use]
147 pub fn summary_lines(self, label: &str) -> Vec<String> {
148 self.summary_lines_for_model(label, None)
149 }
150
151 #[must_use]
152 pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
153 let pricing = model.and_then(pricing_for_model);
154 let cost = pricing.map_or_else(
155 || self.estimate_cost_usd(),
156 |pricing| self.estimate_cost_usd_with_pricing(pricing),
157 );
158 let model_suffix =
159 model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
160 let pricing_suffix = if pricing.is_some() {
161 ""
162 } else if model.is_some() {
163 " pricing=estimated-default"
164 } else {
165 ""
166 };
167 vec![
168 format!(
169 "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
170 self.total_tokens(),
171 self.input_tokens,
172 self.output_tokens,
173 self.cache_creation_input_tokens,
174 self.cache_read_input_tokens,
175 format_usd(cost.total_cost_usd()),
176 model_suffix,
177 pricing_suffix,
178 ),
179 format!(
180 " cost breakdown: input={} output={} cache_write={} cache_read={}",
181 format_usd(cost.input_cost_usd),
182 format_usd(cost.output_cost_usd),
183 format_usd(cost.cache_creation_cost_usd),
184 format_usd(cost.cache_read_cost_usd),
185 ),
186 ]
187 }
188}
189
190fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
191 f64::from(tokens) / 1_000_000.0 * usd_per_million_tokens
192}
193
/// Formats a dollar amount as `$X.XXXX` (always four decimal places).
#[must_use]
pub fn format_usd(amount: f64) -> String {
    format!("${:.4}", amount)
}
198
/// Accumulates per-turn token usage across a conversation.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UsageTracker {
    // Snapshot of the most recently recorded turn only.
    latest_turn: TokenUsage,
    // Field-wise running totals over every recorded turn.
    cumulative: TokenUsage,
    // Number of turns recorded so far.
    turns: u32,
}
205
206impl UsageTracker {
207 #[must_use]
208 pub fn new() -> Self {
209 Self::default()
210 }
211
212 #[must_use]
213 pub fn from_session(session: &Session) -> Self {
214 let mut tracker = Self::new();
215 for message in &session.messages {
216 if let Some(usage) = message.usage {
217 tracker.record(usage);
218 }
219 }
220 tracker
221 }
222
223 pub fn record(&mut self, usage: TokenUsage) {
224 self.latest_turn = usage;
225 self.cumulative.input_tokens += usage.input_tokens;
226 self.cumulative.output_tokens += usage.output_tokens;
227 self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
228 self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
229 self.turns += 1;
230 }
231
232 #[must_use]
233 pub fn current_turn_usage(&self) -> TokenUsage {
234 self.latest_turn
235 }
236
237 #[must_use]
238 pub fn cumulative_usage(&self) -> TokenUsage {
239 self.cumulative
240 }
241
242 #[must_use]
243 pub fn turns(&self) -> u32 {
244 self.turns
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

    // Cumulative totals must be field-wise sums across turns, while the
    // current-turn snapshot reflects only the most recent record() call.
    #[test]
    fn tracks_true_cumulative_usage() {
        let mut tracker = UsageTracker::new();
        tracker.record(TokenUsage {
            input_tokens: 10,
            output_tokens: 4,
            cache_creation_input_tokens: 2,
            cache_read_input_tokens: 1,
        });
        tracker.record(TokenUsage {
            input_tokens: 20,
            output_tokens: 6,
            cache_creation_input_tokens: 3,
            cache_read_input_tokens: 2,
        });

        assert_eq!(tracker.turns(), 2);
        assert_eq!(tracker.current_turn_usage().input_tokens, 20);
        assert_eq!(tracker.current_turn_usage().output_tokens, 6);
        assert_eq!(tracker.cumulative_usage().output_tokens, 10);
        assert_eq!(tracker.cumulative_usage().input_tokens, 30);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 48);
    }

    // Default (Sonnet-tier) rates: 1M input @ $15/M = $15, 500K output
    // @ $75/M = $37.50; grand total with cache costs is $54.675.
    #[test]
    fn computes_cost_summary_lines() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            cache_creation_input_tokens: 100_000,
            cache_read_input_tokens: 200_000,
        };

        let cost = usage.estimate_cost_usd();
        assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
        assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");
        let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6"));
        assert!(lines[0].contains("estimated_cost=$54.6750"));
        assert!(lines[0].contains("model=claude-sonnet-4-6"));
        assert!(lines[1].contains("cache_read=$0.3000"));
    }

    // Substring matching must find tier names inside dated/versioned
    // model ids, and the per-tier rates must flow through cost math.
    #[test]
    fn supports_model_specific_pricing() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        };

        let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing");
        let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
        let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku);
        let opus_cost = usage.estimate_cost_usd_with_pricing(opus);
        assert_eq!(format_usd(haiku_cost.total_cost_usd()), "$3.5000");
        assert_eq!(format_usd(opus_cost.total_cost_usd()), "$52.5000");
    }

    // Prefix matching: every listed family resolves, and the more
    // specific "-mini" prefixes must not shadow their base models.
    #[test]
    fn supports_openai_and_xai_pricing() {
        assert!(pricing_for_model("gpt-4o").is_some());
        assert!(pricing_for_model("gpt-4o-mini").is_some());
        assert!(pricing_for_model("gpt-4.1-nano").is_some());
        assert!(pricing_for_model("o3-mini").is_some());
        assert!(pricing_for_model("o3").is_some());
        assert!(pricing_for_model("grok-3").is_some());
        assert!(pricing_for_model("grok-3-mini-fast").is_some());

        let gpt4o = pricing_for_model("gpt-4o").unwrap();
        assert!((gpt4o.input_cost_per_million - 2.5).abs() < f64::EPSILON);
        let grok3 = pricing_for_model("grok-3").unwrap();
        assert!((grok3.input_cost_per_million - 3.0).abs() < f64::EPSILON);
    }

    // An unrecognized model name still gets a cost line, but it must be
    // annotated as using fallback pricing.
    #[test]
    fn marks_unknown_model_pricing_as_fallback() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 100,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        };
        let lines = usage.summary_lines_for_model("usage", Some("custom-model"));
        assert!(lines[0].contains("pricing=estimated-default"));
    }

    // Replaying a persisted session must reproduce the tracker state
    // from the usage attached to its messages.
    #[test]
    fn reconstructs_usage_from_session_messages() {
        let session = Session {
            version: 1,
            messages: vec![ConversationMessage {
                role: MessageRole::Assistant,
                blocks: vec![ContentBlock::Text {
                    text: "done".to_string(),
                }],
                usage: Some(TokenUsage {
                    input_tokens: 5,
                    output_tokens: 2,
                    cache_creation_input_tokens: 1,
                    cache_read_input_tokens: 0,
                }),
            }],
        };

        let tracker = UsageTracker::from_session(&session);
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
    }
}