use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tracing::{debug, info, trace, warn};
use zentinel_common::budget::{
BudgetAlert, BudgetCheckResult, BudgetPeriod, TenantBudgetStatus, TokenBudgetConfig,
};
/// Budget bookkeeping for a single tenant within the current accounting period.
///
/// The counters are atomics so usage can be recorded through a shared
/// reference; moving the period boundary (`reset`) needs `&mut self`.
struct TenantBudgetState {
    /// Monotonic start of the current period; basis for `elapsed()`.
    period_start: Instant,
    /// Wall-clock start of the current period, in seconds since the Unix epoch.
    period_start_unix: u64,
    /// Tokens counted against the current period so far.
    tokens_used: AtomicU64,
    /// Bitmask of alert thresholds already fired; bit `i` is threshold index `i`.
    alerts_fired: AtomicU8,
}

impl TenantBudgetState {
    /// Starts a fresh period at the current time with no usage and no alerts fired.
    fn new() -> Self {
        Self {
            period_start: Instant::now(),
            period_start_unix: Self::unix_now_secs(),
            tokens_used: AtomicU64::new(0),
            alerts_fired: AtomicU8::new(0),
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Yields 0 if the system clock reads earlier than the epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Bit in `alerts_fired` for `threshold_index`, or `None` when the index
    /// does not fit in the 8-bit mask.
    ///
    /// A plain `1u8 << index` would panic in debug builds and alias onto a
    /// lower bit in release builds for indices >= 8.
    fn alert_mask(threshold_index: u8) -> Option<u8> {
        1u8.checked_shl(u32::from(threshold_index))
    }

    /// Tokens counted against the current period.
    fn tokens_used(&self) -> u64 {
        self.tokens_used.load(Ordering::Acquire)
    }

    /// Adds `tokens` to the current period's usage.
    fn add_tokens(&self, tokens: u64) {
        self.tokens_used.fetch_add(tokens, Ordering::AcqRel);
    }

    /// Time elapsed since the current period started (monotonic clock).
    fn elapsed(&self) -> Duration {
        self.period_start.elapsed()
    }

    /// Begins a new period now, clearing usage and all fired-alert bits.
    fn reset(&mut self) {
        self.period_start = Instant::now();
        self.period_start_unix = Self::unix_now_secs();
        self.tokens_used.store(0, Ordering::Release);
        self.alerts_fired.store(0, Ordering::Release);
    }

    /// Whether the alert for `threshold_index` has already fired this period.
    ///
    /// Indices that do not fit in the mask cannot be tracked; they are
    /// reported as already fired so callers never re-emit them on every call.
    fn has_fired_alert(&self, threshold_index: u8) -> bool {
        match Self::alert_mask(threshold_index) {
            Some(mask) => (self.alerts_fired.load(Ordering::Acquire) & mask) != 0,
            None => true,
        }
    }

    /// Marks the alert for `threshold_index` as fired for the current period.
    ///
    /// Indices that do not fit in the mask are ignored.
    fn mark_alert_fired(&self, threshold_index: u8) {
        if let Some(mask) = Self::alert_mask(threshold_index) {
            self.alerts_fired.fetch_or(mask, Ordering::AcqRel);
        }
    }
}
pub struct TokenBudgetTracker {
config: TokenBudgetConfig,
tenants: DashMap<String, TenantBudgetState>,
route_id: String,
}
impl TokenBudgetTracker {
pub fn new(config: TokenBudgetConfig, route_id: impl Into<String>) -> Self {
let route_id = route_id.into();
info!(
route_id = %route_id,
period = ?config.period,
limit = config.limit,
enforce = config.enforce,
rollover = config.rollover,
"Created token budget tracker"
);
Self {
config,
tenants: DashMap::new(),
route_id,
}
}
pub fn check(&self, tenant: &str, estimated_tokens: u64) -> BudgetCheckResult {
let state = self.get_or_create_tenant(tenant);
let period_secs = self.config.period.as_secs();
let elapsed = state.elapsed();
if elapsed.as_secs() >= period_secs {
drop(state);
self.reset_period(tenant);
return self.check(tenant, estimated_tokens);
}
let current_used = state.tokens_used();
let would_use = current_used + estimated_tokens;
if would_use <= self.config.limit {
let remaining = self.config.limit.saturating_sub(would_use);
trace!(
route_id = %self.route_id,
tenant = tenant,
current_used = current_used,
estimated_tokens = estimated_tokens,
remaining = remaining,
"Budget check: allowed"
);
return BudgetCheckResult::Allowed { remaining };
}
if let Some(burst) = self.config.burst_allowance {
let burst_limit = self.config.limit + (self.config.limit as f64 * burst) as u64;
if would_use <= burst_limit {
let over_by = would_use - self.config.limit;
let remaining = (self.config.limit as i64) - (would_use as i64);
trace!(
route_id = %self.route_id,
tenant = tenant,
over_by = over_by,
"Budget check: soft limit (burst)"
);
return BudgetCheckResult::Soft { remaining, over_by };
}
}
if self.config.enforce {
let retry_after = period_secs.saturating_sub(elapsed.as_secs());
debug!(
route_id = %self.route_id,
tenant = tenant,
current_used = current_used,
limit = self.config.limit,
retry_after_secs = retry_after,
"Budget exhausted"
);
BudgetCheckResult::Exhausted {
retry_after_secs: retry_after,
}
} else {
let over_by = would_use - self.config.limit;
let remaining = (self.config.limit as i64) - (would_use as i64);
debug!(
route_id = %self.route_id,
tenant = tenant,
over_by = over_by,
"Budget exceeded (not enforced)"
);
BudgetCheckResult::Soft { remaining, over_by }
}
}
pub fn record(&self, tenant: &str, actual_tokens: u64) -> Vec<BudgetAlert> {
let state = self.get_or_create_tenant(tenant);
let period_secs = self.config.period.as_secs();
let elapsed = state.elapsed();
if elapsed.as_secs() >= period_secs {
drop(state);
self.reset_period(tenant);
return self.record(tenant, actual_tokens);
}
state.add_tokens(actual_tokens);
let new_total = state.tokens_used();
trace!(
route_id = %self.route_id,
tenant = tenant,
tokens = actual_tokens,
total = new_total,
limit = self.config.limit,
"Recorded token usage"
);
let mut alerts = Vec::new();
let usage_pct = new_total as f64 / self.config.limit as f64;
for (idx, &threshold) in self.config.alert_thresholds.iter().enumerate() {
if usage_pct >= threshold && !state.has_fired_alert(idx as u8) {
state.mark_alert_fired(idx as u8);
let alert = BudgetAlert {
tenant: tenant.to_string(),
threshold,
tokens_used: new_total,
tokens_limit: self.config.limit,
period_start: state.period_start_unix,
};
info!(
route_id = %self.route_id,
tenant = tenant,
threshold_pct = threshold * 100.0,
tokens_used = new_total,
tokens_limit = self.config.limit,
"Budget alert threshold crossed"
);
alerts.push(alert);
}
}
alerts
}
pub fn status(&self, tenant: &str) -> TenantBudgetStatus {
let state = self.get_or_create_tenant(tenant);
let period_secs = self.config.period.as_secs();
let elapsed = state.elapsed();
let tokens_used = state.tokens_used();
let tokens_remaining = self.config.limit.saturating_sub(tokens_used);
let usage_percent = (tokens_used as f64 / self.config.limit as f64) * 100.0;
let period_end = state.period_start_unix + period_secs;
TenantBudgetStatus {
tokens_used,
tokens_limit: self.config.limit,
tokens_remaining,
usage_percent,
period_start: state.period_start_unix,
period_end,
exhausted: tokens_used >= self.config.limit && self.config.enforce,
}
}
pub fn reset_period(&self, tenant: &str) {
if let Some(mut state) = self.tenants.get_mut(tenant) {
let old_tokens = state.tokens_used();
if self.config.rollover && old_tokens < self.config.limit {
let unused = self.config.limit - old_tokens;
state.reset();
let rollover = unused.min(self.config.limit);
state.add_tokens(rollover);
info!(
route_id = %self.route_id,
tenant = tenant,
rollover_tokens = rollover,
"Period reset with rollover"
);
} else {
state.reset();
debug!(
route_id = %self.route_id,
tenant = tenant,
previous_tokens = old_tokens,
"Period reset"
);
}
}
}
pub fn tenant_count(&self) -> usize {
self.tenants.len()
}
pub fn period_secs(&self) -> u64 {
self.config.period.as_secs()
}
pub fn limit(&self) -> u64 {
self.config.limit
}
pub fn is_enforced(&self) -> bool {
self.config.enforce
}
fn get_or_create_tenant(
&self,
tenant: &str,
) -> dashmap::mapref::one::Ref<'_, String, TenantBudgetState> {
self.tenants
.entry(tenant.to_string())
.or_insert_with(TenantBudgetState::new);
self.tenants.get(tenant).expect("Just inserted")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration: 1000 tokens per 60-second period, enforced,
    /// no rollover, no burst, alerts at 50% / 80% / 95%.
    fn test_config() -> TokenBudgetConfig {
        TokenBudgetConfig {
            period: BudgetPeriod::Custom { seconds: 60 },
            limit: 1000,
            alert_thresholds: vec![0.50, 0.80, 0.95],
            enforce: true,
            rollover: false,
            burst_allowance: None,
        }
    }

    /// Builds a tracker for the shared test route from `config`.
    fn tracker_with(config: TokenBudgetConfig) -> TokenBudgetTracker {
        TokenBudgetTracker::new(config, "test-route")
    }

    /// A request within the limit is allowed and reports what would remain.
    #[test]
    fn test_check_allowed() {
        let tracker = tracker_with(test_config());
        let outcome = tracker.check("tenant-1", 100);
        assert!(outcome.is_allowed());
        match outcome {
            BudgetCheckResult::Allowed { remaining } => assert_eq!(remaining, 900),
            _ => panic!("Expected Allowed result"),
        }
    }

    /// Once the whole budget is used, an enforced tracker rejects further requests.
    #[test]
    fn test_check_exhausted() {
        let tracker = tracker_with(test_config());
        tracker.record("tenant-1", 1000);
        let outcome = tracker.check("tenant-1", 100);
        assert!(!outcome.is_allowed());
        match outcome {
            BudgetCheckResult::Exhausted { retry_after_secs } => assert!(retry_after_secs > 0),
            _ => panic!("Expected Exhausted result"),
        }
    }

    /// Each threshold fires exactly once, in order, as usage crosses it.
    #[test]
    fn test_record_alerts() {
        let tracker = tracker_with(test_config());
        let steps = [(500, 0.50), (300, 0.80), (200, 0.95)];
        for (tokens, expected_threshold) in steps {
            let fired = tracker.record("tenant-1", tokens);
            assert_eq!(fired.len(), 1);
            assert!((fired[0].threshold - expected_threshold).abs() < 0.001);
        }
        // Every threshold has already fired; further usage raises nothing new.
        let fired = tracker.record("tenant-1", 100);
        assert!(fired.is_empty());
    }

    /// The status snapshot reflects recorded usage.
    #[test]
    fn test_status() {
        let tracker = tracker_with(test_config());
        tracker.record("tenant-1", 400);
        let snapshot = tracker.status("tenant-1");
        assert_eq!(snapshot.tokens_used, 400);
        assert_eq!(snapshot.tokens_limit, 1000);
        assert_eq!(snapshot.tokens_remaining, 600);
        assert!((snapshot.usage_percent - 40.0).abs() < 0.001);
        assert!(!snapshot.exhausted);
    }

    /// Going over the limit but within the burst allowance yields a soft result.
    #[test]
    fn test_burst_allowance() {
        let config = TokenBudgetConfig {
            burst_allowance: Some(0.10),
            ..test_config()
        };
        let tracker = tracker_with(config);
        tracker.record("tenant-1", 950);
        let outcome = tracker.check("tenant-1", 100);
        assert!(outcome.is_allowed());
        match outcome {
            BudgetCheckResult::Soft { remaining, over_by } => {
                assert_eq!(over_by, 50);
                assert_eq!(remaining, -50);
            }
            _ => panic!("Expected Soft result"),
        }
    }

    /// With enforcement off, an exceeded budget still lets requests through.
    #[test]
    fn test_no_enforcement() {
        let config = TokenBudgetConfig {
            enforce: false,
            ..test_config()
        };
        let tracker = tracker_with(config);
        tracker.record("tenant-1", 1000);
        assert!(tracker.check("tenant-1", 100).is_allowed());
    }

    /// A manual period reset clears recorded usage.
    #[test]
    fn test_period_reset() {
        let tracker = tracker_with(test_config());
        tracker.record("tenant-1", 500);
        assert_eq!(tracker.status("tenant-1").tokens_used, 500);
        tracker.reset_period("tenant-1");
        assert_eq!(tracker.status("tenant-1").tokens_used, 0);
    }

    /// With rollover enabled, the unused 700 tokens show up in `tokens_used`
    /// after the reset.
    #[test]
    fn test_rollover() {
        let config = TokenBudgetConfig {
            rollover: true,
            ..test_config()
        };
        let tracker = tracker_with(config);
        tracker.record("tenant-1", 300);
        tracker.reset_period("tenant-1");
        assert_eq!(tracker.status("tenant-1").tokens_used, 700);
    }

    /// Usage is tracked independently per tenant.
    #[test]
    fn test_multiple_tenants() {
        let tracker = tracker_with(test_config());
        tracker.record("tenant-1", 500);
        tracker.record("tenant-2", 200);
        assert_eq!(tracker.status("tenant-1").tokens_used, 500);
        assert_eq!(tracker.status("tenant-2").tokens_used, 200);
        assert_eq!(tracker.tenant_count(), 2);
    }
}