Skip to main content

albert_runtime/
usage.rs

1use crate::session::Session;
2
// Fallback USD rates (per one million tokens) used by `ModelPricing::default_sonnet_tier`.
// NOTE(review): these figures are identical to the ones `pricing_for_model` hard-codes for
// "opus" models; confirm the "sonnet" naming is intended.
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;

/// USD rates, each expressed per one million tokens, for the four token
/// categories counted in a `TokenUsage`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ModelPricing {
    /// Rate applied to input tokens.
    pub input_cost_per_million: f64,
    /// Rate applied to output tokens.
    pub output_cost_per_million: f64,
    /// Rate applied to cache-creation (cache write) input tokens.
    pub cache_creation_cost_per_million: f64,
    /// Rate applied to cache-read input tokens.
    pub cache_read_cost_per_million: f64,
}

impl ModelPricing {
    /// Returns the fallback pricing assembled from the `DEFAULT_*` constants.
    ///
    /// Being a `const fn`, it can also initialise `const` items.
    #[must_use]
    pub const fn default_sonnet_tier() -> Self {
        ModelPricing {
            input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
            output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
            cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
            cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
        }
    }
}
27
/// Token counts for one model turn, or a field-wise accumulation of several turns.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct TokenUsage {
    /// Input tokens (billed at the input rate).
    pub input_tokens: u32,
    /// Output tokens (billed at the output rate).
    pub output_tokens: u32,
    /// Cache-creation input tokens (shown as `cache_write` in summaries).
    pub cache_creation_input_tokens: u32,
    /// Cache-read input tokens (shown as `cache_read` in summaries).
    pub cache_read_input_tokens: u32,
}
35
/// USD cost of a token usage, broken down by token category.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct UsageCostEstimate {
    /// Cost of input tokens, in USD.
    pub input_cost_usd: f64,
    /// Cost of output tokens, in USD.
    pub output_cost_usd: f64,
    /// Cost of cache-creation input tokens, in USD.
    pub cache_creation_cost_usd: f64,
    /// Cost of cache-read input tokens, in USD.
    pub cache_read_cost_usd: f64,
}

