Skip to main content

sentinel_common/
budget.rs

//! Token budget management and cost attribution types.
//!
//! This module provides configuration types for:
//! - Per-tenant token budgets with period-based limits
//! - Cost attribution with per-model pricing
//!
//! # Token Budgets
//!
//! Token budgets allow tracking cumulative token usage per tenant over
//! configurable periods (hourly, daily, monthly, or a custom length). This enables:
//! - Quota enforcement for API consumers
//! - Usage alerts at configurable thresholds
//! - Optional rollover of unused tokens
//!
//! # Cost Attribution
//!
//! Cost attribution tracks the monetary cost of inference requests based
//! on model-specific pricing for input and output tokens.
19
20use serde::{Deserialize, Serialize};
21
22// ============================================================================
23// Budget Configuration
24// ============================================================================
25
26/// Token budget configuration for per-tenant usage tracking.
27///
28/// Budgets track cumulative token usage over a configurable period,
29/// with optional alerts and enforcement.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TokenBudgetConfig {
32    /// Budget period (when the budget resets)
33    #[serde(default)]
34    pub period: BudgetPeriod,
35
36    /// Total tokens allowed in the period
37    pub limit: u64,
38
39    /// Alert thresholds as percentages (e.g., [0.80, 0.90, 0.95])
40    /// Triggers alerts when usage crosses these thresholds
41    #[serde(default = "default_alert_thresholds")]
42    pub alert_thresholds: Vec<f64>,
43
44    /// Whether to enforce the limit (block requests when exhausted)
45    #[serde(default = "default_true")]
46    pub enforce: bool,
47
48    /// Allow unused tokens to roll over to the next period
49    #[serde(default)]
50    pub rollover: bool,
51
52    /// Allow burst usage above limit as a percentage (soft limit)
53    /// E.g., 0.10 allows 10% burst above the limit
54    #[serde(default)]
55    pub burst_allowance: Option<f64>,
56}
57
/// Serde default for the alert thresholds of a token budget: alert at 80%,
/// 90% and 95% of the limit (stored as fractions).
fn default_alert_thresholds() -> Vec<f64> {
    [0.80, 0.90, 0.95].to_vec()
}
61
/// Serde default that yields `true`.
///
/// Used for boolean fields that should be on unless explicitly disabled.
fn default_true() -> bool {
    true
}
65
66impl Default for TokenBudgetConfig {
67    fn default() -> Self {
68        Self {
69            period: BudgetPeriod::Daily,
70            limit: 1_000_000, // 1M tokens
71            alert_thresholds: default_alert_thresholds(),
72            enforce: true,
73            rollover: false,
74            burst_allowance: None,
75        }
76    }
77}
78
79/// Budget period defining when the budget resets.
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
81#[serde(rename_all = "snake_case")]
82pub enum BudgetPeriod {
83    /// Resets every hour
84    Hourly,
85    /// Resets every day at midnight UTC
86    #[default]
87    Daily,
88    /// Resets on the first of each month at midnight UTC
89    Monthly,
90    /// Custom period in seconds
91    Custom {
92        /// Period duration in seconds
93        seconds: u64,
94    },
95}
96
97impl BudgetPeriod {
98    /// Get the period duration in seconds.
99    pub fn as_secs(&self) -> u64 {
100        match self {
101            BudgetPeriod::Hourly => 3600,
102            BudgetPeriod::Daily => 86400,
103            BudgetPeriod::Monthly => 2_592_000, // 30 days
104            BudgetPeriod::Custom { seconds } => *seconds,
105        }
106    }
107}
108
109// ============================================================================
110// Cost Attribution Configuration
111// ============================================================================
112
113/// Cost attribution configuration for tracking inference costs.
114///
115/// Allows per-model pricing with separate input/output token rates.
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct CostAttributionConfig {
118    /// Whether cost attribution is enabled
119    #[serde(default)]
120    pub enabled: bool,
121
122    /// Per-model pricing rules (evaluated in order, first match wins)
123    #[serde(default)]
124    pub pricing: Vec<ModelPricing>,
125
126    /// Default cost per million input tokens (fallback)
127    #[serde(default = "default_input_cost")]
128    pub default_input_cost: f64,
129
130    /// Default cost per million output tokens (fallback)
131    #[serde(default = "default_output_cost")]
132    pub default_output_cost: f64,
133
134    /// Currency for cost values (default: USD)
135    #[serde(default = "default_currency")]
136    pub currency: String,
137}
138
/// Serde default for the fallback input-token rate: 1.0 per million tokens.
fn default_input_cost() -> f64 {
    1_f64
}
142
/// Serde default for the fallback output-token rate: 2.0 per million tokens.
fn default_output_cost() -> f64 {
    2_f64
}
146
/// Serde default for the cost currency: `"USD"`.
fn default_currency() -> String {
    String::from("USD")
}
150
151impl Default for CostAttributionConfig {
152    fn default() -> Self {
153        Self {
154            enabled: false,
155            pricing: Vec::new(),
156            default_input_cost: default_input_cost(),
157            default_output_cost: default_output_cost(),
158            currency: default_currency(),
159        }
160    }
161}
162
163/// Per-model pricing configuration.
164///
165/// The `model_pattern` supports glob-style matching:
166/// - `gpt-4*` matches `gpt-4`, `gpt-4-turbo`, `gpt-4o`, etc.
167/// - `claude-3-*` matches all Claude 3 variants
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct ModelPricing {
170    /// Model name or pattern (glob-style matching with `*`)
171    pub model_pattern: String,
172
173    /// Cost per million input tokens
174    pub input_cost_per_million: f64,
175
176    /// Cost per million output tokens
177    pub output_cost_per_million: f64,
178
179    /// Optional currency override (defaults to parent config currency)
180    #[serde(default)]
181    pub currency: Option<String>,
182}
183
184impl ModelPricing {
185    /// Create new model pricing with the given pattern and costs.
186    pub fn new(pattern: impl Into<String>, input_cost: f64, output_cost: f64) -> Self {
187        Self {
188            model_pattern: pattern.into(),
189            input_cost_per_million: input_cost,
190            output_cost_per_million: output_cost,
191            currency: None,
192        }
193    }
194
195    /// Check if this pricing rule matches the given model name.
196    pub fn matches(&self, model: &str) -> bool {
197        if self.model_pattern.contains('*') {
198            // Glob-style matching
199            let pattern = &self.model_pattern;
200            if let Some(inner) = pattern.strip_prefix('*').and_then(|p| p.strip_suffix('*')) {
201                // *pattern* - contains
202                model.contains(inner)
203            } else if let Some(suffix) = pattern.strip_prefix('*') {
204                // *pattern - ends with
205                model.ends_with(suffix)
206            } else if let Some(prefix) = pattern.strip_suffix('*') {
207                // pattern* - starts with
208                model.starts_with(prefix)
209            } else {
210                // Complex pattern - split and match parts
211                let parts: Vec<&str> = pattern.split('*').collect();
212                if parts.is_empty() {
213                    return true;
214                }
215
216                let mut remaining = model;
217                for (i, part) in parts.iter().enumerate() {
218                    if part.is_empty() {
219                        continue;
220                    }
221                    if i == 0 {
222                        // First part must be prefix
223                        if !remaining.starts_with(part) {
224                            return false;
225                        }
226                        remaining = &remaining[part.len()..];
227                    } else if i == parts.len() - 1 {
228                        // Last part must be suffix
229                        if !remaining.ends_with(part) {
230                            return false;
231                        }
232                    } else {
233                        // Middle parts must exist
234                        if let Some(idx) = remaining.find(part) {
235                            remaining = &remaining[idx + part.len()..];
236                        } else {
237                            return false;
238                        }
239                    }
240                }
241                true
242            }
243        } else {
244            // Exact match
245            self.model_pattern == model
246        }
247    }
248
249    /// Calculate cost for the given token counts.
250    pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
251        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_cost_per_million;
252        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_cost_per_million;
253        input_cost + output_cost
254    }
255}
256
257// ============================================================================
258// Result Types
259// ============================================================================
260
/// Outcome of checking a request against a tenant's token budget.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// The request fits within the budget.
    Allowed {
        /// Tokens left in the period once this request is accounted for.
        remaining: u64,
    },
    /// The budget is used up.
    Exhausted {
        /// Seconds until the current period resets.
        retry_after_secs: u64,
    },
    /// The request is over the base limit but permitted by the burst
    /// allowance (soft limit).
    Soft {
        /// Tokens remaining relative to the base limit; negative when over it.
        remaining: i64,
        /// How far usage exceeds the base limit.
        over_by: u64,
    },
}

