//! briefcase_core/cost.rs — LLM cost estimation, budget tracking, and
//! model price comparison utilities.
1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use thiserror::Error;
4
/// Pricing metadata for a single LLM; entries in [`CostCalculator`]'s table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Model identifier, e.g. "gpt-4"; also used as the pricing-table key.
    pub model_name: String,
    /// Vendor name, e.g. "openai", "anthropic", "google".
    pub provider: String,
    pub input_cost_per_1k_tokens: f64,  // USD
    pub output_cost_per_1k_tokens: f64, // USD
    /// Maximum combined (input + output) tokens the model accepts.
    pub context_window: usize,
    /// Per-request output-token cap, when the provider enforces one.
    pub max_output_tokens: Option<usize>,
}
14
/// Itemized USD cost for one request, returned by
/// [`CostCalculator::estimate_cost`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostEstimate {
    pub model_name: String,
    pub input_tokens: usize,
    pub output_tokens: usize,
    /// Cost attributable to input tokens, in `currency`.
    pub input_cost: f64,
    /// Cost attributable to output tokens, in `currency`.
    pub output_cost: f64,
    /// `input_cost + output_cost`.
    pub total_cost: f64,
    /// Currency code; always "USD" in this module.
    pub currency: String,
}
25
/// Snapshot of spending relative to a budget, produced by
/// [`CostCalculator::check_budget`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStatus {
    pub budget_usd: f64,
    pub spent_usd: f64,
    /// `budget_usd - spent_usd`; negative once the budget is exceeded.
    pub remaining_usd: f64,
    /// Percentage of the budget consumed, capped at 100.0 for display.
    pub percent_used: f64,
    pub status: BudgetAlert,
}
34
/// Alert level derived from the percentage of budget consumed
/// (see `CostCalculator::check_budget` for the exact thresholds).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BudgetAlert {
    Ok,       // less than 80% used
    Warning,  // 80% up to (but not including) 95% used
    Critical, // 95% up to (but not including) 100% used
    Exceeded, // 100% or more used; also any non-positive budget
}
42
/// Estimates request costs and budget status from a table of per-model pricing.
pub struct CostCalculator {
    // Keyed by `ModelPricing::model_name`.
    pricing_table: HashMap<String, ModelPricing>,
}
46
47impl CostCalculator {
48    /// Create with default pricing table (OpenAI, Anthropic, etc.)
49    pub fn new() -> Self {
50        let mut pricing_table = HashMap::new();
51
52        // OpenAI models
53        pricing_table.insert(
54            "gpt-4".to_string(),
55            ModelPricing {
56                model_name: "gpt-4".to_string(),
57                provider: "openai".to_string(),
58                input_cost_per_1k_tokens: 0.03,
59                output_cost_per_1k_tokens: 0.06,
60                context_window: 8192,
61                max_output_tokens: Some(4096),
62            },
63        );
64
65        pricing_table.insert(
66            "gpt-4-turbo".to_string(),
67            ModelPricing {
68                model_name: "gpt-4-turbo".to_string(),
69                provider: "openai".to_string(),
70                input_cost_per_1k_tokens: 0.01,
71                output_cost_per_1k_tokens: 0.03,
72                context_window: 128000,
73                max_output_tokens: Some(4096),
74            },
75        );
76
77        pricing_table.insert(
78            "gpt-3.5-turbo".to_string(),
79            ModelPricing {
80                model_name: "gpt-3.5-turbo".to_string(),
81                provider: "openai".to_string(),
82                input_cost_per_1k_tokens: 0.0005,
83                output_cost_per_1k_tokens: 0.0015,
84                context_window: 16385,
85                max_output_tokens: Some(4096),
86            },
87        );
88
89        pricing_table.insert(
90            "gpt-4o".to_string(),
91            ModelPricing {
92                model_name: "gpt-4o".to_string(),
93                provider: "openai".to_string(),
94                input_cost_per_1k_tokens: 0.005,
95                output_cost_per_1k_tokens: 0.015,
96                context_window: 128000,
97                max_output_tokens: Some(4096),
98            },
99        );
100
101        pricing_table.insert(
102            "gpt-4o-mini".to_string(),
103            ModelPricing {
104                model_name: "gpt-4o-mini".to_string(),
105                provider: "openai".to_string(),
106                input_cost_per_1k_tokens: 0.00015,
107                output_cost_per_1k_tokens: 0.0006,
108                context_window: 128000,
109                max_output_tokens: Some(16384),
110            },
111        );
112
113        // Anthropic models
114        pricing_table.insert(
115            "claude-3-opus".to_string(),
116            ModelPricing {
117                model_name: "claude-3-opus".to_string(),
118                provider: "anthropic".to_string(),
119                input_cost_per_1k_tokens: 0.015,
120                output_cost_per_1k_tokens: 0.075,
121                context_window: 200000,
122                max_output_tokens: Some(4096),
123            },
124        );
125
126        pricing_table.insert(
127            "claude-3-sonnet".to_string(),
128            ModelPricing {
129                model_name: "claude-3-sonnet".to_string(),
130                provider: "anthropic".to_string(),
131                input_cost_per_1k_tokens: 0.003,
132                output_cost_per_1k_tokens: 0.015,
133                context_window: 200000,
134                max_output_tokens: Some(4096),
135            },
136        );
137
138        pricing_table.insert(
139            "claude-3-haiku".to_string(),
140            ModelPricing {
141                model_name: "claude-3-haiku".to_string(),
142                provider: "anthropic".to_string(),
143                input_cost_per_1k_tokens: 0.00025,
144                output_cost_per_1k_tokens: 0.00125,
145                context_window: 200000,
146                max_output_tokens: Some(4096),
147            },
148        );
149
150        pricing_table.insert(
151            "claude-3-5-sonnet".to_string(),
152            ModelPricing {
153                model_name: "claude-3-5-sonnet".to_string(),
154                provider: "anthropic".to_string(),
155                input_cost_per_1k_tokens: 0.003,
156                output_cost_per_1k_tokens: 0.015,
157                context_window: 200000,
158                max_output_tokens: Some(8192),
159            },
160        );
161
162        // Google models
163        pricing_table.insert(
164            "gemini-pro".to_string(),
165            ModelPricing {
166                model_name: "gemini-pro".to_string(),
167                provider: "google".to_string(),
168                input_cost_per_1k_tokens: 0.0005,
169                output_cost_per_1k_tokens: 0.0015,
170                context_window: 30720,
171                max_output_tokens: Some(2048),
172            },
173        );
174
175        pricing_table.insert(
176            "gemini-ultra".to_string(),
177            ModelPricing {
178                model_name: "gemini-ultra".to_string(),
179                provider: "google".to_string(),
180                input_cost_per_1k_tokens: 0.0125,
181                output_cost_per_1k_tokens: 0.0375,
182                context_window: 30720,
183                max_output_tokens: Some(2048),
184            },
185        );
186
187        Self { pricing_table }
188    }
189
190    /// Estimate cost for a given model and token counts
191    pub fn estimate_cost(
192        &self,
193        model_name: &str,
194        input_tokens: usize,
195        output_tokens: usize,
196    ) -> Result<CostEstimate, CostError> {
197        let pricing = self
198            .pricing_table
199            .get(model_name)
200            .ok_or_else(|| CostError::UnknownModel(model_name.to_string()))?;
201
202        if input_tokens == 0 && output_tokens == 0 {
203            return Err(CostError::InvalidTokenCount);
204        }
205
206        // Check if tokens exceed context window
207        if input_tokens + output_tokens > pricing.context_window {
208            return Err(CostError::InvalidTokenCount);
209        }
210
211        // Check if output tokens exceed max output
212        if let Some(max_output) = pricing.max_output_tokens {
213            if output_tokens > max_output {
214                return Err(CostError::InvalidTokenCount);
215            }
216        }
217
218        let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k_tokens;
219        let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k_tokens;
220        let total_cost = input_cost + output_cost;
221
222        Ok(CostEstimate {
223            model_name: model_name.to_string(),
224            input_tokens,
225            output_tokens,
226            input_cost,
227            output_cost,
228            total_cost,
229            currency: "USD".to_string(),
230        })
231    }
232
233    /// Estimate cost from text (estimates tokens)
234    pub fn estimate_cost_from_text(
235        &self,
236        model_name: &str,
237        input_text: &str,
238        estimated_output_tokens: usize,
239    ) -> Result<CostEstimate, CostError> {
240        let input_tokens = self.estimate_tokens(input_text);
241        self.estimate_cost(model_name, input_tokens, estimated_output_tokens)
242    }
243
244    /// Check budget status
245    pub fn check_budget(&self, spent: f64, budget: f64) -> BudgetStatus {
246        if budget <= 0.0 {
247            return BudgetStatus {
248                budget_usd: budget,
249                spent_usd: spent,
250                remaining_usd: budget - spent,
251                percent_used: 100.0,
252                status: BudgetAlert::Exceeded,
253            };
254        }
255
256        let percent_used = (spent / budget) * 100.0;
257        let remaining = budget - spent;
258
259        let status = match percent_used {
260            p if p >= 100.0 => BudgetAlert::Exceeded,
261            p if p >= 95.0 => BudgetAlert::Critical,
262            p if p >= 80.0 => BudgetAlert::Warning,
263            _ => BudgetAlert::Ok,
264        };
265
266        BudgetStatus {
267            budget_usd: budget,
268            spent_usd: spent,
269            remaining_usd: remaining,
270            percent_used: percent_used.min(100.0),
271            status,
272        }
273    }
274
275    /// Get cheapest model for a given context size
276    pub fn get_cheapest_model(&self, min_context_window: usize) -> Option<&ModelPricing> {
277        self.pricing_table
278            .values()
279            .filter(|pricing| pricing.context_window >= min_context_window)
280            .min_by(|a, b| {
281                let avg_cost_a = (a.input_cost_per_1k_tokens + a.output_cost_per_1k_tokens) / 2.0;
282                let avg_cost_b = (b.input_cost_per_1k_tokens + b.output_cost_per_1k_tokens) / 2.0;
283                avg_cost_a
284                    .partial_cmp(&avg_cost_b)
285                    .unwrap_or(std::cmp::Ordering::Equal)
286            })
287    }
288
289    /// Get all models under a cost threshold (per 1k tokens average)
290    pub fn get_models_under_cost(&self, max_cost_per_1k: f64) -> Vec<&ModelPricing> {
291        self.pricing_table
292            .values()
293            .filter(|pricing| {
294                let avg_cost =
295                    (pricing.input_cost_per_1k_tokens + pricing.output_cost_per_1k_tokens) / 2.0;
296                avg_cost <= max_cost_per_1k
297            })
298            .collect()
299    }
300
301    /// Get models by provider
302    pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelPricing> {
303        self.pricing_table
304            .values()
305            .filter(|pricing| pricing.provider.eq_ignore_ascii_case(provider))
306            .collect()
307    }
308
309    /// Compare cost between two models for given usage
310    pub fn compare_models(
311        &self,
312        model_a: &str,
313        model_b: &str,
314        input_tokens: usize,
315        output_tokens: usize,
316    ) -> Result<ModelComparison, CostError> {
317        let cost_a = self.estimate_cost(model_a, input_tokens, output_tokens)?;
318        let cost_b = self.estimate_cost(model_b, input_tokens, output_tokens)?;
319
320        let savings = cost_a.total_cost - cost_b.total_cost;
321        let percent_difference = if cost_a.total_cost > 0.0 {
322            (savings / cost_a.total_cost) * 100.0
323        } else {
324            0.0
325        };
326
327        Ok(ModelComparison {
328            model_a: cost_a,
329            model_b: cost_b,
330            cheaper_model: if savings > 0.0 { model_b } else { model_a }.to_string(),
331            savings: savings.abs(),
332            percent_difference: percent_difference.abs(),
333        })
334    }
335
336    /// Add custom model pricing
337    pub fn add_model(&mut self, pricing: ModelPricing) {
338        self.pricing_table
339            .insert(pricing.model_name.clone(), pricing);
340    }
341
    /// Remove a model from pricing table
    ///
    /// Returns the removed entry, or `None` if the model was not registered.
    pub fn remove_model(&mut self, model_name: &str) -> Option<ModelPricing> {
        self.pricing_table.remove(model_name)
    }
346
    /// Get all available models
    ///
    /// Note: iteration order is unspecified (backed by a `HashMap`).
    pub fn get_all_models(&self) -> Vec<&ModelPricing> {
        self.pricing_table.values().collect()
    }
351
352    /// Estimate tokens from text (rough approximation: chars / 4)
353    fn estimate_tokens(&self, text: &str) -> usize {
354        // Basic tokenization approximation
355        // Real-world implementations should use proper tokenizers like tiktoken
356        let char_count = text.len();
357
358        // Account for different languages and complexity
359        let token_estimate = if text.is_ascii() {
360            // English text: roughly 4 chars per token
361            (char_count as f64 / 4.0).ceil() as usize
362        } else {
363            // Non-ASCII text: typically more tokens
364            (char_count as f64 / 3.0).ceil() as usize
365        };
366
367        // Add some tokens for special tokens, formatting, etc.
368        token_estimate + (token_estimate / 20) // Add 5% overhead
369    }
370
371    /// Calculate monthly cost projection
372    pub fn project_monthly_cost(
373        &self,
374        model_name: &str,
375        daily_input_tokens: usize,
376        daily_output_tokens: usize,
377        days_per_month: f64,
378    ) -> Result<CostProjection, CostError> {
379        let daily_cost = self.estimate_cost(model_name, daily_input_tokens, daily_output_tokens)?;
380        let monthly_cost = daily_cost.total_cost * days_per_month;
381
382        Ok(CostProjection {
383            model_name: model_name.to_string(),
384            daily_cost: daily_cost.total_cost,
385            monthly_cost,
386            annual_cost: monthly_cost * 12.0,
387            currency: "USD".to_string(),
388        })
389    }
390}
391
/// Defaults to the built-in pricing table via [`CostCalculator::new`].
impl Default for CostCalculator {
    fn default() -> Self {
        Self::new()
    }
}
397
/// Side-by-side cost comparison of two models for identical token usage,
/// produced by [`CostCalculator::compare_models`].
#[derive(Debug, Clone)]
pub struct ModelComparison {
    pub model_a: CostEstimate,
    pub model_b: CostEstimate,
    /// Name of the cheaper model (ties resolve to model A).
    pub cheaper_model: String,
    /// Absolute USD difference between the two total costs.
    pub savings: f64,
    /// Absolute cost difference as a percentage of model A's total cost.
    pub percent_difference: f64,
}
406
/// Daily/monthly/annual cost projection returned by
/// [`CostCalculator::project_monthly_cost`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostProjection {
    pub model_name: String,
    pub daily_cost: f64,
    /// `daily_cost` scaled by the caller-supplied days-per-month.
    pub monthly_cost: f64,
    /// `monthly_cost * 12.0`.
    pub annual_cost: f64,
    /// Currency code; always "USD" in this module.
    pub currency: String,
}
415
/// Errors returned by [`CostCalculator`] estimation methods.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum CostError {
    /// The requested model is not present in the pricing table.
    #[error("Unknown model: {0}")]
    UnknownModel(String),
    /// Token counts are zero, exceed the context window, or exceed the
    /// model's max-output cap.
    #[error("Invalid token count")]
    InvalidTokenCount,
}
423
#[cfg(test)]
mod tests {
    use super::*;

