Skip to main content

codineer_runtime/
usage.rs

1use crate::session::Session;
2use serde::{Deserialize, Serialize};
3
// Fallback per-million-token USD rates, used when no model-specific pricing applies.
// NOTE(review): these are the same numbers `pricing_for_model` returns for "opus" ids,
// even though they back `default_sonnet_tier`; confirm which tier is intended. The unit
// tests in this file currently pin these exact values.
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;

/// USD rates for one model, quoted per million tokens of each token category.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    /// USD per million input tokens billed at the base input rate.
    pub input_cost_per_million: f64,
    /// USD per million generated output tokens.
    pub output_cost_per_million: f64,
    /// USD per million input tokens written to the prompt cache.
    pub cache_creation_cost_per_million: f64,
    /// USD per million input tokens read back from the prompt cache.
    pub cache_read_cost_per_million: f64,
}

impl ModelPricing {
    /// Pricing with only input/output rates; both cache rates are zero.
    #[must_use]
    pub const fn no_cache(input_cost_per_million: f64, output_cost_per_million: f64) -> Self {
        Self {
            cache_creation_cost_per_million: 0.0,
            cache_read_cost_per_million: 0.0,
            input_cost_per_million,
            output_cost_per_million,
        }
    }

    /// The fallback pricing assembled from the `DEFAULT_*` constants in this module.
    #[must_use]
    pub const fn default_sonnet_tier() -> Self {
        Self {
            cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
            cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
            ..Self::no_cache(DEFAULT_INPUT_COST_PER_MILLION, DEFAULT_OUTPUT_COST_PER_MILLION)
        }
    }
}
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
40pub struct TokenUsage {
41    pub input_tokens: u32,
42    pub output_tokens: u32,
43    pub cache_creation_input_tokens: u32,
44    pub cache_read_input_tokens: u32,
45}
46
/// Estimated USD cost of a token usage record, split by token category.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UsageCostEstimate {
    /// Cost of input tokens at the base input rate.
    pub input_cost_usd: f64,
    /// Cost of generated output tokens.
    pub output_cost_usd: f64,
    /// Cost of tokens written to the prompt cache.
    pub cache_creation_cost_usd: f64,
    /// Cost of tokens read from the prompt cache.
    pub cache_read_cost_usd: f64,
}

