1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use thiserror::Error;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct ModelPricing {
7 pub model_name: String,
8 pub provider: String,
9 pub input_cost_per_1k_tokens: f64, pub output_cost_per_1k_tokens: f64, pub context_window: usize,
12 pub max_output_tokens: Option<usize>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CostEstimate {
17 pub model_name: String,
18 pub input_tokens: usize,
19 pub output_tokens: usize,
20 pub input_cost: f64,
21 pub output_cost: f64,
22 pub total_cost: f64,
23 pub currency: String,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct BudgetStatus {
28 pub budget_usd: f64,
29 pub spent_usd: f64,
30 pub remaining_usd: f64,
31 pub percent_used: f64,
32 pub status: BudgetAlert,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub enum BudgetAlert {
37 Ok, Warning, Critical, Exceeded, }
42
43pub struct CostCalculator {
44 pricing_table: HashMap<String, ModelPricing>,
45}
46
47impl CostCalculator {
48 pub fn new() -> Self {
50 let mut pricing_table = HashMap::new();
51
52 pricing_table.insert(
54 "gpt-4".to_string(),
55 ModelPricing {
56 model_name: "gpt-4".to_string(),
57 provider: "openai".to_string(),
58 input_cost_per_1k_tokens: 0.03,
59 output_cost_per_1k_tokens: 0.06,
60 context_window: 8192,
61 max_output_tokens: Some(4096),
62 },
63 );
64
65 pricing_table.insert(
66 "gpt-4-turbo".to_string(),
67 ModelPricing {
68 model_name: "gpt-4-turbo".to_string(),
69 provider: "openai".to_string(),
70 input_cost_per_1k_tokens: 0.01,
71 output_cost_per_1k_tokens: 0.03,
72 context_window: 128000,
73 max_output_tokens: Some(4096),
74 },
75 );
76
77 pricing_table.insert(
78 "gpt-3.5-turbo".to_string(),
79 ModelPricing {
80 model_name: "gpt-3.5-turbo".to_string(),
81 provider: "openai".to_string(),
82 input_cost_per_1k_tokens: 0.0005,
83 output_cost_per_1k_tokens: 0.0015,
84 context_window: 16385,
85 max_output_tokens: Some(4096),
86 },
87 );
88
89 pricing_table.insert(
90 "gpt-4o".to_string(),
91 ModelPricing {
92 model_name: "gpt-4o".to_string(),
93 provider: "openai".to_string(),
94 input_cost_per_1k_tokens: 0.005,
95 output_cost_per_1k_tokens: 0.015,
96 context_window: 128000,
97 max_output_tokens: Some(4096),
98 },
99 );
100
101 pricing_table.insert(
102 "gpt-4o-mini".to_string(),
103 ModelPricing {
104 model_name: "gpt-4o-mini".to_string(),
105 provider: "openai".to_string(),
106 input_cost_per_1k_tokens: 0.00015,
107 output_cost_per_1k_tokens: 0.0006,
108 context_window: 128000,
109 max_output_tokens: Some(16384),
110 },
111 );
112
113 pricing_table.insert(
115 "claude-3-opus".to_string(),
116 ModelPricing {
117 model_name: "claude-3-opus".to_string(),
118 provider: "anthropic".to_string(),
119 input_cost_per_1k_tokens: 0.015,
120 output_cost_per_1k_tokens: 0.075,
121 context_window: 200000,
122 max_output_tokens: Some(4096),
123 },
124 );
125
126 pricing_table.insert(
127 "claude-3-sonnet".to_string(),
128 ModelPricing {
129 model_name: "claude-3-sonnet".to_string(),
130 provider: "anthropic".to_string(),
131 input_cost_per_1k_tokens: 0.003,
132 output_cost_per_1k_tokens: 0.015,
133 context_window: 200000,
134 max_output_tokens: Some(4096),
135 },
136 );
137
138 pricing_table.insert(
139 "claude-3-haiku".to_string(),
140 ModelPricing {
141 model_name: "claude-3-haiku".to_string(),
142 provider: "anthropic".to_string(),
143 input_cost_per_1k_tokens: 0.00025,
144 output_cost_per_1k_tokens: 0.00125,
145 context_window: 200000,
146 max_output_tokens: Some(4096),
147 },
148 );
149
150 pricing_table.insert(
151 "claude-3-5-sonnet".to_string(),
152 ModelPricing {
153 model_name: "claude-3-5-sonnet".to_string(),
154 provider: "anthropic".to_string(),
155 input_cost_per_1k_tokens: 0.003,
156 output_cost_per_1k_tokens: 0.015,
157 context_window: 200000,
158 max_output_tokens: Some(8192),
159 },
160 );
161
162 pricing_table.insert(
164 "gemini-pro".to_string(),
165 ModelPricing {
166 model_name: "gemini-pro".to_string(),
167 provider: "google".to_string(),
168 input_cost_per_1k_tokens: 0.0005,
169 output_cost_per_1k_tokens: 0.0015,
170 context_window: 30720,
171 max_output_tokens: Some(2048),
172 },
173 );
174
175 pricing_table.insert(
176 "gemini-ultra".to_string(),
177 ModelPricing {
178 model_name: "gemini-ultra".to_string(),
179 provider: "google".to_string(),
180 input_cost_per_1k_tokens: 0.0125,
181 output_cost_per_1k_tokens: 0.0375,
182 context_window: 30720,
183 max_output_tokens: Some(2048),
184 },
185 );
186
187 Self { pricing_table }
188 }
189
190 pub fn estimate_cost(
192 &self,
193 model_name: &str,
194 input_tokens: usize,
195 output_tokens: usize,
196 ) -> Result<CostEstimate, CostError> {
197 let pricing = self
198 .pricing_table
199 .get(model_name)
200 .ok_or_else(|| CostError::UnknownModel(model_name.to_string()))?;
201
202 if input_tokens == 0 && output_tokens == 0 {
203 return Err(CostError::InvalidTokenCount);
204 }
205
206 if input_tokens + output_tokens > pricing.context_window {
208 return Err(CostError::InvalidTokenCount);
209 }
210
211 if let Some(max_output) = pricing.max_output_tokens {
213 if output_tokens > max_output {
214 return Err(CostError::InvalidTokenCount);
215 }
216 }
217
218 let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k_tokens;
219 let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k_tokens;
220 let total_cost = input_cost + output_cost;
221
222 Ok(CostEstimate {
223 model_name: model_name.to_string(),
224 input_tokens,
225 output_tokens,
226 input_cost,
227 output_cost,
228 total_cost,
229 currency: "USD".to_string(),
230 })
231 }
232
233 pub fn estimate_cost_from_text(
235 &self,
236 model_name: &str,
237 input_text: &str,
238 estimated_output_tokens: usize,
239 ) -> Result<CostEstimate, CostError> {
240 let input_tokens = self.estimate_tokens(input_text);
241 self.estimate_cost(model_name, input_tokens, estimated_output_tokens)
242 }
243
244 pub fn check_budget(&self, spent: f64, budget: f64) -> BudgetStatus {
246 if budget <= 0.0 {
247 return BudgetStatus {
248 budget_usd: budget,
249 spent_usd: spent,
250 remaining_usd: budget - spent,
251 percent_used: 100.0,
252 status: BudgetAlert::Exceeded,
253 };
254 }
255
256 let percent_used = (spent / budget) * 100.0;
257 let remaining = budget - spent;
258
259 let status = match percent_used {
260 p if p >= 100.0 => BudgetAlert::Exceeded,
261 p if p >= 95.0 => BudgetAlert::Critical,
262 p if p >= 80.0 => BudgetAlert::Warning,
263 _ => BudgetAlert::Ok,
264 };
265
266 BudgetStatus {
267 budget_usd: budget,
268 spent_usd: spent,
269 remaining_usd: remaining,
270 percent_used: percent_used.min(100.0),
271 status,
272 }
273 }
274
275 pub fn get_cheapest_model(&self, min_context_window: usize) -> Option<&ModelPricing> {
277 self.pricing_table
278 .values()
279 .filter(|pricing| pricing.context_window >= min_context_window)
280 .min_by(|a, b| {
281 let avg_cost_a = (a.input_cost_per_1k_tokens + a.output_cost_per_1k_tokens) / 2.0;
282 let avg_cost_b = (b.input_cost_per_1k_tokens + b.output_cost_per_1k_tokens) / 2.0;
283 avg_cost_a
284 .partial_cmp(&avg_cost_b)
285 .unwrap_or(std::cmp::Ordering::Equal)
286 })
287 }
288
289 pub fn get_models_under_cost(&self, max_cost_per_1k: f64) -> Vec<&ModelPricing> {
291 self.pricing_table
292 .values()
293 .filter(|pricing| {
294 let avg_cost =
295 (pricing.input_cost_per_1k_tokens + pricing.output_cost_per_1k_tokens) / 2.0;
296 avg_cost <= max_cost_per_1k
297 })
298 .collect()
299 }
300
301 pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelPricing> {
303 self.pricing_table
304 .values()
305 .filter(|pricing| pricing.provider.eq_ignore_ascii_case(provider))
306 .collect()
307 }
308
309 pub fn compare_models(
311 &self,
312 model_a: &str,
313 model_b: &str,
314 input_tokens: usize,
315 output_tokens: usize,
316 ) -> Result<ModelComparison, CostError> {
317 let cost_a = self.estimate_cost(model_a, input_tokens, output_tokens)?;
318 let cost_b = self.estimate_cost(model_b, input_tokens, output_tokens)?;
319
320 let savings = cost_a.total_cost - cost_b.total_cost;
321 let percent_difference = if cost_a.total_cost > 0.0 {
322 (savings / cost_a.total_cost) * 100.0
323 } else {
324 0.0
325 };
326
327 Ok(ModelComparison {
328 model_a: cost_a,
329 model_b: cost_b,
330 cheaper_model: if savings > 0.0 { model_b } else { model_a }.to_string(),
331 savings: savings.abs(),
332 percent_difference: percent_difference.abs(),
333 })
334 }
335
336 pub fn add_model(&mut self, pricing: ModelPricing) {
338 self.pricing_table
339 .insert(pricing.model_name.clone(), pricing);
340 }
341
342 pub fn remove_model(&mut self, model_name: &str) -> Option<ModelPricing> {
344 self.pricing_table.remove(model_name)
345 }
346
347 pub fn get_all_models(&self) -> Vec<&ModelPricing> {
349 self.pricing_table.values().collect()
350 }
351
352 fn estimate_tokens(&self, text: &str) -> usize {
354 let char_count = text.len();
357
358 let token_estimate = if text.is_ascii() {
360 (char_count as f64 / 4.0).ceil() as usize
362 } else {
363 (char_count as f64 / 3.0).ceil() as usize
365 };
366
367 token_estimate + (token_estimate / 20) }
370
371 pub fn project_monthly_cost(
373 &self,
374 model_name: &str,
375 daily_input_tokens: usize,
376 daily_output_tokens: usize,
377 days_per_month: f64,
378 ) -> Result<CostProjection, CostError> {
379 let daily_cost = self.estimate_cost(model_name, daily_input_tokens, daily_output_tokens)?;
380 let monthly_cost = daily_cost.total_cost * days_per_month;
381
382 Ok(CostProjection {
383 model_name: model_name.to_string(),
384 daily_cost: daily_cost.total_cost,
385 monthly_cost,
386 annual_cost: monthly_cost * 12.0,
387 currency: "USD".to_string(),
388 })
389 }
390}
391
392impl Default for CostCalculator {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
398#[derive(Debug, Clone)]
399pub struct ModelComparison {
400 pub model_a: CostEstimate,
401 pub model_b: CostEstimate,
402 pub cheaper_model: String,
403 pub savings: f64,
404 pub percent_difference: f64,
405}
406
407#[derive(Debug, Clone, Serialize, Deserialize)]
408pub struct CostProjection {
409 pub model_name: String,
410 pub daily_cost: f64,
411 pub monthly_cost: f64,
412 pub annual_cost: f64,
413 pub currency: String,
414}
415
416#[derive(Error, Debug, Clone, PartialEq)]
417pub enum CostError {
418 #[error("Unknown model: {0}")]
419 UnknownModel(String),
420 #[error("Invalid token count")]
421 InvalidTokenCount,
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f64` dollar amounts.
    const EPSILON: f64 = 1e-9;

    /// Asserts that two floating-point amounts agree within `EPSILON`,
    /// instead of relying on exact bit-for-bit equality.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Mean of a model's input and output price per 1k tokens.
    fn average_cost(model: &ModelPricing) -> f64 {
        (model.input_cost_per_1k_tokens + model.output_cost_per_1k_tokens) / 2.0
    }

    #[test]
    fn test_cost_estimation() {
        let calculator = CostCalculator::new();

        let estimate = calculator.estimate_cost("gpt-4", 1000, 500).unwrap();

        assert_eq!(estimate.model_name, "gpt-4");
        assert_eq!(estimate.input_tokens, 1000);
        assert_eq!(estimate.output_tokens, 500);
        assert_close(estimate.input_cost, 0.03);
        assert_close(estimate.output_cost, 0.03);
        assert_close(estimate.total_cost, 0.06);
        assert_eq!(estimate.currency, "USD");
    }

    #[test]
    fn test_unknown_model() {
        let calculator = CostCalculator::new();
        let result = calculator.estimate_cost("unknown-model", 1000, 500);

        assert!(matches!(result, Err(CostError::UnknownModel(_))));
    }

    #[test]
    fn test_invalid_token_count() {
        let calculator = CostCalculator::new();

        // Both counts zero.
        let result = calculator.estimate_cost("gpt-4", 0, 0);
        assert!(matches!(result, Err(CostError::InvalidTokenCount)));

        // Total exceeds the 8192-token context window.
        let result = calculator.estimate_cost("gpt-4", 10000, 0);
        assert!(matches!(result, Err(CostError::InvalidTokenCount)));

        // Output exceeds the 4096-token output limit.
        let result = calculator.estimate_cost("gpt-4", 1000, 5000);
        assert!(matches!(result, Err(CostError::InvalidTokenCount)));
    }

    #[test]
    fn test_budget_status() {
        let calculator = CostCalculator::new();

        let status = calculator.check_budget(50.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Ok);
        assert_close(status.percent_used, 50.0);
        assert_close(status.remaining_usd, 50.0);

        let status = calculator.check_budget(85.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Warning);

        let status = calculator.check_budget(96.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Critical);

        let status = calculator.check_budget(110.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Exceeded);
        assert_close(status.remaining_usd, -10.0);
    }

    #[test]
    fn test_cheapest_model() {
        let calculator = CostCalculator::new();

        let model = calculator
            .get_cheapest_model(8000)
            .expect("the default table has models with a context window of at least 8000");

        assert!(model.context_window >= 8000);

        // No other eligible model may be cheaper than the one returned.
        for other in calculator.get_all_models() {
            if other.context_window >= 8000 {
                assert!(
                    average_cost(model) <= average_cost(other),
                    "{} is cheaper than {}",
                    other.model_name,
                    model.model_name
                );
            }
        }
    }

    #[test]
    fn test_models_under_cost() {
        let calculator = CostCalculator::new();

        let cheap_models = calculator.get_models_under_cost(0.01);
        assert!(!cheap_models.is_empty());

        for model in &cheap_models {
            assert!(average_cost(model) <= 0.01);
        }
    }

    #[test]
    fn test_models_by_provider() {
        let calculator = CostCalculator::new();

        for provider in ["openai", "anthropic"] {
            let models = calculator.get_models_by_provider(provider);
            assert!(!models.is_empty());
            for model in &models {
                assert_eq!(model.provider, provider);
            }
        }
    }

    #[test]
    fn test_model_comparison() {
        let calculator = CostCalculator::new();

        let comparison = calculator
            .compare_models("gpt-4", "gpt-3.5-turbo", 1000, 500)
            .unwrap();

        assert_eq!(comparison.cheaper_model, "gpt-3.5-turbo");
        assert!(comparison.savings > 0.0);
        assert!(comparison.percent_difference > 0.0);
    }

    #[test]
    fn test_cost_from_text() {
        let calculator = CostCalculator::new();

        let text = "Hello, world!";
        let estimate = calculator
            .estimate_cost_from_text("gpt-3.5-turbo", text, 100)
            .unwrap();

        assert!(estimate.input_tokens > 0);
        assert_eq!(estimate.output_tokens, 100);
        assert!(estimate.total_cost > 0.0);
    }

    #[test]
    fn test_token_estimation() {
        let calculator = CostCalculator::new();

        let english_text = "Hello, world! This is a test.";
        let tokens = calculator.estimate_tokens(english_text);

        // Roughly one token per four bytes plus 5% overhead, within +/- 2.
        // Written with additions only so the bounds cannot underflow `usize`.
        let expected = ((english_text.len() as f64 / 4.0).ceil() as usize * 105) / 100;
        assert!(tokens + 2 >= expected && tokens <= expected + 2);

        assert_eq!(calculator.estimate_tokens(""), 0);
    }

    #[test]
    fn test_custom_model() {
        let mut calculator = CostCalculator::new();

        let custom_model = ModelPricing {
            model_name: "custom-model".to_string(),
            provider: "custom".to_string(),
            input_cost_per_1k_tokens: 0.001,
            output_cost_per_1k_tokens: 0.002,
            context_window: 4096,
            max_output_tokens: Some(2048),
        };

        calculator.add_model(custom_model);

        let estimate = calculator.estimate_cost("custom-model", 1000, 500).unwrap();
        assert_close(estimate.input_cost, 0.001);
        assert_close(estimate.output_cost, 0.001);
        assert_close(estimate.total_cost, 0.002);

        let removed = calculator.remove_model("custom-model");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().model_name, "custom-model");

        // Once removed, the model must be unknown again.
        let result = calculator.estimate_cost("custom-model", 1000, 500);
        assert!(matches!(result, Err(CostError::UnknownModel(_))));
    }

    #[test]
    fn test_cost_projection() {
        let calculator = CostCalculator::new();

        let projection = calculator
            .project_monthly_cost("gpt-4", 4000, 2000, 30.0)
            .unwrap();

        assert_eq!(projection.model_name, "gpt-4");
        assert!(projection.daily_cost > 0.0);
        assert_close(projection.monthly_cost, projection.daily_cost * 30.0);
        assert_close(projection.annual_cost, projection.monthly_cost * 12.0);
    }

    #[test]
    fn test_all_default_models_available() {
        let calculator = CostCalculator::new();

        let test_models = [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-3.5-turbo",
            "gpt-4o",
            "gpt-4o-mini",
            "claude-3-opus",
            "claude-3-sonnet",
            "claude-3-haiku",
            "claude-3-5-sonnet",
            "gemini-pro",
            "gemini-ultra",
        ];

        for model in &test_models {
            let result = calculator.estimate_cost(model, 1000, 500);
            assert!(result.is_ok(), "Model {} should be available", model);
        }
    }
}