use std::collections::HashMap;
use std::sync::RwLock;
use tracing::debug;
use super::bucket::TokenBucket;
use super::config::RateLimitConfig;
/// Outcome of a single rate-limit check.
///
/// All fields are plain values, so the type derives the usual value-type
/// traits: it can be logged with `{:?}`, cloned, and compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// `true` when the request may proceed.
    pub allowed: bool,
    /// Tokens still available in the bucket that produced this result
    /// (`u64::MAX` when rate limiting was bypassed).
    pub remaining: u64,
    /// Capacity of the bucket that produced this result
    /// (`u64::MAX` when rate limiting was bypassed).
    pub limit: u64,
    /// Seconds the caller should wait before retrying; `0` when allowed.
    pub retry_after_secs: u64,
}
/// Rate limiter holding a set of token buckets addressed by string key.
///
/// The bucket map lives behind an `RwLock` and the rejection counter is an
/// atomic, so every method takes `&self`.
pub struct RateLimiter {
    /// Settings consulted on every check (enabled flag, tiers, default
    /// limits, per-operation cost).
    config: RateLimitConfig,
    /// Buckets created on first use, keyed by strings such as
    /// `user:<id>`, `org:<id>` and `apikey:<id>`.
    buckets: RwLock<HashMap<String, TokenBucket>>,
    /// Running count of rejections reported through `record_rejection`.
    rejections_total: std::sync::atomic::AtomicU64,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
buckets: RwLock::new(HashMap::new()),
rejections_total: std::sync::atomic::AtomicU64::new(0),
}
}
pub fn check(
&self,
user_id: &str,
org_ids: &[String],
plan_tier: Option<&str>,
operation: &str,
) -> RateLimitResult {
if !self.config.enabled {
return RateLimitResult {
allowed: true,
remaining: u64::MAX,
limit: u64::MAX,
retry_after_secs: 0,
};
}
let cost = self.config.operation_cost(operation);
let (qps, burst) = self.resolve_tier(plan_tier);
let user_key = format!("user:{user_id}");
let user_result = self.check_bucket(&user_key, qps, burst, cost);
if !user_result.allowed {
debug!(
user_id = %user_id,
operation = %operation,
cost,
"rate limited (user bucket)"
);
return user_result;
}
for org_id in org_ids {
let org_key = format!("org:{org_id}");
let org_result = self.check_bucket(&org_key, qps * 10, burst * 10, cost);
if !org_result.allowed {
debug!(
user_id = %user_id,
org_id = %org_id,
operation = %operation,
"rate limited (org bucket)"
);
return org_result;
}
}
user_result
}
pub fn check_api_key(
&self,
key_id: &str,
max_qps: u64,
max_burst: u64,
operation: &str,
) -> RateLimitResult {
if !self.config.enabled || max_qps == 0 {
return RateLimitResult {
allowed: true,
remaining: u64::MAX,
limit: u64::MAX,
retry_after_secs: 0,
};
}
let cost = self.config.operation_cost(operation);
let key = format!("apikey:{key_id}");
self.check_bucket(&key, max_qps, max_burst, cost)
}
fn check_bucket(&self, key: &str, qps: u64, burst: u64, cost: u64) -> RateLimitResult {
{
let buckets = self.buckets.read().unwrap_or_else(|p| p.into_inner());
if let Some(bucket) = buckets.get(key) {
let allowed = bucket.try_acquire(cost);
return RateLimitResult {
allowed,
remaining: bucket.available(),
limit: bucket.capacity(),
retry_after_secs: if allowed {
0
} else {
(bucket.retry_after_ms() / 1000).max(1)
},
};
}
}
let mut buckets = self.buckets.write().unwrap_or_else(|p| p.into_inner());
let bucket = buckets
.entry(key.to_string())
.or_insert_with(|| TokenBucket::new(burst, qps as f64));
let allowed = bucket.try_acquire(cost);
RateLimitResult {
allowed,
remaining: bucket.available(),
limit: bucket.capacity(),
retry_after_secs: if allowed {
0
} else {
(bucket.retry_after_ms() / 1000).max(1)
},
}
}
fn resolve_tier(&self, plan_tier: Option<&str>) -> (u64, u64) {
if let Some(tier_name) = plan_tier
&& let Some(tier) = self.config.tier(tier_name)
{
return (tier.qps, tier.burst);
}
(self.config.default_qps, self.config.default_burst)
}
pub fn response_headers(result: &RateLimitResult) -> Vec<(String, String)> {
vec![
("X-RateLimit-Limit".into(), result.limit.to_string()),
("X-RateLimit-Remaining".into(), result.remaining.to_string()),
(
"X-RateLimit-Reset".into(),
result.retry_after_secs.to_string(),
),
]
}
pub fn retry_after_header(result: &RateLimitResult) -> Option<(String, String)> {
if result.allowed {
None
} else {
Some(("Retry-After".into(), result.retry_after_secs.to_string()))
}
}
pub fn record_rejection(&self) -> u64 {
self.rejections_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1
}
pub fn rejections_total(&self) -> u64 {
self.rejections_total
.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn active_buckets(&self) -> usize {
self.buckets.read().unwrap_or_else(|p| p.into_inner()).len()
}
pub fn config(&self) -> &RateLimitConfig {
&self.config
}
}
/// A default limiter is simply one built from the default configuration.
impl Default for RateLimiter {
    fn default() -> Self {
        let config = RateLimitConfig::default();
        Self::new(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled configuration with default limits of 10 qps / 20 burst and a
    /// much larger `pro` tier.
    fn enabled_config() -> RateLimitConfig {
        use crate::control::security::ratelimit::config::RateLimitTier;
        let mut config = RateLimitConfig {
            enabled: true,
            default_qps: 10,
            default_burst: 20,
            ..Default::default()
        };
        config.tiers.insert(
            "pro".into(),
            RateLimitTier {
                qps: 5000,
                burst: 10000,
            },
        );
        config
    }

    #[test]
    fn disabled_allows_all() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let result = limiter.check("u1", &[], None, "point_get");
        assert!(result.allowed);
    }

    #[test]
    fn basic_rate_limiting() {
        let limiter = RateLimiter::new(enabled_config());
        // The default burst of 20 admits exactly twenty requests.
        for _ in 0..20 {
            let r = limiter.check("u1", &[], None, "point_get");
            assert!(r.allowed);
        }
        let r = limiter.check("u1", &[], None, "point_get");
        assert!(!r.allowed);
        assert!(r.retry_after_secs > 0);
    }

    #[test]
    fn cost_multiplier_drains_faster() {
        let limiter = RateLimiter::new(enabled_config());
        let r = limiter.check("u1", &[], None, "vector_search");
        assert!(r.allowed);
        let r = limiter.check("u1", &[], None, "vector_search");
        assert!(!r.allowed);
    }

    #[test]
    fn tier_resolution() {
        let limiter = RateLimiter::new(enabled_config());
        // Far more than the default burst: only passes if the `pro` tier is used.
        for _ in 0..100 {
            let r = limiter.check("u1", &[], Some("pro"), "point_get");
            assert!(r.allowed);
        }
    }

    #[test]
    fn per_user_isolation() {
        let limiter = RateLimiter::new(enabled_config());
        for _ in 0..20 {
            limiter.check("u1", &[], None, "point_get");
        }
        let r = limiter.check("u1", &[], None, "point_get");
        assert!(!r.allowed);
        let r = limiter.check("u2", &[], None, "point_get");
        assert!(r.allowed);
    }

    #[test]
    fn org_bucket_is_created_alongside_user_bucket() {
        let limiter = RateLimiter::new(enabled_config());
        let orgs = vec!["o1".to_string()];
        let r = limiter.check("u1", &orgs, None, "point_get");
        assert!(r.allowed);
        // One `user:` bucket plus one `org:` bucket.
        assert_eq!(limiter.active_buckets(), 2);
    }

    #[test]
    fn api_key_zero_qps_is_unlimited() {
        let limiter = RateLimiter::new(enabled_config());
        let r = limiter.check_api_key("k1", 0, 0, "point_get");
        assert!(r.allowed);
        assert_eq!(r.remaining, u64::MAX);
        assert_eq!(r.limit, u64::MAX);
        assert_eq!(r.retry_after_secs, 0);
        // The bypass must not allocate a bucket.
        assert_eq!(limiter.active_buckets(), 0);
    }

    #[test]
    fn response_headers() {
        let result = RateLimitResult {
            allowed: true,
            remaining: 50,
            limit: 100,
            retry_after_secs: 0,
        };
        let headers = RateLimiter::response_headers(&result);
        assert_eq!(headers.len(), 3);
        assert_eq!(headers[0].0, "X-RateLimit-Limit");
        assert_eq!(headers[0].1, "100");
        assert_eq!(headers[1].0, "X-RateLimit-Remaining");
        assert_eq!(headers[1].1, "50");
        assert_eq!(headers[2].0, "X-RateLimit-Reset");
        assert_eq!(headers[2].1, "0");
    }

    #[test]
    fn retry_after_header_only_when_rejected() {
        let allowed = RateLimitResult {
            allowed: true,
            remaining: 1,
            limit: 2,
            retry_after_secs: 0,
        };
        assert!(RateLimiter::retry_after_header(&allowed).is_none());

        let rejected = RateLimitResult {
            allowed: false,
            remaining: 0,
            limit: 2,
            retry_after_secs: 3,
        };
        assert_eq!(
            RateLimiter::retry_after_header(&rejected),
            Some(("Retry-After".to_string(), "3".to_string()))
        );
    }

    #[test]
    fn rejection_counter_increments() {
        let limiter = RateLimiter::new(enabled_config());
        assert_eq!(limiter.rejections_total(), 0);
        assert_eq!(limiter.record_rejection(), 1);
        assert_eq!(limiter.record_rejection(), 2);
        assert_eq!(limiter.rejections_total(), 2);
    }
}