use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tracing::{debug, trace};
use zentinel_config::TokenRateLimit;
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenRateLimitResult {
    /// The request may proceed.
    Allowed,
    /// The token budget is exhausted; retry after the given delay.
    TokensExceeded {
        retry_after_ms: u64,
    },
    /// The request budget is exhausted; retry after the given delay.
    RequestsExceeded {
        retry_after_ms: u64,
    },
}
impl TokenRateLimitResult {
    /// Returns `true` when the check passed.
    pub fn is_allowed(&self) -> bool {
        *self == Self::Allowed
    }
    /// Milliseconds to wait before retrying; `0` when the check passed.
    pub fn retry_after_ms(&self) -> u64 {
        match *self {
            Self::Allowed => 0,
            Self::TokensExceeded { retry_after_ms }
            | Self::RequestsExceeded { retry_after_ms } => retry_after_ms,
        }
    }
}
struct TokenBucket {
tokens: AtomicU64,
max_tokens: u64,
refill_rate: f64,
last_refill: std::sync::Mutex<Instant>,
}
impl TokenBucket {
fn new(tokens_per_minute: u64, burst_tokens: u64) -> Self {
let refill_rate = tokens_per_minute as f64 / 60_000.0;
Self {
tokens: AtomicU64::new(burst_tokens),
max_tokens: burst_tokens,
refill_rate,
last_refill: std::sync::Mutex::new(Instant::now()),
}
}
fn try_consume(&self, amount: u64) -> Result<(), u64> {
self.refill();
loop {
let current = self.tokens.load(Ordering::Acquire);
if current < amount {
let needed = amount - current;
let wait_ms = (needed as f64 / self.refill_rate).ceil() as u64;
return Err(wait_ms);
}
if self
.tokens
.compare_exchange(
current,
current - amount,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
return Ok(());
}
}
}
fn refill(&self) {
let mut last = self.last_refill.lock().unwrap();
let now = Instant::now();
let elapsed = now.duration_since(*last);
if elapsed.as_millis() > 0 {
let refill_amount = (elapsed.as_millis() as f64 * self.refill_rate) as u64;
if refill_amount > 0 {
let current = self.tokens.load(Ordering::Acquire);
let new_tokens = (current + refill_amount).min(self.max_tokens);
self.tokens.store(new_tokens, Ordering::Release);
*last = now;
}
}
}
fn current_tokens(&self) -> u64 {
self.refill();
self.tokens.load(Ordering::Acquire)
}
}
/// Per-key rate limiter enforcing a token budget and, optionally, a
/// request budget — both implemented as token buckets.
pub struct TokenRateLimiter {
    /// One token bucket per key, created lazily on first `check`.
    token_buckets: DashMap<String, TokenBucket>,
    /// Per-key request buckets; `None` unless `requests_per_minute` is set.
    request_buckets: Option<DashMap<String, TokenBucket>>,
    /// Limits this limiter was constructed with.
    config: TokenRateLimit,
}
impl TokenRateLimiter {
pub fn new(config: TokenRateLimit) -> Self {
let request_buckets = config.requests_per_minute.map(|rpm| DashMap::new());
Self {
token_buckets: DashMap::new(),
request_buckets,
config,
}
}
pub fn check(&self, key: &str, estimated_tokens: u64) -> TokenRateLimitResult {
let token_bucket = self
.token_buckets
.entry(key.to_string())
.or_insert_with(|| {
TokenBucket::new(self.config.tokens_per_minute, self.config.burst_tokens)
});
if let Err(retry_ms) = token_bucket.try_consume(estimated_tokens) {
trace!(
key = key,
estimated_tokens = estimated_tokens,
retry_after_ms = retry_ms,
"Token rate limit exceeded"
);
return TokenRateLimitResult::TokensExceeded {
retry_after_ms: retry_ms,
};
}
if let (Some(rpm), Some(ref request_buckets)) =
(self.config.requests_per_minute, &self.request_buckets)
{
let request_bucket = request_buckets.entry(key.to_string()).or_insert_with(|| {
let burst = rpm.max(1) / 6;
TokenBucket::new(rpm, burst.max(1))
});
if let Err(retry_ms) = request_bucket.try_consume(1) {
trace!(
key = key,
retry_after_ms = retry_ms,
"Request rate limit exceeded"
);
return TokenRateLimitResult::RequestsExceeded {
retry_after_ms: retry_ms,
};
}
}
trace!(
key = key,
estimated_tokens = estimated_tokens,
"Rate limit check passed"
);
TokenRateLimitResult::Allowed
}
pub fn record_actual(&self, key: &str, actual_tokens: u64, estimated_tokens: u64) {
if let Some(bucket) = self.token_buckets.get(key) {
if actual_tokens < estimated_tokens {
let refund = estimated_tokens - actual_tokens;
let current = bucket.tokens.load(Ordering::Acquire);
let new_tokens = (current + refund).min(bucket.max_tokens);
bucket.tokens.store(new_tokens, Ordering::Release);
debug!(
key = key,
actual = actual_tokens,
estimated = estimated_tokens,
refund = refund,
"Refunded over-estimated tokens"
);
} else if actual_tokens > estimated_tokens {
let extra = actual_tokens - estimated_tokens;
let current = bucket.tokens.load(Ordering::Acquire);
let to_consume = extra.min(current);
if to_consume > 0 {
bucket.tokens.fetch_sub(to_consume, Ordering::AcqRel);
}
debug!(
key = key,
actual = actual_tokens,
estimated = estimated_tokens,
consumed_extra = to_consume,
"Consumed under-estimated tokens"
);
}
}
}
pub fn current_tokens(&self, key: &str) -> Option<u64> {
self.token_buckets.get(key).map(|b| b.current_tokens())
}
pub fn stats(&self) -> TokenRateLimiterStats {
TokenRateLimiterStats {
active_keys: self.token_buckets.len(),
tokens_per_minute: self.config.tokens_per_minute,
requests_per_minute: self.config.requests_per_minute,
}
}
}
/// Snapshot of limiter configuration and activity, as returned by
/// `TokenRateLimiter::stats`.
#[derive(Debug, Clone)]
pub struct TokenRateLimiterStats {
    /// Number of keys that currently have a token bucket.
    pub active_keys: usize,
    /// Configured token budget per minute.
    pub tokens_per_minute: u64,
    /// Configured request budget per minute, if any.
    pub requests_per_minute: Option<u64>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use zentinel_config::TokenEstimation;

