Skip to main content

gatekpr_rate_limiter/
state.rs

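//! Rate limit state tracking
//!
//! Provides per-key rate limit tracking with fixed-window counters: each key has a
//! per-minute and a per-hour window, and each counter resets to zero when its window ends.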
4
5use crate::config::RateLimitConfig;
6use dashmap::DashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
/// Outcome of a single rate limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request was admitted.
    Allowed {
        /// How many more requests fit in the current minute window
        remaining_minute: u32,
        /// How many more requests fit in the current hour window
        remaining_hour: u32,
    },
    /// The request was rejected because a limit has been reached.
    Exceeded {
        /// Seconds the caller should wait before retrying
        retry_after: u64,
        /// The limit that was hit: `"minute"` or `"hour"`
        limit_type: &'static str,
    },
}

impl RateLimitResult {
    /// Returns `true` when the request was admitted.
    pub fn is_allowed(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force this to be revisited.
        match self {
            Self::Allowed { .. } => true,
            Self::Exceeded { .. } => false,
        }
    }

    /// Returns `true` when the request was rejected by a limit.
    pub fn is_exceeded(&self) -> bool {
        !self.is_allowed()
    }

    /// The retry hint in seconds for a rejected request, or `None` if it was admitted.
    pub fn retry_after(&self) -> Option<u64> {
        match *self {
            Self::Exceeded { retry_after, .. } => Some(retry_after),
            Self::Allowed { .. } => None,
        }
    }
}
48
49/// Rate limit state for a single key
50#[derive(Debug)]
51pub struct KeyRateLimit {
52    minute_count: u32,
53    hour_count: u32,
54    minute_reset: Instant,
55    hour_reset: Instant,
56}
57
58impl KeyRateLimit {
59    /// Create a new rate limit state
60    pub fn new() -> Self {
61        let now = Instant::now();
62        Self {
63            minute_count: 0,
64            hour_count: 0,
65            minute_reset: now + Duration::from_secs(60),
66            hour_reset: now + Duration::from_secs(3600),
67        }
68    }
69
70    /// Check the rate limit and increment counters if allowed
71    ///
72    /// Returns the result of the rate limit check.
73    pub fn check_and_increment(&mut self, config: &RateLimitConfig) -> RateLimitResult {
74        let now = Instant::now();
75
76        // Reset minute counter if window has passed
77        if now >= self.minute_reset {
78            self.minute_count = 0;
79            self.minute_reset = now + Duration::from_secs(60);
80        }
81
82        // Reset hour counter if window has passed
83        if now >= self.hour_reset {
84            self.hour_count = 0;
85            self.hour_reset = now + Duration::from_secs(3600);
86        }
87
88        // Check minute limit
89        if self.minute_count >= config.requests_per_minute {
90            let retry_after = self.minute_reset.duration_since(now).as_secs().max(1);
91            return RateLimitResult::Exceeded {
92                retry_after,
93                limit_type: "minute",
94            };
95        }
96
97        // Check hour limit
98        if self.hour_count >= config.requests_per_hour {
99            let retry_after = self.hour_reset.duration_since(now).as_secs().max(1);
100            return RateLimitResult::Exceeded {
101                retry_after,
102                limit_type: "hour",
103            };
104        }
105
106        // Increment counters
107        self.minute_count += 1;
108        self.hour_count += 1;
109
110        RateLimitResult::Allowed {
111            remaining_minute: config.requests_per_minute - self.minute_count,
112            remaining_hour: config.requests_per_hour - self.hour_count,
113        }
114    }
115
116    /// Check if this rate limit state has expired (inactive for over an hour)
117    pub fn is_expired(&self) -> bool {
118        Instant::now() >= self.hour_reset
119    }
120
121    /// Get the current minute count
122    pub fn minute_count(&self) -> u32 {
123        self.minute_count
124    }
125
126    /// Get the current hour count
127    pub fn hour_count(&self) -> u32 {
128        self.hour_count
129    }
130}
131
132impl Default for KeyRateLimit {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138/// Thread-safe rate limit store using DashMap
139///
140/// Provides lock-free concurrent access to rate limit states.
141#[derive(Clone)]
142pub struct RateLimitStore {
143    state: Arc<DashMap<String, KeyRateLimit>>,
144}
145
146impl RateLimitStore {
147    /// Create a new rate limit store
148    pub fn new() -> Self {
149        Self {
150            state: Arc::new(DashMap::with_capacity(1000)),
151        }
152    }
153
154    /// Create a store with pre-allocated capacity
155    pub fn with_capacity(capacity: usize) -> Self {
156        Self {
157            state: Arc::new(DashMap::with_capacity(capacity)),
158        }
159    }
160
161    /// Check and increment the rate limit for a key
162    pub fn check(&self, key: &str, config: &RateLimitConfig) -> RateLimitResult {
163        let mut entry = self.state.entry(key.to_string()).or_default();
164        entry.check_and_increment(config)
165    }
166
167    /// Clean up expired entries
168    ///
169    /// Should be called periodically to prevent memory growth.
170    pub fn cleanup_expired(&self) {
171        self.state.retain(|_, limit| !limit.is_expired());
172    }
173
174    /// Get the number of tracked keys
175    pub fn len(&self) -> usize {
176        self.state.len()
177    }
178
179    /// Check if the store is empty
180    pub fn is_empty(&self) -> bool {
181        self.state.is_empty()
182    }
183
184    /// Remove a specific key from tracking
185    pub fn remove(&self, key: &str) {
186        self.state.remove(key);
187    }
188
189    /// Clear all tracked keys
190    pub fn clear(&self) {
191        self.state.clear();
192    }
193
194    /// Spawn a background task to periodically clean up expired entries
195    ///
196    /// Returns a join handle that can be used to abort the task if needed.
197    /// The task logs cleanup activity at debug level.
198    ///
199    /// # Arguments
200    /// * `interval` - How often to run cleanup (recommended: 1 hour)
201    ///
202    /// # Example
203    /// ```ignore
204    /// let store = Arc::new(RateLimitStore::new());
205    /// let handle = store.clone().spawn_cleanup_task(Duration::from_secs(3600));
206    /// // Later: handle.abort() to stop cleanup
207    /// ```
208    #[cfg(feature = "cleanup-task")]
209    pub fn spawn_cleanup_task(
210        self: Arc<Self>,
211        interval: std::time::Duration,
212    ) -> tokio::task::JoinHandle<()> {
213        tokio::spawn(async move {
214            let mut ticker = tokio::time::interval(interval);
215            // Skip the immediate first tick
216            ticker.tick().await;
217            loop {
218                ticker.tick().await;
219                let before = self.len();
220                self.cleanup_expired();
221                let after = self.len();
222                if before != after {
223                    tracing::debug!(
224                        before = before,
225                        after = after,
226                        removed = before - after,
227                        "Rate limiter cleanup completed"
228                    );
229                }
230            }
231        })
232    }
233}
234
235impl Default for RateLimitStore {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_rate_limit_new() {
        let fresh = KeyRateLimit::new();
        assert_eq!(fresh.minute_count(), 0);
        assert_eq!(fresh.hour_count(), 0);
    }

    #[test]
    fn test_check_and_increment_allowed() {
        let config = RateLimitConfig::for_plan("free");
        let mut limit = KeyRateLimit::new();

        assert!(limit.check_and_increment(&config).is_allowed());
        // An admitted request is counted in both windows.
        assert_eq!((limit.minute_count(), limit.hour_count()), (1, 1));
    }

    #[test]
    fn test_check_and_increment_exceeded_minute() {
        let config = RateLimitConfig::for_plan("free"); // 20/min
        let mut limit = KeyRateLimit::new();

        // Every request up to the per-minute budget is admitted...
        assert!((0..20).all(|_| limit.check_and_increment(&config).is_allowed()));

        // ...and the first one past it is rejected with a positive retry hint.
        let rejected = limit.check_and_increment(&config);
        assert!(rejected.is_exceeded());
        assert!(matches!(rejected.retry_after(), Some(secs) if secs > 0));
    }

    #[test]
    fn test_rate_limit_result_methods() {
        let allowed = RateLimitResult::Allowed {
            remaining_minute: 10,
            remaining_hour: 100,
        };
        assert!(allowed.is_allowed() && !allowed.is_exceeded());
        assert_eq!(allowed.retry_after(), None);

        let exceeded = RateLimitResult::Exceeded {
            retry_after: 30,
            limit_type: "minute",
        };
        assert!(exceeded.is_exceeded() && !exceeded.is_allowed());
        assert_eq!(exceeded.retry_after(), Some(30));
    }

    #[test]
    fn test_store_basic() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        // Each distinct key gets its own tracked entry.
        for (seen, user) in ["user1", "user2"].iter().enumerate() {
            assert!(store.check(user, &config).is_allowed());
            assert_eq!(store.len(), seen + 1);
        }
    }

    #[test]
    fn test_store_remove() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        let _ = store.check("user1", &config);
        assert_eq!(store.len(), 1);

        store.remove("user1");
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_store_clear() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        for user in ["user1", "user2"].iter() {
            let _ = store.check(user, &config);
        }
        assert_eq!(store.len(), 2);

        store.clear();
        assert!(store.is_empty());
    }
}