Skip to main content

grapsus_proxy/inference/
budget.rs

1//! Token budget tracker for per-tenant cumulative usage tracking.
2//!
3//! Unlike rate limiting (tokens per minute), budgets track cumulative usage
4//! over longer periods (hourly, daily, monthly) with optional enforcement.
5
6use dashmap::DashMap;
7use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
8use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
9use tracing::{debug, info, trace, warn};
10
11use grapsus_common::budget::{
12    BudgetAlert, BudgetCheckResult, BudgetPeriod, TenantBudgetStatus, TokenBudgetConfig,
13};
14
/// Usage counters for a single tenant within the current budget period.
///
/// The token counter and alert mask are atomics so they can be updated through a
/// shared reference; only [`TenantBudgetState::reset`] needs exclusive access.
struct TenantBudgetState {
    /// Monotonic start of the current period (drives expiry checks).
    period_start: Instant,
    /// Wall-clock start of the current period as Unix seconds (for reporting).
    period_start_unix: u64,
    /// Tokens recorded so far in the current period.
    tokens_used: AtomicU64,
    /// Bitmask of alert thresholds that have already fired this period.
    /// Bit 0 = first threshold, bit 1 = second, ... up to bit 7.
    alerts_fired: AtomicU8,
}

impl TenantBudgetState {
    /// Creates an empty state whose period starts now.
    fn new() -> Self {
        Self {
            period_start: Instant::now(),
            period_start_unix: Self::unix_now(),
            tokens_used: AtomicU64::new(0),
            alerts_fired: AtomicU8::new(0),
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Falls back to 0 if the system clock reports a time before the epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Bit in `alerts_fired` for `threshold_index`, or `None` when the index does
    /// not fit in the 8-bit mask.
    ///
    /// A plain `1u8 << index` panics in debug builds for `index >= 8` and silently
    /// aliases lower bits in release builds, so the shift is checked instead.
    fn alert_mask(threshold_index: u8) -> Option<u8> {
        1u8.checked_shl(u32::from(threshold_index))
    }

    /// Tokens recorded in the current period.
    fn tokens_used(&self) -> u64 {
        self.tokens_used.load(Ordering::Acquire)
    }

    /// Adds `tokens` to the current period's usage.
    fn add_tokens(&self, tokens: u64) {
        self.tokens_used.fetch_add(tokens, Ordering::AcqRel);
    }

    /// Time elapsed since the current period started.
    fn elapsed(&self) -> Duration {
        self.period_start.elapsed()
    }

    /// Starts a fresh period: restarts both clocks and clears usage and alert state.
    fn reset(&mut self) {
        self.period_start = Instant::now();
        self.period_start_unix = Self::unix_now();
        self.tokens_used.store(0, Ordering::Release);
        self.alerts_fired.store(0, Ordering::Release);
    }

    /// Returns `true` if the alert for `threshold_index` has fired this period.
    ///
    /// Indices that cannot be tracked (8 and above) are reported as already fired,
    /// so callers never emit an untrackable alert over and over.
    fn has_fired_alert(&self, threshold_index: u8) -> bool {
        match Self::alert_mask(threshold_index) {
            Some(mask) => (self.alerts_fired.load(Ordering::Acquire) & mask) != 0,
            None => true,
        }
    }

    /// Marks the alert for `threshold_index` as fired; untrackable indices are ignored.
    fn mark_alert_fired(&self, threshold_index: u8) {
        if let Some(mask) = Self::alert_mask(threshold_index) {
            self.alerts_fired.fetch_or(mask, Ordering::AcqRel);
        }
    }
}
77
78/// Token budget tracker for per-tenant usage tracking.
79///
80/// Tracks cumulative token usage over configurable periods (hourly, daily, monthly)
81/// with support for:
82/// - Configurable alert thresholds
83/// - Hard or soft enforcement
84/// - Optional burst allowance
85/// - Period rollover
86pub struct TokenBudgetTracker {
87    /// Budget configuration
88    config: TokenBudgetConfig,
89    /// Per-tenant budget state
90    tenants: DashMap<String, TenantBudgetState>,
91    /// Route ID for logging
92    route_id: String,
93}
94
95impl TokenBudgetTracker {
96    /// Create a new token budget tracker with the given configuration.
97    pub fn new(config: TokenBudgetConfig, route_id: impl Into<String>) -> Self {
98        let route_id = route_id.into();
99
100        info!(
101            route_id = %route_id,
102            period = ?config.period,
103            limit = config.limit,
104            enforce = config.enforce,
105            rollover = config.rollover,
106            "Created token budget tracker"
107        );
108
109        Self {
110            config,
111            tenants: DashMap::new(),
112            route_id,
113        }
114    }
115
116    /// Check if a request with the given token count is allowed.
117    ///
118    /// This does NOT consume tokens - call `record()` after the request completes.
119    pub fn check(&self, tenant: &str, estimated_tokens: u64) -> BudgetCheckResult {
120        let state = self.get_or_create_tenant(tenant);
121        let period_secs = self.config.period.as_secs();
122
123        // Check if period has expired
124        let elapsed = state.elapsed();
125        if elapsed.as_secs() >= period_secs {
126            drop(state);
127            self.reset_period(tenant);
128            return self.check(tenant, estimated_tokens);
129        }
130
131        let current_used = state.tokens_used();
132        let would_use = current_used + estimated_tokens;
133
134        // Check against limit
135        if would_use <= self.config.limit {
136            let remaining = self.config.limit.saturating_sub(would_use);
137            trace!(
138                route_id = %self.route_id,
139                tenant = tenant,
140                current_used = current_used,
141                estimated_tokens = estimated_tokens,
142                remaining = remaining,
143                "Budget check: allowed"
144            );
145            return BudgetCheckResult::Allowed { remaining };
146        }
147
148        // Check burst allowance
149        if let Some(burst) = self.config.burst_allowance {
150            let burst_limit = self.config.limit + (self.config.limit as f64 * burst) as u64;
151            if would_use <= burst_limit {
152                let over_by = would_use - self.config.limit;
153                let remaining = (self.config.limit as i64) - (would_use as i64);
154                trace!(
155                    route_id = %self.route_id,
156                    tenant = tenant,
157                    over_by = over_by,
158                    "Budget check: soft limit (burst)"
159                );
160                return BudgetCheckResult::Soft { remaining, over_by };
161            }
162        }
163
164        // Budget exhausted
165        if self.config.enforce {
166            let retry_after = period_secs.saturating_sub(elapsed.as_secs());
167            debug!(
168                route_id = %self.route_id,
169                tenant = tenant,
170                current_used = current_used,
171                limit = self.config.limit,
172                retry_after_secs = retry_after,
173                "Budget exhausted"
174            );
175            BudgetCheckResult::Exhausted {
176                retry_after_secs: retry_after,
177            }
178        } else {
179            // Not enforcing, just log and allow
180            let over_by = would_use - self.config.limit;
181            let remaining = (self.config.limit as i64) - (would_use as i64);
182            debug!(
183                route_id = %self.route_id,
184                tenant = tenant,
185                over_by = over_by,
186                "Budget exceeded (not enforced)"
187            );
188            BudgetCheckResult::Soft { remaining, over_by }
189        }
190    }
191
192    /// Record actual token usage after a request completes.
193    ///
194    /// Returns any budget alerts that should be fired.
195    pub fn record(&self, tenant: &str, actual_tokens: u64) -> Vec<BudgetAlert> {
196        let state = self.get_or_create_tenant(tenant);
197        let period_secs = self.config.period.as_secs();
198
199        // Check if period has expired
200        let elapsed = state.elapsed();
201        if elapsed.as_secs() >= period_secs {
202            drop(state);
203            self.reset_period(tenant);
204            return self.record(tenant, actual_tokens);
205        }
206
207        // Add tokens
208        state.add_tokens(actual_tokens);
209        let new_total = state.tokens_used();
210
211        trace!(
212            route_id = %self.route_id,
213            tenant = tenant,
214            tokens = actual_tokens,
215            total = new_total,
216            limit = self.config.limit,
217            "Recorded token usage"
218        );
219
220        // Check for alert thresholds
221        let mut alerts = Vec::new();
222        let usage_pct = new_total as f64 / self.config.limit as f64;
223
224        for (idx, &threshold) in self.config.alert_thresholds.iter().enumerate() {
225            if usage_pct >= threshold && !state.has_fired_alert(idx as u8) {
226                state.mark_alert_fired(idx as u8);
227
228                let alert = BudgetAlert {
229                    tenant: tenant.to_string(),
230                    threshold,
231                    tokens_used: new_total,
232                    tokens_limit: self.config.limit,
233                    period_start: state.period_start_unix,
234                };
235
236                info!(
237                    route_id = %self.route_id,
238                    tenant = tenant,
239                    threshold_pct = threshold * 100.0,
240                    tokens_used = new_total,
241                    tokens_limit = self.config.limit,
242                    "Budget alert threshold crossed"
243                );
244
245                alerts.push(alert);
246            }
247        }
248
249        alerts
250    }
251
252    /// Get the current budget status for a tenant.
253    pub fn status(&self, tenant: &str) -> TenantBudgetStatus {
254        let state = self.get_or_create_tenant(tenant);
255        let period_secs = self.config.period.as_secs();
256        let elapsed = state.elapsed();
257
258        let tokens_used = state.tokens_used();
259        let tokens_remaining = self.config.limit.saturating_sub(tokens_used);
260        let usage_percent = (tokens_used as f64 / self.config.limit as f64) * 100.0;
261        let period_end = state.period_start_unix + period_secs;
262
263        TenantBudgetStatus {
264            tokens_used,
265            tokens_limit: self.config.limit,
266            tokens_remaining,
267            usage_percent,
268            period_start: state.period_start_unix,
269            period_end,
270            exhausted: tokens_used >= self.config.limit && self.config.enforce,
271        }
272    }
273
274    /// Reset the budget period for a tenant.
275    pub fn reset_period(&self, tenant: &str) {
276        if let Some(mut state) = self.tenants.get_mut(tenant) {
277            let old_tokens = state.tokens_used();
278
279            // Handle rollover
280            if self.config.rollover && old_tokens < self.config.limit {
281                let unused = self.config.limit - old_tokens;
282                state.reset();
283                // Add back unused tokens (capped at limit)
284                let rollover = unused.min(self.config.limit);
285                state.add_tokens(rollover);
286                info!(
287                    route_id = %self.route_id,
288                    tenant = tenant,
289                    rollover_tokens = rollover,
290                    "Period reset with rollover"
291                );
292            } else {
293                state.reset();
294                debug!(
295                    route_id = %self.route_id,
296                    tenant = tenant,
297                    previous_tokens = old_tokens,
298                    "Period reset"
299                );
300            }
301        }
302    }
303
304    /// Get the number of tracked tenants.
305    pub fn tenant_count(&self) -> usize {
306        self.tenants.len()
307    }
308
309    /// Get the period duration in seconds.
310    pub fn period_secs(&self) -> u64 {
311        self.config.period.as_secs()
312    }
313
314    /// Get the budget limit.
315    pub fn limit(&self) -> u64 {
316        self.config.limit
317    }
318
319    /// Check if enforcement is enabled.
320    pub fn is_enforced(&self) -> bool {
321        self.config.enforce
322    }
323
324    fn get_or_create_tenant(
325        &self,
326        tenant: &str,
327    ) -> dashmap::mapref::one::Ref<'_, String, TenantBudgetState> {
328        self.tenants
329            .entry(tenant.to_string())
330            .or_insert_with(TenantBudgetState::new);
331        self.tenants.get(tenant).expect("Just inserted")
332    }
333}
334
335// ============================================================================
336// Tests
337// ============================================================================
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration: 60-second period, 1000-token enforced limit,
    /// alerts at 50% / 80% / 95%, no rollover and no burst allowance.
    fn test_config() -> TokenBudgetConfig {
        TokenBudgetConfig {
            period: BudgetPeriod::Custom { seconds: 60 },
            limit: 1000,
            alert_thresholds: vec![0.50, 0.80, 0.95],
            enforce: true,
            rollover: false,
            burst_allowance: None,
        }
    }

