Skip to main content

wraith_runtime/
usage.rs

1use crate::session::Session;
2use serde::{Deserialize, Serialize};
3
/// Fallback rate, in USD per one million input tokens.
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
/// Fallback rate, in USD per one million output tokens.
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
/// Fallback rate, in USD per one million cache-creation input tokens.
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
/// Fallback rate, in USD per one million cache-read input tokens.
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;

/// Rates used to price a token-usage sample, one rate per token category.
///
/// Every rate is expressed in USD per one million tokens.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    /// Rate for input tokens.
    pub input_cost_per_million: f64,
    /// Rate for output tokens.
    pub output_cost_per_million: f64,
    /// Rate for cache-creation input tokens.
    pub cache_creation_cost_per_million: f64,
    /// Rate for cache-read input tokens.
    pub cache_read_cost_per_million: f64,
}

impl ModelPricing {
    /// Returns the pricing built from the `DEFAULT_*_COST_PER_MILLION` constants.
    ///
    /// Usable in `const` contexts.
    ///
    /// NOTE(review): these rates are identical to the ones `pricing_for_model`
    /// hard-codes for "opus" model names — confirm that labelling them the
    /// "sonnet tier" is intentional.
    #[must_use]
    pub const fn default_sonnet_tier() -> Self {
        ModelPricing {
            cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
            cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
            output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
            input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
        }
    }
}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
30pub struct TokenUsage {
31    pub input_tokens: u32,
32    pub output_tokens: u32,
33    pub cache_creation_input_tokens: u32,
34    pub cache_read_input_tokens: u32,
35}
36
/// Estimated cost of a usage sample in USD, broken down by token category.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UsageCostEstimate {
    /// Cost attributed to input tokens.
    pub input_cost_usd: f64,
    /// Cost attributed to output tokens.
    pub output_cost_usd: f64,
    /// Cost attributed to cache-creation input tokens.
    pub cache_creation_cost_usd: f64,
    /// Cost attributed to cache-read input tokens.
    pub cache_read_cost_usd: f64,
}

