use std::collections::HashMap;
/// A token bucket: a balance of tokens that requests spend and that is
/// topped up over time at a fixed rate, never rising above `capacity`.
///
/// Timestamps are plain `i64` ticks supplied by the caller, and
/// `refill_rate` is expressed in tokens per tick.
#[derive(Debug, Clone)]
pub struct TokenBucket {
    /// Tokens that can currently be spent.
    available: f64,
    /// Upper bound applied to `available` on every refill.
    capacity: f64,
    /// Tokens credited per elapsed tick.
    refill_rate: f64,
    /// Latest timestamp for which tokens have already been credited.
    last_refill: i64,
    /// Number of `try_consume` calls that succeeded.
    total_consumed: u64,
    /// Number of `try_consume` calls that were refused.
    total_rejected: u64,
}

impl TokenBucket {
    /// Creates a bucket that starts full (`available == capacity`) and whose
    /// refill clock starts at `now`.
    #[must_use]
    pub fn new(capacity: f64, refill_rate: f64, now: i64) -> Self {
        Self {
            available: capacity,
            capacity,
            refill_rate,
            last_refill: now,
            total_consumed: 0,
            total_rejected: 0,
        }
    }

    /// Credits the tokens earned between `last_refill` and `now`, clamped to
    /// `capacity`.
    ///
    /// A `now` that lies before `last_refill` credits nothing and leaves the
    /// refill clock where it is.
    fn refill(&mut self, now: i64) {
        // Saturating subtraction: timestamps far apart (e.g. `i64::MIN` to
        // `i64::MAX`) must not overflow, which would panic in debug builds.
        let elapsed = now.saturating_sub(self.last_refill).max(0) as f64;
        self.available = (self.available + elapsed * self.refill_rate).min(self.capacity);
        // Never rewind the clock. If an out-of-order timestamp moved it
        // backwards, the next in-order call would credit the same interval a
        // second time.
        self.last_refill = self.last_refill.max(now);
    }

    /// Refills the bucket up to `now`, then tries to spend `tokens`.
    ///
    /// Returns `true` and deducts `tokens` when the balance is sufficient;
    /// otherwise returns `false` and leaves the balance untouched. Either way
    /// the matching counter is incremented by one per call, not per token.
    ///
    /// `tokens` is deducted as given, so callers are expected to pass a
    /// non-negative amount.
    pub fn try_consume(&mut self, tokens: f64, now: i64) -> bool {
        self.refill(now);
        if self.available >= tokens {
            self.available -= tokens;
            self.total_consumed += 1;
            true
        } else {
            self.total_rejected += 1;
            false
        }
    }

    /// Tokens available as of the most recent refill.
    #[must_use]
    pub fn available(&self) -> f64 {
        self.available
    }

    /// Number of successful `try_consume` calls so far.
    #[must_use]
    pub fn total_consumed(&self) -> u64 {
        self.total_consumed
    }

    /// Number of refused `try_consume` calls so far.
    #[must_use]
    pub fn total_rejected(&self) -> u64 {
        self.total_rejected
    }
}
/// Tunable parameters for the limiter: one set for each per-user bucket, one
/// set for the shared global bucket, and the cost of a single request.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Capacity given to every per-user bucket.
    pub per_user_capacity: f64,
    /// Refill rate given to every per-user bucket.
    pub per_user_refill_rate: f64,
    /// Capacity of the single global bucket.
    pub global_capacity: f64,
    /// Refill rate of the single global bucket.
    pub global_refill_rate: f64,
    /// Tokens charged for one request.
    pub tokens_per_request: f64,
}

impl Default for RateLimitConfig {
    /// Defaults: each user gets a 20-token bucket refilled at 2.0, the global
    /// bucket holds 500 tokens refilled at 100.0, and a request costs 1 token.
    fn default() -> Self {
        RateLimitConfig {
            tokens_per_request: 1.0,
            global_refill_rate: 100.0,
            global_capacity: 500.0,
            per_user_refill_rate: 2.0,
            per_user_capacity: 20.0,
        }
    }
}
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
    /// The request may proceed.
    Allowed,
    /// The per-user bucket could not cover the request.
    UserLimitExceeded,
    /// The global bucket could not cover the request.
    GlobalLimitExceeded,
}