    // Exact-value check against gpt-4's published per-1k rates.
    #[test]
    fn test_cost_estimation() {
        let calculator = CostCalculator::new();

        let estimate = calculator.estimate_cost("gpt-4", 1000, 500).unwrap();

        assert_eq!(estimate.model_name, "gpt-4");
        assert_eq!(estimate.input_tokens, 1000);
        assert_eq!(estimate.output_tokens, 500);
        assert_eq!(estimate.input_cost, 0.03); // 1000 tokens = 1k tokens * $0.03
        assert_eq!(estimate.output_cost, 0.03); // 500 tokens = 0.5k tokens * $0.06
        assert_eq!(estimate.total_cost, 0.06);
        assert_eq!(estimate.currency, "USD");
    }

    #[test]
    fn test_unknown_model() {
        let calculator = CostCalculator::new();
        let result = calculator.estimate_cost("unknown-model", 1000, 500);

        assert!(matches!(result, Err(CostError::UnknownModel(_))));
    }

    // All three rejection paths of estimate_cost's token validation.
    #[test]
    fn test_invalid_token_count() {
        let calculator = CostCalculator::new();

        // Zero tokens
        let result = calculator.estimate_cost("gpt-4", 0, 0);
        assert!(matches!(result, Err(CostError::InvalidTokenCount)));

        // Exceeding context window (gpt-4 has 8192)
        let result = calculator.estimate_cost("gpt-4", 10000, 0);
        assert!(matches!(result, Err(CostError::InvalidTokenCount)));

        // Exceeding max output tokens (gpt-4 has 4096)
        let result = calculator.estimate_cost("gpt-4", 1000, 5000);
        assert!(matches!(result, Err(CostError::InvalidTokenCount)));
    }

