use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use governor::{
Quota, RateLimiter as GovernorLimiter,
clock::{Clock, DefaultClock},
state::keyed::DashMapStateStore,
};
/// Keyed rate limiter that keeps an independent set of buckets per endpoint name.
///
/// Cloning is cheap: both fields are reference-counted, so every clone shares
/// the same configuration and the same per-endpoint limiter state.
#[derive(Clone)]
pub struct RateLimiter {
    /// Limits and on/off switch shared by all clones.
    config: Arc<RateLimitConfig>,
    /// One limiter per endpoint name, created on first use.
    limiters: Arc<DashMap<String, Arc<EndpointLimiter>>>,
}

impl std::fmt::Debug for RateLimiter {
    /// Shows only the configuration; the limiter state is elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RateLimiter");
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
/// Governor limiter holding one rate-limiting state per [`RateLimitKey`],
/// stored in a `DashMap`-backed state store and timed by governor's default clock.
type KeyedLimiter =
    GovernorLimiter<RateLimitKey, DashMapStateStore<RateLimitKey>, DefaultClock>;

/// Limiter for a single endpoint, paired with the [`EndpointLimit`] it was
/// built from so rejections can report the configured numbers.
struct EndpointLimiter {
    /// Per-key limiter enforcing `limit`.
    limiter: KeyedLimiter,
    /// The configuration `limiter` was built from.
    limit: EndpointLimit,
}
/// Complete rate-limiting configuration; can be assembled with
/// [`RateLimitConfig::builder`].
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Limit used for any endpoint without an entry in `endpoint_limits`.
    pub default_limit: EndpointLimit,
    /// Per-endpoint overrides, keyed by endpoint name.
    pub endpoint_limits: HashMap<String, EndpointLimit>,
    /// Master switch; when `false` every check is allowed and no state is kept.
    pub enabled: bool,
    /// Intended period for evicting idle per-key state.
    ///
    /// NOTE(review): this type does not schedule any eviction on its own —
    /// confirm that a caller acts on this value.
    pub cleanup_interval: Duration,
}
/// Rate for one endpoint: `requests` per `window`, plus `burst` extra requests
/// that may be spent back-to-back on top of `requests`.
#[derive(Clone, Debug)]
pub struct EndpointLimit {
    /// Sustained number of requests allowed per `window`.
    pub requests: u32,
    /// Length of the period `requests` is measured over.
    pub window: Duration,
    /// Extra capacity beyond `requests` for short spikes.
    pub burst: u32,
}
/// Identity a request is rate limited under, such as a client IP or user id.
///
/// Equality and hashing cover both fields, so `key_type` namespaces `value`:
/// an IP and a user id with the same text never share a bucket.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RateLimitKey {
    /// Category label such as `"ip"`, `"user"` or `"composite"`.
    pub key_type: String,
    /// The identifier within that category.
    pub value: String,
}
/// Details returned when a request is rejected.
#[derive(Clone, Debug)]
pub struct RateLimitInfo {
    /// How long to wait before the next request for this key can succeed.
    pub retry_after: Duration,
    /// Reported by `RateLimiter::check` as the endpoint's burst capacity
    /// (`requests + burst`) rather than a live count.
    pub current_count: u32,
    /// Configured sustained `requests` for the endpoint.
    pub limit: u32,
    /// Configured window `limit` applies to.
    pub window: Duration,
}
impl RateLimiter {
    /// Creates a limiter from `config`.
    ///
    /// Per-endpoint limiters are created on first use, so construction is cheap.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config: Arc::new(config),
            limiters: Arc::new(DashMap::new()),
        }
    }

    /// Preset for authentication endpoints.
    ///
    /// `login`, `token`, `refresh`, `authorize` and `revoke` get their own
    /// per-minute limits; every other endpoint falls back to 100 requests per
    /// minute (burst 10).
    pub fn for_auth() -> Self {
        let config = RateLimitConfig::builder()
            .default_limit(100, Duration::from_secs(60))
            .endpoint_limit(
                "login",
                EndpointLimit {
                    requests: 5,
                    window: Duration::from_secs(60),
                    burst: 2,
                },
            )
            .endpoint_limit(
                "token",
                EndpointLimit {
                    requests: 10,
                    window: Duration::from_secs(60),
                    burst: 3,
                },
            )
            .endpoint_limit(
                "refresh",
                EndpointLimit {
                    requests: 20,
                    window: Duration::from_secs(60),
                    burst: 5,
                },
            )
            .endpoint_limit(
                "authorize",
                EndpointLimit {
                    requests: 10,
                    window: Duration::from_secs(60),
                    burst: 3,
                },
            )
            .endpoint_limit(
                "revoke",
                EndpointLimit {
                    requests: 10,
                    window: Duration::from_secs(60),
                    burst: 2,
                },
            )
            .build();
        Self::new(config)
    }

    /// A limiter with `enabled: false`.
    ///
    /// `check` always succeeds, `record` does nothing and `get_usage` returns
    /// `None`.
    pub fn disabled() -> Self {
        Self::new(RateLimitConfig {
            default_limit: EndpointLimit {
                requests: u32::MAX,
                window: Duration::from_secs(1),
                burst: 0,
            },
            endpoint_limits: HashMap::new(),
            enabled: false,
            cleanup_interval: Duration::from_secs(3600),
        })
    }

    /// Checks `key` against `endpoint`'s limit, consuming one request if allowed.
    ///
    /// Always succeeds when limiting is disabled. Endpoints without an entry in
    /// `endpoint_limits` use the default limit.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitInfo`] when `key` has exhausted its capacity for
    /// `endpoint`; a rate-limited metric is recorded for the endpoint and key type.
    pub async fn check(&self, key: &RateLimitKey, endpoint: &str) -> Result<(), RateLimitInfo> {
        if !self.config.enabled {
            return Ok(());
        }
        let endpoint_limiter = self.endpoint_limiter(endpoint);
        match endpoint_limiter.limiter.check_key(key) {
            Ok(()) => Ok(()),
            Err(not_until) => {
                let retry_after = not_until.wait_time_from(DefaultClock::default().now());
                crate::auth_metrics::record_rate_limited(endpoint, &key.key_type);
                let limit = &endpoint_limiter.limit;
                Err(RateLimitInfo {
                    retry_after,
                    // Same capacity `build_endpoint_limiter` enforces (never below 1).
                    current_count: limit.requests.saturating_add(limit.burst).max(1),
                    limit: limit.requests,
                    window: limit.window,
                })
            }
        }
    }

    /// Consumes one request for `key` on `endpoint`, ignoring whether it was allowed.
    ///
    /// NOTE(review): `check` already consumes a request when it succeeds, so
    /// calling both `check` and `record` for the same request counts it twice —
    /// confirm that is intended.
    pub async fn record(&self, key: &RateLimitKey, endpoint: &str) {
        if !self.config.enabled {
            return;
        }
        let endpoint_limiter = self.endpoint_limiter(endpoint);
        let _ = endpoint_limiter.limiter.check_key(key);
    }

    /// Reports `(used, limit)` for `endpoint`, or `None` when limiting is disabled.
    ///
    /// NOTE(review): the used count is a placeholder that is always `0`, and the
    /// key is not consulted; only `limit` (the endpoint's `requests`) is meaningful.
    pub async fn get_usage(&self, _key: &RateLimitKey, endpoint: &str) -> Option<(u32, u32)> {
        if !self.config.enabled {
            return None;
        }
        Some((0, self.endpoint_limiter(endpoint).limit.requests))
    }

    /// Discards accumulated rate-limit state.
    ///
    /// NOTE(review): `key` is not consulted — every endpoint limiter created so
    /// far is rebuilt from its configured limit, so this resets *all* keys, not
    /// only `key`'s. A true per-key reset would need a way to remove a single
    /// key from the keyed state store; none is used here.
    pub async fn reset(&self, key: &RateLimitKey) {
        let _ = key; // intentionally unused; see the note above
        for mut entry in self.limiters.iter_mut() {
            let fresh = Arc::new(build_endpoint_limiter(&entry.value().limit));
            *entry.value_mut() = fresh;
        }
    }

    /// Drops every endpoint limiter; they are recreated fresh on next use.
    pub async fn reset_all(&self) {
        self.limiters.clear();
    }

    /// Evicts per-key state that has fully replenished and shrinks each
    /// endpoint's state store.
    ///
    /// Without this, every distinct key ever seen (for example every client IP)
    /// stays in memory indefinitely. Nothing in this type calls it
    /// automatically; schedule it from the caller, e.g. every
    /// [`RateLimiter::cleanup_interval`].
    pub fn cleanup(&self) {
        // Snapshot first so no DashMap shard lock is held while each store is swept.
        let endpoint_limiters: Vec<Arc<EndpointLimiter>> = self
            .limiters
            .iter()
            .map(|entry| Arc::clone(entry.value()))
            .collect();
        for endpoint_limiter in endpoint_limiters {
            endpoint_limiter.limiter.retain_recent();
            endpoint_limiter.limiter.shrink_to_fit();
        }
    }

    /// Configured period at which [`RateLimiter::cleanup`] is meant to run.
    #[must_use]
    pub fn cleanup_interval(&self) -> Duration {
        self.config.cleanup_interval
    }

    /// Returns the limiter for `endpoint`, creating it from the configured (or
    /// default) limit on first use.
    fn endpoint_limiter(&self, endpoint: &str) -> Arc<EndpointLimiter> {
        // Fast path: shared read access only.
        if let Some(existing) = self.limiters.get(endpoint) {
            return Arc::clone(existing.value());
        }
        // Build inside `or_insert_with` so a concurrent insert for the same
        // endpoint does not waste a limiter; the closure only reads `config`.
        let entry = self
            .limiters
            .entry(endpoint.to_string())
            .or_insert_with(|| {
                let limit = self
                    .config
                    .endpoint_limits
                    .get(endpoint)
                    .unwrap_or(&self.config.default_limit);
                Arc::new(build_endpoint_limiter(limit))
            });
        Arc::clone(entry.value())
    }
}
fn build_endpoint_limiter(limit: &EndpointLimit) -> EndpointLimiter {
let requests = limit.requests.max(1);
let burst_cap = limit.requests.saturating_add(limit.burst).max(1);
let replenish = limit.window / requests;
let quota = Quota::with_period(replenish)
.unwrap_or_else(|| Quota::per_minute(NonZeroU32::new(1).expect("1 is nonzero")))
.allow_burst(NonZeroU32::new(burst_cap).expect("burst_cap is nonzero"));
EndpointLimiter {
limiter: GovernorLimiter::keyed(quota),
limit: limit.clone(),
}
}
impl RateLimitConfig {
pub fn builder() -> RateLimitConfigBuilder {
RateLimitConfigBuilder::default()
}
}
/// Fluent builder for [`RateLimitConfig`].
///
/// Limiting is enabled automatically once any limit is configured. An explicit
/// call to [`RateLimitConfigBuilder::enabled`] always wins, regardless of
/// whether it comes before or after the limits.
#[derive(Debug, Default)]
#[must_use = "builder methods return the updated builder"]
pub struct RateLimitConfigBuilder {
    default_limit: Option<EndpointLimit>,
    endpoint_limits: HashMap<String, EndpointLimit>,
    /// Explicit on/off choice; `None` means "enabled iff any limit was set".
    /// Kept separate so a later `default_limit`/`endpoint_limit` call cannot
    /// silently undo an earlier `enabled(false)`.
    enabled: Option<bool>,
    cleanup_interval: Option<Duration>,
}

