sentinel_proxy/inference/
budget.rs1use dashmap::DashMap;
7use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
8use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
9use tracing::{debug, info, trace, warn};
10
11use sentinel_common::budget::{
12 BudgetAlert, BudgetCheckResult, BudgetPeriod, TenantBudgetStatus, TokenBudgetConfig,
13};
14
/// Per-tenant usage for the budget period currently in progress.
struct TenantBudgetState {
    /// Tokens consumed so far in this period.
    tokens_used: AtomicU64,
    /// Bitmask of alert thresholds already fired this period; bit `i`
    /// corresponds to the `i`-th configured alert threshold.
    alerts_fired: AtomicU8,
    /// Monotonic start of the period, used to measure how much of it has elapsed.
    period_start: Instant,
    /// Wall-clock start of the period in whole seconds since the Unix epoch,
    /// used when reporting the period to callers.
    period_start_unix: u64,
}
27
28impl TenantBudgetState {
29 fn new() -> Self {
30 let now_unix = SystemTime::now()
31 .duration_since(UNIX_EPOCH)
32 .unwrap_or_default()
33 .as_secs();
34
35 Self {
36 period_start: Instant::now(),
37 period_start_unix: now_unix,
38 tokens_used: AtomicU64::new(0),
39 alerts_fired: AtomicU8::new(0),
40 }
41 }
42
43 fn tokens_used(&self) -> u64 {
44 self.tokens_used.load(Ordering::Acquire)
45 }
46
47 fn add_tokens(&self, tokens: u64) {
48 self.tokens_used.fetch_add(tokens, Ordering::AcqRel);
49 }
50
51 fn elapsed(&self) -> Duration {
52 self.period_start.elapsed()
53 }
54
55 fn reset(&mut self) {
56 let now_unix = SystemTime::now()
57 .duration_since(UNIX_EPOCH)
58 .unwrap_or_default()
59 .as_secs();
60
61 self.period_start = Instant::now();
62 self.period_start_unix = now_unix;
63 self.tokens_used.store(0, Ordering::Release);
64 self.alerts_fired.store(0, Ordering::Release);
65 }
66
67 fn has_fired_alert(&self, threshold_index: u8) -> bool {
68 let mask = 1u8 << threshold_index;
69 (self.alerts_fired.load(Ordering::Acquire) & mask) != 0
70 }
71
72 fn mark_alert_fired(&self, threshold_index: u8) {
73 let mask = 1u8 << threshold_index;
74 self.alerts_fired.fetch_or(mask, Ordering::AcqRel);
75 }
76}
77
78pub struct TokenBudgetTracker {
87 config: TokenBudgetConfig,
89 tenants: DashMap<String, TenantBudgetState>,
91 route_id: String,
93}
94
95impl TokenBudgetTracker {
96 pub fn new(config: TokenBudgetConfig, route_id: impl Into<String>) -> Self {
98 let route_id = route_id.into();
99
100 info!(
101 route_id = %route_id,
102 period = ?config.period,
103 limit = config.limit,
104 enforce = config.enforce,
105 rollover = config.rollover,
106 "Created token budget tracker"
107 );
108
109 Self {
110 config,
111 tenants: DashMap::new(),
112 route_id,
113 }
114 }
115
116 pub fn check(&self, tenant: &str, estimated_tokens: u64) -> BudgetCheckResult {
120 let state = self.get_or_create_tenant(tenant);
121 let period_secs = self.config.period.as_secs();
122
123 let elapsed = state.elapsed();
125 if elapsed.as_secs() >= period_secs {
126 drop(state);
127 self.reset_period(tenant);
128 return self.check(tenant, estimated_tokens);
129 }
130
131 let current_used = state.tokens_used();
132 let would_use = current_used + estimated_tokens;
133
134 if would_use <= self.config.limit {
136 let remaining = self.config.limit.saturating_sub(would_use);
137 trace!(
138 route_id = %self.route_id,
139 tenant = tenant,
140 current_used = current_used,
141 estimated_tokens = estimated_tokens,
142 remaining = remaining,
143 "Budget check: allowed"
144 );
145 return BudgetCheckResult::Allowed { remaining };
146 }
147
148 if let Some(burst) = self.config.burst_allowance {
150 let burst_limit = self.config.limit + (self.config.limit as f64 * burst) as u64;
151 if would_use <= burst_limit {
152 let over_by = would_use - self.config.limit;
153 let remaining = (self.config.limit as i64) - (would_use as i64);
154 trace!(
155 route_id = %self.route_id,
156 tenant = tenant,
157 over_by = over_by,
158 "Budget check: soft limit (burst)"
159 );
160 return BudgetCheckResult::Soft { remaining, over_by };
161 }
162 }
163
164 if self.config.enforce {
166 let retry_after = period_secs.saturating_sub(elapsed.as_secs());
167 debug!(
168 route_id = %self.route_id,
169 tenant = tenant,
170 current_used = current_used,
171 limit = self.config.limit,
172 retry_after_secs = retry_after,
173 "Budget exhausted"
174 );
175 BudgetCheckResult::Exhausted {
176 retry_after_secs: retry_after,
177 }
178 } else {
179 let over_by = would_use - self.config.limit;
181 let remaining = (self.config.limit as i64) - (would_use as i64);
182 debug!(
183 route_id = %self.route_id,
184 tenant = tenant,
185 over_by = over_by,
186 "Budget exceeded (not enforced)"
187 );
188 BudgetCheckResult::Soft { remaining, over_by }
189 }
190 }
191
192 pub fn record(&self, tenant: &str, actual_tokens: u64) -> Vec<BudgetAlert> {
196 let state = self.get_or_create_tenant(tenant);
197 let period_secs = self.config.period.as_secs();
198
199 let elapsed = state.elapsed();
201 if elapsed.as_secs() >= period_secs {
202 drop(state);
203 self.reset_period(tenant);
204 return self.record(tenant, actual_tokens);
205 }
206
207 state.add_tokens(actual_tokens);
209 let new_total = state.tokens_used();
210
211 trace!(
212 route_id = %self.route_id,
213 tenant = tenant,
214 tokens = actual_tokens,
215 total = new_total,
216 limit = self.config.limit,
217 "Recorded token usage"
218 );
219
220 let mut alerts = Vec::new();
222 let usage_pct = new_total as f64 / self.config.limit as f64;
223
224 for (idx, &threshold) in self.config.alert_thresholds.iter().enumerate() {
225 if usage_pct >= threshold && !state.has_fired_alert(idx as u8) {
226 state.mark_alert_fired(idx as u8);
227
228 let alert = BudgetAlert {
229 tenant: tenant.to_string(),
230 threshold,
231 tokens_used: new_total,
232 tokens_limit: self.config.limit,
233 period_start: state.period_start_unix,
234 };
235
236 info!(
237 route_id = %self.route_id,
238 tenant = tenant,
239 threshold_pct = threshold * 100.0,
240 tokens_used = new_total,
241 tokens_limit = self.config.limit,
242 "Budget alert threshold crossed"
243 );
244
245 alerts.push(alert);
246 }
247 }
248
249 alerts
250 }
251
252 pub fn status(&self, tenant: &str) -> TenantBudgetStatus {
254 let state = self.get_or_create_tenant(tenant);
255 let period_secs = self.config.period.as_secs();
256 let elapsed = state.elapsed();
257
258 let tokens_used = state.tokens_used();
259 let tokens_remaining = self.config.limit.saturating_sub(tokens_used);
260 let usage_percent = (tokens_used as f64 / self.config.limit as f64) * 100.0;
261 let period_end = state.period_start_unix + period_secs;
262
263 TenantBudgetStatus {
264 tokens_used,
265 tokens_limit: self.config.limit,
266 tokens_remaining,
267 usage_percent,
268 period_start: state.period_start_unix,
269 period_end,
270 exhausted: tokens_used >= self.config.limit && self.config.enforce,
271 }
272 }
273
274 pub fn reset_period(&self, tenant: &str) {
276 if let Some(mut state) = self.tenants.get_mut(tenant) {
277 let old_tokens = state.tokens_used();
278
279 if self.config.rollover && old_tokens < self.config.limit {
281 let unused = self.config.limit - old_tokens;
282 state.reset();
283 let rollover = unused.min(self.config.limit);
285 state.add_tokens(rollover);
286 info!(
287 route_id = %self.route_id,
288 tenant = tenant,
289 rollover_tokens = rollover,
290 "Period reset with rollover"
291 );
292 } else {
293 state.reset();
294 debug!(
295 route_id = %self.route_id,
296 tenant = tenant,
297 previous_tokens = old_tokens,
298 "Period reset"
299 );
300 }
301 }
302 }
303
304 pub fn tenant_count(&self) -> usize {
306 self.tenants.len()
307 }
308
309 pub fn period_secs(&self) -> u64 {
311 self.config.period.as_secs()
312 }
313
314 pub fn limit(&self) -> u64 {
316 self.config.limit
317 }
318
319 pub fn is_enforced(&self) -> bool {
321 self.config.enforce
322 }
323
324 fn get_or_create_tenant(&self, tenant: &str) -> dashmap::mapref::one::Ref<'_, String, TenantBudgetState> {
325 self.tenants
326 .entry(tenant.to_string())
327 .or_insert_with(TenantBudgetState::new);
328 self.tenants.get(tenant).expect("Just inserted")
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline config: 1000 tokens per 60-second period, enforced, alerts at
    /// 50/80/95 %, no rollover and no burst allowance.
    fn test_config() -> TokenBudgetConfig {
        TokenBudgetConfig {
            period: BudgetPeriod::Custom { seconds: 60 },
            limit: 1000,
            alert_thresholds: vec![0.50, 0.80, 0.95],
            enforce: true,
            rollover: false,
            burst_allowance: None,
        }
    }

    /// Build a tracker for the shared test route from `config`.
    fn tracker_with(config: TokenBudgetConfig) -> TokenBudgetTracker {
        TokenBudgetTracker::new(config, "test-route")
    }

    #[test]
    fn test_check_allowed() {
        let tracker = tracker_with(test_config());

        let result = tracker.check("tenant-1", 100);
        assert!(result.is_allowed());
        match result {
            BudgetCheckResult::Allowed { remaining } => assert_eq!(remaining, 900),
            _ => panic!("Expected Allowed result"),
        }
    }

    #[test]
    fn test_check_exhausted() {
        let tracker = tracker_with(test_config());
        // Use up the whole budget first.
        tracker.record("tenant-1", 1000);

        let result = tracker.check("tenant-1", 100);
        assert!(!result.is_allowed());
        match result {
            BudgetCheckResult::Exhausted { retry_after_secs } => assert!(retry_after_secs > 0),
            _ => panic!("Expected Exhausted result"),
        }
    }

    #[test]
    fn test_record_alerts() {
        let tracker = tracker_with(test_config());

        // Each step adds tokens and names the single threshold it should cross,
        // or `None` once every threshold has already fired.
        let steps: [(u64, Option<f64>); 4] = [
            (500, Some(0.50)),
            (300, Some(0.80)),
            (200, Some(0.95)),
            (100, None),
        ];

        for &(tokens, expected) in &steps {
            let alerts = tracker.record("tenant-1", tokens);
            match expected {
                Some(threshold) => {
                    assert_eq!(alerts.len(), 1);
                    assert!((alerts[0].threshold - threshold).abs() < 0.001);
                }
                None => assert!(alerts.is_empty()),
            }
        }
    }

    #[test]
    fn test_status() {
        let tracker = tracker_with(test_config());
        tracker.record("tenant-1", 400);

        let status = tracker.status("tenant-1");
        assert_eq!(status.tokens_used, 400);
        assert_eq!(status.tokens_limit, 1000);
        assert_eq!(status.tokens_remaining, 600);
        assert!((status.usage_percent - 40.0).abs() < 0.001);
        assert!(!status.exhausted);
    }

    #[test]
    fn test_burst_allowance() {
        // 10 % burst allows up to 1100 tokens.
        let tracker = tracker_with(TokenBudgetConfig {
            burst_allowance: Some(0.10),
            ..test_config()
        });
        tracker.record("tenant-1", 950);

        let result = tracker.check("tenant-1", 100);
        assert!(result.is_allowed());
        match result {
            BudgetCheckResult::Soft { remaining, over_by } => {
                assert_eq!(over_by, 50);
                assert_eq!(remaining, -50);
            }
            _ => panic!("Expected Soft result"),
        }
    }

    #[test]
    fn test_no_enforcement() {
        let tracker = tracker_with(TokenBudgetConfig {
            enforce: false,
            ..test_config()
        });
        tracker.record("tenant-1", 1000);

        // Over the limit, but enforcement is off.
        assert!(tracker.check("tenant-1", 100).is_allowed());
    }

    #[test]
    fn test_period_reset() {
        let tracker = tracker_with(test_config());

        tracker.record("tenant-1", 500);
        assert_eq!(tracker.status("tenant-1").tokens_used, 500);

        tracker.reset_period("tenant-1");
        assert_eq!(tracker.status("tenant-1").tokens_used, 0);
    }

    #[test]
    fn test_rollover() {
        let tracker = tracker_with(TokenBudgetConfig {
            rollover: true,
            ..test_config()
        });

        tracker.record("tenant-1", 300);
        tracker.reset_period("tenant-1");

        // The 700 unused tokens are carried into the new period's usage.
        assert_eq!(tracker.status("tenant-1").tokens_used, 700);
    }

    #[test]
    fn test_multiple_tenants() {
        let tracker = tracker_with(test_config());
        let usage = [("tenant-1", 500u64), ("tenant-2", 200u64)];

        for &(tenant, tokens) in &usage {
            tracker.record(tenant, tokens);
        }
        for &(tenant, tokens) in &usage {
            assert_eq!(tracker.status(tenant).tokens_used, tokens);
        }
        assert_eq!(tracker.tenant_count(), usage.len());
    }
}