    /// A request that fits the budget is allowed and reports what would remain.
    #[test]
    fn test_check_allowed() {
        let tracker = TokenBudgetTracker::new(test_config(), "test-route");

        let result = tracker.check("tenant-1", 100);
        assert!(result.is_allowed());

        match result {
            BudgetCheckResult::Allowed { remaining } => assert_eq!(remaining, 900),
            _ => panic!("Expected Allowed result"),
        }
    }

    /// Once the whole budget is recorded, further requests are rejected.
    #[test]
    fn test_check_exhausted() {
        let tracker = TokenBudgetTracker::new(test_config(), "test-route");

        // Consume the entire budget.
        tracker.record("tenant-1", 1000);

        // Any further request must be rejected with a retry hint.
        let result = tracker.check("tenant-1", 100);
        assert!(!result.is_allowed());

        match result {
            BudgetCheckResult::Exhausted { retry_after_secs } => assert!(retry_after_secs > 0),
            _ => panic!("Expected Exhausted result"),
        }
    }

    /// Each threshold fires exactly once, in order, as usage grows.
    #[test]
    fn test_record_alerts() {
        let tracker = TokenBudgetTracker::new(test_config(), "test-route");

        // 500 tokens reaches the 50% threshold.
        let alerts = tracker.record("tenant-1", 500);
        assert_eq!(alerts.len(), 1);
        assert!((alerts[0].threshold - 0.50).abs() < 0.001);

        // 300 more (800 total) reaches the 80% threshold.
        let alerts = tracker.record("tenant-1", 300);
        assert_eq!(alerts.len(), 1);
        assert!((alerts[0].threshold - 0.80).abs() < 0.001);

        // 200 more (1000 total) passes 95%; 100% is not a configured threshold.
        let alerts = tracker.record("tenant-1", 200);
        assert_eq!(alerts.len(), 1);
        assert!((alerts[0].threshold - 0.95).abs() < 0.001);

        // Every threshold has fired already, so nothing more is reported.
        let alerts = tracker.record("tenant-1", 100);
        assert!(alerts.is_empty());
    }