impl UsageCostEstimate {
    /// Returns the sum of the four per-category costs, in USD.
    ///
    /// Terms are added left to right in field-declaration order.
    #[must_use]
    pub fn total_cost_usd(self) -> f64 {
        let Self {
            input_cost_usd,
            output_cost_usd,
            cache_creation_cost_usd,
            cache_read_cost_usd,
        } = self;
        input_cost_usd + output_cost_usd + cache_creation_cost_usd + cache_read_cost_usd
    }
}
53
54#[must_use]
55pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
56    let normalized = model.to_ascii_lowercase();
57    if normalized.contains("haiku") {
58        return Some(ModelPricing {
59            input_cost_per_million: 1.0,
60            output_cost_per_million: 5.0,
61            cache_creation_cost_per_million: 1.25,
62            cache_read_cost_per_million: 0.1,
63        });
64    }
65    if normalized.contains("opus") {
66        return Some(ModelPricing {
67            input_cost_per_million: 15.0,
68            output_cost_per_million: 75.0,
69            cache_creation_cost_per_million: 18.75,
70            cache_read_cost_per_million: 1.5,
71        });
72    }
73    if normalized.contains("sonnet") {
74        return Some(ModelPricing::default_sonnet_tier());
75    }
76    None
77}
78
79impl TokenUsage {
80    #[must_use]
81    pub fn total_tokens(self) -> u32 {
82        self.input_tokens
83            + self.output_tokens
84            + self.cache_creation_input_tokens
85            + self.cache_read_input_tokens
86    }
87
88    #[must_use]
89    pub fn estimate_cost_usd(self) -> UsageCostEstimate {
90        self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
91    }
92
93    #[must_use]
94    pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
95        UsageCostEstimate {
96            input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
97            output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
98            cache_creation_cost_usd: cost_for_tokens(
99                self.cache_creation_input_tokens,
100                pricing.cache_creation_cost_per_million,
101            ),
102            cache_read_cost_usd: cost_for_tokens(
103                self.cache_read_input_tokens,
104                pricing.cache_read_cost_per_million,
105            ),
106        }
107    }
108
109    #[must_use]
110    pub fn summary_lines(self, label: &str) -> Vec<String> {
111        self.summary_lines_for_model(label, None)
112    }
113
114    #[must_use]
115    pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
116        let pricing = model.and_then(pricing_for_model);
117        let cost = pricing.map_or_else(
118            || self.estimate_cost_usd(),
119            |pricing| self.estimate_cost_usd_with_pricing(pricing),
120        );
121        let model_suffix =
122            model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
123        let pricing_suffix = if pricing.is_some() {
124            ""
125        } else if model.is_some() {
126            " pricing=estimated-default"
127        } else {
128            ""
129        };
130        vec![
131            format!(
132                "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
133                self.total_tokens(),
134                self.input_tokens,
135                self.output_tokens,
136                self.cache_creation_input_tokens,
137                self.cache_read_input_tokens,
138                format_usd(cost.total_cost_usd()),
139                model_suffix,
140                pricing_suffix,
141            ),
142            format!(
143                "  cost breakdown: input={} output={} cache_write={} cache_read={}",
144                format_usd(cost.input_cost_usd),
145                format_usd(cost.output_cost_usd),
146                format_usd(cost.cache_creation_cost_usd),
147                format_usd(cost.cache_read_cost_usd),
148            ),
149        ]
150    }
151}
152
/// Converts a token count into USD, given a rate quoted per one million tokens.
fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
    const TOKENS_PER_MILLION: f64 = 1_000_000.0;
    // Scale to millions first, then apply the rate (same operation order as before).
    let millions = f64::from(tokens) / TOKENS_PER_MILLION;
    millions * usd_per_million_tokens
}
156
/// Formats a USD amount as `$` followed by the value rounded to four decimal
/// places, e.g. `15.0` becomes `"$15.0000"`.
#[must_use]
pub fn format_usd(amount: f64) -> String {
    let decimals: usize = 4;
    format!("${amount:.decimals$}")
}
161
162#[derive(Debug, Clone, Default, PartialEq, Eq)]
163pub struct UsageTracker {
164    latest_turn: TokenUsage,
165    cumulative: TokenUsage,
166    turns: u32,
167}
168
169impl UsageTracker {
170    #[must_use]
171    pub fn new() -> Self {
172        Self::default()
173    }
174
175    #[must_use]
176    pub fn from_session(session: &Session) -> Self {
177        let mut tracker = Self::new();
178        for message in &session.messages {
179            if let Some(usage) = message.usage {
180                tracker.record(usage);
181            }
182        }
183        tracker
184    }
185
186    pub fn record(&mut self, usage: TokenUsage) {
187        self.latest_turn = usage;
188        self.cumulative.input_tokens += usage.input_tokens;
189        self.cumulative.output_tokens += usage.output_tokens;
190        self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
191        self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
192        self.turns += 1;
193    }
194
195    #[must_use]
196    pub fn current_turn_usage(&self) -> TokenUsage {
197        self.latest_turn
198    }
199
200    #[must_use]
201    pub fn cumulative_usage(&self) -> TokenUsage {
202        self.cumulative
203    }
204
205    #[must_use]
206    pub fn turns(&self) -> u32 {
207        self.turns
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

    /// Builds a `TokenUsage` from its counters in declaration order:
    /// input, output, cache creation (write), cache read.
    fn tokens(input: u32, output: u32, cache_write: u32, cache_read: u32) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            cache_creation_input_tokens: cache_write,
            cache_read_input_tokens: cache_read,
        }
    }

    #[test]
    fn tracks_true_cumulative_usage() {
        let mut tracker = UsageTracker::new();
        for turn in [tokens(10, 4, 2, 1), tokens(20, 6, 3, 2)] {
            tracker.record(turn);
        }

        let latest = tracker.current_turn_usage();
        let totals = tracker.cumulative_usage();
        assert_eq!(tracker.turns(), 2);
        assert_eq!(latest.input_tokens, 20);
        assert_eq!(latest.output_tokens, 6);
        assert_eq!(totals.output_tokens, 10);
        assert_eq!(totals.input_tokens, 30);
        assert_eq!(totals.total_tokens(), 48);
    }

    #[test]
    fn computes_cost_summary_lines() {
        let sample = tokens(1_000_000, 500_000, 100_000, 200_000);

        let cost = sample.estimate_cost_usd();
        assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
        assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");

        let lines = sample.summary_lines_for_model("usage", Some("ternlang-sonnet-4-20250514"));
        let (headline, breakdown) = (&lines[0], &lines[1]);
        assert!(headline.contains("estimated_cost=$54.6750"));
        assert!(headline.contains("model=ternlang-sonnet-4-20250514"));
        assert!(breakdown.contains("cache_read=$0.3000"));
    }

    #[test]
    fn supports_model_specific_pricing() {
        let sample = tokens(1_000_000, 500_000, 0, 0);
        let haiku = pricing_for_model("ternlang-haiku-4-5-20251001").expect("haiku pricing");
        let opus = pricing_for_model("ternlang-opus-4-6").expect("opus pricing");

        let total_under = |pricing| {
            format_usd(sample.estimate_cost_usd_with_pricing(pricing).total_cost_usd())
        };
        assert_eq!(total_under(haiku), "$3.5000");
        assert_eq!(total_under(opus), "$52.5000");
    }

    #[test]
    fn marks_unknown_model_pricing_as_fallback() {
        let lines = tokens(100, 100, 0, 0).summary_lines_for_model("usage", Some("custom-model"));
        assert!(lines[0].contains("pricing=estimated-default"));
    }

    #[test]
    fn reconstructs_usage_from_session_messages() {
        let reply = ConversationMessage {
            role: MessageRole::Assistant,
            blocks: vec![ContentBlock::Text {
                text: "done".to_string(),
            }],
            usage: Some(tokens(5, 2, 1, 0)),
        };
        let session = Session {
            version: 1,
            messages: vec![reply],
        };

        let tracker = UsageTracker::from_session(&session);
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
    }
}