    // Covers every BudgetAlert level plus the negative-remaining case.
    #[test]
    fn test_budget_status() {
        let calculator = CostCalculator::new();

        // OK status
        let status = calculator.check_budget(50.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Ok);
        assert_eq!(status.percent_used, 50.0);
        assert_eq!(status.remaining_usd, 50.0);

        // Warning status
        let status = calculator.check_budget(85.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Warning);

        // Critical status
        let status = calculator.check_budget(96.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Critical);

        // Exceeded status
        let status = calculator.check_budget(110.0, 100.0);
        assert_eq!(status.status, BudgetAlert::Exceeded);
        assert_eq!(status.remaining_usd, -10.0);
    }

    #[test]
    fn test_cheapest_model() {
        let calculator = CostCalculator::new();

        let cheapest = calculator.get_cheapest_model(8000);
        assert!(cheapest.is_some());
        let model = cheapest.unwrap();

        // Should be one of the cheaper models with sufficient context
        assert!(model.context_window >= 8000);
    }

    #[test]
    fn test_models_under_cost() {
        let calculator = CostCalculator::new();

        let cheap_models = calculator.get_models_under_cost(0.01);
        assert!(!cheap_models.is_empty());

        // All returned models should have average cost <= 0.01
        for model in &cheap_models {
            let avg_cost = (model.input_cost_per_1k_tokens + model.output_cost_per_1k_tokens) / 2.0;
            assert!(avg_cost <= 0.01);
        }
    }