impl BudgetCheckResult {
    /// Whether the request may proceed: `true` for [`Allowed`](Self::Allowed)
    /// and [`Soft`](Self::Soft), `false` for [`Exhausted`](Self::Exhausted).
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Exhausted { .. } => false,
            Self::Allowed { .. } | Self::Soft { .. } => true,
        }
    }

    /// Seconds to wait before retrying: the reset delay for
    /// [`Exhausted`](Self::Exhausted), and `0` for every allowed outcome.
    pub fn retry_after_secs(&self) -> u64 {
        match *self {
            Self::Exhausted { retry_after_secs } => retry_after_secs,
            Self::Allowed { .. } | Self::Soft { .. } => 0,
        }
    }
}
297
/// Notification generated when a tenant's usage crosses a budget threshold.
#[derive(Debug, Clone)]
pub struct BudgetAlert {
    /// Tenant/client identifier.
    pub tenant: String,
    /// Threshold that was crossed, as a fraction of the limit (e.g. `0.80` for 80%).
    pub threshold: f64,
    /// Tokens used so far in the current period.
    pub tokens_used: u64,
    /// Budget limit for the period.
    pub tokens_limit: u64,
    /// Start of the current period (Unix timestamp).
    pub period_start: u64,
}

impl BudgetAlert {
    /// Usage as a percentage of the limit (`80.0` means 80%; exceeds `100.0`
    /// when over the limit).
    ///
    /// Unlike [`threshold`](Self::threshold), which is a fraction, this is a
    /// percentage. A `tokens_limit` of zero yields `0.0` instead of dividing by zero.
    pub fn usage_percent(&self) -> f64 {
        match self.tokens_limit {
            0 => 0.0,
            limit => self.tokens_used as f64 / limit as f64 * 100.0,
        }
    }
}
322
/// Snapshot of a tenant's budget state for the current period.
#[derive(Debug, Clone)]
pub struct TenantBudgetStatus {
    /// Tokens consumed so far in the current period.
    pub tokens_used: u64,
    /// Token limit for the period.
    pub tokens_limit: u64,
    /// Tokens still available in the period.
    pub tokens_remaining: u64,
    /// Usage percentage.
    pub usage_percent: f64,
    /// When the current period began (Unix timestamp).
    pub period_start: u64,
    /// When the current period ends (Unix timestamp).
    pub period_end: u64,
    /// `true` once the budget has been used up.
    pub exhausted: bool,
}
341
/// Breakdown of the computed cost of a request.
#[derive(Debug, Clone)]
pub struct CostResult {
    /// Cost attributed to input tokens.
    pub input_cost: f64,
    /// Cost attributed to output tokens.
    pub output_cost: f64,
    /// Sum of `input_cost` and `output_cost`.
    pub total_cost: f64,
    /// Currency the costs are expressed in.
    pub currency: String,
    /// Model the request was made against.
    pub model: String,
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
}