    /// The status snapshot mirrors recorded usage.
    #[test]
    fn test_status() {
        let tracker = TokenBudgetTracker::new(test_config(), "test-route");

        tracker.record("tenant-1", 400);

        let status = tracker.status("tenant-1");
        assert_eq!(status.tokens_used, 400);
        assert_eq!(status.tokens_limit, 1000);
        assert_eq!(status.tokens_remaining, 600);
        assert!((status.usage_percent - 40.0).abs() < 0.001);
        assert!(!status.exhausted);
    }

    /// Going over the limit but within the burst allowance yields a soft result.
    #[test]
    fn test_burst_allowance() {
        // 10% burst on top of the baseline configuration.
        let config = TokenBudgetConfig {
            burst_allowance: Some(0.10),
            ..test_config()
        };

        let tracker = TokenBudgetTracker::new(config, "test-route");

        // 950 recorded + 100 requested = 1050, i.e. 5% over the limit.
        tracker.record("tenant-1", 950);

        let result = tracker.check("tenant-1", 100);
        assert!(result.is_allowed());

        match result {
            BudgetCheckResult::Soft { remaining, over_by } => {
                assert_eq!(over_by, 50);
                assert_eq!(remaining, -50);
            }
            _ => panic!("Expected Soft result"),
        }
    }

