// sentinel_proxy/inference/budget.rs

//! Token budget tracker for per-tenant cumulative usage tracking.
//!
//! Unlike rate limiting (tokens per minute), budgets track cumulative usage
//! over longer periods (hourly, daily, monthly) with optional enforcement.

6use dashmap::DashMap;
7use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
8use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
9use tracing::{debug, info, trace, warn};
10
11use sentinel_common::budget::{
12    BudgetAlert, BudgetCheckResult, BudgetPeriod, TenantBudgetStatus, TokenBudgetConfig,
13};
14
/// Bookkeeping for a single tenant within its current budget period.
struct TenantBudgetState {
    /// Monotonic instant at which the current period began; compared against the
    /// configured period length to detect expiry.
    period_start: Instant,
    /// Wall-clock start of the current period in seconds since the Unix epoch,
    /// used when reporting status and alerts.
    period_start_unix: u64,
    /// Running total of tokens consumed during the current period.
    tokens_used: AtomicU64,
    /// One bit per configured alert threshold, set once that threshold has fired
    /// (bit 0 = first threshold, bit 1 = second, and so on).
    alerts_fired: AtomicU8,
}

28impl TenantBudgetState {
29    fn new() -> Self {
30        let now_unix = SystemTime::now()
31            .duration_since(UNIX_EPOCH)
32            .unwrap_or_default()
33            .as_secs();
34
35        Self {
36            period_start: Instant::now(),
37            period_start_unix: now_unix,
38            tokens_used: AtomicU64::new(0),
39            alerts_fired: AtomicU8::new(0),
40        }
41    }
42
43    fn tokens_used(&self) -> u64 {
44        self.tokens_used.load(Ordering::Acquire)
45    }
46
47    fn add_tokens(&self, tokens: u64) {
48        self.tokens_used.fetch_add(tokens, Ordering::AcqRel);
49    }
50
51    fn elapsed(&self) -> Duration {
52        self.period_start.elapsed()
53    }
54
55    fn reset(&mut self) {
56        let now_unix = SystemTime::now()
57            .duration_since(UNIX_EPOCH)
58            .unwrap_or_default()
59            .as_secs();
60
61        self.period_start = Instant::now();
62        self.period_start_unix = now_unix;
63        self.tokens_used.store(0, Ordering::Release);
64        self.alerts_fired.store(0, Ordering::Release);
65    }
66
67    fn has_fired_alert(&self, threshold_index: u8) -> bool {
68        let mask = 1u8 << threshold_index;
69        (self.alerts_fired.load(Ordering::Acquire) & mask) != 0
70    }
71
72    fn mark_alert_fired(&self, threshold_index: u8) {
73        let mask = 1u8 << threshold_index;
74        self.alerts_fired.fetch_or(mask, Ordering::AcqRel);
75    }
76}
77
78/// Token budget tracker for per-tenant usage tracking.
79///
80/// Tracks cumulative token usage over configurable periods (hourly, daily, monthly)
81/// with support for:
82/// - Configurable alert thresholds
83/// - Hard or soft enforcement
84/// - Optional burst allowance
85/// - Period rollover
86pub struct TokenBudgetTracker {
87    /// Budget configuration
88    config: TokenBudgetConfig,
89    /// Per-tenant budget state
90    tenants: DashMap<String, TenantBudgetState>,
91    /// Route ID for logging
92    route_id: String,
93}
94
95impl TokenBudgetTracker {
96    /// Create a new token budget tracker with the given configuration.
97    pub fn new(config: TokenBudgetConfig, route_id: impl Into<String>) -> Self {
98        let route_id = route_id.into();
99
100        info!(
101            route_id = %route_id,
102            period = ?config.period,
103            limit = config.limit,
104            enforce = config.enforce,
105            rollover = config.rollover,
106            "Created token budget tracker"
107        );
108
109        Self {
110            config,
111            tenants: DashMap::new(),
112            route_id,
113        }
114    }
115
116    /// Check if a request with the given token count is allowed.
117    ///
118    /// This does NOT consume tokens - call `record()` after the request completes.
119    pub fn check(&self, tenant: &str, estimated_tokens: u64) -> BudgetCheckResult {
120        let state = self.get_or_create_tenant(tenant);
121        let period_secs = self.config.period.as_secs();
122
123        // Check if period has expired
124        let elapsed = state.elapsed();
125        if elapsed.as_secs() >= period_secs {
126            drop(state);
127            self.reset_period(tenant);
128            return self.check(tenant, estimated_tokens);
129        }
130
131        let current_used = state.tokens_used();
132        let would_use = current_used + estimated_tokens;
133
134        // Check against limit
135        if would_use <= self.config.limit {
136            let remaining = self.config.limit.saturating_sub(would_use);
137            trace!(
138                route_id = %self.route_id,
139                tenant = tenant,
140                current_used = current_used,
141                estimated_tokens = estimated_tokens,
142                remaining = remaining,
143                "Budget check: allowed"
144            );
145            return BudgetCheckResult::Allowed { remaining };
146        }
147
148        // Check burst allowance
149        if let Some(burst) = self.config.burst_allowance {
150            let burst_limit = self.config.limit + (self.config.limit as f64 * burst) as u64;
151            if would_use <= burst_limit {
152                let over_by = would_use - self.config.limit;
153                let remaining = (self.config.limit as i64) - (would_use as i64);
154                trace!(
155                    route_id = %self.route_id,
156                    tenant = tenant,
157                    over_by = over_by,
158                    "Budget check: soft limit (burst)"
159                );
160                return BudgetCheckResult::Soft { remaining, over_by };
161            }
162        }
163
164        // Budget exhausted
165        if self.config.enforce {
166            let retry_after = period_secs.saturating_sub(elapsed.as_secs());
167            debug!(
168                route_id = %self.route_id,
169                tenant = tenant,
170                current_used = current_used,
171                limit = self.config.limit,
172                retry_after_secs = retry_after,
173                "Budget exhausted"
174            );
175            BudgetCheckResult::Exhausted {
176                retry_after_secs: retry_after,
177            }
178        } else {
179            // Not enforcing, just log and allow
180            let over_by = would_use - self.config.limit;
181            let remaining = (self.config.limit as i64) - (would_use as i64);
182            debug!(
183                route_id = %self.route_id,
184                tenant = tenant,
185                over_by = over_by,
186                "Budget exceeded (not enforced)"
187            );
188            BudgetCheckResult::Soft { remaining, over_by }
189        }
190    }
191
192    /// Record actual token usage after a request completes.
193    ///
194    /// Returns any budget alerts that should be fired.
195    pub fn record(&self, tenant: &str, actual_tokens: u64) -> Vec<BudgetAlert> {
196        let state = self.get_or_create_tenant(tenant);
197        let period_secs = self.config.period.as_secs();
198
199        // Check if period has expired
200        let elapsed = state.elapsed();
201        if elapsed.as_secs() >= period_secs {
202            drop(state);
203            self.reset_period(tenant);
204            return self.record(tenant, actual_tokens);
205        }
206
207        // Add tokens
208        state.add_tokens(actual_tokens);
209        let new_total = state.tokens_used();
210
211        trace!(
212            route_id = %self.route_id,
213            tenant = tenant,
214            tokens = actual_tokens,
215            total = new_total,
216            limit = self.config.limit,
217            "Recorded token usage"
218        );
219
220        // Check for alert thresholds
221        let mut alerts = Vec::new();
222        let usage_pct = new_total as f64 / self.config.limit as f64;
223
224        for (idx, &threshold) in self.config.alert_thresholds.iter().enumerate() {
225            if usage_pct >= threshold && !state.has_fired_alert(idx as u8) {
226                state.mark_alert_fired(idx as u8);
227
228                let alert = BudgetAlert {
229                    tenant: tenant.to_string(),
230                    threshold,
231                    tokens_used: new_total,
232                    tokens_limit: self.config.limit,
233                    period_start: state.period_start_unix,
234                };
235
236                info!(
237                    route_id = %self.route_id,
238                    tenant = tenant,
239                    threshold_pct = threshold * 100.0,
240                    tokens_used = new_total,
241                    tokens_limit = self.config.limit,
242                    "Budget alert threshold crossed"
243                );
244
245                alerts.push(alert);
246            }
247        }
248
249        alerts
250    }
251
252    /// Get the current budget status for a tenant.
253    pub fn status(&self, tenant: &str) -> TenantBudgetStatus {
254        let state = self.get_or_create_tenant(tenant);
255        let period_secs = self.config.period.as_secs();
256        let elapsed = state.elapsed();
257
258        let tokens_used = state.tokens_used();
259        let tokens_remaining = self.config.limit.saturating_sub(tokens_used);
260        let usage_percent = (tokens_used as f64 / self.config.limit as f64) * 100.0;
261        let period_end = state.period_start_unix + period_secs;
262
263        TenantBudgetStatus {
264            tokens_used,
265            tokens_limit: self.config.limit,
266            tokens_remaining,
267            usage_percent,
268            period_start: state.period_start_unix,
269            period_end,
270            exhausted: tokens_used >= self.config.limit && self.config.enforce,
271        }
272    }
273
274    /// Reset the budget period for a tenant.
275    pub fn reset_period(&self, tenant: &str) {
276        if let Some(mut state) = self.tenants.get_mut(tenant) {
277            let old_tokens = state.tokens_used();
278
279            // Handle rollover
280            if self.config.rollover && old_tokens < self.config.limit {
281                let unused = self.config.limit - old_tokens;
282                state.reset();
283                // Add back unused tokens (capped at limit)
284                let rollover = unused.min(self.config.limit);
285                state.add_tokens(rollover);
286                info!(
287                    route_id = %self.route_id,
288                    tenant = tenant,
289                    rollover_tokens = rollover,
290                    "Period reset with rollover"
291                );
292            } else {
293                state.reset();
294                debug!(
295                    route_id = %self.route_id,
296                    tenant = tenant,
297                    previous_tokens = old_tokens,
298                    "Period reset"
299                );
300            }
301        }
302    }
303
304    /// Get the number of tracked tenants.
305    pub fn tenant_count(&self) -> usize {
306        self.tenants.len()
307    }
308
309    /// Get the period duration in seconds.
310    pub fn period_secs(&self) -> u64 {
311        self.config.period.as_secs()
312    }
313
314    /// Get the budget limit.
315    pub fn limit(&self) -> u64 {
316        self.config.limit
317    }
318
319    /// Check if enforcement is enabled.
320    pub fn is_enforced(&self) -> bool {
321        self.config.enforce
322    }
323
324    fn get_or_create_tenant(&self, tenant: &str) -> dashmap::mapref::one::Ref<'_, String, TenantBudgetState> {
325        self.tenants
326            .entry(tenant.to_string())
327            .or_insert_with(TenantBudgetState::new);
328        self.tenants.get(tenant).expect("Just inserted")
329    }
330}
331
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// 1000-token budget over a 60-second custom period: hard-enforced, alerts at
    /// 50% / 80% / 95%, no rollover and no burst.
    fn test_config() -> TokenBudgetConfig {
        TokenBudgetConfig {
            period: BudgetPeriod::Custom { seconds: 60 },
            limit: 1000,
            alert_thresholds: vec![0.50, 0.80, 0.95],
            enforce: true,
            rollover: false,
            burst_allowance: None,
        }
    }

    /// Builds a tracker for the shared test route.
    fn tracker_with(config: TokenBudgetConfig) -> TokenBudgetTracker {
        TokenBudgetTracker::new(config, "test-route")
    }

    #[test]
    fn test_check_allowed() {
        let tracker = tracker_with(test_config());

        let result = tracker.check("tenant-1", 100);
        assert!(result.is_allowed());

        match result {
            BudgetCheckResult::Allowed { remaining } => assert_eq!(remaining, 900),
            _ => panic!("Expected Allowed result"),
        }
    }

    #[test]
    fn test_check_exhausted() {
        let tracker = tracker_with(test_config());

        // Spend the whole budget so any further request must be rejected.
        tracker.record("tenant-1", 1000);

        let result = tracker.check("tenant-1", 100);
        assert!(!result.is_allowed());

        match result {
            BudgetCheckResult::Exhausted { retry_after_secs } => assert!(retry_after_secs > 0),
            _ => panic!("Expected Exhausted result"),
        }
    }

    #[test]
    fn test_record_alerts() {
        let tracker = tracker_with(test_config());

        // Each step crosses exactly one configured threshold: the running total goes
        // 500 (50%), 800 (80%), then 1000 (past 95%; 100% is not a threshold).
        let steps = [(500, 0.50), (300, 0.80), (200, 0.95)];
        for (tokens, expected_threshold) in steps {
            let alerts = tracker.record("tenant-1", tokens);
            assert_eq!(alerts.len(), 1);
            assert!((alerts[0].threshold - expected_threshold).abs() < 0.001);
        }

        // Every threshold has already fired, so further usage raises nothing.
        assert!(tracker.record("tenant-1", 100).is_empty());
    }

    #[test]
    fn test_status() {
        let tracker = tracker_with(test_config());
        tracker.record("tenant-1", 400);

        let status = tracker.status("tenant-1");
        assert_eq!(status.tokens_used, 400);
        assert_eq!(status.tokens_limit, 1000);
        assert_eq!(status.tokens_remaining, 600);
        assert!((status.usage_percent - 40.0).abs() < 0.001);
        assert!(!status.exhausted);
    }

    #[test]
    fn test_burst_allowance() {
        let config = TokenBudgetConfig {
            burst_allowance: Some(0.10), // 10% burst
            ..test_config()
        };
        let tracker = tracker_with(config);

        // 950 recorded + 100 requested = 1050: 5% over the limit, inside the 10% burst.
        tracker.record("tenant-1", 950);

        let result = tracker.check("tenant-1", 100);
        assert!(result.is_allowed());

        match result {
            BudgetCheckResult::Soft { remaining, over_by } => {
                assert_eq!(over_by, 50);
                assert_eq!(remaining, -50);
            }
            _ => panic!("Expected Soft result"),
        }
    }

    #[test]
    fn test_no_enforcement() {
        let config = TokenBudgetConfig {
            enforce: false,
            ..test_config()
        };
        let tracker = tracker_with(config);

        // Exceeding the budget without enforcement still yields an allowed (soft) result.
        tracker.record("tenant-1", 1000);
        assert!(tracker.check("tenant-1", 100).is_allowed());
    }

    #[test]
    fn test_period_reset() {
        let tracker = tracker_with(test_config());

        tracker.record("tenant-1", 500);
        assert_eq!(tracker.status("tenant-1").tokens_used, 500);

        tracker.reset_period("tenant-1");
        assert_eq!(tracker.status("tenant-1").tokens_used, 0);
    }

    #[test]
    fn test_rollover() {
        let config = TokenBudgetConfig {
            rollover: true,
            ..test_config()
        };
        let tracker = tracker_with(config);

        // 300 used leaves 700 unused, which the reset carries into the new period.
        tracker.record("tenant-1", 300);
        tracker.reset_period("tenant-1");

        assert_eq!(tracker.status("tenant-1").tokens_used, 700);
    }

    #[test]
    fn test_multiple_tenants() {
        let tracker = tracker_with(test_config());

        let usage = [("tenant-1", 500), ("tenant-2", 200)];
        for (tenant, tokens) in usage {
            tracker.record(tenant, tokens);
        }
        for (tenant, tokens) in usage {
            assert_eq!(tracker.status(tenant).tokens_used, tokens);
        }
        assert_eq!(tracker.tenant_count(), 2);
    }
}