impl UsageCostEstimate {
    /// Returns the sum of the four per-category costs, in USD.
    #[must_use]
    pub fn total_cost_usd(self) -> f64 {
        let Self {
            input_cost_usd,
            output_cost_usd,
            cache_creation_cost_usd,
            cache_read_cost_usd,
        } = self;

        // Fold left to right, starting from the input cost, so the
        // floating-point additions happen in a fixed, predictable order.
        [output_cost_usd, cache_creation_cost_usd, cache_read_cost_usd]
            .iter()
            .fold(input_cost_usd, |running, part| running + part)
    }
}
54
55#[must_use]
56pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
57    let normalized = model.to_ascii_lowercase();
58    if normalized.contains("haiku") {
59        return Some(ModelPricing {
60            input_cost_per_million: 1.0,
61            output_cost_per_million: 5.0,
62            cache_creation_cost_per_million: 1.25,
63            cache_read_cost_per_million: 0.1,
64        });
65    }
66    if normalized.contains("opus") {
67        return Some(ModelPricing {
68            input_cost_per_million: 15.0,
69            output_cost_per_million: 75.0,
70            cache_creation_cost_per_million: 18.75,
71            cache_read_cost_per_million: 1.5,
72        });
73    }
74    if normalized.contains("sonnet") {
75        return Some(ModelPricing::default_sonnet_tier());
76    }
77    None
78}
79
80impl TokenUsage {
81    #[must_use]
82    pub fn total_tokens(self) -> u32 {
83        self.input_tokens
84            + self.output_tokens
85            + self.cache_creation_input_tokens
86            + self.cache_read_input_tokens
87    }
88
89    #[must_use]
90    pub fn estimate_cost_usd(self) -> UsageCostEstimate {
91        self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
92    }
93
94    #[must_use]
95    pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
96        UsageCostEstimate {
97            input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
98            output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
99            cache_creation_cost_usd: cost_for_tokens(
100                self.cache_creation_input_tokens,
101                pricing.cache_creation_cost_per_million,
102            ),
103            cache_read_cost_usd: cost_for_tokens(
104                self.cache_read_input_tokens,
105                pricing.cache_read_cost_per_million,
106            ),
107        }
108    }
109
110    #[must_use]
111    pub fn summary_lines(self, label: &str) -> Vec<String> {
112        self.summary_lines_for_model(label, None)
113    }
114
115    #[must_use]
116    pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
117        let pricing = model.and_then(pricing_for_model);
118        let cost = pricing.map_or_else(
119            || self.estimate_cost_usd(),
120            |pricing| self.estimate_cost_usd_with_pricing(pricing),
121        );
122        let model_suffix =
123            model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
124        let pricing_suffix = if pricing.is_some() {
125            ""
126        } else if model.is_some() {
127            " pricing=estimated-default"
128        } else {
129            ""
130        };
131        vec![
132            format!(
133                "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
134                self.total_tokens(),
135                self.input_tokens,
136                self.output_tokens,
137                self.cache_creation_input_tokens,
138                self.cache_read_input_tokens,
139                format_usd(cost.total_cost_usd()),
140                model_suffix,
141                pricing_suffix,
142            ),
143            format!(
144                "  cost breakdown: input={} output={} cache_write={} cache_read={}",
145                format_usd(cost.input_cost_usd),
146                format_usd(cost.output_cost_usd),
147                format_usd(cost.cache_creation_cost_usd),
148                format_usd(cost.cache_read_cost_usd),
149            ),
150        ]
151    }
152}
153
/// Converts a token count into a USD amount at the given per-million-token rate.
fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
    // `u32` always fits in an `f64` exactly, so the lossless `From` conversion is used.
    let millions_of_tokens = f64::from(tokens) / 1_000_000.0;
    millions_of_tokens * usd_per_million_tokens
}
157
/// Renders `amount` as a dollar string with exactly four decimal places,
/// e.g. `15.0` becomes `"$15.0000"`.
#[must_use]
pub fn format_usd(amount: f64) -> String {
    format!("${:.4}", amount)
}
162
163#[derive(Debug, Clone, Default, PartialEq, Eq)]
164pub struct UsageTracker {
165    latest_turn: TokenUsage,
166    cumulative: TokenUsage,
167    turns: u32,
168}
169
170impl UsageTracker {
171    #[must_use]
172    pub fn new() -> Self {
173        Self::default()
174    }
175
176    #[must_use]
177    pub fn from_session(session: &Session) -> Self {
178        let mut tracker = Self::new();
179        for message in &session.messages {
180            if let Some(usage) = message.usage {
181                tracker.record(usage);
182            }
183        }
184        tracker
185    }
186
187    pub fn record(&mut self, usage: TokenUsage) {
188        self.latest_turn = usage;
189        self.cumulative.input_tokens += usage.input_tokens;
190        self.cumulative.output_tokens += usage.output_tokens;
191        self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
192        self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
193        self.turns += 1;
194    }
195
196    #[must_use]
197    pub fn current_turn_usage(&self) -> TokenUsage {
198        self.latest_turn
199    }
200
201    #[must_use]
202    pub fn cumulative_usage(&self) -> TokenUsage {
203        self.cumulative
204    }
205
206    #[must_use]
207    pub fn turns(&self) -> u32 {
208        self.turns
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

    /// Shorthand constructor so each test can state its token counts on one line.
    fn sample(input: u32, output: u32, cache_write: u32, cache_read: u32) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            cache_creation_input_tokens: cache_write,
            cache_read_input_tokens: cache_read,
        }
    }

    // Two recorded turns: the latest turn is the second sample, the cumulative
    // totals are the per-category sums.
    #[test]
    fn tracks_true_cumulative_usage() {
        let mut tracker = UsageTracker::new();
        tracker.record(sample(10, 4, 2, 1));
        tracker.record(sample(20, 6, 3, 2));

        assert_eq!(tracker.turns(), 2);

        let latest = tracker.current_turn_usage();
        assert_eq!(latest.input_tokens, 20);
        assert_eq!(latest.output_tokens, 6);

        let totals = tracker.cumulative_usage();
        assert_eq!(totals.output_tokens, 10);
        assert_eq!(totals.input_tokens, 30);
        assert_eq!(totals.total_tokens(), 48);
    }

    // Default pricing and the summary lines for a recognised model name.
    #[test]
    fn computes_cost_summary_lines() {
        let usage = sample(1_000_000, 500_000, 100_000, 200_000);

        let cost = usage.estimate_cost_usd();
        assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
        assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");

        let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6"));
        let (headline, breakdown) = (&lines[0], &lines[1]);
        assert!(headline.contains("estimated_cost=$54.6750"));
        assert!(headline.contains("model=claude-sonnet-4-6"));
        assert!(breakdown.contains("cache_read=$0.3000"));
    }

    // Haiku and opus names resolve to their own rates.
    #[test]
    fn supports_model_specific_pricing() {
        let usage = sample(1_000_000, 500_000, 0, 0);

        let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing");
        let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");

        let haiku_total = usage.estimate_cost_usd_with_pricing(haiku).total_cost_usd();
        let opus_total = usage.estimate_cost_usd_with_pricing(opus).total_cost_usd();
        assert_eq!(format_usd(haiku_total), "$3.5000");
        assert_eq!(format_usd(opus_total), "$52.5000");
    }

    // An unrecognised model name is flagged as using fallback pricing.
    #[test]
    fn marks_unknown_model_pricing_as_fallback() {
        let usage = sample(100, 100, 0, 0);
        let lines = usage.summary_lines_for_model("usage", Some("custom-model"));
        assert!(lines[0].contains("pricing=estimated-default"));
    }

    // A session with one usage-bearing message yields one recorded turn.
    #[test]
    fn reconstructs_usage_from_session_messages() {
        let assistant_reply = ConversationMessage {
            role: MessageRole::Assistant,
            blocks: vec![ContentBlock::Text {
                text: "done".to_string(),
            }],
            usage: Some(sample(5, 2, 1, 0)),
        };
        let session = Session {
            version: 1,
            messages: vec![assistant_reply],
        };

        let tracker = UsageTracker::from_session(&session);
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
    }
}
310}