    /// With enforcement disabled, exceeding the budget is still allowed.
    #[test]
    fn test_no_enforcement() {
        let config = TokenBudgetConfig {
            enforce: false,
            ..test_config()
        };

        let tracker = TokenBudgetTracker::new(config, "test-route");

        // Consume the entire budget.
        tracker.record("tenant-1", 1000);

        // The request goes through anyway (soft result).
        let result = tracker.check("tenant-1", 100);
        assert!(result.is_allowed());
    }

    /// An explicit period reset clears recorded usage.
    #[test]
    fn test_period_reset() {
        let tracker = TokenBudgetTracker::new(test_config(), "test-route");

        tracker.record("tenant-1", 500);
        assert_eq!(tracker.status("tenant-1").tokens_used, 500);

        tracker.reset_period("tenant-1");
        assert_eq!(tracker.status("tenant-1").tokens_used, 0);
    }

    /// With rollover enabled, the unused amount shows up in the new period.
    #[test]
    fn test_rollover() {
        let config = TokenBudgetConfig {
            rollover: true,
            ..test_config()
        };

        let tracker = TokenBudgetTracker::new(config, "test-route");

        // 300 used leaves 700 unused.
        tracker.record("tenant-1", 300);

        // Reset with rollover enabled.
        tracker.reset_period("tenant-1");

        // The 700 unused tokens are carried into the new period's counter.
        let status = tracker.status("tenant-1");
        assert_eq!(status.tokens_used, 700);
    }

    /// Tenants are tracked independently of each other.
    #[test]
    fn test_multiple_tenants() {
        let tracker = TokenBudgetTracker::new(test_config(), "test-route");

        tracker.record("tenant-1", 500);
        tracker.record("tenant-2", 200);

        assert_eq!(tracker.status("tenant-1").tokens_used, 500);
        assert_eq!(tracker.status("tenant-2").tokens_used, 200);
        assert_eq!(tracker.tenant_count(), 2);
    }
}