Skip to main content

grapsus_proxy/inference/
cost.rs

//! Cost calculator for inference request attribution.
//!
//! Calculates request costs from per-model pricing, applied separately to input and
//! output token counts.
4
5use tracing::{debug, trace};
6
7use grapsus_common::budget::{CostAttributionConfig, CostResult, ModelPricing};
8
9/// Cost calculator for inference requests.
10///
11/// Uses per-model pricing rules to calculate costs for inference requests
12/// based on input and output token counts.
13pub struct CostCalculator {
14    /// Configuration
15    config: CostAttributionConfig,
16    /// Route ID for logging
17    route_id: String,
18}
19
20impl CostCalculator {
21    /// Create a new cost calculator with the given configuration.
22    pub fn new(config: CostAttributionConfig, route_id: impl Into<String>) -> Self {
23        let route_id = route_id.into();
24
25        debug!(
26            route_id = %route_id,
27            enabled = config.enabled,
28            pricing_rules = config.pricing.len(),
29            default_input = config.default_input_cost,
30            default_output = config.default_output_cost,
31            currency = %config.currency,
32            "Created cost calculator"
33        );
34
35        Self { config, route_id }
36    }
37
38    /// Check if cost attribution is enabled.
39    pub fn is_enabled(&self) -> bool {
40        self.config.enabled
41    }
42
43    /// Calculate the cost for a request/response.
44    ///
45    /// Uses the first matching pricing rule, or falls back to default pricing.
46    pub fn calculate(&self, model: &str, input_tokens: u64, output_tokens: u64) -> CostResult {
47        if !self.config.enabled {
48            return CostResult::new(model, input_tokens, output_tokens, 0.0, 0.0, "USD");
49        }
50
51        // Find matching pricing rule
52        let (input_cost_per_million, output_cost_per_million, currency) =
53            if let Some(pricing) = self.find_pricing(model) {
54                let currency = pricing
55                    .currency
56                    .as_ref()
57                    .unwrap_or(&self.config.currency)
58                    .clone();
59                (
60                    pricing.input_cost_per_million,
61                    pricing.output_cost_per_million,
62                    currency,
63                )
64            } else {
65                (
66                    self.config.default_input_cost,
67                    self.config.default_output_cost,
68                    self.config.currency.clone(),
69                )
70            };
71
72        // Calculate costs
73        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_cost_per_million;
74        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_cost_per_million;
75        let total_cost = input_cost + output_cost;
76
77        trace!(
78            route_id = %self.route_id,
79            model = model,
80            input_tokens = input_tokens,
81            output_tokens = output_tokens,
82            input_cost = input_cost,
83            output_cost = output_cost,
84            total_cost = total_cost,
85            currency = %currency,
86            "Calculated cost"
87        );
88
89        CostResult::new(
90            model,
91            input_tokens,
92            output_tokens,
93            input_cost,
94            output_cost,
95            currency,
96        )
97    }
98
99    /// Find the pricing rule for a model.
100    ///
101    /// Returns the first matching rule, or None if no rules match.
102    pub fn find_pricing(&self, model: &str) -> Option<&ModelPricing> {
103        self.config.pricing.iter().find(|p| p.matches(model))
104    }
105
106    /// Get the default currency.
107    pub fn currency(&self) -> &str {
108        &self.config.currency
109    }
110
111    /// Get the number of pricing rules.
112    pub fn pricing_rule_count(&self) -> usize {
113        self.config.pricing.len()
114    }
115}
116
117// ============================================================================
118// Tests
119// ============================================================================
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Route config with three non-overlapping rules (one with a EUR override), default
    /// prices of 1.0 / 2.0 per million tokens, and USD as the route currency.
    fn test_config() -> CostAttributionConfig {
        CostAttributionConfig {
            enabled: true,
            pricing: vec![
                ModelPricing {
                    model_pattern: "gpt-4*".to_string(),
                    input_cost_per_million: 30.0,
                    output_cost_per_million: 60.0,
                    currency: None,
                },
                ModelPricing {
                    model_pattern: "gpt-3.5*".to_string(),
                    input_cost_per_million: 0.5,
                    output_cost_per_million: 1.5,
                    currency: None,
                },
                ModelPricing {
                    model_pattern: "claude-*".to_string(),
                    input_cost_per_million: 15.0,
                    output_cost_per_million: 75.0,
                    currency: Some("EUR".to_string()),
                },
            ],
            default_input_cost: 1.0,
            default_output_cost: 2.0,
            currency: "USD".to_string(),
        }
    }

    /// A rule that overlaps with `gpt-4*` for turbo models, priced differently.
    fn gpt4_turbo_rule() -> ModelPricing {
        ModelPricing {
            model_pattern: "gpt-4-turbo*".to_string(),
            input_cost_per_million: 10.0,
            output_cost_per_million: 20.0,
            currency: None,
        }
    }

    #[test]
    fn test_calculate_gpt4() {
        let calc = CostCalculator::new(test_config(), "test-route");

        // 1000 input tokens, 500 output tokens
        let result = calc.calculate("gpt-4-turbo", 1000, 500);

        assert_eq!(result.model, "gpt-4-turbo");
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert_eq!(result.currency, "USD");

        // $30/1M input = $0.00003 per token, 1000 tokens = $0.03
        assert!((result.input_cost - 0.03).abs() < 0.001);

        // $60/1M output = $0.00006 per token, 500 tokens = $0.03
        assert!((result.output_cost - 0.03).abs() < 0.001);
    }

    #[test]
    fn test_calculate_gpt35() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-3.5-turbo", 1_000_000, 1_000_000);

        // $0.5/1M input = $0.50 for 1M tokens
        assert!((result.input_cost - 0.5).abs() < 0.001);

        // $1.5/1M output = $1.50 for 1M tokens
        assert!((result.output_cost - 1.5).abs() < 0.001);

        assert!((result.total_cost - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_calculate_claude_with_currency_override() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("claude-3-opus", 1000, 1000);

        // Should use EUR from the pricing rule
        assert_eq!(result.currency, "EUR");
    }

    #[test]
    fn test_calculate_unknown_model_uses_default() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("llama-3", 1_000_000, 1_000_000);

        // Should use default pricing
        assert!((result.input_cost - 1.0).abs() < 0.001);
        assert!((result.output_cost - 2.0).abs() < 0.001);
        assert_eq!(result.currency, "USD");
    }

    #[test]
    fn test_disabled_returns_zero() {
        let mut config = test_config();
        config.enabled = false;

        let calc = CostCalculator::new(config, "test-route");

        let result = calc.calculate("gpt-4", 1000, 500);

        assert!((result.input_cost).abs() < 0.00001);
        assert!((result.output_cost).abs() < 0.00001);
        assert!((result.total_cost).abs() < 0.00001);
    }

    #[test]
    fn test_zero_tokens_cost_nothing() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-4", 0, 0);

        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
        assert!((result.total_cost).abs() < 0.00001);
        assert_eq!(result.currency, "USD");
    }

    #[test]
    fn test_first_matching_rule_wins() {
        let mut config = test_config();
        // Placed before "gpt-4*": it must win for turbo models, while other gpt-4
        // models still fall through to "gpt-4*".
        config.pricing.insert(0, gpt4_turbo_rule());
        let calc = CostCalculator::new(config, "test-route");

        let turbo = calc.calculate("gpt-4-turbo", 1_000_000, 1_000_000);
        assert!((turbo.input_cost - 10.0).abs() < 0.001);
        assert!((turbo.output_cost - 20.0).abs() < 0.001);

        let base = calc.calculate("gpt-4", 1_000_000, 1_000_000);
        assert!((base.input_cost - 30.0).abs() < 0.001);
        assert!((base.output_cost - 60.0).abs() < 0.001);
    }

    #[test]
    fn test_earlier_broad_rule_shadows_later_specific_rule() {
        let mut config = test_config();
        // Appended after "gpt-4*": first match wins, so it is never selected for gpt-4 models.
        config.pricing.push(gpt4_turbo_rule());
        let calc = CostCalculator::new(config, "test-route");

        let result = calc.calculate("gpt-4-turbo", 1_000_000, 1_000_000);
        assert!((result.input_cost - 30.0).abs() < 0.001);
        assert!((result.output_cost - 60.0).abs() < 0.001);
    }

    #[test]
    fn test_accessors_reflect_config() {
        let calc = CostCalculator::new(test_config(), "test-route");
        assert!(calc.is_enabled());
        assert_eq!(calc.currency(), "USD");
        assert_eq!(calc.pricing_rule_count(), 3);

        let mut disabled = test_config();
        disabled.enabled = false;
        assert!(!CostCalculator::new(disabled, "test-route").is_enabled());
    }

    #[test]
    fn test_find_pricing() {
        let calc = CostCalculator::new(test_config(), "test-route");

        assert!(calc.find_pricing("gpt-4").is_some());
        assert!(calc.find_pricing("gpt-4-turbo").is_some());
        assert!(calc.find_pricing("gpt-3.5-turbo").is_some());
        assert!(calc.find_pricing("claude-3-sonnet").is_some());
        assert!(calc.find_pricing("llama-3").is_none());
    }
}