impl RateLimitConfigBuilder {
    /// Sets the fallback limit: `requests` per `window`, with a burst of 10% of
    /// `requests` (rounded down).
    pub fn default_limit(mut self, requests: u32, window: Duration) -> Self {
        self.default_limit = Some(EndpointLimit {
            requests,
            window,
            burst: requests / 10,
        });
        self
    }

    /// Adds, or replaces, the limit for `endpoint`.
    pub fn endpoint_limit(mut self, endpoint: impl Into<String>, limit: EndpointLimit) -> Self {
        self.endpoint_limits.insert(endpoint.into(), limit);
        self
    }

    /// Sets the intended eviction period for idle per-key state (default 300 s).
    pub fn cleanup_interval(mut self, interval: Duration) -> Self {
        self.cleanup_interval = Some(interval);
        self
    }

    /// Explicitly enables or disables limiting, overriding the automatic choice.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = Some(enabled);
        self
    }

    /// Finalizes the configuration.
    ///
    /// Without a default limit, 100 requests per 60 s (burst 10) is used.
    /// Without an explicit `enabled`, limiting is on iff any limit was set.
    pub fn build(self) -> RateLimitConfig {
        let enabled = self
            .enabled
            .unwrap_or(self.default_limit.is_some() || !self.endpoint_limits.is_empty());
        RateLimitConfig {
            default_limit: self.default_limit.unwrap_or_else(|| EndpointLimit {
                requests: 100,
                window: Duration::from_secs(60),
                burst: 10,
            }),
            endpoint_limits: self.endpoint_limits,
            enabled,
            cleanup_interval: self.cleanup_interval.unwrap_or(Duration::from_secs(300)),
        }
    }
}
impl RateLimitKey {
    /// Builds a key from a fixed category label and its value.
    fn tagged(key_type: &str, value: impl Into<String>) -> Self {
        Self {
            key_type: key_type.to_owned(),
            value: value.into(),
        }
    }

