1use serde::{Deserialize, Serialize};
21
/// Configuration for a token budget.
///
/// `limit` is the only field without a serde default, so it is the one key
/// that must be present when deserializing; every other field falls back to
/// the default noted on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenBudgetConfig {
    /// Budget period. Falls back to `BudgetPeriod::default()` when omitted.
    #[serde(default)]
    pub period: BudgetPeriod,

    /// Budget ceiling (required).
    /// NOTE(review): presumably the number of tokens allowed per `period` — confirm
    /// against the code that enforces the budget.
    pub limit: u64,

    /// Alert thresholds. Falls back to `default_alert_thresholds()` when omitted.
    /// NOTE(review): the defaults (0.80 / 0.90 / 0.95) suggest fractions of `limit`
    /// rather than percentages — confirm.
    #[serde(default = "default_alert_thresholds")]
    pub alert_thresholds: Vec<f64>,

    /// Falls back to `true` (via `default_true`) when omitted.
    /// NOTE(review): presumably decides whether exceeding the budget blocks
    /// requests — confirm against the caller.
    #[serde(default = "default_true")]
    pub enforce: bool,

    /// Falls back to `false` when omitted.
    /// NOTE(review): presumably whether unused budget carries over into the next
    /// period — confirm.
    #[serde(default)]
    pub rollover: bool,

    /// Optional burst allowance; falls back to `None` when omitted.
    /// NOTE(review): the unit is not visible here (fraction of `limit`? absolute
    /// tokens?) — confirm.
    #[serde(default)]
    pub burst_allowance: Option<f64>,
}
57
/// Serde default for `TokenBudgetConfig::alert_thresholds`: 0.80, 0.90 and 0.95.
fn default_alert_thresholds() -> Vec<f64> {
    Vec::from([0.80, 0.90, 0.95])
}
61
/// Serde default helper that yields `true`; referenced through
/// `#[serde(default = "default_true")]` for fields that should be on unless
/// explicitly disabled.
fn default_true() -> bool {
    true
}
65
66impl Default for TokenBudgetConfig {
67 fn default() -> Self {
68 Self {
69 period: BudgetPeriod::Daily,
70 limit: 1_000_000, alert_thresholds: default_alert_thresholds(),
72 enforce: true,
73 rollover: false,
74 burst_allowance: None,
75 }
76 }
77}
78
/// Budget period selector.
///
/// Variant names are (de)serialized in snake_case, and `Daily` is the
/// `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum BudgetPeriod {
    /// One hour; `as_secs` reports 3600 seconds.
    Hourly,
    /// One day; `as_secs` reports 86 400 seconds. This is the default variant.
    #[default]
    Daily,
    /// `as_secs` reports a fixed 2 592 000 seconds (30 days), not a calendar month.
    Monthly,
    /// A caller-defined period.
    Custom {
        /// Period length in seconds; returned unchanged by `as_secs`.
        seconds: u64,
    },
}
96
97impl BudgetPeriod {
98 pub fn as_secs(&self) -> u64 {
100 match self {
101 BudgetPeriod::Hourly => 3600,
102 BudgetPeriod::Daily => 86400,
103 BudgetPeriod::Monthly => 2_592_000, BudgetPeriod::Custom { seconds } => *seconds,
105 }
106 }
107}
108
/// Configuration for cost attribution.
///
/// Every field has a serde default, so an empty map deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostAttributionConfig {
    /// Falls back to `false` when omitted.
    /// NOTE(review): presumably switches cost attribution on or off — confirm
    /// against the consumer of this config.
    #[serde(default)]
    pub enabled: bool,

    /// Per-model pricing entries; falls back to an empty list when omitted.
    #[serde(default)]
    pub pricing: Vec<ModelPricing>,

    /// Falls back to `default_input_cost()` (1.0) when omitted.
    /// NOTE(review): presumably the input cost per million tokens applied when no
    /// `pricing` entry matches — confirm.
    #[serde(default = "default_input_cost")]
    pub default_input_cost: f64,

    /// Falls back to `default_output_cost()` (2.0) when omitted.
    /// NOTE(review): presumably the output cost per million tokens applied when no
    /// `pricing` entry matches — confirm.
    #[serde(default = "default_output_cost")]
    pub default_output_cost: f64,

    /// Currency label; falls back to `default_currency()` ("USD") when omitted.
    #[serde(default = "default_currency")]
    pub currency: String,
}
138
/// Serde default for `CostAttributionConfig::default_input_cost`: 1.0.
/// NOTE(review): the unit is not visible here (likely cost per million tokens,
/// mirroring `ModelPricing`) — confirm.
fn default_input_cost() -> f64 {
    1.0
}
142
/// Serde default for `CostAttributionConfig::default_output_cost`: 2.0.
/// NOTE(review): the unit is not visible here (likely cost per million tokens,
/// mirroring `ModelPricing`) — confirm.
fn default_output_cost() -> f64 {
    2.0
}
146
/// Serde default for `CostAttributionConfig::currency`: the string "USD".
fn default_currency() -> String {
    String::from("USD")
}
150
151impl Default for CostAttributionConfig {
152 fn default() -> Self {
153 Self {
154 enabled: false,
155 pricing: Vec::new(),
156 default_input_cost: default_input_cost(),
157 default_output_cost: default_output_cost(),
158 currency: default_currency(),
159 }
160 }
161}
162
/// Pricing entry for the model(s) selected by `model_pattern`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Model name to match, optionally containing `*` wildcards; interpreted by
    /// `ModelPricing::matches`.
    pub model_pattern: String,

    /// Cost charged per one million input tokens (see `calculate_cost`).
    pub input_cost_per_million: f64,

    /// Cost charged per one million output tokens (see `calculate_cost`).
    pub output_cost_per_million: f64,

    /// Optional currency label; falls back to `None` when omitted.
    /// NOTE(review): presumably overrides the config-level currency for this
    /// entry — confirm against the cost calculator.
    #[serde(default)]
    pub currency: Option<String>,
}
183
184impl ModelPricing {
185 pub fn new(pattern: impl Into<String>, input_cost: f64, output_cost: f64) -> Self {
187 Self {
188 model_pattern: pattern.into(),
189 input_cost_per_million: input_cost,
190 output_cost_per_million: output_cost,
191 currency: None,
192 }
193 }
194
195 pub fn matches(&self, model: &str) -> bool {
197 if self.model_pattern.contains('*') {
198 let pattern = &self.model_pattern;
200 if let Some(inner) = pattern.strip_prefix('*').and_then(|p| p.strip_suffix('*')) {
201 model.contains(inner)
203 } else if let Some(suffix) = pattern.strip_prefix('*') {
204 model.ends_with(suffix)
206 } else if let Some(prefix) = pattern.strip_suffix('*') {
207 model.starts_with(prefix)
209 } else {
210 let parts: Vec<&str> = pattern.split('*').collect();
212 if parts.is_empty() {
213 return true;
214 }
215
216 let mut remaining = model;
217 for (i, part) in parts.iter().enumerate() {
218 if part.is_empty() {
219 continue;
220 }
221 if i == 0 {
222 if !remaining.starts_with(part) {
224 return false;
225 }
226 remaining = &remaining[part.len()..];
227 } else if i == parts.len() - 1 {
228 if !remaining.ends_with(part) {
230 return false;
231 }
232 } else {
233 if let Some(idx) = remaining.find(part) {
235 remaining = &remaining[idx + part.len()..];
236 } else {
237 return false;
238 }
239 }
240 }
241 true
242 }
243 } else {
244 self.model_pattern == model
246 }
247 }
248
249 pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
251 let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_cost_per_million;
252 let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_cost_per_million;
253 input_cost + output_cost
254 }
255}
256
/// Outcome of a budget check.
///
/// `is_allowed` treats `Allowed` and `Soft` as permitted and `Exhausted` as
/// denied.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// The check passed.
    Allowed {
        /// Budget still available.
        /// NOTE(review): presumably counted in tokens — confirm.
        remaining: u64,
    },
    /// The check failed; `is_allowed` returns `false`.
    Exhausted {
        /// Seconds to wait before retrying, as reported by `retry_after_secs`.
        retry_after_secs: u64,
    },
    /// Over budget but still counted as allowed by `is_allowed`.
    /// NOTE(review): presumably produced when the budget is not enforced —
    /// confirm against the code that builds this value.
    Soft {
        /// Signed remaining budget; the signed type lets it go negative.
        remaining: i64,
        /// Amount by which the budget has been exceeded.
        over_by: u64,
    },
}
282
283impl BudgetCheckResult {
284 pub fn is_allowed(&self) -> bool {
286 matches!(self, Self::Allowed { .. } | Self::Soft { .. })
287 }
288
289 pub fn retry_after_secs(&self) -> u64 {
291 match self {
292 Self::Exhausted { retry_after_secs } => *retry_after_secs,
293 _ => 0,
294 }
295 }
296}
297
/// Alert record describing budget usage for one tenant.
#[derive(Debug, Clone)]
pub struct BudgetAlert {
    /// Tenant identifier.
    pub tenant: String,
    /// Threshold associated with this alert.
    /// NOTE(review): presumably a fraction of the limit (e.g. 0.80) — confirm.
    pub threshold: f64,
    /// Tokens used; numerator of `usage_percent`.
    pub tokens_used: u64,
    /// Token limit; denominator of `usage_percent`.
    pub tokens_limit: u64,
    /// Start of the budget period.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub period_start: u64,
}
312
313impl BudgetAlert {
314 pub fn usage_percent(&self) -> f64 {
316 if self.tokens_limit == 0 {
317 return 0.0;
318 }
319 (self.tokens_used as f64 / self.tokens_limit as f64) * 100.0
320 }
321}
322
/// Snapshot of a tenant's budget state.
///
/// This is a plain data holder: nothing in this type derives one field from
/// another, so whoever constructs it is responsible for keeping them consistent.
#[derive(Debug, Clone)]
pub struct TenantBudgetStatus {
    /// Tokens used so far.
    pub tokens_used: u64,
    /// Token limit.
    pub tokens_limit: u64,
    /// Tokens remaining.
    /// NOTE(review): presumably `tokens_limit - tokens_used`, floored at zero — confirm.
    pub tokens_remaining: u64,
    /// Usage as a percentage.
    /// NOTE(review): presumably on the same 0–100 scale as
    /// `BudgetAlert::usage_percent` — confirm.
    pub usage_percent: f64,
    /// Start of the budget period.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub period_start: u64,
    /// End of the budget period.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub period_end: u64,
    /// Whether the budget is exhausted.
    pub exhausted: bool,
}
341
/// Cost breakdown for one model invocation.
#[derive(Debug, Clone)]
pub struct CostResult {
    /// Cost attributed to input tokens.
    pub input_cost: f64,
    /// Cost attributed to output tokens.
    pub output_cost: f64,
    /// Combined cost; `CostResult::new` sets this to `input_cost + output_cost`.
    pub total_cost: f64,
    /// Currency label for the cost fields.
    pub currency: String,
    /// Model name the cost was computed for.
    pub model: String,
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
}
360
361impl CostResult {
362 pub fn new(
364 model: impl Into<String>,
365 input_tokens: u64,
366 output_tokens: u64,
367 input_cost: f64,
368 output_cost: f64,
369 currency: impl Into<String>,
370 ) -> Self {
371 Self {
372 input_cost,
373 output_cost,
374 total_cost: input_cost + output_cost,
375 currency: currency.into(),
376 model: model.into(),
377 input_tokens,
378 output_tokens,
379 }
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `pattern` matches every model in `accepted` and none of
    /// the models in `rejected`.
    fn check_pattern(pattern: &str, accepted: &[&str], rejected: &[&str]) {
        let pricing = ModelPricing::new(pattern, 30.0, 60.0);
        for model in accepted {
            assert!(pricing.matches(model));
        }
        for model in rejected {
            assert!(!pricing.matches(model));
        }
    }

    /// Returns `true` when `actual` is strictly within `tolerance` of `expected`.
    fn close_to(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn test_budget_period_as_secs() {
        let cases = [
            (BudgetPeriod::Hourly, 3600),
            (BudgetPeriod::Daily, 86400),
            (BudgetPeriod::Monthly, 2_592_000),
            (BudgetPeriod::Custom { seconds: 7200 }, 7200),
        ];
        for (period, expected) in cases {
            assert_eq!(period.as_secs(), expected);
        }
    }

    #[test]
    fn test_model_pricing_exact_match() {
        check_pattern("gpt-4", &["gpt-4"], &["gpt-4-turbo", "gpt-3.5"]);
    }

    #[test]
    fn test_model_pricing_prefix_match() {
        check_pattern("gpt-4*", &["gpt-4", "gpt-4-turbo", "gpt-4o"], &["gpt-3.5"]);
    }

    #[test]
    fn test_model_pricing_suffix_match() {
        check_pattern("*-turbo", &["gpt-4-turbo", "gpt-3.5-turbo"], &["gpt-4"]);
    }

    #[test]
    fn test_model_pricing_contains_match() {
        check_pattern(
            "*claude*",
            &["claude-3", "anthropic-claude-3-opus"],
            &["gpt-4"],
        );
    }

    #[test]
    fn test_model_pricing_calculate_cost() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);

        // One million tokens each way: 30 + 60.
        let full = pricing.calculate_cost(1_000_000, 1_000_000);
        assert!(close_to(full, 90.0, 0.001));

        // Small request, checked against the formula written out by hand.
        let small = pricing.calculate_cost(1000, 500);
        let expected = (1000.0 / 1_000_000.0) * 30.0 + (500.0 / 1_000_000.0) * 60.0;
        assert!(close_to(small, expected, 0.0001));
    }

    #[test]
    fn test_budget_check_result_is_allowed() {
        let allowed = BudgetCheckResult::Allowed { remaining: 1000 };
        let soft = BudgetCheckResult::Soft {
            remaining: -100,
            over_by: 100,
        };
        let exhausted = BudgetCheckResult::Exhausted {
            retry_after_secs: 3600,
        };

        assert!(allowed.is_allowed());
        assert!(soft.is_allowed());
        assert!(!exhausted.is_allowed());
    }

    #[test]
    fn test_budget_alert_usage_percent() {
        let alert = BudgetAlert {
            tenant: "test".to_string(),
            threshold: 0.80,
            tokens_used: 800_000,
            tokens_limit: 1_000_000,
            period_start: 0,
        };
        assert!(close_to(alert.usage_percent(), 80.0, 0.001));
    }

    #[test]
    fn test_cost_result_new() {
        let result = CostResult::new("gpt-4", 1000, 500, 0.03, 0.03, "USD");
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert!(close_to(result.total_cost, 0.06, 0.001));
    }

    #[test]
    fn test_token_budget_config_default() {
        let TokenBudgetConfig {
            period,
            limit,
            alert_thresholds,
            enforce,
            rollover,
            burst_allowance,
        } = TokenBudgetConfig::default();

        assert_eq!(period, BudgetPeriod::Daily);
        assert_eq!(limit, 1_000_000);
        assert!(enforce);
        assert!(!rollover);
        assert!(burst_allowance.is_none());
        assert_eq!(alert_thresholds, vec![0.80, 0.90, 0.95]);
    }

    #[test]
    fn test_cost_attribution_config_default() {
        let CostAttributionConfig {
            enabled,
            pricing,
            default_input_cost,
            default_output_cost,
            currency,
        } = CostAttributionConfig::default();

        assert!(!enabled);
        assert!(pricing.is_empty());
        assert!(close_to(default_input_cost, 1.0, 0.001));
        assert!(close_to(default_output_cost, 2.0, 0.001));
        assert_eq!(currency, "USD");
    }
}