// fraiseql_server/middleware/rate_limit/dispatch.rs

//! Rate limiter enum dispatch — routes calls to the active backend.
//!
//! `RateLimiter` is the public handle used by the rest of the server.
//! It wraps either the in-memory or the Redis backend behind a uniform
//! API (the request-checking methods are async) so callers never need to
//! know which backend is active.
6
7#[cfg(feature = "redis-rate-limiting")]
8use super::redis::RedisRateLimiter;
9use super::{
10    config::{CheckResult, RateLimitConfig, RateLimitingSecurityConfig},
11    in_memory::InMemoryRateLimiter,
12};
13
14/// Rate limiter that dispatches to either an in-memory or Redis backend.
15///
16/// Construct via [`RateLimiter::new`] (in-memory, default) or
17/// `RateLimiter::new_redis` (distributed Redis, requires the
18/// `redis-rate-limiting` Cargo feature).
19#[non_exhaustive]
20pub enum RateLimiter {
21    /// Single-node token-bucket limiter backed by `HashMap` with `RwLock`.
22    InMemory(InMemoryRateLimiter),
23    /// Distributed token-bucket limiter backed by Redis Lua scripts.
24    #[cfg(feature = "redis-rate-limiting")]
25    Redis(RedisRateLimiter),
26}
27
28impl RateLimiter {
29    /// Create an in-memory rate limiter.
30    pub fn new(config: RateLimitConfig) -> Self {
31        Self::InMemory(InMemoryRateLimiter::new(config))
32    }
33
34    /// Create a Redis-backed distributed rate limiter.
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if the Redis URL is invalid or the initial connection
39    /// attempt fails.
40    #[cfg(feature = "redis-rate-limiting")]
41    pub async fn new_redis(url: &str, config: RateLimitConfig) -> Result<Self, redis::RedisError> {
42        let rl = RedisRateLimiter::new(url, config).await?;
43        Ok(Self::Redis(rl))
44    }
45
46    /// Attach per-path rules from `[security.rate_limiting]` auth endpoint fields.
47    #[must_use]
48    pub fn with_path_rules_from_security(self, sec: &RateLimitingSecurityConfig) -> Self {
49        match self {
50            Self::InMemory(rl) => Self::InMemory(rl.with_path_rules_from_security(sec)),
51            #[cfg(feature = "redis-rate-limiting")]
52            Self::Redis(rl) => Self::Redis(rl.with_path_rules_from_security(sec)),
53        }
54    }
55
56    /// Return the active rate limit configuration.
57    pub const fn config(&self) -> &RateLimitConfig {
58        match self {
59            Self::InMemory(rl) => rl.config(),
60            #[cfg(feature = "redis-rate-limiting")]
61            Self::Redis(rl) => rl.config(),
62        }
63    }
64
65    /// Number of per-path rate limit rules registered.
66    pub const fn path_rule_count(&self) -> usize {
67        match self {
68            Self::InMemory(rl) => rl.path_rule_count(),
69            #[cfg(feature = "redis-rate-limiting")]
70            Self::Redis(rl) => rl.path_rule_count(),
71        }
72    }
73
74    /// Seconds a client should wait before retrying after a per-path rate limit rejection.
75    ///
76    /// Returns the window duration for the matching path rule (e.g. 60s for an
77    /// auth/start rule with 5 req/60s), not the IP token-bucket interval.
78    pub fn retry_after_for_path(&self, path: &str) -> u32 {
79        match self {
80            Self::InMemory(rl) => rl.retry_after_for_path(path),
81            #[cfg(feature = "redis-rate-limiting")]
82            Self::Redis(rl) => rl.retry_after_for_path(path),
83        }
84    }
85
86    /// Check whether a request from `ip` is within the global IP rate limit.
87    pub async fn check_ip_limit(&self, ip: &str) -> CheckResult {
88        match self {
89            Self::InMemory(rl) => rl.check_ip_limit(ip).await,
90            #[cfg(feature = "redis-rate-limiting")]
91            Self::Redis(rl) => rl.check_ip_limit(ip).await,
92        }
93    }
94
95    /// Check whether a request from `user_id` is within the per-user limit.
96    pub async fn check_user_limit(&self, user_id: &str) -> CheckResult {
97        match self {
98            Self::InMemory(rl) => rl.check_user_limit(user_id).await,
99            #[cfg(feature = "redis-rate-limiting")]
100            Self::Redis(rl) => rl.check_user_limit(user_id).await,
101        }
102    }
103
104    /// Check the per-path rate limit for a request from `ip` to `path`.
105    ///
106    /// Returns an allowed [`CheckResult`] when no rule governs the path.
107    /// `CheckResult::retry_after_secs` reflects the actual per-path window, not
108    /// the global IP rate.
109    pub async fn check_path_limit(&self, path: &str, ip: &str) -> CheckResult {
110        match self {
111            Self::InMemory(rl) => rl.check_path_limit(path, ip).await,
112            #[cfg(feature = "redis-rate-limiting")]
113            Self::Redis(rl) => rl.check_path_limit(path, ip).await,
114        }
115    }
116
117    /// Evict stale in-memory buckets.
118    ///
119    /// No-op for the Redis backend — Redis handles expiry via `PEXPIRE`.
120    pub async fn cleanup(&self) {
121        match self {
122            Self::InMemory(rl) => rl.cleanup().await,
123            #[cfg(feature = "redis-rate-limiting")]
124            Self::Redis(_) => {},
125        }
126    }
127
128    /// Conservative static estimate of how long (in seconds) a client must wait
129    /// before the IP-level bucket refills one token: `ceil(1 / rps_per_ip)`.
130    ///
131    /// Used when no backend-computed `retry_after_ms` is available (e.g., the
132    /// in-memory backend before the precise value is plumbed end-to-end, or as
133    /// a fallback on Redis errors).  Minimum 1 second.
134    #[must_use]
135    pub fn retry_after_secs(&self) -> u32 {
136        let rps = self.config().rps_per_ip;
137        if rps == 0 {
138            return 1;
139        }
140        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
141        // Reason: ceil(1/rps) is always a small positive integer
142        {
143            ((1.0_f64 / f64::from(rps)).ceil() as u32).max(1)
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an in-memory limiter whose only non-default setting is `rps_per_ip`.
    fn limiter_with_rps(rps_per_ip: u32) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            rps_per_ip,
            ..RateLimitConfig::default()
        })
    }

    #[test]
    fn new_creates_in_memory_backend() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        assert!(matches!(limiter, RateLimiter::InMemory(_)));
    }

    #[test]
    fn config_returns_reference_to_inner_config() {
        let limiter = limiter_with_rps(42);
        assert_eq!(limiter.config().rps_per_ip, 42);
    }

    #[test]
    fn path_rule_count_starts_at_zero() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        assert_eq!(limiter.path_rule_count(), 0);
    }

    #[test]
    fn retry_after_secs_minimum_is_one() {
        let limiter = limiter_with_rps(u32::MAX);
        assert_eq!(limiter.retry_after_secs(), 1, "minimum retry_after must be 1s");
    }

    #[test]
    fn retry_after_secs_handles_zero_rps() {
        // Exercises the guard that prevents a division by zero.
        let limiter = limiter_with_rps(0);
        assert_eq!(limiter.retry_after_secs(), 1, "zero rps must fall back to 1s");
    }

    #[test]
    fn retry_after_secs_is_one_for_one_rps() {
        // ceil(1 / 1) == 1: one token per second refills within one second.
        let limiter = limiter_with_rps(1);
        assert_eq!(limiter.retry_after_secs(), 1);
    }
}