Skip to main content

enact_core/kernel/
cost.rs

//! Token Usage and Cost Calculation
//!
//! This module provides structures for tracking token usage and calculating
//! costs for LLM API calls. It supports both per-call and cumulative tracking.
//!
//! ## Features
//!
//! - `TokenUsage`: Per-call token counts (prompt, completion, total)
//! - `ModelPricing`: Per-model prices in USD per 1 million tokens
//! - `CostCalculator`: Calculates costs based on token usage and pricing
//! - `UsageAccumulator`: Tracks cumulative usage across multiple calls
11
12use serde::{Deserialize, Serialize};
13
14// =============================================================================
15// Token Usage
16// =============================================================================
17
18/// Token usage for a single LLM call
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct TokenUsage {
21    /// Number of tokens in the prompt/input
22    pub prompt_tokens: u32,
23    /// Number of tokens in the completion/output
24    pub completion_tokens: u32,
25    /// Total tokens (prompt + completion)
26    pub total_tokens: u32,
27}
28
29impl TokenUsage {
30    /// Create a new TokenUsage instance
31    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
32        Self {
33            prompt_tokens,
34            completion_tokens,
35            total_tokens: prompt_tokens + completion_tokens,
36        }
37    }
38
39    /// Create from total tokens only (when breakdown not available)
40    pub fn from_total(total_tokens: u32) -> Self {
41        Self {
42            prompt_tokens: 0,
43            completion_tokens: 0,
44            total_tokens,
45        }
46    }
47
48    /// Check if usage data is available
49    pub fn is_empty(&self) -> bool {
50        self.total_tokens == 0
51    }
52}
53
54// =============================================================================
55// Cost Calculator
56// =============================================================================
57
58/// Model pricing configuration (per 1M tokens)
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ModelPricing {
61    /// Cost per 1 million input tokens (USD)
62    pub input_cost_per_1m: f64,
63    /// Cost per 1 million output tokens (USD)
64    pub output_cost_per_1m: f64,
65}
66
67impl ModelPricing {
68    /// Create pricing for a model
69    pub fn new(input_cost_per_1m: f64, output_cost_per_1m: f64) -> Self {
70        Self {
71            input_cost_per_1m,
72            output_cost_per_1m,
73        }
74    }
75
76    /// Default pricing (conservative estimate)
77    pub fn default_pricing() -> Self {
78        Self {
79            input_cost_per_1m: 3.0,   // $3 per 1M input tokens
80            output_cost_per_1m: 15.0, // $15 per 1M output tokens
81        }
82    }
83
84    /// GPT-4 pricing
85    pub fn gpt4() -> Self {
86        Self {
87            input_cost_per_1m: 30.0,
88            output_cost_per_1m: 60.0,
89        }
90    }
91
92    /// GPT-4 Turbo pricing
93    pub fn gpt4_turbo() -> Self {
94        Self {
95            input_cost_per_1m: 10.0,
96            output_cost_per_1m: 30.0,
97        }
98    }
99
100    /// GPT-4o pricing
101    pub fn gpt4o() -> Self {
102        Self {
103            input_cost_per_1m: 2.5,
104            output_cost_per_1m: 10.0,
105        }
106    }
107
108    /// GPT-4o mini pricing
109    pub fn gpt4o_mini() -> Self {
110        Self {
111            input_cost_per_1m: 0.15,
112            output_cost_per_1m: 0.60,
113        }
114    }
115
116    /// Claude 3 Opus pricing
117    pub fn claude3_opus() -> Self {
118        Self {
119            input_cost_per_1m: 15.0,
120            output_cost_per_1m: 75.0,
121        }
122    }
123
124    /// Claude 3.5 Sonnet pricing
125    pub fn claude35_sonnet() -> Self {
126        Self {
127            input_cost_per_1m: 3.0,
128            output_cost_per_1m: 15.0,
129        }
130    }
131
132    /// Claude 3 Haiku pricing
133    pub fn claude3_haiku() -> Self {
134        Self {
135            input_cost_per_1m: 0.25,
136            output_cost_per_1m: 1.25,
137        }
138    }
139}
140
141/// Cost calculator for token usage
142pub struct CostCalculator;
143
144impl CostCalculator {
145    /// Calculate cost for a given token usage and pricing
146    pub fn calculate_cost(usage: &TokenUsage, pricing: &ModelPricing) -> f64 {
147        let input_cost = (usage.prompt_tokens as f64 / 1_000_000.0) * pricing.input_cost_per_1m;
148        let output_cost =
149            (usage.completion_tokens as f64 / 1_000_000.0) * pricing.output_cost_per_1m;
150        input_cost + output_cost
151    }
152
153    /// Calculate cost using default pricing
154    pub fn calculate_cost_default(usage: &TokenUsage) -> f64 {
155        Self::calculate_cost(usage, &ModelPricing::default_pricing())
156    }
157
158    /// Get pricing for a model by name
159    pub fn pricing_for_model(model: &str) -> ModelPricing {
160        let model_lower = model.to_lowercase();
161
162        if model_lower.contains("gpt-4o-mini") || model_lower.contains("gpt-4-1-mini") {
163            ModelPricing::gpt4o_mini()
164        } else if model_lower.contains("gpt-4o") {
165            ModelPricing::gpt4o()
166        } else if model_lower.contains("gpt-4-turbo") {
167            ModelPricing::gpt4_turbo()
168        } else if model_lower.contains("gpt-4") {
169            ModelPricing::gpt4()
170        } else if model_lower.contains("claude-3-opus") || model_lower.contains("opus") {
171            ModelPricing::claude3_opus()
172        } else if model_lower.contains("claude-3.5-sonnet")
173            || model_lower.contains("claude-3-5-sonnet")
174            || model_lower.contains("sonnet")
175        {
176            ModelPricing::claude35_sonnet()
177        } else if model_lower.contains("claude-3-haiku") || model_lower.contains("haiku") {
178            ModelPricing::claude3_haiku()
179        } else {
180            // Default to a conservative estimate
181            ModelPricing::default_pricing()
182        }
183    }
184}
185
186// =============================================================================
187// Usage Accumulator
188// =============================================================================
189
190/// Accumulates token usage and cost across multiple LLM calls
191#[derive(Debug, Clone, Default)]
192pub struct UsageAccumulator {
193    /// Total prompt tokens across all calls
194    pub total_prompt_tokens: u64,
195    /// Total completion tokens across all calls
196    pub total_completion_tokens: u64,
197    /// Total tokens across all calls
198    pub total_tokens: u64,
199    /// Total cost across all calls (USD)
200    pub total_cost_usd: f64,
201    /// Number of LLM calls made
202    pub call_count: u64,
203}
204
205impl UsageAccumulator {
206    /// Create a new accumulator
207    pub fn new() -> Self {
208        Self::default()
209    }
210
211    /// Add usage from a single call
212    pub fn add(&mut self, usage: &TokenUsage, cost: f64) {
213        self.total_prompt_tokens += usage.prompt_tokens as u64;
214        self.total_completion_tokens += usage.completion_tokens as u64;
215        self.total_tokens += usage.total_tokens as u64;
216        self.total_cost_usd += cost;
217        self.call_count += 1;
218    }
219
220    /// Add usage with automatic cost calculation
221    pub fn add_with_pricing(&mut self, usage: &TokenUsage, pricing: &ModelPricing) {
222        let cost = CostCalculator::calculate_cost(usage, pricing);
223        self.add(usage, cost);
224    }
225
226    /// Add usage with model-based pricing
227    pub fn add_for_model(&mut self, usage: &TokenUsage, model: &str) {
228        let pricing = CostCalculator::pricing_for_model(model);
229        self.add_with_pricing(usage, &pricing);
230    }
231
232    /// Get average tokens per call
233    pub fn avg_tokens_per_call(&self) -> f64 {
234        if self.call_count == 0 {
235            0.0
236        } else {
237            self.total_tokens as f64 / self.call_count as f64
238        }
239    }
240
241    /// Get average cost per call
242    pub fn avg_cost_per_call(&self) -> f64 {
243        if self.call_count == 0 {
244            0.0
245        } else {
246            self.total_cost_usd / self.call_count as f64
247        }
248    }
249}
250
251// =============================================================================
252// Tests
253// =============================================================================
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing computed dollar amounts.
    const EPS: f64 = 1e-9;

    /// Asserts that two floating-point amounts agree within `EPS`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_token_usage_new() {
        let usage = TokenUsage::new(100, 50);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_token_usage_from_total() {
        let usage = TokenUsage::from_total(200);
        assert_eq!(usage.prompt_tokens, 0);
        assert_eq!(usage.completion_tokens, 0);
        assert_eq!(usage.total_tokens, 200);
    }

    #[test]
    fn test_token_usage_is_empty() {
        let empty = TokenUsage::default();
        let non_empty = TokenUsage::new(10, 5);

        assert!(empty.is_empty());
        assert!(!non_empty.is_empty());
        // A total without a breakdown still counts as recorded usage.
        assert!(!TokenUsage::from_total(200).is_empty());
    }

    #[test]
    fn test_cost_calculator() {
        let usage = TokenUsage::new(1_000_000, 500_000);
        let pricing = ModelPricing::new(3.0, 15.0);

        // 1M input tokens at $3/1M = $3
        // 500K output tokens at $15/1M = $7.50
        // Total = $10.50
        assert_close(CostCalculator::calculate_cost(&usage, &pricing), 10.50);
    }

    #[test]
    fn test_cost_calculator_small_usage() {
        let usage = TokenUsage::new(1000, 500);
        let pricing = ModelPricing::gpt4o_mini();

        // 1K input tokens at $0.15/1M = $0.00015
        // 500 output tokens at $0.60/1M = $0.0003
        // Total = $0.00045 -- pinned exactly rather than only range-checked.
        assert_close(CostCalculator::calculate_cost(&usage, &pricing), 0.00045);
    }

    #[test]
    fn test_cost_calculator_default_pricing() {
        let usage = TokenUsage::new(1_000_000, 1_000_000);

        // Default pricing is $3 input + $15 output per 1M tokens.
        assert_close(CostCalculator::calculate_cost_default(&usage), 18.0);
    }

    #[test]
    fn test_pricing_for_model() {
        let gpt4 = CostCalculator::pricing_for_model("gpt-4");
        assert_eq!(gpt4.input_cost_per_1m, 30.0);

        let gpt4o = CostCalculator::pricing_for_model("gpt-4o");
        assert_eq!(gpt4o.input_cost_per_1m, 2.5);

        let claude = CostCalculator::pricing_for_model("claude-3.5-sonnet");
        assert_eq!(claude.input_cost_per_1m, 3.0);
    }

    #[test]
    fn test_pricing_for_model_prefers_most_specific_match() {
        // Each of these names also contains `gpt-4`; the more specific rule must win.
        let mini = CostCalculator::pricing_for_model("gpt-4o-mini");
        assert_eq!(mini.input_cost_per_1m, 0.15);
        assert_eq!(mini.output_cost_per_1m, 0.60);

        let turbo = CostCalculator::pricing_for_model("gpt-4-turbo");
        assert_eq!(turbo.input_cost_per_1m, 10.0);
        assert_eq!(turbo.output_cost_per_1m, 30.0);
    }

    #[test]
    fn test_pricing_for_model_is_case_insensitive() {
        let gpt4o = CostCalculator::pricing_for_model("GPT-4O");
        assert_eq!(gpt4o.input_cost_per_1m, 2.5);

        let haiku = CostCalculator::pricing_for_model("Claude-3-Haiku");
        assert_eq!(haiku.input_cost_per_1m, 0.25);
        assert_eq!(haiku.output_cost_per_1m, 1.25);
    }

    #[test]
    fn test_pricing_for_model_unknown_falls_back_to_default() {
        let unknown = CostCalculator::pricing_for_model("some-unlisted-model");
        let fallback = ModelPricing::default_pricing();
        assert_eq!(unknown.input_cost_per_1m, fallback.input_cost_per_1m);
        assert_eq!(unknown.output_cost_per_1m, fallback.output_cost_per_1m);
    }

    #[test]
    fn test_usage_accumulator() {
        let mut acc = UsageAccumulator::new();

        acc.add(&TokenUsage::new(100, 50), 0.01);
        acc.add(&TokenUsage::new(200, 100), 0.02);

        assert_eq!(acc.total_prompt_tokens, 300);
        assert_eq!(acc.total_completion_tokens, 150);
        assert_eq!(acc.total_tokens, 450);
        assert_close(acc.total_cost_usd, 0.03);
        assert_eq!(acc.call_count, 2);
    }

    #[test]
    fn test_usage_accumulator_with_pricing() {
        let mut acc = UsageAccumulator::new();

        acc.add_with_pricing(
            &TokenUsage::new(1_000_000, 500_000),
            &ModelPricing::new(3.0, 15.0),
        );

        assert_eq!(acc.total_tokens, 1_500_000);
        assert_eq!(acc.call_count, 1);
        assert_close(acc.total_cost_usd, 10.50);
    }

    #[test]
    fn test_usage_accumulator_for_model() {
        let mut acc = UsageAccumulator::new();

        // gpt-4o: $2.50 input + $10 output per 1M tokens.
        acc.add_for_model(&TokenUsage::new(1_000_000, 1_000_000), "gpt-4o");

        assert_eq!(acc.total_tokens, 2_000_000);
        assert_eq!(acc.call_count, 1);
        assert_close(acc.total_cost_usd, 12.5);
    }

    #[test]
    fn test_usage_accumulator_averages() {
        let mut acc = UsageAccumulator::new();

        acc.add(&TokenUsage::new(100, 100), 0.02);
        acc.add(&TokenUsage::new(200, 200), 0.04);

        assert_eq!(acc.avg_tokens_per_call(), 300.0);
        assert_close(acc.avg_cost_per_call(), 0.03);
    }

    #[test]
    fn test_usage_accumulator_empty() {
        let acc = UsageAccumulator::new();
        assert_eq!(acc.avg_tokens_per_call(), 0.0);
        assert_eq!(acc.avg_cost_per_call(), 0.0);
    }
}
351}