    /// Shared fixture: 1000 tokens/min, 10 requests/min, burst of 200.
    fn test_config() -> TokenRateLimit {
        TokenRateLimit {
            tokens_per_minute: 1000,
            requests_per_minute: Some(10),
            burst_tokens: 200,
            estimation_method: TokenEstimation::Chars,
        }
    }

    #[test]
    fn test_basic_rate_limiting() {
        let limiter = TokenRateLimiter::new(test_config());
        assert!(limiter.check("test-key", 50).is_allowed());
        let remaining = limiter.current_tokens("test-key").unwrap();
        assert!(remaining > 0);
    }

    #[test]
    fn test_token_exhaustion() {
        let limiter = TokenRateLimiter::new(test_config());
        // Drain the 200-token burst in four 50-token checks.
        for _ in 0..4 {
            let _ = limiter.check("test-key", 50);
        }
        let verdict = limiter.check("test-key", 50);
        assert!(!verdict.is_allowed());
        assert!(matches!(
            verdict,
            TokenRateLimitResult::TokensExceeded { .. }
        ));
    }

    #[test]
    fn test_actual_token_refund() {
        let limiter = TokenRateLimiter::new(test_config());
        let _ = limiter.check("test-key", 100);
        let before = limiter.current_tokens("test-key").unwrap();
        // Actual usage was half the estimate; the difference comes back.
        limiter.record_actual("test-key", 50, 100);
        assert!(limiter.current_tokens("test-key").unwrap() > before);
    }
}