impl RateLimitDecision {
    /// Returns `true` only for [`RateLimitDecision::Allowed`].
    #[must_use]
    pub fn is_allowed(self) -> bool {
        match self {
            Self::Allowed => true,
            Self::UserLimitExceeded | Self::GlobalLimitExceeded => false,
        }
    }
}
/// Two-tier rate limiter: every request must be covered by both a shared
/// global bucket and the bucket belonging to the requesting user.
///
/// Per-user buckets are created the first time a user id is seen. Nothing in
/// this type removes them, so the map grows with the number of distinct ids.
#[derive(Debug)]
pub struct RecommendationRateLimiter {
    /// One bucket per user id, created on first use.
    user_buckets: HashMap<String, TokenBucket>,
    /// Bucket shared by all users.
    global_bucket: TokenBucket,
    /// Parameters used for the global bucket and for each new user bucket.
    config: RateLimitConfig,
}

impl RecommendationRateLimiter {
    /// Creates a limiter with a full global bucket whose clock starts at
    /// `now`, and no per-user buckets yet.
    #[must_use]
    pub fn new(config: RateLimitConfig, now: i64) -> Self {
        let global_bucket =
            TokenBucket::new(config.global_capacity, config.global_refill_rate, now);
        Self {
            user_buckets: HashMap::new(),
            global_bucket,
            config,
        }
    }

    /// Builds a fresh, full per-user bucket from the configured per-user
    /// capacity and refill rate.
    fn new_user_bucket(&self, now: i64) -> TokenBucket {
        TokenBucket::new(
            self.config.per_user_capacity,
            self.config.per_user_refill_rate,
            now,
        )
    }

    /// Decides whether `user_id` may make one request at time `now`, and
    /// charges `tokens_per_request` to both buckets when it may.
    ///
    /// The global bucket is checked first. When it cannot cover the request
    /// the call returns [`RateLimitDecision::GlobalLimitExceeded`] without
    /// touching, or creating, the user's bucket. When the user's bucket
    /// refuses, the call returns [`RateLimitDecision::UserLimitExceeded`] and
    /// the global balance is left as it was after its refill.
    pub fn check_and_consume(&mut self, user_id: &str, now: i64) -> RateLimitDecision {
        let tokens = self.config.tokens_per_request;

        // Only peek at the global bucket here; it is charged at the end, once
        // the user's bucket has accepted the request.
        self.global_bucket.refill(now);
        if self.global_bucket.available < tokens {
            self.global_bucket.total_rejected += 1;
            return RateLimitDecision::GlobalLimitExceeded;
        }

        // Look up by `&str` so that the owned key is allocated only the first
        // time a user is seen, not on every request.
        let user_allowed = match self.user_buckets.get_mut(user_id) {
            Some(bucket) => bucket.try_consume(tokens, now),
            None => {
                let mut bucket = self.new_user_bucket(now);
                let allowed = bucket.try_consume(tokens, now);
                self.user_buckets.insert(user_id.to_owned(), bucket);
                allowed
            }
        };
        if !user_allowed {
            return RateLimitDecision::UserLimitExceeded;
        }

        self.global_bucket.available -= tokens;
        self.global_bucket.total_consumed += 1;
        RateLimitDecision::Allowed
    }

    /// Tokens left in the user's bucket as of its last refill, or `None` when
    /// no bucket exists for `user_id`.
    #[must_use]
    pub fn user_available_tokens(&self, user_id: &str) -> Option<f64> {
        self.user_buckets.get(user_id).map(|b| b.available)
    }