impl CostResult {
    /// Build a cost result; `total_cost` is derived as `input_cost + output_cost`.
    pub fn new(
        model: impl Into<String>,
        input_tokens: u64,
        output_tokens: u64,
        input_cost: f64,
        output_cost: f64,
        currency: impl Into<String>,
    ) -> Self {
        let total_cost = input_cost + output_cost;
        let currency = currency.into();
        let model = model.into();

        Self {
            model,
            currency,
            input_tokens,
            output_tokens,
            input_cost,
            output_cost,
            total_cost,
        }
    }
}
382
383// ============================================================================
384// Tests
385// ============================================================================
386
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_budget_period_as_secs() {
        assert_eq!(BudgetPeriod::Hourly.as_secs(), 3600);
        assert_eq!(BudgetPeriod::Daily.as_secs(), 86400);
        assert_eq!(BudgetPeriod::Monthly.as_secs(), 2_592_000);
        assert_eq!(BudgetPeriod::Custom { seconds: 7200 }.as_secs(), 7200);
    }

    #[test]
    fn test_budget_period_default_is_daily() {
        assert_eq!(BudgetPeriod::default(), BudgetPeriod::Daily);
    }

    #[test]
    fn test_model_pricing_exact_match() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);
        assert!(pricing.matches("gpt-4"));
        assert!(!pricing.matches("gpt-4-turbo"));
        assert!(!pricing.matches("gpt-3.5"));
        assert!(!pricing.matches(""));
    }

    #[test]
    fn test_model_pricing_prefix_match() {
        let pricing = ModelPricing::new("gpt-4*", 30.0, 60.0);
        assert!(pricing.matches("gpt-4"));
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-4o"));
        assert!(!pricing.matches("gpt-3.5"));
    }

    #[test]
    fn test_model_pricing_suffix_match() {
        let pricing = ModelPricing::new("*-turbo", 30.0, 60.0);
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-3.5-turbo"));
        assert!(!pricing.matches("gpt-4"));
    }

    #[test]
    fn test_model_pricing_contains_match() {
        let pricing = ModelPricing::new("*claude*", 30.0, 60.0);
        assert!(pricing.matches("claude-3"));
        assert!(pricing.matches("anthropic-claude-3-opus"));
        assert!(!pricing.matches("gpt-4"));
    }

    #[test]
    fn test_model_pricing_middle_wildcard_match() {
        let pricing = ModelPricing::new("gpt-*-turbo", 1.0, 2.0);
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-3.5-turbo"));
        assert!(!pricing.matches("gpt-turbo"));
        assert!(!pricing.matches("gpt-4"));

        let pricing = ModelPricing::new("claude-3-*", 1.0, 2.0);
        assert!(pricing.matches("claude-3-opus"));
        assert!(!pricing.matches("claude-2.1"));
    }

    #[test]
    fn test_model_pricing_lone_wildcard_matches_everything() {
        let pricing = ModelPricing::new("*", 1.0, 2.0);
        assert!(pricing.matches(""));
        assert!(pricing.matches("any-model"));
    }

    #[test]
    fn test_model_pricing_calculate_cost() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);

        // 1M input tokens = $30, 1M output tokens = $60
        let cost = pricing.calculate_cost(1_000_000, 1_000_000);
        assert!((cost - 90.0).abs() < 0.001);

        // 1000 input tokens, 500 output tokens
        let cost = pricing.calculate_cost(1000, 500);
        let expected = (1000.0 / 1_000_000.0) * 30.0 + (500.0 / 1_000_000.0) * 60.0;
        assert!((cost - expected).abs() < 0.0001);

        // No tokens, no cost.
        assert_eq!(pricing.calculate_cost(0, 0), 0.0);
    }

    #[test]
    fn test_budget_check_result_is_allowed() {
        assert!(BudgetCheckResult::Allowed { remaining: 1000 }.is_allowed());
        assert!(BudgetCheckResult::Soft {
            remaining: -100,
            over_by: 100
        }
        .is_allowed());
        assert!(!BudgetCheckResult::Exhausted {
            retry_after_secs: 3600
        }
        .is_allowed());
    }

    #[test]
    fn test_budget_check_result_retry_after_secs() {
        assert_eq!(
            BudgetCheckResult::Exhausted {
                retry_after_secs: 3600
            }
            .retry_after_secs(),
            3600
        );
        assert_eq!(
            BudgetCheckResult::Allowed { remaining: 1000 }.retry_after_secs(),
            0
        );
        assert_eq!(
            BudgetCheckResult::Soft {
                remaining: -100,
                over_by: 100
            }
            .retry_after_secs(),
            0
        );
    }

    #[test]
    fn test_budget_alert_usage_percent() {
        let alert = BudgetAlert {
            tenant: "test".to_string(),
            threshold: 0.80,
            tokens_used: 800_000,
            tokens_limit: 1_000_000,
            period_start: 0,
        };
        assert!((alert.usage_percent() - 80.0).abs() < 0.001);
    }

    #[test]
    fn test_budget_alert_usage_percent_zero_limit() {
        let alert = BudgetAlert {
            tenant: "test".to_string(),
            threshold: 0.80,
            tokens_used: 10,
            tokens_limit: 0,
            period_start: 0,
        };
        assert_eq!(alert.usage_percent(), 0.0);
    }

    #[test]
    fn test_cost_result_new() {
        let result = CostResult::new("gpt-4", 1000, 500, 0.03, 0.03, "USD");
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert_eq!(result.currency, "USD");
        assert!((result.total_cost - 0.06).abs() < 0.001);
    }

    #[test]
    fn test_token_budget_config_default() {
        let config = TokenBudgetConfig::default();
        assert_eq!(config.period, BudgetPeriod::Daily);
        assert_eq!(config.limit, 1_000_000);
        assert!(config.enforce);
        assert!(!config.rollover);
        assert!(config.burst_allowance.is_none());
        assert_eq!(config.alert_thresholds, vec![0.80, 0.90, 0.95]);
    }

    #[test]
    fn test_cost_attribution_config_default() {
        let config = CostAttributionConfig::default();
        assert!(!config.enabled);
        assert!(config.pricing.is_empty());
        assert!((config.default_input_cost - 1.0).abs() < 0.001);
        assert!((config.default_output_cost - 2.0).abs() < 0.001);
        assert_eq!(config.currency, "USD");
    }
}