use std::collections::HashMap;
use std::sync::RwLock;
use tracing::debug;
use nodedb_types::{DatabaseId, TenantId};
use super::bucket::TokenBucket;
use super::config::RateLimitConfig;
/// Outcome of a single rate-limit check.
///
/// Derives the usual value-type traits so a result can be logged with `{:?}`,
/// copied freely, and compared directly in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitResult {
    /// `true` when the request may proceed.
    pub allowed: bool,
    /// Tokens still available in the bucket that produced this result
    /// (`u64::MAX` when limiting was bypassed).
    pub remaining: u64,
    /// Capacity of that bucket (`u64::MAX` when limiting was bypassed).
    pub limit: u64,
    /// Seconds the caller should wait before retrying; `0` when allowed.
    pub retry_after_secs: u64,
}
/// Optional tenant- and database-level QPS caps applied by `RateLimiter::check`
/// after the user and org buckets.
pub struct QuotaCheckParams {
/// Tenant-wide cap; `None` or `Some(0)` means the tenant bucket is not consulted.
pub tenant_max_qps: Option<u64>,
/// Database-wide cap; `None` or `Some(0)` means the database bucket is not consulted.
pub database_max_qps: Option<u64>,
/// Tenant whose numeric id forms the `tenant:<id>` bucket key.
pub tenant_id: TenantId,
/// Database whose numeric id forms the `database:<id>` bucket key.
pub database_id: DatabaseId,
}
/// Result of a login-attempt rate-limit check.
///
/// Derives the usual value-type traits so outcomes can be logged and compared
/// with `==` / `assert_eq!` rather than only through `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoginRateLimitOutcome {
    /// Neither the per-IP nor the per-username bucket rejected the attempt.
    Allowed,
    /// The bucket keyed by the peer address is exhausted.
    IpExceeded,
    /// The bucket keyed by the username is exhausted.
    UserExceeded,
}
/// Token-bucket rate limiter shared across request handlers.
///
/// All buckets live in one map behind an `RwLock`; the key prefix
/// (`user:`, `org:`, `tenant:`, `database:`, `apikey:`, `login_ip:`,
/// `login_user:`) identifies which dimension a bucket belongs to.
pub struct RateLimiter {
// Static configuration supplied at construction time.
config: RateLimitConfig,
// Every bucket, keyed by "<dimension>:<identifier>".
// NOTE(review): nothing in this file removes entries, so the map only grows — confirm eviction happens elsewhere.
buckets: RwLock<HashMap<String, TokenBucket>>,
// Counter bumped by `record_rejection`.
rejections_total: std::sync::atomic::AtomicU64,
// Capacity of each per-IP login bucket; 0 disables the per-IP check.
login_ip_cap: std::sync::atomic::AtomicU64,
// Capacity of each per-username login bucket; 0 disables the per-user check.
login_user_cap: std::sync::atomic::AtomicU64,
}
impl RateLimiter {
/// Creates a limiter from `config` with no buckets, a zeroed rejection
/// counter, and the initial login capacities (30 per IP, 10 per username).
pub fn new(config: RateLimitConfig) -> Self {
    use std::sync::atomic::AtomicU64;

    // Starting login capacities; `set_login_capacities` can replace them later.
    const INITIAL_LOGIN_IP_CAP: u64 = 30;
    const INITIAL_LOGIN_USER_CAP: u64 = 10;

    let buckets = RwLock::new(HashMap::new());
    Self {
        config,
        buckets,
        rejections_total: AtomicU64::new(0),
        login_ip_cap: AtomicU64::new(INITIAL_LOGIN_IP_CAP),
        login_user_cap: AtomicU64::new(INITIAL_LOGIN_USER_CAP),
    }
}
/// Replaces the per-IP and per-username login capacities.
///
/// Only buckets created after this call use the new values; a bucket that
/// already exists keeps the capacity it was created with.
pub fn set_login_capacities(&self, ip_cap: u64, user_cap: u64) {
    use std::sync::atomic::Ordering::Relaxed;

    let updates = [
        (&self.login_ip_cap, ip_cap),
        (&self.login_user_cap, user_cap),
    ];
    for (slot, value) in updates {
        slot.store(value, Relaxed);
    }
}
/// Rate-limits a login attempt by peer address first, then by username.
///
/// * A capacity of `0` switches that dimension off.
/// * An empty `username` skips the per-user check.
/// * The IP bucket is charged before the user bucket is consulted, so an
///   attempt rejected with `UserExceeded` has still consumed an IP token.
pub fn check_login(&self, peer_addr: &str, username: &str) -> LoginRateLimitOutcome {
    use std::sync::atomic::Ordering::Relaxed;

    // The refill rate handed to the bucket is the capacity spread over 60 seconds.
    let per_second = |cap: u64| (cap as f64) / 60.0;

    let ip_limit = self.login_ip_cap.load(Relaxed);
    let user_limit = self.login_user_cap.load(Relaxed);

    let ip_blocked = ip_limit > 0
        && !self.check_login_bucket(
            &format!("login_ip:{peer_addr}"),
            ip_limit,
            per_second(ip_limit),
        );
    if ip_blocked {
        return LoginRateLimitOutcome::IpExceeded;
    }

    let user_tracked = user_limit > 0 && !username.is_empty();
    if user_tracked {
        let key = format!("login_user:{username}");
        if !self.check_login_bucket(&key, user_limit, per_second(user_limit)) {
            return LoginRateLimitOutcome::UserExceeded;
        }
    }

    LoginRateLimitOutcome::Allowed
}
/// Takes one token from the login bucket stored under `key`, creating the
/// bucket with `capacity` / `rate_per_sec` if it does not exist yet.
///
/// Returns whether the token was granted. A poisoned lock is recovered
/// rather than propagated.
fn check_login_bucket(&self, key: &str, capacity: u64, rate_per_sec: f64) -> bool {
    // Fast path: an existing bucket only needs the shared lock.
    let existing = {
        let table = self
            .buckets
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        table.get(key).map(|bucket| bucket.try_acquire(1))
    };
    if let Some(granted) = existing {
        return granted;
    }

    // Slow path: take the exclusive lock. `entry` copes with another thread
    // having inserted the same key in between the two locks.
    let mut table = self
        .buckets
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let bucket = table
        .entry(key.to_string())
        .or_insert_with(|| TokenBucket::new(capacity, rate_per_sec));
    let granted = bucket.try_acquire(1);
    granted
}
/// Checks a request against every applicable bucket, in this order:
/// user, each org, tenant, database.
///
/// * When limiting is disabled in the config, returns an "unlimited" result
///   (`u64::MAX` for `remaining` and `limit`).
/// * `plan_tier` selects the user's `(qps, burst)`; org buckets get ten times
///   those values (saturating, so huge tier values cannot overflow).
/// * Tenant and database buckets are consulted only when `quota` supplies a
///   cap greater than zero; the cap is used as both rate and burst.
/// * The first bucket that denies ends the check and its result is returned.
///   Tokens already taken from earlier buckets are not returned.
/// * When every bucket allows, the user bucket's result is returned.
pub fn check(
    &self,
    user_id: &str,
    org_ids: &[String],
    plan_tier: Option<&str>,
    operation: &str,
    quota: Option<&QuotaCheckParams>,
) -> RateLimitResult {
    if !self.config.enabled {
        return RateLimitResult {
            allowed: true,
            remaining: u64::MAX,
            limit: u64::MAX,
            retry_after_secs: 0,
        };
    }

    let cost = self.config.operation_cost(operation);
    let (qps, burst) = self.resolve_tier(plan_tier);

    let user_key = format!("user:{user_id}");
    let user_result = self.check_bucket(&user_key, qps, burst, cost);
    if !user_result.allowed {
        debug!(
            user_id = %user_id,
            operation = %operation,
            cost,
            "rate limited (user bucket)"
        );
        return user_result;
    }

    // Org limits are 10x the user's tier. Computed once, and saturating so a
    // very large configured tier cannot overflow (panic in debug, wrap in release).
    let org_qps = qps.saturating_mul(10);
    let org_burst = burst.saturating_mul(10);
    for org_id in org_ids {
        let org_key = format!("org:{org_id}");
        let org_result = self.check_bucket(&org_key, org_qps, org_burst, cost);
        if !org_result.allowed {
            debug!(
                user_id = %user_id,
                org_id = %org_id,
                operation = %operation,
                "rate limited (org bucket)"
            );
            return org_result;
        }
    }

    if let Some(q) = quota {
        // `None` and `Some(0)` both mean "no cap".
        if let Some(tenant_qps) = q.tenant_max_qps.filter(|&v| v > 0) {
            let tenant_key = format!("tenant:{}", q.tenant_id.as_u64());
            let tenant_result = self.check_bucket(&tenant_key, tenant_qps, tenant_qps, cost);
            if !tenant_result.allowed {
                debug!(
                    tenant_id = q.tenant_id.as_u64(),
                    operation = %operation,
                    "rate limited (tenant bucket)"
                );
                return tenant_result;
            }
        }
        if let Some(db_qps) = q.database_max_qps.filter(|&v| v > 0) {
            let db_key = format!("database:{}", q.database_id.as_u64());
            let db_result = self.check_bucket(&db_key, db_qps, db_qps, cost);
            if !db_result.allowed {
                debug!(
                    database_id = q.database_id.as_u64(),
                    operation = %operation,
                    "rate limited (database bucket)"
                );
                return db_result;
            }
        }
    }

    user_result
}
/// Checks a request against the bucket belonging to one API key.
///
/// Returns an "unlimited" result (`u64::MAX` for `remaining` and `limit`)
/// when limiting is disabled in the config or when `max_qps` is `0`.
/// Otherwise the bucket `apikey:<key_id>` is charged the configured cost of
/// `operation`.
pub fn check_api_key(
    &self,
    key_id: &str,
    max_qps: u64,
    max_burst: u64,
    operation: &str,
) -> RateLimitResult {
    let limited = self.config.enabled && max_qps != 0;
    if !limited {
        let unlimited = RateLimitResult {
            allowed: true,
            remaining: u64::MAX,
            limit: u64::MAX,
            retry_after_secs: 0,
        };
        return unlimited;
    }

    let cost = self.config.operation_cost(operation);
    let bucket_key = format!("apikey:{key_id}");
    self.check_bucket(&bucket_key, max_qps, max_burst, cost)
}
fn check_bucket(&self, key: &str, qps: u64, burst: u64, cost: u64) -> RateLimitResult {
{
let buckets = self.buckets.read().unwrap_or_else(|p| p.into_inner());
if let Some(bucket) = buckets.get(key) {
let allowed = bucket.try_acquire(cost);
return RateLimitResult {
allowed,
remaining: bucket.available(),
limit: bucket.capacity(),
retry_after_secs: if allowed {
0
} else {
(bucket.retry_after_ms() / 1000).max(1)
},
};
}
}
let mut buckets = self.buckets.write().unwrap_or_else(|p| p.into_inner());
let bucket = buckets
.entry(key.to_string())
.or_insert_with(|| TokenBucket::new(burst, qps as f64));
let allowed = bucket.try_acquire(cost);
RateLimitResult {
allowed,
remaining: bucket.available(),
limit: bucket.capacity(),
retry_after_secs: if allowed {
0
} else {
(bucket.retry_after_ms() / 1000).max(1)
},
}
}
fn resolve_tier(&self, plan_tier: Option<&str>) -> (u64, u64) {
if let Some(tier_name) = plan_tier
&& let Some(tier) = self.config.tier(tier_name)
{
return (tier.qps, tier.burst);
}
(self.config.default_qps, self.config.default_burst)
}
/// Renders `result` as the three `X-RateLimit-*` response headers, in the
/// order Limit, Remaining, Reset.
///
/// `X-RateLimit-Reset` carries `retry_after_secs` (a delay in seconds).
pub fn response_headers(result: &RateLimitResult) -> Vec<(String, String)> {
    let fields: [(&str, u64); 3] = [
        ("X-RateLimit-Limit", result.limit),
        ("X-RateLimit-Remaining", result.remaining),
        ("X-RateLimit-Reset", result.retry_after_secs),
    ];
    fields
        .iter()
        .map(|&(name, value)| (name.to_owned(), value.to_string()))
        .collect()
}
/// Returns the `Retry-After` header for a denied request, or `None` when the
/// request was allowed.
pub fn retry_after_header(result: &RateLimitResult) -> Option<(String, String)> {
    let denied = !result.allowed;
    denied.then(|| {
        (
            String::from("Retry-After"),
            result.retry_after_secs.to_string(),
        )
    })
}
/// Increments the rejection counter and returns its new value.
pub fn record_rejection(&self) -> u64 {
    use std::sync::atomic::Ordering::Relaxed;

    // `fetch_add` yields the value before the increment.
    let previous = self.rejections_total.fetch_add(1, Relaxed);
    previous + 1
}
/// Returns the number of rejections recorded so far.
pub fn rejections_total(&self) -> u64 {
    use std::sync::atomic::Ordering::Relaxed;

    self.rejections_total.load(Relaxed)
}
/// Reports whether rate limiting is switched on in the configuration.
pub fn is_enabled(&self) -> bool {
    let settings = self.config();
    settings.enabled
}
/// Returns how many buckets are currently stored, across all dimensions.
///
/// A poisoned lock is recovered rather than propagated.
pub fn active_buckets(&self) -> usize {
    let table = match self.buckets.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    table.len()
}
pub fn config(&self) -> &RateLimitConfig {
&self.config
}
}
/// A limiter built from `RateLimitConfig::default()`.
impl Default for RateLimiter {
    fn default() -> Self {
        let config = RateLimitConfig::default();
        Self::new(config)
    }
}
/// Debug output shows only the two login capacities; every other field is
/// elided behind `..`.
impl std::fmt::Debug for RateLimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::sync::atomic::Ordering::Relaxed;

        let ip_cap = self.login_ip_cap.load(Relaxed);
        let user_cap = self.login_user_cap.load(Relaxed);

        let mut out = f.debug_struct("RateLimiter");
        out.field("login_ip_cap", &ip_cap);
        out.field("login_user_cap", &user_cap);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Test config: limiting enabled, default tier 10 qps / 20 burst, plus a