    /// Tokens left in the global bucket as of its last refill.
    #[must_use]
    pub fn global_available_tokens(&self) -> f64 {
        self.global_bucket.available
    }

    /// Sum of the rejection counters of all per-user buckets currently held.
    #[must_use]
    pub fn total_user_rejections(&self) -> u64 {
        self.user_buckets.values().map(|b| b.total_rejected).sum()
    }

    /// Number of requests refused because the global bucket was short.
    #[must_use]
    pub fn total_global_rejections(&self) -> u64 {
        self.global_bucket.total_rejected
    }

    /// Number of user ids that currently have a bucket.
    #[must_use]
    pub fn tracked_users(&self) -> usize {
        self.user_buckets.len()
    }

    /// Replaces the user's bucket with a fresh, full one whose clock starts
    /// at `now`, creating it when the user had none.
    ///
    /// The replaced bucket's counters are discarded along with it.
    pub fn reset_user(&mut self, user_id: &str, now: i64) {
        let bucket = self.new_user_bucket(now);
        self.user_buckets.insert(user_id.to_owned(), bucket);
    }
}

impl Default for RecommendationRateLimiter {
    /// Limiter built from [`RateLimitConfig::default`] with its clock at 0.
    fn default() -> Self {
        Self::new(RateLimitConfig::default(), 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter whose global bucket is generous (1000 tokens, refilled
    /// at 500.0) so that the per-user settings are what the tests exercise.
    fn make_limiter(capacity: f64, rate: f64) -> RecommendationRateLimiter {
        RecommendationRateLimiter::new(
            RateLimitConfig {
                tokens_per_request: 1.0,
                global_refill_rate: 500.0,
                global_capacity: 1000.0,
                per_user_refill_rate: rate,
                per_user_capacity: capacity,
            },
            0,
        )
    }

    /// True when the two values differ by less than `f64::EPSILON`.
    fn approx_eq(left: f64, right: f64) -> bool {
        (left - right).abs() < f64::EPSILON
    }

    #[test]
    fn test_new_limiter_allows_first_request() {
        let mut limiter = make_limiter(5.0, 1.0);
        let decision = limiter.check_and_consume("alice", 0);
        assert_eq!(decision, RateLimitDecision::Allowed);
    }

    #[test]
    fn test_burst_then_reject() {
        // Capacity 3 with no refill: exactly three requests fit.
        let mut limiter = make_limiter(3.0, 0.0);
        for _ in 0..3 {
            assert_eq!(
                limiter.check_and_consume("alice", 0),
                RateLimitDecision::Allowed
            );
        }
        assert_eq!(
            limiter.check_and_consume("alice", 0),
            RateLimitDecision::UserLimitExceeded
        );
    }

    #[test]
    fn test_refill_after_time() {
        // Two tokens, refilled at two per tick.
        let mut limiter = make_limiter(2.0, 2.0);
        for _ in 0..2 {
            limiter.check_and_consume("alice", 0);
        }
        assert_eq!(
            limiter.check_and_consume("alice", 0),
            RateLimitDecision::UserLimitExceeded
        );
        // One tick later the bucket has been topped up again.
        assert_eq!(
            limiter.check_and_consume("alice", 1),
            RateLimitDecision::Allowed
        );
    }

    #[test]
    fn test_different_users_independent() {
        let mut limiter = make_limiter(1.0, 0.0);
        for user in ["alice", "bob"] {
            assert_eq!(
                limiter.check_and_consume(user, 0),
                RateLimitDecision::Allowed
            );
        }
        for user in ["alice", "bob"] {
            assert_eq!(
                limiter.check_and_consume(user, 0),
                RateLimitDecision::UserLimitExceeded
            );
        }
    }

    #[test]
    fn test_global_limit_blocks_all_users() {
        // Roomy per-user buckets, but only two tokens globally and no refill.
        let mut limiter = RecommendationRateLimiter::new(
            RateLimitConfig {
                per_user_capacity: 1000.0,
                per_user_refill_rate: 0.0,
                global_capacity: 2.0,
                global_refill_rate: 0.0,
                tokens_per_request: 1.0,
            },
            0,
        );
        limiter.check_and_consume("alice", 0);
        limiter.check_and_consume("bob", 0);
        assert_eq!(
            limiter.check_and_consume("charlie", 0),
            RateLimitDecision::GlobalLimitExceeded
        );
    }

    #[test]
    fn test_user_available_tokens_before_first_request() {
        let limiter = make_limiter(5.0, 1.0);
        assert_eq!(limiter.user_available_tokens("unknown"), None);
    }

    #[test]
    fn test_user_available_tokens_after_request() {
        let mut limiter = make_limiter(5.0, 1.0);
        limiter.check_and_consume("alice", 0);
        let remaining = limiter
            .user_available_tokens("alice")
            .expect("bucket should exist");
        assert!(approx_eq(remaining, 4.0));
    }

    #[test]
    fn test_tracked_users_increases() {
        let mut limiter = make_limiter(5.0, 1.0);
        assert_eq!(limiter.tracked_users(), 0);
        for (user, expected) in [("alice", 1), ("bob", 2)] {
            limiter.check_and_consume(user, 0);
            assert_eq!(limiter.tracked_users(), expected);
        }
    }

    #[test]
    fn test_total_user_rejections() {
        // One token, no refill: the first call succeeds, the next two fail.
        let mut limiter = make_limiter(1.0, 0.0);
        for _ in 0..3 {
            limiter.check_and_consume("alice", 0);
        }
        assert_eq!(limiter.total_user_rejections(), 2);
    }

    #[test]
    fn test_reset_user_refills_bucket() {
        let mut limiter = make_limiter(1.0, 0.0);
        limiter.check_and_consume("alice", 0);
        assert_eq!(
            limiter.check_and_consume("alice", 0),
            RateLimitDecision::UserLimitExceeded
        );
        limiter.reset_user("alice", 0);
        assert_eq!(
            limiter.check_and_consume("alice", 0),
            RateLimitDecision::Allowed
        );
    }

    #[test]
    fn test_token_bucket_available_increases_after_refill() {
        let mut bucket = TokenBucket::new(10.0, 5.0, 0);
        bucket.try_consume(5.0, 0);
        assert!(approx_eq(bucket.available(), 5.0));
        // Two ticks at 5.0 per tick would add ten tokens; capacity caps it at 10.
        bucket.refill(2);
        assert!(approx_eq(bucket.available(), 10.0));
    }

    #[test]
    fn test_rate_limit_decision_is_allowed() {
        assert!(RateLimitDecision::Allowed.is_allowed());
        for refused in [
            RateLimitDecision::UserLimitExceeded,
            RateLimitDecision::GlobalLimitExceeded,
        ] {
            assert!(!refused.is_allowed());
        }
    }

    #[test]
    fn test_token_bucket_stats() {
        // Four attempts against three tokens: three succeed, one is refused.
        let mut bucket = TokenBucket::new(3.0, 0.0, 0);
        for _ in 0..4 {
            bucket.try_consume(1.0, 0);
        }
        assert_eq!(bucket.total_consumed(), 3);
        assert_eq!(bucket.total_rejected(), 1);
    }

    #[test]
    fn test_global_available_tokens_decreases() {
        let mut limiter = make_limiter(100.0, 10.0);
        let before = limiter.global_available_tokens();
        limiter.check_and_consume("alice", 0);
        let after = limiter.global_available_tokens();
        assert!(approx_eq(before - after, 1.0));
    }

    #[test]
    fn test_rate_limit_config_default() {
        let defaults = RateLimitConfig::default();
        assert!(defaults.per_user_capacity > 0.0);
        assert!(defaults.per_user_refill_rate > 0.0);
        assert!(defaults.global_capacity > 0.0);
        assert!(approx_eq(defaults.tokens_per_request, 1.0));
    }
}