// gatekpr_rate_limiter/state.rs

use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;

use crate::config::RateLimitConfig;

/// Outcome of checking a single request against the rate limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request was admitted.
    Allowed {
        /// Requests still available in the current per-minute window.
        remaining_minute: u32,
        /// Requests still available in the current per-hour window.
        remaining_hour: u32,
    },
    /// The request was rejected because a window's quota is used up.
    Exceeded {
        /// Seconds the caller should wait before retrying.
        retry_after: u64,
        /// Which window rejected the request (the store produces `"minute"` or `"hour"`).
        limit_type: &'static str,
    },
}

29impl RateLimitResult {
30 pub fn is_allowed(&self) -> bool {
32 matches!(self, RateLimitResult::Allowed { .. })
33 }
34
35 pub fn is_exceeded(&self) -> bool {
37 matches!(self, RateLimitResult::Exceeded { .. })
38 }
39
40 pub fn retry_after(&self) -> Option<u64> {
42 match self {
43 RateLimitResult::Exceeded { retry_after, .. } => Some(*retry_after),
44 _ => None,
45 }
46 }
47}
48
/// Per-key counters for a per-minute and a per-hour fixed window.
#[derive(Debug)]
pub struct KeyRateLimit {
    /// Requests admitted in the current minute window.
    minute_count: u32,
    /// Requests admitted in the current hour window.
    hour_count: u32,
    /// Instant at which the minute window rolls over.
    minute_reset: Instant,
    /// Instant at which the hour window rolls over.
    hour_reset: Instant,
}

