1use tracing::{debug, trace};
6
7use grapsus_common::budget::{CostAttributionConfig, CostResult, ModelPricing};
8
/// Per-route calculator that turns request token counts into a [`CostResult`]
/// using the configured per-model pricing rules and default rates.
pub struct CostCalculator {
    /// Cost-attribution settings: enabled flag, ordered pricing rules,
    /// default per-million-token rates and the calculator-wide currency.
    config: CostAttributionConfig,
    /// Identifier of the route this calculator serves; attached to the
    /// `tracing` events this calculator emits.
    route_id: String,
}
19
20impl CostCalculator {
21 pub fn new(config: CostAttributionConfig, route_id: impl Into<String>) -> Self {
23 let route_id = route_id.into();
24
25 debug!(
26 route_id = %route_id,
27 enabled = config.enabled,
28 pricing_rules = config.pricing.len(),
29 default_input = config.default_input_cost,
30 default_output = config.default_output_cost,
31 currency = %config.currency,
32 "Created cost calculator"
33 );
34
35 Self { config, route_id }
36 }
37
38 pub fn is_enabled(&self) -> bool {
40 self.config.enabled
41 }
42
43 pub fn calculate(&self, model: &str, input_tokens: u64, output_tokens: u64) -> CostResult {
47 if !self.config.enabled {
48 return CostResult::new(model, input_tokens, output_tokens, 0.0, 0.0, "USD");
49 }
50
51 let (input_cost_per_million, output_cost_per_million, currency) =
53 if let Some(pricing) = self.find_pricing(model) {
54 let currency = pricing
55 .currency
56 .as_ref()
57 .unwrap_or(&self.config.currency)
58 .clone();
59 (
60 pricing.input_cost_per_million,
61 pricing.output_cost_per_million,
62 currency,
63 )
64 } else {
65 (
66 self.config.default_input_cost,
67 self.config.default_output_cost,
68 self.config.currency.clone(),
69 )
70 };
71
72 let input_cost = (input_tokens as f64 / 1_000_000.0) * input_cost_per_million;
74 let output_cost = (output_tokens as f64 / 1_000_000.0) * output_cost_per_million;
75 let total_cost = input_cost + output_cost;
76
77 trace!(
78 route_id = %self.route_id,
79 model = model,
80 input_tokens = input_tokens,
81 output_tokens = output_tokens,
82 input_cost = input_cost,
83 output_cost = output_cost,
84 total_cost = total_cost,
85 currency = %currency,
86 "Calculated cost"
87 );
88
89 CostResult::new(
90 model,
91 input_tokens,
92 output_tokens,
93 input_cost,
94 output_cost,
95 currency,
96 )
97 }
98
99 pub fn find_pricing(&self, model: &str) -> Option<&ModelPricing> {
103 self.config.pricing.iter().find(|p| p.matches(model))
104 }
105
106 pub fn currency(&self) -> &str {
108 &self.config.currency
109 }
110
111 pub fn pricing_rule_count(&self) -> usize {
113 self.config.pricing.len()
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point cost comparisons.
    const EPSILON: f64 = 1e-6;

    /// Asserts that `actual` is within [`EPSILON`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Three rules (gpt-4*, gpt-3.5*, claude-* with an EUR override),
    /// defaults of 1.0 / 2.0 per million tokens, calculator currency USD.
    fn test_config() -> CostAttributionConfig {
        CostAttributionConfig {
            enabled: true,
            pricing: vec![
                ModelPricing {
                    model_pattern: "gpt-4*".to_string(),
                    input_cost_per_million: 30.0,
                    output_cost_per_million: 60.0,
                    currency: None,
                },
                ModelPricing {
                    model_pattern: "gpt-3.5*".to_string(),
                    input_cost_per_million: 0.5,
                    output_cost_per_million: 1.5,
                    currency: None,
                },
                ModelPricing {
                    model_pattern: "claude-*".to_string(),
                    input_cost_per_million: 15.0,
                    output_cost_per_million: 75.0,
                    currency: Some("EUR".to_string()),
                },
            ],
            default_input_cost: 1.0,
            default_output_cost: 2.0,
            currency: "USD".to_string(),
        }
    }

    #[test]
    fn test_calculate_gpt4() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-4-turbo", 1000, 500);

        assert_eq!(result.model, "gpt-4-turbo");
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert_eq!(result.currency, "USD");

        // 1_000 tokens at 30 per million.
        assert_close(result.input_cost, 0.03);
        // 500 tokens at 60 per million.
        assert_close(result.output_cost, 0.03);
        assert_close(result.total_cost, 0.06);
    }

    #[test]
    fn test_calculate_gpt35() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-3.5-turbo", 1_000_000, 1_000_000);

        assert_close(result.input_cost, 0.5);
        assert_close(result.output_cost, 1.5);
        assert_close(result.total_cost, 2.0);
    }

    #[test]
    fn test_calculate_claude_with_currency_override() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("claude-3-opus", 1000, 1000);

        assert_eq!(result.currency, "EUR");
    }

    #[test]
    fn test_calculate_claude_uses_rule_prices() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("claude-3-opus", 1_000_000, 1_000_000);

        assert_close(result.input_cost, 15.0);
        assert_close(result.output_cost, 75.0);
        assert_close(result.total_cost, 90.0);
        assert_eq!(result.currency, "EUR");
    }

    #[test]
    fn test_calculate_unknown_model_uses_default() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("llama-3", 1_000_000, 1_000_000);

        assert_close(result.input_cost, 1.0);
        assert_close(result.output_cost, 2.0);
        assert_eq!(result.currency, "USD");
    }

    #[test]
    fn test_zero_tokens_cost_nothing() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-4", 0, 0);

        assert_close(result.input_cost, 0.0);
        assert_close(result.output_cost, 0.0);
        assert_close(result.total_cost, 0.0);
    }

    #[test]
    fn test_first_matching_rule_wins() {
        let mut config = test_config();
        // A later rule with the same pattern must be shadowed by the earlier one.
        config.pricing.push(ModelPricing {
            model_pattern: "gpt-4*".to_string(),
            input_cost_per_million: 999.0,
            output_cost_per_million: 999.0,
            currency: None,
        });
        let calc = CostCalculator::new(config, "test-route");

        let result = calc.calculate("gpt-4", 1_000_000, 1_000_000);

        assert_close(result.input_cost, 30.0);
        assert_close(result.output_cost, 60.0);
    }

    #[test]
    fn test_empty_pricing_falls_back_to_defaults() {
        let mut config = test_config();
        config.pricing.clear();
        let calc = CostCalculator::new(config, "test-route");

        assert_eq!(calc.pricing_rule_count(), 0);
        assert!(calc.find_pricing("gpt-4").is_none());

        let result = calc.calculate("gpt-4", 1_000_000, 1_000_000);

        assert_close(result.input_cost, 1.0);
        assert_close(result.output_cost, 2.0);
        assert_eq!(result.currency, "USD");
    }

    #[test]
    fn test_disabled_returns_zero() {
        let mut config = test_config();
        config.enabled = false;

        let calc = CostCalculator::new(config, "test-route");
        assert!(!calc.is_enabled());

        let result = calc.calculate("gpt-4", 1000, 500);

        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert_eq!(result.currency, calc.currency());
        assert_close(result.input_cost, 0.0);
        assert_close(result.output_cost, 0.0);
        assert_close(result.total_cost, 0.0);
    }

    #[test]
    fn test_accessors() {
        let calc = CostCalculator::new(test_config(), "test-route");

        assert!(calc.is_enabled());
        assert_eq!(calc.currency(), "USD");
        assert_eq!(calc.pricing_rule_count(), 3);
    }

    #[test]
    fn test_find_pricing() {
        let calc = CostCalculator::new(test_config(), "test-route");

        assert!(calc.find_pricing("gpt-4").is_some());
        assert!(calc.find_pricing("gpt-4-turbo").is_some());
        assert!(calc.find_pricing("gpt-3.5-turbo").is_some());
        assert!(calc.find_pricing("claude-3-sonnet").is_some());
        assert!(calc.find_pricing("llama-3").is_none());
    }
}