1use serde::{Deserialize, Serialize};
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TokenBudgetConfig {
32 #[serde(default)]
34 pub period: BudgetPeriod,
35
36 pub limit: u64,
38
39 #[serde(default = "default_alert_thresholds")]
42 pub alert_thresholds: Vec<f64>,
43
44 #[serde(default = "default_true")]
46 pub enforce: bool,
47
48 #[serde(default)]
50 pub rollover: bool,
51
52 #[serde(default)]
55 pub burst_allowance: Option<f64>,
56}
57
/// Serde default for `TokenBudgetConfig::alert_thresholds`: alert at the usage
/// fractions 0.80, 0.90 and 0.95 of the limit.
fn default_alert_thresholds() -> Vec<f64> {
    [0.80, 0.90, 0.95].to_vec()
}
61
/// Serde default helper that always yields `true`; used for
/// `TokenBudgetConfig::enforce`.
fn default_true() -> bool {
    true
}
65
66impl Default for TokenBudgetConfig {
67 fn default() -> Self {
68 Self {
69 period: BudgetPeriod::Daily,
70 limit: 1_000_000, alert_thresholds: default_alert_thresholds(),
72 enforce: true,
73 rollover: false,
74 burst_allowance: None,
75 }
76 }
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
81#[serde(rename_all = "snake_case")]
82pub enum BudgetPeriod {
83 Hourly,
85 #[default]
87 Daily,
88 Monthly,
90 Custom {
92 seconds: u64,
94 },
95}
96
97impl BudgetPeriod {
98 pub fn as_secs(&self) -> u64 {
100 match self {
101 BudgetPeriod::Hourly => 3600,
102 BudgetPeriod::Daily => 86400,
103 BudgetPeriod::Monthly => 2_592_000, BudgetPeriod::Custom { seconds } => *seconds,
105 }
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct CostAttributionConfig {
118 #[serde(default)]
120 pub enabled: bool,
121
122 #[serde(default)]
124 pub pricing: Vec<ModelPricing>,
125
126 #[serde(default = "default_input_cost")]
128 pub default_input_cost: f64,
129
130 #[serde(default = "default_output_cost")]
132 pub default_output_cost: f64,
133
134 #[serde(default = "default_currency")]
136 pub currency: String,
137}
138
/// Serde default for `CostAttributionConfig::default_input_cost`.
fn default_input_cost() -> f64 {
    1.0_f64
}
142
/// Serde default for `CostAttributionConfig::default_output_cost`.
fn default_output_cost() -> f64 {
    2.0_f64
}
146
/// Serde default for `CostAttributionConfig::currency`: `"USD"`.
fn default_currency() -> String {
    String::from("USD")
}
150
151impl Default for CostAttributionConfig {
152 fn default() -> Self {
153 Self {
154 enabled: false,
155 pricing: Vec::new(),
156 default_input_cost: default_input_cost(),
157 default_output_cost: default_output_cost(),
158 currency: default_currency(),
159 }
160 }
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct ModelPricing {
170 pub model_pattern: String,
172
173 pub input_cost_per_million: f64,
175
176 pub output_cost_per_million: f64,
178
179 #[serde(default)]
181 pub currency: Option<String>,
182}
183
184impl ModelPricing {
185 pub fn new(pattern: impl Into<String>, input_cost: f64, output_cost: f64) -> Self {
187 Self {
188 model_pattern: pattern.into(),
189 input_cost_per_million: input_cost,
190 output_cost_per_million: output_cost,
191 currency: None,
192 }
193 }
194
195 pub fn matches(&self, model: &str) -> bool {
197 if self.model_pattern.contains('*') {
198 let pattern = &self.model_pattern;
200 if pattern.starts_with('*') && pattern.ends_with('*') {
201 let inner = &pattern[1..pattern.len() - 1];
203 model.contains(inner)
204 } else if pattern.starts_with('*') {
205 model.ends_with(&pattern[1..])
207 } else if pattern.ends_with('*') {
208 model.starts_with(&pattern[..pattern.len() - 1])
210 } else {
211 let parts: Vec<&str> = pattern.split('*').collect();
213 if parts.is_empty() {
214 return true;
215 }
216
217 let mut remaining = model;
218 for (i, part) in parts.iter().enumerate() {
219 if part.is_empty() {
220 continue;
221 }
222 if i == 0 {
223 if !remaining.starts_with(part) {
225 return false;
226 }
227 remaining = &remaining[part.len()..];
228 } else if i == parts.len() - 1 {
229 if !remaining.ends_with(part) {
231 return false;
232 }
233 } else {
234 if let Some(idx) = remaining.find(part) {
236 remaining = &remaining[idx + part.len()..];
237 } else {
238 return false;
239 }
240 }
241 }
242 true
243 }
244 } else {
245 self.model_pattern == model
247 }
248 }
249
250 pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
252 let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_cost_per_million;
253 let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_cost_per_million;
254 input_cost + output_cost
255 }
256}
257
/// Outcome of checking a request against a tenant's token budget.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// The request fits within the budget.
    Allowed {
        /// Tokens left in the budget (presumably for the current period — confirm).
        remaining: u64,
    },
    /// The budget is used up and the request is not allowed.
    Exhausted {
        /// Seconds the caller should wait before retrying.
        retry_after_secs: u64,
    },
    /// Over the limit but still allowed (see `is_allowed`).
    Soft {
        /// Signed remaining budget; negative once the limit has been passed.
        remaining: i64,
        /// Number of tokens by which the limit is exceeded.
        over_by: u64,
    },
}
283
284impl BudgetCheckResult {
285 pub fn is_allowed(&self) -> bool {
287 matches!(self, Self::Allowed { .. } | Self::Soft { .. })
288 }
289
290 pub fn retry_after_secs(&self) -> u64 {
292 match self {
293 Self::Exhausted { retry_after_secs } => *retry_after_secs,
294 _ => 0,
295 }
296 }
297}
298
/// A usage-threshold alert for one tenant.
#[derive(Debug, Clone)]
pub struct BudgetAlert {
    /// Tenant the alert concerns.
    pub tenant: String,
    /// Threshold that triggered the alert, as a fraction of the limit (e.g. `0.80`).
    pub threshold: f64,
    /// Tokens consumed so far in the period.
    pub tokens_used: u64,
    /// Token limit for the period.
    pub tokens_limit: u64,
    /// Start of the period.
    /// NOTE(review): presumably Unix seconds — confirm with the producer.
    pub period_start: u64,
}
313
314impl BudgetAlert {
315 pub fn usage_percent(&self) -> f64 {
317 if self.tokens_limit == 0 {
318 return 0.0;
319 }
320 (self.tokens_used as f64 / self.tokens_limit as f64) * 100.0
321 }
322}
323
/// Snapshot of one tenant's budget for the current period.
#[derive(Debug, Clone)]
pub struct TenantBudgetStatus {
    /// Tokens consumed so far in the period.
    pub tokens_used: u64,
    /// Token limit for the period.
    pub tokens_limit: u64,
    /// Tokens still available in the period.
    pub tokens_remaining: u64,
    /// Usage relative to the limit.
    /// NOTE(review): presumably on the same 0–100 scale as
    /// `BudgetAlert::usage_percent` — confirm with the producer.
    pub usage_percent: f64,
    /// Start of the period (presumably Unix seconds — confirm).
    pub period_start: u64,
    /// End of the period (presumably Unix seconds — confirm).
    pub period_end: u64,
    /// Whether the budget is exhausted.
    pub exhausted: bool,
}
342
/// Cost breakdown for one request against one model.
#[derive(Debug, Clone)]
pub struct CostResult {
    /// Cost attributed to input tokens.
    pub input_cost: f64,
    /// Cost attributed to output tokens.
    pub output_cost: f64,
    /// Sum of `input_cost` and `output_cost` (when built via `CostResult::new`).
    pub total_cost: f64,
    /// Currency label of the costs.
    pub currency: String,
    /// Model the request was made against.
    pub model: String,
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
}
361
362impl CostResult {
363 pub fn new(
365 model: impl Into<String>,
366 input_tokens: u64,
367 output_tokens: u64,
368 input_cost: f64,
369 output_cost: f64,
370 currency: impl Into<String>,
371 ) -> Self {
372 Self {
373 input_cost,
374 output_cost,
375 total_cost: input_cost + output_cost,
376 currency: currency.into(),
377 model: model.into(),
378 input_tokens,
379 output_tokens,
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_budget_period_as_secs() {
        assert_eq!(BudgetPeriod::Hourly.as_secs(), 3600);
        assert_eq!(BudgetPeriod::Daily.as_secs(), 86400);
        assert_eq!(BudgetPeriod::Monthly.as_secs(), 2_592_000);
        assert_eq!(BudgetPeriod::Custom { seconds: 7200 }.as_secs(), 7200);
    }

    #[test]
    fn test_budget_period_default_is_daily() {
        assert_eq!(BudgetPeriod::default(), BudgetPeriod::Daily);
    }

    #[test]
    fn test_model_pricing_exact_match() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);
        assert!(pricing.matches("gpt-4"));
        assert!(!pricing.matches("gpt-4-turbo"));
        assert!(!pricing.matches("gpt-3.5"));
    }

    #[test]
    fn test_model_pricing_prefix_match() {
        let pricing = ModelPricing::new("gpt-4*", 30.0, 60.0);
        assert!(pricing.matches("gpt-4"));
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-4o"));
        assert!(!pricing.matches("gpt-3.5"));
    }

    #[test]
    fn test_model_pricing_suffix_match() {
        let pricing = ModelPricing::new("*-turbo", 30.0, 60.0);
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-3.5-turbo"));
        assert!(!pricing.matches("gpt-4"));
    }

    #[test]
    fn test_model_pricing_contains_match() {
        let pricing = ModelPricing::new("*claude*", 30.0, 60.0);
        assert!(pricing.matches("claude-3"));
        assert!(pricing.matches("anthropic-claude-3-opus"));
        assert!(!pricing.matches("gpt-4"));
    }

    #[test]
    fn test_model_pricing_inner_wildcard_match() {
        let pricing = ModelPricing::new("gpt-*-turbo", 30.0, 60.0);
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-3.5-turbo"));
        assert!(!pricing.matches("gpt-4"));
        assert!(!pricing.matches("claude-3-turbo"));

        let pricing = ModelPricing::new("claude-*-opus-*-v1", 1.0, 1.0);
        assert!(pricing.matches("claude-3-opus-2024-v1"));
        assert!(!pricing.matches("claude-3-sonnet-v1"));
    }

    #[test]
    fn test_model_pricing_prefix_and_suffix_do_not_overlap() {
        // "ab" and "bc" would have to share the middle "b" to match "abc".
        let pricing = ModelPricing::new("ab*bc", 1.0, 1.0);
        assert!(!pricing.matches("abc"));
        assert!(pricing.matches("abbc"));
    }

    #[test]
    fn test_model_pricing_calculate_cost() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);

        // 1M input tokens at 30/M plus 1M output tokens at 60/M.
        let cost = pricing.calculate_cost(1_000_000, 1_000_000);
        assert!((cost - 90.0).abs() < 0.001);

        // Fractional millions scale linearly.
        let cost = pricing.calculate_cost(1000, 500);
        let expected = (1000.0 / 1_000_000.0) * 30.0 + (500.0 / 1_000_000.0) * 60.0;
        assert!((cost - expected).abs() < 0.0001);
    }

    #[test]
    fn test_budget_check_result_is_allowed() {
        assert!(BudgetCheckResult::Allowed { remaining: 1000 }.is_allowed());
        assert!(BudgetCheckResult::Soft { remaining: -100, over_by: 100 }.is_allowed());
        assert!(!BudgetCheckResult::Exhausted { retry_after_secs: 3600 }.is_allowed());
    }

    #[test]
    fn test_budget_check_result_retry_after_secs() {
        let exhausted = BudgetCheckResult::Exhausted { retry_after_secs: 3600 };
        assert_eq!(exhausted.retry_after_secs(), 3600);
        assert_eq!(BudgetCheckResult::Allowed { remaining: 10 }.retry_after_secs(), 0);
        let soft = BudgetCheckResult::Soft { remaining: -1, over_by: 1 };
        assert_eq!(soft.retry_after_secs(), 0);
    }

    #[test]
    fn test_budget_alert_usage_percent() {
        let alert = BudgetAlert {
            tenant: "test".to_string(),
            threshold: 0.80,
            tokens_used: 800_000,
            tokens_limit: 1_000_000,
            period_start: 0,
        };
        assert!((alert.usage_percent() - 80.0).abs() < 0.001);
    }

    #[test]
    fn test_budget_alert_usage_percent_zero_limit() {
        let alert = BudgetAlert {
            tenant: "test".to_string(),
            threshold: 0.80,
            tokens_used: 500,
            tokens_limit: 0,
            period_start: 0,
        };
        assert!(alert.usage_percent().abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_result_new() {
        let result = CostResult::new("gpt-4", 1000, 500, 0.03, 0.03, "USD");
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert!((result.total_cost - 0.06).abs() < 0.001);
    }

    #[test]
    fn test_token_budget_config_default() {
        let config = TokenBudgetConfig::default();
        assert_eq!(config.period, BudgetPeriod::Daily);
        assert_eq!(config.limit, 1_000_000);
        assert!(config.enforce);
        assert!(!config.rollover);
        assert!(config.burst_allowance.is_none());
        assert_eq!(config.alert_thresholds, vec![0.80, 0.90, 0.95]);
    }

    #[test]
    fn test_cost_attribution_config_default() {
        let config = CostAttributionConfig::default();
        assert!(!config.enabled);
        assert!(config.pricing.is_empty());
        assert!((config.default_input_cost - 1.0).abs() < 0.001);
        assert!((config.default_output_cost - 2.0).abs() < 0.001);
        assert_eq!(config.currency, "USD");
    }
}