    /// Key for a client IP address (category `"ip"`).
    pub fn ip(ip: impl Into<String>) -> Self {
        Self::tagged("ip", ip)
    }

    /// Key for an authenticated user id (category `"user"`).
    pub fn user(user_id: impl Into<String>) -> Self {
        Self::tagged("user", user_id)
    }

    /// Key for an API-key prefix (category `"api_key"`).
    pub fn api_key_prefix(prefix: impl Into<String>) -> Self {
        Self::tagged("api_key", prefix)
    }

    /// Key for a session id (category `"session"`).
    pub fn session(session_id: impl Into<String>) -> Self {
        Self::tagged("session", session_id)
    }

    /// Key built from several `(name, value)` parts, rendered as `name:value`
    /// pairs joined by `|` (category `"composite"`).
    ///
    /// NOTE(review): parts are not escaped, so a name or value containing `|`
    /// or `:` can make two different part lists render the same key.
    pub fn composite(components: Vec<(&str, &str)>) -> Self {
        let mut rendered = String::new();
        for (position, (name, part)) in components.into_iter().enumerate() {
            if position > 0 {
                rendered.push('|');
            }
            rendered.push_str(name);
            rendered.push(':');
            rendered.push_str(part);
        }
        Self::tagged("composite", rendered)
    }
}
impl std::fmt::Display for RateLimitInfo {
    /// Renders e.g. `Rate limited: 7/5 requests in 60s, retry after 12s`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            retry_after,
            current_count,
            limit,
            window,
        } = self;
        write!(
            f,
            "Rate limited: {current_count}/{limit} requests in {window:?}, retry after {retry_after:?}"
        )
    }
}