impl UsageCostEstimate {
    /// Total estimated cost: the sum of all four per-category costs.
    #[must_use]
    pub fn total_cost_usd(self) -> f64 {
        let Self {
            input_cost_usd: input,
            output_cost_usd: output,
            cache_creation_cost_usd: cache_write,
            cache_read_cost_usd: cache_read,
        } = self;
        // Summed left-to-right in the same order as before so float rounding is unchanged.
        input + output + cache_write + cache_read
    }
}
64
65/// Approximate pricing as of 2025-Q2. Update periodically from provider pricing pages.
66#[must_use]
67pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
68    let normalized = model.to_ascii_lowercase();
69
70    if normalized.contains("haiku") {
71        return Some(ModelPricing {
72            input_cost_per_million: 1.0,
73            output_cost_per_million: 5.0,
74            cache_creation_cost_per_million: 1.25,
75            cache_read_cost_per_million: 0.1,
76        });
77    }
78    if normalized.contains("opus") {
79        return Some(ModelPricing {
80            input_cost_per_million: 15.0,
81            output_cost_per_million: 75.0,
82            cache_creation_cost_per_million: 18.75,
83            cache_read_cost_per_million: 1.5,
84        });
85    }
86    if normalized.contains("sonnet") {
87        return Some(ModelPricing::default_sonnet_tier());
88    }
89
90    if normalized.starts_with("gpt-4o-mini") {
91        return Some(ModelPricing::no_cache(0.15, 0.60));
92    }
93    if normalized.starts_with("gpt-4o") || normalized.starts_with("chatgpt-4o") {
94        return Some(ModelPricing::no_cache(2.5, 10.0));
95    }
96    if normalized.starts_with("gpt-4.1") {
97        return Some(ModelPricing::no_cache(2.0, 8.0));
98    }
99    if normalized.starts_with("o3-mini") || normalized.starts_with("o4-mini") {
100        return Some(ModelPricing::no_cache(1.1, 4.4));
101    }
102    if normalized.starts_with("o3") {
103        return Some(ModelPricing::no_cache(10.0, 40.0));
104    }
105
106    if normalized.starts_with("grok-3-mini") {
107        return Some(ModelPricing::no_cache(0.30, 0.50));
108    }
109    if normalized.starts_with("grok-3") {
110        return Some(ModelPricing::no_cache(3.0, 15.0));
111    }
112
113    None
114}
115
116impl TokenUsage {
117    #[must_use]
118    pub fn total_tokens(self) -> u32 {
119        self.input_tokens
120            + self.output_tokens
121            + self.cache_creation_input_tokens
122            + self.cache_read_input_tokens
123    }
124
125    #[must_use]
126    pub fn estimate_cost_usd(self) -> UsageCostEstimate {
127        self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
128    }
129
130    #[must_use]
131    pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
132        UsageCostEstimate {
133            input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
134            output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
135            cache_creation_cost_usd: cost_for_tokens(
136                self.cache_creation_input_tokens,
137                pricing.cache_creation_cost_per_million,
138            ),
139            cache_read_cost_usd: cost_for_tokens(
140                self.cache_read_input_tokens,
141                pricing.cache_read_cost_per_million,
142            ),
143        }
144    }
145
146    #[must_use]
147    pub fn summary_lines(self, label: &str) -> Vec<String> {
148        self.summary_lines_for_model(label, None)
149    }
150
151    #[must_use]
152    pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
153        let pricing = model.and_then(pricing_for_model);
154        let cost = pricing.map_or_else(
155            || self.estimate_cost_usd(),
156            |pricing| self.estimate_cost_usd_with_pricing(pricing),
157        );
158        let model_suffix =
159            model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
160        let pricing_suffix = if pricing.is_some() {
161            ""
162        } else if model.is_some() {
163            " pricing=estimated-default"
164        } else {
165            ""
166        };
167        vec![
168            format!(
169                "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
170                self.total_tokens(),
171                self.input_tokens,
172                self.output_tokens,
173                self.cache_creation_input_tokens,
174                self.cache_read_input_tokens,
175                format_usd(cost.total_cost_usd()),
176                model_suffix,
177                pricing_suffix,
178            ),
179            format!(
180                "  cost breakdown: input={} output={} cache_write={} cache_read={}",
181                format_usd(cost.input_cost_usd),
182                format_usd(cost.output_cost_usd),
183                format_usd(cost.cache_creation_cost_usd),
184                format_usd(cost.cache_read_cost_usd),
185            ),
186        ]
187    }
188}
189
/// Converts a token count to USD, given a price quoted per million tokens.
fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
    const TOKENS_PER_MILLION: f64 = 1_000_000.0;
    // Divide first, then multiply, matching the original evaluation order exactly.
    let millions_of_tokens = f64::from(tokens) / TOKENS_PER_MILLION;
    millions_of_tokens * usd_per_million_tokens
}
193
/// Renders a USD amount as a `$` sign followed by the value to exactly four decimal places.
#[must_use]
pub fn format_usd(amount: f64) -> String {
    let mut rendered = String::from("$");
    rendered.push_str(&format!("{amount:.4}"));
    rendered
}
198
199#[derive(Debug, Clone, Default, PartialEq, Eq)]
200pub struct UsageTracker {
201    latest_turn: TokenUsage,
202    cumulative: TokenUsage,
203    turns: u32,
204}
205
206impl UsageTracker {
207    #[must_use]
208    pub fn new() -> Self {
209        Self::default()
210    }
211
212    #[must_use]
213    pub fn from_session(session: &Session) -> Self {
214        let mut tracker = Self::new();
215        for message in &session.messages {
216            if let Some(usage) = message.usage {
217                tracker.record(usage);
218            }
219        }
220        tracker
221    }
222
223    pub fn record(&mut self, usage: TokenUsage) {
224        self.latest_turn = usage;
225        self.cumulative.input_tokens += usage.input_tokens;
226        self.cumulative.output_tokens += usage.output_tokens;
227        self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
228        self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
229        self.turns += 1;
230    }
231
232    #[must_use]
233    pub fn current_turn_usage(&self) -> TokenUsage {
234        self.latest_turn
235    }
236
237    #[must_use]
238    pub fn cumulative_usage(&self) -> TokenUsage {
239        self.cumulative
240    }
241
242    #[must_use]
243    pub fn turns(&self) -> u32 {
244        self.turns
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

    /// Builds a `TokenUsage` from positional counts: input, output, cache write, cache read.
    fn usage(input: u32, output: u32, cache_write: u32, cache_read: u32) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            cache_creation_input_tokens: cache_write,
            cache_read_input_tokens: cache_read,
        }
    }

    #[test]
    fn tracks_true_cumulative_usage() {
        let mut tracker = UsageTracker::new();
        for turn in [usage(10, 4, 2, 1), usage(20, 6, 3, 2)] {
            tracker.record(turn);
        }

        let latest = tracker.current_turn_usage();
        let cumulative = tracker.cumulative_usage();
        assert_eq!(tracker.turns(), 2);
        assert_eq!(latest.input_tokens, 20);
        assert_eq!(latest.output_tokens, 6);
        assert_eq!(cumulative.output_tokens, 10);
        assert_eq!(cumulative.input_tokens, 30);
        assert_eq!(cumulative.total_tokens(), 48);
    }

    #[test]
    fn computes_cost_summary_lines() {
        let sample = usage(1_000_000, 500_000, 100_000, 200_000);

        let cost = sample.estimate_cost_usd();
        assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
        assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");

        let lines = sample.summary_lines_for_model("usage", Some("claude-sonnet-4-6"));
        let (headline, breakdown) = (&lines[0], &lines[1]);
        assert!(headline.contains("estimated_cost=$54.6750"));
        assert!(headline.contains("model=claude-sonnet-4-6"));
        assert!(breakdown.contains("cache_read=$0.3000"));
    }

    #[test]
    fn supports_model_specific_pricing() {
        let sample = usage(1_000_000, 500_000, 0, 0);

        let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing");
        let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
        let total =
            |pricing| format_usd(sample.estimate_cost_usd_with_pricing(pricing).total_cost_usd());

        assert_eq!(total(haiku), "$3.5000");
        assert_eq!(total(opus), "$52.5000");
    }

    #[test]
    fn supports_openai_and_xai_pricing() {
        let known_models = [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4.1-nano",
            "o3-mini",
            "o3",
            "grok-3",
            "grok-3-mini-fast",
        ];
        for model in known_models {
            assert!(pricing_for_model(model).is_some());
        }

        let gpt4o = pricing_for_model("gpt-4o").unwrap();
        assert!((gpt4o.input_cost_per_million - 2.5).abs() < f64::EPSILON);
        let grok3 = pricing_for_model("grok-3").unwrap();
        assert!((grok3.input_cost_per_million - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn marks_unknown_model_pricing_as_fallback() {
        let lines = usage(100, 100, 0, 0).summary_lines_for_model("usage", Some("custom-model"));
        assert!(lines[0].contains("pricing=estimated-default"));
    }

    #[test]
    fn reconstructs_usage_from_session_messages() {
        let reply = ConversationMessage {
            role: MessageRole::Assistant,
            blocks: vec![ContentBlock::Text {
                text: "done".to_string(),
            }],
            usage: Some(usage(5, 2, 1, 0)),
        };
        let session = Session {
            version: 1,
            messages: vec![reply],
        };

        let tracker = UsageTracker::from_session(&session);
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
    }
}