/// "pro" tier at 5000 qps / 10000 burst.
fn enabled_config() -> RateLimitConfig {
    use crate::control::security::ratelimit::config::RateLimitTier;

    let pro_tier = RateLimitTier {
        qps: 5000,
        burst: 10000,
    };
    let mut cfg = RateLimitConfig {
        enabled: true,
        default_qps: 10,
        default_burst: 20,
        ..Default::default()
    };
    cfg.tiers.insert("pro".into(), pro_tier);
    cfg
}
// A limiter built from the default config must let requests through.
#[test]
fn disabled_allows_all() {
    let limiter = RateLimiter::new(RateLimitConfig::default());
    let outcome = limiter.check("u1", &[], None, "point_get", None);
    assert!(outcome.allowed);
}
// The default burst of 20 admits 20 `point_get` calls; the 21st is denied
// and carries a non-zero retry hint.
#[test]
fn basic_rate_limiting() {
    let limiter = RateLimiter::new(enabled_config());
    let probe = || limiter.check("u1", &[], None, "point_get", None);

    for _ in 0..20 {
        assert!(probe().allowed);
    }

    let denied = probe();
    assert!(!denied.allowed);
    assert!(denied.retry_after_secs > 0);
}
// `vector_search` is expected to cost enough that one call empties the
// default 20-token burst, so the second call is denied.
#[test]
fn cost_multiplier_drains_faster() {
    let limiter = RateLimiter::new(enabled_config());
    let search = || limiter.check("u1", &[], None, "vector_search", None);

    assert!(search().allowed);
    assert!(!search().allowed);
}
// With the "pro" tier, 100 calls fit comfortably (far beyond the default
// burst of 20), proving the tier was resolved.
#[test]
fn tier_resolution() {
    let limiter = RateLimiter::new(enabled_config());
    let mut attempts = 0..100;
    while let Some(_) = attempts.next() {
        let outcome = limiter.check("u1", &[], Some("pro"), "point_get", None);
        assert!(outcome.allowed);
    }
}
// Exhausting one user's bucket must not affect another user.
#[test]
fn per_user_isolation() {
    let limiter = RateLimiter::new(enabled_config());
    let hit = |user: &str| limiter.check(user, &[], None, "point_get", None);

    // Drain u1's burst of 20.
    for _ in 0..20 {
        hit("u1");
    }

    assert!(!hit("u1").allowed);
    assert!(hit("u2").allowed);
}
// Three headers are produced and the first is the limit.
#[test]
fn response_headers() {
    let headers = RateLimiter::response_headers(&RateLimitResult {
        allowed: true,
        remaining: 50,
        limit: 100,
        retry_after_secs: 0,
    });

    assert_eq!(headers.len(), 3);
    let (name, value) = &headers[0];
    assert_eq!(name, "X-RateLimit-Limit");
    assert_eq!(value, "100");
}
/// Default-config limiter with the given login capacities applied.
fn login_limiter(ip_cap: u64, user_cap: u64) -> RateLimiter {
    let config = RateLimitConfig::default();
    let limiter = RateLimiter::new(config);
    limiter.set_login_capacities(ip_cap, user_cap);
    limiter
}
// 30 attempts from one IP (each with a distinct username) pass; the 31st is
// blocked by the IP bucket while a different IP is unaffected.
#[test]
fn login_rate_limit_ip() {
    const ATTACKER_IP: &str = "10.0.0.1";
    let limiter = login_limiter(30, 10);

    for i in 0..30 {
        let username = format!("user_{i}");
        let allowed = matches!(
            limiter.check_login(ATTACKER_IP, &username),
            LoginRateLimitOutcome::Allowed
        );
        assert!(allowed, "attempt {i} should be allowed");
    }

    let overflow = limiter.check_login(ATTACKER_IP, "user_overflow");
    assert!(
        matches!(overflow, LoginRateLimitOutcome::IpExceeded),
        "31st attempt from same IP must be rate-limited"
    );

    let elsewhere = limiter.check_login("10.0.0.2", "user_other");
    assert!(
        matches!(elsewhere, LoginRateLimitOutcome::Allowed),
        "different IP must still be allowed"
    );
}
// 10 attempts against one username (each from a distinct IP) pass; the 11th
// is blocked by the user bucket while another username is unaffected.
#[test]
fn login_rate_limit_user() {
    const TARGET: &str = "victim";
    let limiter = login_limiter(30, 10);

    for i in 0..10 {
        let peer = format!("10.0.0.{i}");
        let allowed = matches!(
            limiter.check_login(&peer, TARGET),
            LoginRateLimitOutcome::Allowed
        );
        assert!(allowed, "attempt {i} should be allowed");
    }

    let overflow = limiter.check_login("10.0.0.200", TARGET);
    assert!(
        matches!(overflow, LoginRateLimitOutcome::UserExceeded),
        "11th attempt for same user must be rate-limited"
    );

    let bystander = limiter.check_login("10.0.0.200", "other_user");
    assert!(
        matches!(bystander, LoginRateLimitOutcome::Allowed),
        "different username must still be allowed"
    );
}
// With an IP capacity of 2, the third attempt is rejected and the IP bucket
// reports no tokens left.
#[test]
fn login_rate_limit_window() {
    const PEER: &str = "192.0.2.1";
    let limiter = login_limiter(2, 100);

    for _ in 0..2 {
        assert!(matches!(
            limiter.check_login(PEER, "u"),
            LoginRateLimitOutcome::Allowed
        ));
    }
    assert!(matches!(
        limiter.check_login(PEER, "u"),
        LoginRateLimitOutcome::IpExceeded
    ));

    // Inspect the bucket directly through the private map.
    let table = limiter
        .buckets
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let ip_bucket = table
        .get("login_ip:192.0.2.1")
        .expect("bucket must exist");
    assert_eq!(
        ip_bucket.available(),
        0,
        "bucket must be empty after exhaustion"
    );
}
// Emits a `LoginRateLimited` audit event through the capturing emitter and
// checks it was recorded. Note this exercises the emitter only; no
// `RateLimiter` is involved.
#[test]
fn login_rate_limit_audit() {
    use crate::control::security::audit::emitter::test_helpers::CapturingEmitter;
    use crate::control::security::audit::emitter::{AuditEmitContext, AuditEmitter};
    use crate::control::security::audit::event::AuditEvent;

    let sink = CapturingEmitter::new();
    let context = AuditEmitContext::new(None, "", "alice");
    sink.emit(
        AuditEvent::LoginRateLimited,
        "login_rate_limit",
        "ip=10.0.0.1 user=alice",
        context,
    );

    let events = sink.recorded();
    assert_eq!(events.len(), 1);
    let first = &events[0];
    assert_eq!(first.0, AuditEvent::LoginRateLimited);
    assert!(first.2.contains("alice"));
}
// Ten checks (five of them rejected) must finish within 10 ms, showing that
// a rejection does not sleep or block.
// NOTE(review): wall-clock bound; may be flaky on a heavily loaded machine.
#[test]
fn login_rate_limit_constant_time() {
    let limiter = login_limiter(5, 5);

    let started = std::time::Instant::now();
    let mut i = 0;
    while i < 10 {
        let _ = limiter.check_login("10.1.2.3", &format!("user{i}"));
        i += 1;
    }
    let elapsed = started.elapsed();

    assert!(
        elapsed.as_millis() < 10,
        "check_login must be non-blocking; took {elapsed:?}"
    );
}
/// Database id shared by the quota tests (`DatabaseId::DEFAULT`).
fn db_id() -> DatabaseId {
DatabaseId::DEFAULT
}
/// Builds a tenant id from `n`, so each quota test can use its own tenant.
fn t_id(n: u64) -> TenantId {
TenantId::new(n)
}
// The database cap (5) is the tightest limit: five requests pass, the sixth
// is denied although the user (burst 20) and tenant (1000) buckets have room.
#[test]
fn database_cap_deny_while_tenant_has_headroom() {
    let limiter = RateLimiter::new(enabled_config());
    let quota = QuotaCheckParams {
        tenant_max_qps: Some(1000),
        database_max_qps: Some(5),
        tenant_id: t_id(1),
        database_id: db_id(),
    };
    for _ in 0..5 {
        let r = limiter.check("u1", &[], None, "point_get", Some(&quota));
        assert!(r.allowed, "first 5 should be allowed under database cap");
    }
    let r = limiter.check("u1", &[], None, "point_get", Some(&quota));
    assert!(
        !r.allowed,
        "database bucket exhausted — request must be denied"
    );
}
// The tenant cap (3) is the tightest limit: three requests pass, the fourth
// is denied although the user (burst 20) and database (1000) buckets have room.
#[test]
fn tenant_cap_deny_while_database_has_headroom() {
    let limiter = RateLimiter::new(enabled_config());
    let quota = QuotaCheckParams {
        tenant_max_qps: Some(3),
        database_max_qps: Some(1000),
        tenant_id: t_id(2),
        database_id: db_id(),
    };
    for _ in 0..3 {
        let r = limiter.check("u2", &[], None, "point_get", Some(&quota));
        assert!(r.allowed, "first 3 should be allowed under tenant cap");
    }
    let r = limiter.check("u2", &[], None, "point_get", Some(&quota));
    assert!(
        !r.allowed,
        "tenant bucket exhausted — request must be denied"
    );
}
// Both caps are 1, so the first request drains both buckets. The second
// request is stopped by the tenant bucket, which is checked first. A third
// request with only the database cap set confirms the database bucket was
// drained by the first request too.
#[test]
fn when_both_would_deny_tenant_wins_over_database() {
    let limiter = RateLimiter::new(enabled_config());
    let quota = QuotaCheckParams {
        tenant_max_qps: Some(1),
        database_max_qps: Some(1),
        tenant_id: t_id(3),
        database_id: db_id(),
    };
    let r = limiter.check("u3", &[], None, "point_get", Some(&quota));
    assert!(r.allowed, "first request should be allowed");
    let r2 = limiter.check("u3", &[], None, "point_get", Some(&quota));
    assert!(!r2.allowed, "second request must be denied");

    // Same database, no tenant cap: only the database bucket can deny here.
    let quota_db_only = QuotaCheckParams {
        tenant_max_qps: None,
        database_max_qps: Some(1),
        tenant_id: t_id(3),
        database_id: db_id(),
    };
    let r3 = limiter.check("u3", &[], None, "point_get", Some(&quota_db_only));
    assert!(!r3.allowed, "database bucket should also be exhausted");
}
// Without quota params only the user bucket applies: the default burst of 20
// passes and the 21st request is denied.
#[test]
fn no_quota_params_skips_tenant_and_database_buckets() {
    let limiter = RateLimiter::new(enabled_config());
    let probe = || limiter.check("u4", &[], None, "point_get", None);

    for _ in 0..20 {
        assert!(probe().allowed);
    }
    assert!(!probe().allowed);
}
}