/// Marks `RateLimitInfo` as a standard error; its message comes from `Display`
/// and it has no underlying source.
impl std::error::Error for RateLimitInfo {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled limiter whose `"test"` endpoint allows `requests` per
    /// minute plus `burst` extra.
    fn limiter_for_test_endpoint(requests: u32, burst: u32) -> RateLimiter {
        RateLimiter::new(
            RateLimitConfig::builder()
                .endpoint_limit(
                    "test",
                    EndpointLimit {
                        requests,
                        window: Duration::from_secs(60),
                        burst,
                    },
                )
                .build(),
        )
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_under_limit() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .default_limit(5, Duration::from_secs(60))
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        for _ in 0..5 {
            assert!(limiter.check(&key, "test").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_blocks_over_limit() {
        let limiter = limiter_for_test_endpoint(2, 0);
        let key = RateLimitKey::ip("192.168.1.1");
        assert!(limiter.check(&key, "test").await.is_ok());
        assert!(limiter.check(&key, "test").await.is_ok());
        let result = limiter.check(&key, "test").await;
        assert!(result.is_err());
        let info = result.unwrap_err();
        assert_eq!(info.limit, 2);
        assert!(info.retry_after > Duration::ZERO);
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_burst() {
        let limiter = limiter_for_test_endpoint(2, 2);
        let key = RateLimitKey::ip("192.168.1.1");
        for i in 0..4 {
            assert!(
                limiter.check(&key, "test").await.is_ok(),
                "Request {} should be allowed",
                i
            );
        }
        assert!(limiter.check(&key, "test").await.is_err());
    }

    #[tokio::test]
    async fn test_rate_limiter_disabled() {
        let limiter = RateLimiter::disabled();
        let key = RateLimitKey::ip("192.168.1.1");
        for _ in 0..1000 {
            assert!(limiter.check(&key, "test").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_different_keys() {
        let limiter = limiter_for_test_endpoint(1, 0);
        let key1 = RateLimitKey::ip("192.168.1.1");
        let key2 = RateLimitKey::ip("192.168.1.2");
        assert!(limiter.check(&key1, "test").await.is_ok());
        assert!(limiter.check(&key2, "test").await.is_ok());
        assert!(limiter.check(&key1, "test").await.is_err());
        assert!(limiter.check(&key2, "test").await.is_err());
    }

    #[tokio::test]
    async fn test_rate_limiter_reset() {
        let limiter = limiter_for_test_endpoint(1, 0);
        let key = RateLimitKey::ip("192.168.1.1");
        assert!(limiter.check(&key, "test").await.is_ok());
        assert!(limiter.check(&key, "test").await.is_err());
        limiter.reset(&key).await;
        assert!(limiter.check(&key, "test").await.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limiter_reset_all() {
        let limiter = limiter_for_test_endpoint(1, 0);
        let key = RateLimitKey::ip("192.168.1.1");
        assert!(limiter.check(&key, "test").await.is_ok());
        assert!(limiter.check(&key, "test").await.is_err());
        limiter.reset_all().await;
        assert!(limiter.check(&key, "test").await.is_ok());
    }

    #[tokio::test]
    async fn test_record_consumes_quota() {
        let limiter = limiter_for_test_endpoint(2, 0);
        let key = RateLimitKey::ip("192.168.1.1");
        limiter.record(&key, "test").await;
        assert!(limiter.check(&key, "test").await.is_ok());
        assert!(limiter.check(&key, "test").await.is_err());
    }

    #[tokio::test]
    async fn test_endpoints_have_independent_buckets() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .default_limit(1, Duration::from_secs(60))
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        assert!(limiter.check(&key, "a").await.is_ok());
        assert!(limiter.check(&key, "b").await.is_ok());
        assert!(limiter.check(&key, "a").await.is_err());
        assert!(limiter.check(&key, "b").await.is_err());
    }

    #[tokio::test]
    async fn test_rate_limiter_for_auth() {
        let limiter = RateLimiter::for_auth();
        let key = RateLimitKey::ip("192.168.1.1");
        for i in 0..7 {
            assert!(
                limiter.check(&key, "login").await.is_ok(),
                "Login request {} should be allowed",
                i
            );
        }
        assert!(limiter.check(&key, "login").await.is_err());
    }

    #[test]
    fn test_rate_limit_key_creation() {
        let ip_key = RateLimitKey::ip("10.0.0.1");
        assert_eq!(ip_key.key_type, "ip");
        assert_eq!(ip_key.value, "10.0.0.1");
        let user_key = RateLimitKey::user("user123");
        assert_eq!(user_key.key_type, "user");
        assert_eq!(RateLimitKey::api_key_prefix("sk_live").key_type, "api_key");
        assert_eq!(RateLimitKey::session("abc").key_type, "session");
        let composite = RateLimitKey::composite(vec![("ip", "10.0.0.1"), ("endpoint", "/login")]);
        assert_eq!(composite.key_type, "composite");
        assert_eq!(composite.value, "ip:10.0.0.1|endpoint:/login");
        assert_eq!(RateLimitKey::composite(vec![]).value, "");
    }

    #[test]
    fn test_rate_limit_info_display() {
        let info = RateLimitInfo {
            retry_after: Duration::from_secs(12),
            current_count: 7,
            limit: 5,
            window: Duration::from_secs(60),
        };
        assert_eq!(
            info.to_string(),
            "Rate limited: 7/5 requests in 60s, retry after 12s"
        );
    }

    #[test]
    fn test_builder_defaults() {
        let config = RateLimitConfig::builder().build();
        assert!(!config.enabled);
        assert_eq!(config.default_limit.requests, 100);
        assert_eq!(config.default_limit.burst, 10);
        assert_eq!(config.cleanup_interval, Duration::from_secs(300));

        let config = RateLimitConfig::builder()
            .default_limit(25, Duration::from_secs(1))
            .build();
        assert!(config.enabled);
        assert_eq!(config.default_limit.burst, 2);

        let config = RateLimitConfig::builder()
            .default_limit(5, Duration::from_secs(1))
            .enabled(false)
            .build();
        assert!(!config.enabled);
    }

    #[tokio::test]
    async fn test_get_usage_returns_limit_when_enabled() {
        let limiter = RateLimiter::new(
            RateLimitConfig::builder()
                .default_limit(10, Duration::from_secs(60))
                .build(),
        );
        let key = RateLimitKey::ip("192.168.1.1");
        let usage = limiter.get_usage(&key, "test").await;
        assert_eq!(usage, Some((0, 10)));
        let disabled = RateLimiter::disabled();
        assert_eq!(disabled.get_usage(&key, "test").await, None);
    }
}