58impl KeyRateLimit {
59 pub fn new() -> Self {
61 let now = Instant::now();
62 Self {
63 minute_count: 0,
64 hour_count: 0,
65 minute_reset: now + Duration::from_secs(60),
66 hour_reset: now + Duration::from_secs(3600),
67 }
68 }
69
70 pub fn check_and_increment(&mut self, config: &RateLimitConfig) -> RateLimitResult {
74 let now = Instant::now();
75
76 if now >= self.minute_reset {
78 self.minute_count = 0;
79 self.minute_reset = now + Duration::from_secs(60);
80 }
81
82 if now >= self.hour_reset {
84 self.hour_count = 0;
85 self.hour_reset = now + Duration::from_secs(3600);
86 }
87
88 if self.minute_count >= config.requests_per_minute {
90 let retry_after = self.minute_reset.duration_since(now).as_secs().max(1);
91 return RateLimitResult::Exceeded {
92 retry_after,
93 limit_type: "minute",
94 };
95 }
96
97 if self.hour_count >= config.requests_per_hour {
99 let retry_after = self.hour_reset.duration_since(now).as_secs().max(1);
100 return RateLimitResult::Exceeded {
101 retry_after,
102 limit_type: "hour",
103 };
104 }
105
106 self.minute_count += 1;
108 self.hour_count += 1;
109
110 RateLimitResult::Allowed {
111 remaining_minute: config.requests_per_minute - self.minute_count,
112 remaining_hour: config.requests_per_hour - self.hour_count,
113 }
114 }
115
116 pub fn is_expired(&self) -> bool {
118 Instant::now() >= self.hour_reset
119 }
120
121 pub fn minute_count(&self) -> u32 {
123 self.minute_count
124 }
125
126 pub fn hour_count(&self) -> u32 {
128 self.hour_count
129 }
130}
131
132impl Default for KeyRateLimit {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138#[derive(Clone)]
142pub struct RateLimitStore {
143 state: Arc<DashMap<String, KeyRateLimit>>,
144}
145
146impl RateLimitStore {
147 pub fn new() -> Self {
149 Self {
150 state: Arc::new(DashMap::with_capacity(1000)),
151 }
152 }
153
154 pub fn with_capacity(capacity: usize) -> Self {
156 Self {
157 state: Arc::new(DashMap::with_capacity(capacity)),
158 }
159 }
160
161 pub fn check(&self, key: &str, config: &RateLimitConfig) -> RateLimitResult {
163 let mut entry = self.state.entry(key.to_string()).or_default();
164 entry.check_and_increment(config)
165 }
166
167 pub fn cleanup_expired(&self) {
171 self.state.retain(|_, limit| !limit.is_expired());
172 }
173
174 pub fn len(&self) -> usize {
176 self.state.len()
177 }
178
179 pub fn is_empty(&self) -> bool {
181 self.state.is_empty()
182 }
183
184 pub fn remove(&self, key: &str) {
186 self.state.remove(key);
187 }
188
189 pub fn clear(&self) {
191 self.state.clear();
192 }
193
194 #[cfg(feature = "cleanup-task")]
209 pub fn spawn_cleanup_task(
210 self: Arc<Self>,
211 interval: std::time::Duration,
212 ) -> tokio::task::JoinHandle<()> {
213 tokio::spawn(async move {
214 let mut ticker = tokio::time::interval(interval);
215 ticker.tick().await;
217 loop {
218 ticker.tick().await;
219 let before = self.len();
220 self.cleanup_expired();
221 let after = self.len();
222 if before != after {
223 tracing::debug!(
224 before = before,
225 after = after,
226 removed = before - after,
227 "Rate limiter cleanup completed"
228 );
229 }
230 }
231 })
232 }
233}
234
235impl Default for RateLimitStore {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Free-plan config with explicit quotas, so tests do not depend on plan defaults.
    fn config_with(per_minute: u32, per_hour: u32) -> RateLimitConfig {
        let mut config = RateLimitConfig::for_plan("free");
        config.requests_per_minute = per_minute;
        config.requests_per_hour = per_hour;
        config
    }

    #[test]
    fn test_key_rate_limit_new() {
        let limit = KeyRateLimit::new();
        assert_eq!(limit.minute_count, 0);
        assert_eq!(limit.hour_count, 0);
        assert!(!limit.is_expired());
    }

    #[test]
    fn test_check_and_increment_allowed() {
        let mut limit = KeyRateLimit::new();
        let config = config_with(5, 50);

        let result = limit.check_and_increment(&config);
        assert_eq!(
            result,
            RateLimitResult::Allowed {
                remaining_minute: 4,
                remaining_hour: 49,
            }
        );
        assert_eq!(limit.minute_count, 1);
        assert_eq!(limit.hour_count, 1);
    }

    #[test]
    fn test_check_and_increment_exceeded_minute() {
        let mut limit = KeyRateLimit::new();
        let config = RateLimitConfig::for_plan("free");

        // Exhaust the plan's own minute quota instead of hard-coding its size.
        for _ in 0..config.requests_per_minute {
            assert!(limit.check_and_increment(&config).is_allowed());
        }

        let result = limit.check_and_increment(&config);
        assert!(result.is_exceeded());
        assert!(result.retry_after().unwrap() > 0);
    }

    #[test]
    fn test_minute_limit_reports_minute_window() {
        let mut limit = KeyRateLimit::new();
        let config = config_with(3, 100);

        for _ in 0..config.requests_per_minute {
            assert!(limit.check_and_increment(&config).is_allowed());
        }

        let result = limit.check_and_increment(&config);
        assert!(matches!(
            result,
            RateLimitResult::Exceeded { limit_type: "minute", .. }
        ));
        assert!((1..=60).contains(&result.retry_after().unwrap()));
        // Rejected requests are not counted.
        assert_eq!(limit.minute_count(), 3);
    }

    #[test]
    fn test_check_and_increment_exceeded_hour() {
        let mut limit = KeyRateLimit::new();
        let config = config_with(10, 2);

        assert!(limit.check_and_increment(&config).is_allowed());
        assert!(limit.check_and_increment(&config).is_allowed());

        match limit.check_and_increment(&config) {
            RateLimitResult::Exceeded {
                retry_after,
                limit_type,
            } => {
                assert_eq!(limit_type, "hour");
                assert!((1..=3600).contains(&retry_after));
            }
            other => panic!("expected hour limit, got {:?}", other),
        }
    }

    #[test]
    fn test_minute_window_rolls_over() {
        let mut limit = KeyRateLimit::new();
        let config = config_with(1, 10);

        assert!(limit.check_and_increment(&config).is_allowed());
        assert!(limit.check_and_increment(&config).is_exceeded());

        // Pretend the minute window has elapsed.
        limit.minute_reset = Instant::now();
        assert!(limit.check_and_increment(&config).is_allowed());
        assert_eq!(limit.minute_count(), 1);
        assert_eq!(limit.hour_count(), 2);
    }

    #[test]
    fn test_rate_limit_result_methods() {
        let allowed = RateLimitResult::Allowed {
            remaining_minute: 10,
            remaining_hour: 100,
        };
        assert!(allowed.is_allowed());
        assert!(!allowed.is_exceeded());
        assert!(allowed.retry_after().is_none());

        let exceeded = RateLimitResult::Exceeded {
            retry_after: 30,
            limit_type: "minute",
        };
        assert!(!exceeded.is_allowed());
        assert!(exceeded.is_exceeded());
        assert_eq!(exceeded.retry_after(), Some(30));
    }

    #[test]
    fn test_store_basic() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        let result = store.check("user1", &config);
        assert!(result.is_allowed());
        assert_eq!(store.len(), 1);

        let result = store.check("user2", &config);
        assert!(result.is_allowed());
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_store_same_key_shares_counters() {
        let store = RateLimitStore::new();
        let config = config_with(5, 50);

        store.check("user1", &config);
        let second = store.check("user1", &config);
        assert_eq!(
            second,
            RateLimitResult::Allowed {
                remaining_minute: 3,
                remaining_hour: 48,
            }
        );
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_store_cleanup_expired() {
        let store = RateLimitStore::new();
        let config = config_with(5, 50);

        store.check("fresh", &config);
        store.check("stale", &config);

        store.cleanup_expired();
        assert_eq!(store.len(), 2, "live entries must survive cleanup");

        {
            let mut stale = store
                .state
                .get_mut("stale")
                .expect("entry was just inserted");
            let now = Instant::now();
            stale.minute_reset = now;
            stale.hour_reset = now;
        }

        store.cleanup_expired();
        assert_eq!(store.len(), 1);
        assert!(store.state.contains_key("fresh"));
    }

    #[test]
    fn test_store_remove() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        store.check("user1", &config);
        assert_eq!(store.len(), 1);

        store.remove("user1");
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_store_clear() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig::for_plan("free");

        store.check("user1", &config);
        store.check("user2", &config);
        assert_eq!(store.len(), 2);

        store.clear();
        assert!(store.is_empty());
    }
}