    #[test]
    fn test_models_by_provider() {
        let calculator = CostCalculator::new();

        let openai_models = calculator.get_models_by_provider("openai");
        assert!(!openai_models.is_empty());
        for model in &openai_models {
            assert_eq!(model.provider, "openai");
        }

        let anthropic_models = calculator.get_models_by_provider("anthropic");
        assert!(!anthropic_models.is_empty());
        for model in &anthropic_models {
            assert_eq!(model.provider, "anthropic");
        }
    }

    #[test]
    fn test_model_comparison() {
        let calculator = CostCalculator::new();

        let comparison = calculator
            .compare_models("gpt-4", "gpt-3.5-turbo", 1000, 500)
            .unwrap();

        // GPT-3.5-turbo should be cheaper than GPT-4
        assert_eq!(comparison.cheaper_model, "gpt-3.5-turbo");
        assert!(comparison.savings > 0.0);
        assert!(comparison.percent_difference > 0.0);
    }

    #[test]
    fn test_cost_from_text() {
        let calculator = CostCalculator::new();

        let text = "Hello, world!";
        let estimate = calculator
            .estimate_cost_from_text("gpt-3.5-turbo", text, 100)
            .unwrap();

        assert!(estimate.input_tokens > 0);
        assert_eq!(estimate.output_tokens, 100);
        assert!(estimate.total_cost > 0.0);
    }

