use std::sync::Arc;
use dashmap::DashMap;
use tracing::debug;
use super::{
config::{CheckResult, RateLimitConfig, RateLimitingSecurityConfig},
key::{PathRateLimit, path_matches_rule},
token_bucket::TokenBucket,
};
/// Process-local rate limiter that keeps one [`TokenBucket`] per key in concurrent maps.
///
/// Separate maps hold buckets for client IPs, users, `(path rule, IP)` pairs and tenants.
/// The maps sit behind `Arc` — NOTE(review): presumably so a background cleanup task can
/// share them; confirm against the owner of this value.
pub struct InMemoryRateLimiter {
    /// Global settings: enable flag, per-IP/per-user rates, burst size, bucket-count cap and
    /// cleanup cadence.
    pub(super) config: RateLimitConfig,
    /// Per-path limits, searched in order; the first rule whose prefix matches wins.
    pub(super) path_rules: Vec<PathRateLimit>,
    /// Buckets keyed by IP, or by `"{tenant}:{ip}"` when a tenant id is supplied.
    pub(super) ip_buckets: Arc<DashMap<String, TokenBucket>>,
    /// Buckets keyed by user id, or by `"{tenant}:{user_id}"` when a tenant id is supplied.
    pub(super) user_buckets: Arc<DashMap<String, TokenBucket>>,
    /// Buckets keyed by `(matched rule prefix, ip)`.
    pub(super) path_ip_buckets: Arc<DashMap<(String, String), TokenBucket>>,
    /// Buckets keyed by `"tenant:{tenant_key}"`.
    pub(super) tenant_buckets: Arc<DashMap<String, TokenBucket>>,
}
impl InMemoryRateLimiter {
pub(super) fn new(config: RateLimitConfig) -> Self {
Self {
config,
ip_buckets: Arc::new(DashMap::new()),
user_buckets: Arc::new(DashMap::new()),
path_rules: Vec::new(),
path_ip_buckets: Arc::new(DashMap::new()),
tenant_buckets: Arc::new(DashMap::new()),
}
}
#[allow(clippy::cast_precision_loss)] #[must_use]
pub(super) fn with_path_rules_from_security(
mut self,
sec: &RateLimitingSecurityConfig,
) -> Self {
let mut rules = Vec::new();
if sec.auth_start_max_requests > 0 && sec.auth_start_window_secs > 0 {
rules.push(PathRateLimit {
path_prefix: "/auth/start".to_string(),
tokens_per_sec: f64::from(sec.auth_start_max_requests)
/ sec.auth_start_window_secs as f64,
burst: f64::from(sec.auth_start_max_requests),
});
}
if sec.auth_callback_max_requests > 0 && sec.auth_callback_window_secs > 0 {
rules.push(PathRateLimit {
path_prefix: "/auth/callback".to_string(),
tokens_per_sec: f64::from(sec.auth_callback_max_requests)
/ sec.auth_callback_window_secs as f64,
burst: f64::from(sec.auth_callback_max_requests),
});
}
if sec.auth_refresh_max_requests > 0 && sec.auth_refresh_window_secs > 0 {
rules.push(PathRateLimit {
path_prefix: "/auth/refresh".to_string(),
tokens_per_sec: f64::from(sec.auth_refresh_max_requests)
/ sec.auth_refresh_window_secs as f64,
burst: f64::from(sec.auth_refresh_max_requests),
});
}
self.path_rules = rules;
self
}
pub(super) async fn check_path_limit(&self, path: &str, ip: &str) -> CheckResult {
if !self.config.enabled {
return CheckResult::allow(f64::from(self.config.burst_size));
}
let rule = self.path_rules.iter().find(|r| path_matches_rule(path, &r.path_prefix));
let Some(rule) = rule else {
return CheckResult::allow(f64::from(self.config.burst_size));
};
let key = (rule.path_prefix.clone(), ip.to_string());
let (tokens_per_sec, burst) = (rule.tokens_per_sec, rule.burst);
if !self.path_ip_buckets.contains_key(&key)
&& self.path_ip_buckets.len() >= self.config.max_buckets
{
debug!(
ip = ip,
path = path,
"Path-IP bucket capacity reached — denying unseen combination"
);
let retry = if tokens_per_sec > 0.0 {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
((1.0_f64 / tokens_per_sec).ceil() as u32).max(1)
} else {
1
};
return CheckResult::deny(retry);
}
let (allowed, remaining) = {
let mut bucket_ref = self
.path_ip_buckets
.entry(key)
.or_insert_with(|| TokenBucket::new(burst, tokens_per_sec));
let bucket = bucket_ref.value_mut();
let allowed = bucket.try_consume(1.0);
let remaining = bucket.token_count();
(allowed, remaining)
};
if allowed {
CheckResult::allow(remaining)
} else {
debug!(ip = ip, path = path, "Per-path rate limit exceeded");
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let retry = if tokens_per_sec > 0.0 {
((1.0_f64 / tokens_per_sec).ceil() as u32).max(1)
} else {
1
};
CheckResult::deny(retry)
}
}
pub(super) const fn config(&self) -> &RateLimitConfig {
&self.config
}
pub(super) async fn check_ip_limit(&self, ip: &str, tenant_id: Option<&str>) -> CheckResult {
if !self.config.enabled {
return CheckResult::allow(f64::from(self.config.burst_size));
}
let key = tenant_id.map_or_else(|| ip.to_string(), |tid| format!("{}:{}", tid, ip));
if !self.ip_buckets.contains_key(&key) && self.ip_buckets.len() >= self.config.max_buckets {
debug!(ip = ip, tenant_id = ?tenant_id, "IP bucket capacity reached — denying unseen IP");
return CheckResult::deny(1);
}
let (allowed, remaining) = {
let mut bucket_ref = self.ip_buckets.entry(key).or_insert_with(|| {
TokenBucket::new(
f64::from(self.config.burst_size),
f64::from(self.config.rps_per_ip),
)
});
let bucket = bucket_ref.value_mut();
let allowed = bucket.try_consume(1.0);
let remaining = bucket.token_count();
(allowed, remaining)
};
if allowed {
CheckResult::allow(remaining)
} else {
debug!(ip = ip, "Rate limit exceeded for IP");
let rps = self.config.rps_per_ip;
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let retry = if rps == 0 {
1
} else {
((1.0_f64 / f64::from(rps)).ceil() as u32).max(1)
};
CheckResult::deny(retry)
}
}
pub(super) async fn check_user_limit(
&self,
user_id: &str,
tenant_id: Option<&str>,
) -> CheckResult {
if !self.config.enabled {
return CheckResult::allow(f64::from(self.config.burst_size));
}
let key =
tenant_id.map_or_else(|| user_id.to_string(), |tid| format!("{}:{}", tid, user_id));
if !self.user_buckets.contains_key(&key)
&& self.user_buckets.len() >= self.config.max_buckets
{
debug!(user_id = user_id, tenant_id = ?tenant_id, "User bucket capacity reached — denying unseen user");
return CheckResult::deny(1);
}
let (allowed, remaining) = {
let mut bucket_ref = self.user_buckets.entry(key).or_insert_with(|| {
TokenBucket::new(
f64::from(self.config.burst_size),
f64::from(self.config.rps_per_user),
)
});
let bucket = bucket_ref.value_mut();
let allowed = bucket.try_consume(1.0);
let remaining = bucket.token_count();
(allowed, remaining)
};
if allowed {
CheckResult::allow(remaining)
} else {
debug!(user_id = user_id, "Rate limit exceeded for user");
let rps = self.config.rps_per_user;
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let retry = if rps == 0 {
1
} else {
((1.0_f64 / f64::from(rps)).ceil() as u32).max(1)
};
CheckResult::deny(retry)
}
}
#[allow(clippy::cast_precision_loss)] pub(super) async fn check_tenant_limit(
&self,
tenant_key: &str,
rps: u32,
burst: u32,
) -> CheckResult {
let bucket_key = format!("tenant:{tenant_key}");
if !self.tenant_buckets.contains_key(&bucket_key)
&& self.tenant_buckets.len() >= self.config.max_buckets
{
debug!(
tenant_key = tenant_key,
"Tenant bucket capacity reached — denying unseen tenant"
);
return CheckResult::deny(1);
}
let (allowed, remaining) = {
let mut bucket_ref = self
.tenant_buckets
.entry(bucket_key)
.or_insert_with(|| TokenBucket::new(f64::from(burst), f64::from(rps)));
let bucket = bucket_ref.value_mut();
let allowed = bucket.try_consume(1.0);
let remaining = bucket.token_count();
(allowed, remaining)
};
if allowed {
CheckResult::allow(remaining)
} else {
debug!(tenant_key = tenant_key, "Per-tenant rate limit exceeded");
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let retry = if rps == 0 {
1
} else {
((1.0_f64 / f64::from(rps)).ceil() as u32).max(1)
};
CheckResult::deny(retry)
}
}
#[allow(clippy::cast_precision_loss)] pub(super) async fn cleanup(&self) {
let ip_refill_secs = if self.config.rps_per_ip == 0 {
self.config.cleanup_interval_secs as f64
} else {
f64::from(self.config.burst_size) / f64::from(self.config.rps_per_ip)
};
let user_refill_secs = if self.config.rps_per_user == 0 {
self.config.cleanup_interval_secs as f64
} else {
f64::from(self.config.burst_size) / f64::from(self.config.rps_per_user)
};
let now = std::time::Instant::now();
let ip_threshold = now
.checked_sub(std::time::Duration::from_secs_f64(ip_refill_secs))
.unwrap_or(now);
let user_threshold = now
.checked_sub(std::time::Duration::from_secs_f64(user_refill_secs))
.unwrap_or(now);
let before_ip = self.ip_buckets.len();
self.ip_buckets.retain(|_, b| b.last_refill >= ip_threshold);
let evicted_ip = before_ip.saturating_sub(self.ip_buckets.len());
let before_user = self.user_buckets.len();
self.user_buckets.retain(|_, b| b.last_refill >= user_threshold);
let evicted_user = before_user.saturating_sub(self.user_buckets.len());
self.path_ip_buckets.retain(|_, b| b.last_refill >= ip_threshold);
self.tenant_buckets.retain(|_, b| b.last_refill >= ip_threshold);
debug!(evicted_ip, evicted_user, "Rate limiter cleanup complete");
}
pub(super) const fn path_rule_count(&self) -> usize {
self.path_rules.len()
}
pub(super) fn retry_after_for_path(&self, path: &str) -> u32 {
if let Some(rule) = self.path_rules.iter().find(|r| path_matches_rule(path, &r.path_prefix))
{
if rule.tokens_per_sec > 0.0 {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
return ((1.0_f64 / rule.tokens_per_sec).ceil() as u32).max(1);
}
}
1
}
}