1use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20pub struct TokenUsage {
21 pub prompt_tokens: u32,
23 pub completion_tokens: u32,
25 pub total_tokens: u32,
27}
28
29impl TokenUsage {
30 pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
32 Self {
33 prompt_tokens,
34 completion_tokens,
35 total_tokens: prompt_tokens + completion_tokens,
36 }
37 }
38
39 pub fn from_total(total_tokens: u32) -> Self {
41 Self {
42 prompt_tokens: 0,
43 completion_tokens: 0,
44 total_tokens,
45 }
46 }
47
48 pub fn is_empty(&self) -> bool {
50 self.total_tokens == 0
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ModelPricing {
61 pub input_cost_per_1m: f64,
63 pub output_cost_per_1m: f64,
65}
66
67impl ModelPricing {
68 pub fn new(input_cost_per_1m: f64, output_cost_per_1m: f64) -> Self {
70 Self {
71 input_cost_per_1m,
72 output_cost_per_1m,
73 }
74 }
75
76 pub fn default_pricing() -> Self {
78 Self {
79 input_cost_per_1m: 3.0, output_cost_per_1m: 15.0, }
82 }
83
84 pub fn gpt4() -> Self {
86 Self {
87 input_cost_per_1m: 30.0,
88 output_cost_per_1m: 60.0,
89 }
90 }
91
92 pub fn gpt4_turbo() -> Self {
94 Self {
95 input_cost_per_1m: 10.0,
96 output_cost_per_1m: 30.0,
97 }
98 }
99
100 pub fn gpt4o() -> Self {
102 Self {
103 input_cost_per_1m: 2.5,
104 output_cost_per_1m: 10.0,
105 }
106 }
107
108 pub fn gpt4o_mini() -> Self {
110 Self {
111 input_cost_per_1m: 0.15,
112 output_cost_per_1m: 0.60,
113 }
114 }
115
116 pub fn claude3_opus() -> Self {
118 Self {
119 input_cost_per_1m: 15.0,
120 output_cost_per_1m: 75.0,
121 }
122 }
123
124 pub fn claude35_sonnet() -> Self {
126 Self {
127 input_cost_per_1m: 3.0,
128 output_cost_per_1m: 15.0,
129 }
130 }
131
132 pub fn claude3_haiku() -> Self {
134 Self {
135 input_cost_per_1m: 0.25,
136 output_cost_per_1m: 1.25,
137 }
138 }
139}
140
141pub struct CostCalculator;
143
144impl CostCalculator {
145 pub fn calculate_cost(usage: &TokenUsage, pricing: &ModelPricing) -> f64 {
147 let input_cost = (usage.prompt_tokens as f64 / 1_000_000.0) * pricing.input_cost_per_1m;
148 let output_cost =
149 (usage.completion_tokens as f64 / 1_000_000.0) * pricing.output_cost_per_1m;
150 input_cost + output_cost
151 }
152
153 pub fn calculate_cost_default(usage: &TokenUsage) -> f64 {
155 Self::calculate_cost(usage, &ModelPricing::default_pricing())
156 }
157
158 pub fn pricing_for_model(model: &str) -> ModelPricing {
160 let model_lower = model.to_lowercase();
161
162 if model_lower.contains("gpt-4o-mini") || model_lower.contains("gpt-4-1-mini") {
163 ModelPricing::gpt4o_mini()
164 } else if model_lower.contains("gpt-4o") {
165 ModelPricing::gpt4o()
166 } else if model_lower.contains("gpt-4-turbo") {
167 ModelPricing::gpt4_turbo()
168 } else if model_lower.contains("gpt-4") {
169 ModelPricing::gpt4()
170 } else if model_lower.contains("claude-3-opus") || model_lower.contains("opus") {
171 ModelPricing::claude3_opus()
172 } else if model_lower.contains("claude-3.5-sonnet")
173 || model_lower.contains("claude-3-5-sonnet")
174 || model_lower.contains("sonnet")
175 {
176 ModelPricing::claude35_sonnet()
177 } else if model_lower.contains("claude-3-haiku") || model_lower.contains("haiku") {
178 ModelPricing::claude3_haiku()
179 } else {
180 ModelPricing::default_pricing()
182 }
183 }
184}
185
/// Running totals of token usage and cost.
///
/// Every field starts at zero when the value is created through [`Default`].
#[derive(Clone, Debug, Default)]
pub struct UsageAccumulator {
    /// Sum of the prompt token counts added so far.
    pub total_prompt_tokens: u64,
    /// Sum of the completion token counts added so far.
    pub total_completion_tokens: u64,
    /// Sum of the overall token counts added so far.
    pub total_tokens: u64,
    /// Sum of the costs added so far.
    pub total_cost_usd: f64,
    /// Number of usage records added so far.
    pub call_count: u64,
}
204
205impl UsageAccumulator {
206 pub fn new() -> Self {
208 Self::default()
209 }
210
211 pub fn add(&mut self, usage: &TokenUsage, cost: f64) {
213 self.total_prompt_tokens += usage.prompt_tokens as u64;
214 self.total_completion_tokens += usage.completion_tokens as u64;
215 self.total_tokens += usage.total_tokens as u64;
216 self.total_cost_usd += cost;
217 self.call_count += 1;
218 }
219
220 pub fn add_with_pricing(&mut self, usage: &TokenUsage, pricing: &ModelPricing) {
222 let cost = CostCalculator::calculate_cost(usage, pricing);
223 self.add(usage, cost);
224 }
225
226 pub fn add_for_model(&mut self, usage: &TokenUsage, model: &str) {
228 let pricing = CostCalculator::pricing_for_model(model);
229 self.add_with_pricing(usage, &pricing);
230 }
231
232 pub fn avg_tokens_per_call(&self) -> f64 {
234 if self.call_count == 0 {
235 0.0
236 } else {
237 self.total_tokens as f64 / self.call_count as f64
238 }
239 }
240
241 pub fn avg_cost_per_call(&self) -> f64 {
243 if self.call_count == 0 {
244 0.0
245 } else {
246 self.total_cost_usd / self.call_count as f64
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} (+/- {}), got {}",
            expected,
            tolerance,
            actual
        );
    }

    #[test]
    fn test_token_usage_new() {
        let TokenUsage {
            prompt_tokens,
            completion_tokens,
            total_tokens,
        } = TokenUsage::new(100, 50);

        assert_eq!(prompt_tokens, 100);
        assert_eq!(completion_tokens, 50);
        assert_eq!(total_tokens, 150);
    }

    #[test]
    fn test_token_usage_from_total() {
        let TokenUsage {
            prompt_tokens,
            completion_tokens,
            total_tokens,
        } = TokenUsage::from_total(200);

        assert_eq!(prompt_tokens, 0);
        assert_eq!(completion_tokens, 0);
        assert_eq!(total_tokens, 200);
    }

    #[test]
    fn test_token_usage_is_empty() {
        assert!(TokenUsage::default().is_empty());
        assert!(!TokenUsage::new(10, 5).is_empty());
    }

    #[test]
    fn test_cost_calculator() {
        // 1M prompt tokens at 3.0 plus 0.5M completion tokens at 15.0.
        let cost = CostCalculator::calculate_cost(
            &TokenUsage::new(1_000_000, 500_000),
            &ModelPricing::new(3.0, 15.0),
        );
        assert_close(cost, 10.50, 0.001);
    }

    #[test]
    fn test_cost_calculator_small_usage() {
        let cost = CostCalculator::calculate_cost(
            &TokenUsage::new(1000, 500),
            &ModelPricing::gpt4o_mini(),
        );
        assert!(cost > 0.0);
        assert!(cost < 0.001);
    }

    #[test]
    fn test_pricing_for_model() {
        let expected_input_rates = [
            ("gpt-4", 30.0),
            ("gpt-4o", 2.5),
            ("claude-3.5-sonnet", 3.0),
        ];

        for (model, input_rate) in expected_input_rates.iter() {
            let pricing = CostCalculator::pricing_for_model(model);
            assert_eq!(pricing.input_cost_per_1m, *input_rate);
        }
    }

    #[test]
    fn test_usage_accumulator() {
        let mut totals = UsageAccumulator::new();
        totals.add(&TokenUsage::new(100, 50), 0.01);
        totals.add(&TokenUsage::new(200, 100), 0.02);

        assert_eq!(totals.total_prompt_tokens, 300);
        assert_eq!(totals.total_completion_tokens, 150);
        assert_eq!(totals.total_tokens, 450);
        assert_close(totals.total_cost_usd, 0.03, 0.0001);
        assert_eq!(totals.call_count, 2);
    }

    #[test]
    fn test_usage_accumulator_averages() {
        let mut totals = UsageAccumulator::new();
        totals.add(&TokenUsage::new(100, 100), 0.02);
        totals.add(&TokenUsage::new(200, 200), 0.04);

        assert_eq!(totals.avg_tokens_per_call(), 300.0);
        assert_close(totals.avg_cost_per_call(), 0.03, 0.0001);
    }

    #[test]
    fn test_usage_accumulator_empty() {
        let untouched = UsageAccumulator::new();
        assert_eq!(untouched.avg_tokens_per_call(), 0.0);
        assert_eq!(untouched.avg_cost_per_call(), 0.0);
    }
}