#[cfg(feature = "redis-rate-limiting")]
use super::redis::RedisRateLimiter;
use super::{
config::{CheckResult, RateLimitConfig, RateLimitingSecurityConfig},
in_memory::InMemoryRateLimiter,
};
/// Rate-limiter facade that forwards every call to one of the compiled-in backends.
///
/// Build the in-memory variant with [`RateLimiter::new`]; the Redis variant (and its
/// `RateLimiter::new_redis` constructor) only exists when the `redis-rate-limiting`
/// feature is enabled.
#[non_exhaustive]
pub enum RateLimiter {
    /// Backend implemented by [`InMemoryRateLimiter`].
    InMemory(InMemoryRateLimiter),
    /// Backend implemented by `RedisRateLimiter`; compiled only with the
    /// `redis-rate-limiting` feature.
    #[cfg(feature = "redis-rate-limiting")]
    Redis(RedisRateLimiter),
}
impl RateLimiter {
    /// Creates a limiter backed by [`InMemoryRateLimiter`] using `config`.
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self::InMemory(InMemoryRateLimiter::new(config))
    }

    /// Creates a limiter backed by `RedisRateLimiter`, passing `url` and `config`
    /// straight through to `RedisRateLimiter::new`.
    ///
    /// # Errors
    ///
    /// Returns the `redis::RedisError` produced by `RedisRateLimiter::new`
    /// (presumably when the server at `url` cannot be reached — confirm in that module).
    #[cfg(feature = "redis-rate-limiting")]
    pub async fn new_redis(
        url: &str,
        config: RateLimitConfig,
    ) -> Result<Self, redis::RedisError> {
        RedisRateLimiter::new(url, config).await.map(Self::Redis)
    }

    /// Returns this limiter with the path rules from `sec` applied to the active backend.
    ///
    /// The backend variant is preserved; the work is delegated to the backend's own
    /// `with_path_rules_from_security`.
    #[must_use]
    pub fn with_path_rules_from_security(self, sec: &RateLimitingSecurityConfig) -> Self {
        match self {
            Self::InMemory(rl) => Self::InMemory(rl.with_path_rules_from_security(sec)),
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(rl) => Self::Redis(rl.with_path_rules_from_security(sec)),
        }
    }

    /// Returns the configuration held by the active backend.
    #[must_use]
    pub const fn config(&self) -> &RateLimitConfig {
        match self {
            Self::InMemory(rl) => rl.config(),
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(rl) => rl.config(),
        }
    }

    /// Returns the number of path rules registered on the active backend.
    #[must_use]
    pub const fn path_rule_count(&self) -> usize {
        match self {
            Self::InMemory(rl) => rl.path_rule_count(),
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(rl) => rl.path_rule_count(),
        }
    }

    /// Returns the active backend's retry-after value for `path`.
    ///
    /// NOTE(review): presumably in seconds, like [`Self::retry_after_secs`] — confirm in
    /// the backend implementations.
    #[must_use]
    pub fn retry_after_for_path(&self, path: &str) -> u32 {
        match self {
            Self::InMemory(rl) => rl.retry_after_for_path(path),
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(rl) => rl.retry_after_for_path(path),
        }
    }

    /// Checks the per-IP limit for `ip` on the active backend.
    pub async fn check_ip_limit(&self, ip: &str) -> CheckResult {
        match self {
            Self::InMemory(rl) => rl.check_ip_limit(ip).await,
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(rl) => rl.check_ip_limit(ip).await,
        }
    }

    /// Checks the per-user limit for `user_id` on the active backend.
    pub async fn check_user_limit(&self, user_id: &str) -> CheckResult {
        match self {
            Self::InMemory(rl) => rl.check_user_limit(user_id).await,
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(rl) => rl.check_user_limit(user_id).await,
        }
    }

    /// Checks the limit for `path` on behalf of `ip` on the active backend.
    pub async fn check_path_limit(&self, path: &str, ip: &str) -> CheckResult {
        match self {
            Self::InMemory(rl) => rl.check_path_limit(path, ip).await,
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(rl) => rl.check_path_limit(path, ip).await,
        }
    }

    /// Runs backend housekeeping.
    ///
    /// Only the in-memory backend does any work here.
    pub async fn cleanup(&self) {
        match self {
            Self::InMemory(rl) => rl.cleanup().await,
            // Intentional no-op. NOTE(review): presumably expiry is handled inside Redis
            // itself (e.g. key TTLs) — confirm against `RedisRateLimiter`.
            #[cfg(feature = "redis-rate-limiting")]
            Self::Redis(_) => {}
        }
    }

    /// Returns the global retry-after hint, in whole seconds, derived from `rps_per_ip`.
    ///
    /// Computes `ceil(1 / rps_per_ip)` clamped to a minimum of 1. A rate of 0 falls back
    /// to 1. Because `rps_per_ip` is an integer, `ceil(1 / rps)` is 1 for every rate of at
    /// least one, so the result is currently always 1.
    #[must_use]
    pub fn retry_after_secs(&self) -> u32 {
        match self.config().rps_per_ip {
            // Avoid dividing by zero: no sustained rate means the minimum wait.
            0 => 1,
            // `1 / rps` (integer floor) is 1 when rps == 1 and 0 otherwise. Clamping to 1
            // makes it equal to ceil(1 / rps) without a float round-trip or truncating cast.
            rps => (1 / rps).max(1),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory limiter whose only non-default setting is `rps_per_ip`.
    fn limiter_with_rps(rps: u32) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            rps_per_ip: rps,
            ..RateLimitConfig::default()
        })
    }

    #[test]
    fn new_creates_in_memory_backend() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        assert!(matches!(limiter, RateLimiter::InMemory(_)));
    }

    #[test]
    fn config_returns_reference_to_inner_config() {
        let limiter = limiter_with_rps(42);
        assert_eq!(limiter.config().rps_per_ip, 42);
    }

    #[test]
    fn path_rule_count_starts_at_zero() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        assert_eq!(limiter.path_rule_count(), 0);
    }

    #[test]
    fn retry_after_secs_minimum_is_one() {
        let limiter = limiter_with_rps(u32::MAX);
        assert_eq!(limiter.retry_after_secs(), 1, "minimum retry_after must be 1s");
    }

    #[test]
    fn retry_after_secs_falls_back_to_one_when_rps_is_zero() {
        // Exercises the division-by-zero guard.
        let limiter = limiter_with_rps(0);
        assert_eq!(limiter.retry_after_secs(), 1, "rps_per_ip == 0 must fall back to 1s");
    }

    #[test]
    fn retry_after_secs_is_one_for_single_request_per_second() {
        // Boundary where ceil(1 / rps) is exactly 1 without any clamping.
        let limiter = limiter_with_rps(1);
        assert_eq!(limiter.retry_after_secs(), 1);
    }
}