1use crate::session::Session;
2
// Fallback per-million-token rates in USD, consumed by `ModelPricing::default_sonnet_tier`.
// NOTE(review): these figures are identical to the `opus` rates in `pricing_for_model`
// even though they back the "sonnet" default tier — confirm this is intended.

/// Fallback USD rate per one million input tokens.
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
/// Fallback USD rate per one million output tokens.
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
/// Fallback USD rate per one million cache-creation input tokens.
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
/// Fallback USD rate per one million cache-read input tokens.
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;
7
/// Per-million-token USD rates used to turn token counts into a cost estimate.
///
/// One rate exists for each token category tracked by `TokenUsage`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    /// USD charged per one million input tokens.
    pub input_cost_per_million: f64,
    /// USD charged per one million output tokens.
    pub output_cost_per_million: f64,
    /// USD charged per one million cache-creation input tokens.
    pub cache_creation_cost_per_million: f64,
    /// USD charged per one million cache-read input tokens.
    pub cache_read_cost_per_million: f64,
}
15
16impl ModelPricing {
17 #[must_use]
18 pub const fn default_sonnet_tier() -> Self {
19 Self {
20 input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
21 output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
22 cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
23 cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
24 }
25 }
26}
27
/// Token counters for one request, or a running total of several.
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of input tokens.
    pub input_tokens: u32,
    /// Number of output tokens.
    pub output_tokens: u32,
    /// Number of input tokens counted as cache creation (reported as `cache_write`).
    pub cache_creation_input_tokens: u32,
    /// Number of input tokens counted as cache reads.
    pub cache_read_input_tokens: u32,
}
35
/// Estimated USD cost of a `TokenUsage`, broken down by token category.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UsageCostEstimate {
    /// Estimated USD cost of the input tokens.
    pub input_cost_usd: f64,
    /// Estimated USD cost of the output tokens.
    pub output_cost_usd: f64,
    /// Estimated USD cost of the cache-creation input tokens.
    pub cache_creation_cost_usd: f64,
    /// Estimated USD cost of the cache-read input tokens.
    pub cache_read_cost_usd: f64,
}
43
44impl UsageCostEstimate {
45 #[must_use]
46 pub fn total_cost_usd(self) -> f64 {
47 self.input_cost_usd
48 + self.output_cost_usd
49 + self.cache_creation_cost_usd
50 + self.cache_read_cost_usd
51 }
52}
53
54#[must_use]
55pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
56 let normalized = model.to_ascii_lowercase();
57 if normalized.contains("haiku") {
58 return Some(ModelPricing {
59 input_cost_per_million: 1.0,
60 output_cost_per_million: 5.0,
61 cache_creation_cost_per_million: 1.25,
62 cache_read_cost_per_million: 0.1,
63 });
64 }
65 if normalized.contains("opus") {
66 return Some(ModelPricing {
67 input_cost_per_million: 15.0,
68 output_cost_per_million: 75.0,
69 cache_creation_cost_per_million: 18.75,
70 cache_read_cost_per_million: 1.5,
71 });
72 }
73 if normalized.contains("sonnet") {
74 return Some(ModelPricing::default_sonnet_tier());
75 }
76 None
77}
78
79impl TokenUsage {
80 #[must_use]
81 pub fn total_tokens(self) -> u32 {
82 self.input_tokens
83 + self.output_tokens
84 + self.cache_creation_input_tokens
85 + self.cache_read_input_tokens
86 }
87
88 #[must_use]
89 pub fn estimate_cost_usd(self) -> UsageCostEstimate {
90 self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
91 }
92
93 #[must_use]
94 pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
95 UsageCostEstimate {
96 input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
97 output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
98 cache_creation_cost_usd: cost_for_tokens(
99 self.cache_creation_input_tokens,
100 pricing.cache_creation_cost_per_million,
101 ),
102 cache_read_cost_usd: cost_for_tokens(
103 self.cache_read_input_tokens,
104 pricing.cache_read_cost_per_million,
105 ),
106 }
107 }
108
109 #[must_use]
110 pub fn summary_lines(self, label: &str) -> Vec<String> {
111 self.summary_lines_for_model(label, None)
112 }
113
114 #[must_use]
115 pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
116 let pricing = model.and_then(pricing_for_model);
117 let cost = pricing.map_or_else(
118 || self.estimate_cost_usd(),
119 |pricing| self.estimate_cost_usd_with_pricing(pricing),
120 );
121 let model_suffix =
122 model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
123 let pricing_suffix = if pricing.is_some() {
124 ""
125 } else if model.is_some() {
126 " pricing=estimated-default"
127 } else {
128 ""
129 };
130 vec![
131 format!(
132 "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
133 self.total_tokens(),
134 self.input_tokens,
135 self.output_tokens,
136 self.cache_creation_input_tokens,
137 self.cache_read_input_tokens,
138 format_usd(cost.total_cost_usd()),
139 model_suffix,
140 pricing_suffix,
141 ),
142 format!(
143 " cost breakdown: input={} output={} cache_write={} cache_read={}",
144 format_usd(cost.input_cost_usd),
145 format_usd(cost.output_cost_usd),
146 format_usd(cost.cache_creation_cost_usd),
147 format_usd(cost.cache_read_cost_usd),
148 ),
149 ]
150 }
151}
152
/// Converts a token count into USD given a price per one million tokens.
fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
    const TOKENS_PER_MILLION: f64 = 1_000_000.0;

    // Divide first, then multiply, so the floating-point rounding stays the same.
    let millions_of_tokens = f64::from(tokens) / TOKENS_PER_MILLION;
    millions_of_tokens * usd_per_million_tokens
}
156
/// Formats a dollar amount as `$` followed by the value with four decimal places.
#[must_use]
pub fn format_usd(amount: f64) -> String {
    format!("${:.4}", amount)
}
161
162#[derive(Debug, Clone, Default, PartialEq, Eq)]
163pub struct UsageTracker {
164 latest_turn: TokenUsage,
165 cumulative: TokenUsage,
166 turns: u32,
167}
168
169impl UsageTracker {
170 #[must_use]
171 pub fn new() -> Self {
172 Self::default()
173 }
174
175 #[must_use]
176 pub fn from_session(session: &Session) -> Self {
177 let mut tracker = Self::new();
178 for message in &session.messages {
179 if let Some(usage) = message.usage {
180 tracker.record(usage);
181 }
182 }
183 tracker
184 }
185
186 pub fn record(&mut self, usage: TokenUsage) {
187 self.latest_turn = usage;
188 self.cumulative.input_tokens += usage.input_tokens;
189 self.cumulative.output_tokens += usage.output_tokens;
190 self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
191 self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
192 self.turns += 1;
193 }
194
195 #[must_use]
196 pub fn current_turn_usage(&self) -> TokenUsage {
197 self.latest_turn
198 }
199
200 #[must_use]
201 pub fn cumulative_usage(&self) -> TokenUsage {
202 self.cumulative
203 }
204
205 #[must_use]
206 pub fn turns(&self) -> u32 {
207 self.turns
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

    /// Recording two turns keeps the latest turn apart from the running totals.
    #[test]
    fn tracks_true_cumulative_usage() {
        let recorded_turns = [
            TokenUsage {
                input_tokens: 10,
                output_tokens: 4,
                cache_creation_input_tokens: 2,
                cache_read_input_tokens: 1,
            },
            TokenUsage {
                input_tokens: 20,
                output_tokens: 6,
                cache_creation_input_tokens: 3,
                cache_read_input_tokens: 2,
            },
        ];

        let mut tracker = UsageTracker::new();
        for usage in recorded_turns.iter().copied() {
            tracker.record(usage);
        }

        let latest = tracker.current_turn_usage();
        let cumulative = tracker.cumulative_usage();
        assert_eq!(tracker.turns(), 2);
        assert_eq!(latest.input_tokens, 20);
        assert_eq!(latest.output_tokens, 6);
        assert_eq!(cumulative.output_tokens, 10);
        assert_eq!(cumulative.input_tokens, 30);
        assert_eq!(cumulative.total_tokens(), 48);
    }

    /// Default pricing produces the expected per-category and total figures.
    #[test]
    fn computes_cost_summary_lines() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            cache_creation_input_tokens: 100_000,
            cache_read_input_tokens: 200_000,
        };

        let cost = usage.estimate_cost_usd();
        assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
        assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");

        let lines = usage.summary_lines_for_model("usage", Some("ternlang-sonnet-4-20250514"));
        let (headline, breakdown) = (&lines[0], &lines[1]);
        assert!(headline.contains("estimated_cost=$54.6750"));
        assert!(headline.contains("model=ternlang-sonnet-4-20250514"));
        assert!(breakdown.contains("cache_read=$0.3000"));
    }

    /// Haiku and opus model names resolve to different price tables.
    #[test]
    fn supports_model_specific_pricing() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            ..TokenUsage::default()
        };

        let haiku = pricing_for_model("ternlang-haiku-4-5-20251001").expect("haiku pricing");
        let opus = pricing_for_model("ternlang-opus-4-6").expect("opus pricing");

        let haiku_total = usage.estimate_cost_usd_with_pricing(haiku).total_cost_usd();
        let opus_total = usage.estimate_cost_usd_with_pricing(opus).total_cost_usd();
        assert_eq!(format_usd(haiku_total), "$3.5000");
        assert_eq!(format_usd(opus_total), "$52.5000");
    }

    /// A model name without a known family marker is flagged as default-priced.
    #[test]
    fn marks_unknown_model_pricing_as_fallback() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 100,
            ..TokenUsage::default()
        };

        let lines = usage.summary_lines_for_model("usage", Some("custom-model"));
        assert!(lines[0].contains("pricing=estimated-default"));
    }

    /// A tracker rebuilt from a session counts the usage attached to its messages.
    #[test]
    fn reconstructs_usage_from_session_messages() {
        let assistant_reply = ConversationMessage {
            role: MessageRole::Assistant,
            blocks: vec![ContentBlock::Text {
                text: String::from("done"),
            }],
            usage: Some(TokenUsage {
                input_tokens: 5,
                output_tokens: 2,
                cache_creation_input_tokens: 1,
                cache_read_input_tokens: 0,
            }),
        };
        let session = Session {
            version: 1,
            messages: vec![assistant_reply],
        };

        let tracker = UsageTracker::from_session(&session);
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
    }
}