    // Heuristic sanity checks; exact values depend on the bytes/4 rule.
    #[test]
    fn test_token_estimation() {
        let calculator = CostCalculator::new();

        // English text
        let english_text = "Hello, world! This is a test.";
        let tokens = calculator.estimate_tokens(english_text);

        // Should be roughly chars/4 with some overhead
        let expected = ((english_text.len() as f64 / 4.0).ceil() as usize * 105) / 100; // 5% overhead
        assert!(tokens >= expected - 2 && tokens <= expected + 2);

        // Empty text
        assert_eq!(calculator.estimate_tokens(""), 0);
    }

    // Round-trips a custom entry through add_model / remove_model.
    #[test]
    fn test_custom_model() {
        let mut calculator = CostCalculator::new();

        let custom_model = ModelPricing {
            model_name: "custom-model".to_string(),
            provider: "custom".to_string(),
            input_cost_per_1k_tokens: 0.001,
            output_cost_per_1k_tokens: 0.002,
            context_window: 4096,
            max_output_tokens: Some(2048),
        };

        calculator.add_model(custom_model.clone());

        let estimate = calculator.estimate_cost("custom-model", 1000, 500).unwrap();
        assert_eq!(estimate.input_cost, 0.001);
        assert_eq!(estimate.output_cost, 0.001);
        assert_eq!(estimate.total_cost, 0.002);

        // Test removal
        let removed = calculator.remove_model("custom-model");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().model_name, "custom-model");

        // Should no longer be available
        let result = calculator.estimate_cost("custom-model", 1000, 500);
        assert!(matches!(result, Err(CostError::UnknownModel(_))));
    }

    #[test]
    fn test_cost_projection() {
        let calculator = CostCalculator::new();

        let projection = calculator
            .project_monthly_cost("gpt-4", 4000, 2000, 30.0)
            .unwrap();

        assert_eq!(projection.model_name, "gpt-4");
        assert!(projection.daily_cost > 0.0);
        assert_eq!(projection.monthly_cost, projection.daily_cost * 30.0);
        assert_eq!(projection.annual_cost, projection.monthly_cost * 12.0);
    }

    #[test]
    fn test_all_default_models_available() {
        let calculator = CostCalculator::new();

        // Test all default models can be used for cost estimation
        let test_models = [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-3.5-turbo",
            "gpt-4o",
            "gpt-4o-mini",
            "claude-3-opus",
            "claude-3-sonnet",
            "claude-3-haiku",
            "claude-3-5-sonnet",
            "gemini-pro",
            "gemini-ultra",
        ];

        for model in &test_models {
            let result = calculator.estimate_cost(model, 1000, 500);
            assert!(result.is_ok(), "Model {} should be available", model);
        }
    }
}