rate_gate/
lib.rs

1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use std::time::{Duration, Instant};
4
5use hashbrown::HashMap;
6
7#[derive(Default, Debug, Clone)]
8pub struct Limiter<T>
9where
10    T: Hash + Eq + Send + 'static,
11{
12    requests: Arc<Mutex<HashMap<T, AssociatedEntity>>>,
13}
14
/// Per-entity rate-limit state as stored by the limiter.
///
/// This is the value handed back by `Limiter::remove_limited_entity`. The fields are
/// private; use the read-only accessors to inspect it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AssociatedEntity {
    /// Requests still allowed in the current window; `0` means the entity is limited
    /// until the window expires.
    bucket: usize,
    /// When the current window started, i.e. when the bucket was last (re)filled.
    bucket_init: Instant,
    /// Capacity chosen by the user; the bucket is refilled to this value when a window expires.
    bucket_max: usize,
    /// Length of a window. Once this much time has passed since `bucket_init`, the next
    /// check refills the bucket to `bucket_max`.
    refresh_rate: Duration,
}

impl AssociatedEntity {
    /// Requests left in the bucket as last stored.
    ///
    /// This raw value does not account for a window that has expired but has not yet
    /// been observed by a check.
    pub fn bucket(&self) -> usize {
        self.bucket
    }

    /// The instant at which the current window started.
    pub fn bucket_init(&self) -> Instant {
        self.bucket_init
    }

    /// The value the bucket is refilled with at the start of each window.
    pub fn bucket_max(&self) -> usize {
        self.bucket_max
    }

    /// The length of one window.
    pub fn refresh_rate(&self) -> Duration {
        self.refresh_rate
    }
}
22
23impl<T> Limiter<T>
24where
25    T: Hash + Eq + Send + 'static,
26{
27    pub fn new() -> Self {
28        Limiter {
29            requests: Arc::new(Mutex::new(HashMap::new())),
30        }
31    }
32
33    /// Adds a entity to the limiter
34    /// `entity` is something hashable like a IP, username, etc...
35    ///
36    /// `max_limit` is the max number of requests in the given timeframe that you allow for that specific entity
37    ///
38    /// `refresh_rate` is the timeframe after which the entity gets a renewed limit
39    pub fn add_limited_entity(&self, entity: T, max_limit: usize, refresh_rate: Duration) {
40        let mut requests = self.requests.lock().unwrap();
41        requests.insert(
42            entity,
43            AssociatedEntity {
44                bucket: max_limit,
45                bucket_init: Instant::now(),
46                bucket_max: max_limit,
47                refresh_rate,
48            },
49        );
50    }
51
52    /// Removes a entity from the limiter
53    ///
54    /// Removes a key from the map, returning the value at the key if the key was previously in the map.
55    /// Keeps the allocated memory for reuse.
56    ///
57    /// The key may be any borrowed form of the map's key type,
58    /// but Hash and Eq on the borrowed form must match those for the key type.
59    pub fn remove_limited_entity(&self, entity: T) -> Option<AssociatedEntity> {
60        let mut requests = self.requests.lock().unwrap();
61        requests.remove(&entity)
62    }
63
64    /// Checks whether a entity has requests left to consume.
65    ///
66    /// `entity` has been added by you previously with `add_limited_entity`
67    ///
68    /// ### returns:
69    ///
70    /// `None` -> entity was not found by the limiter, create one with `add_limited_entity`.
71    ///
72    /// `Some(false)` -> entity is rate limited, no requests to consume.
73    ///
74    /// `Some(true)` -> everything worked, entity had requests left.
75    pub fn is_entity_limited(&mut self, entity: &T) -> Option<bool> {
76        let mut requests = self.requests.lock().unwrap();
77        let now = Instant::now();
78
79        if let Some(entry) = requests.get_mut(entity) {
80            if now.duration_since(entry.bucket_init) >= entry.refresh_rate {
81                entry.bucket = entry.bucket_max;
82                entry.bucket_init = now;
83            }
84
85            if entry.bucket > 0 {
86                entry.bucket -= 1; // request allowed
87                Some(true)
88            } else {
89                // entity is limited, request denied.
90                Some(false)
91            }
92        } else {
93            None
94        }
95    }
96
97    /// Returns the current amount of requests left in the entity's bucket.
98    ///
99    /// `entity` has been added by you previously with `add_limited_entity`
100    ///
101    /// ### returns:
102    ///
103    /// `None` -> entity was not found by the limiter, create one with `add_limited_entity`.
104    ///
105    /// `Some(usize)` -> the current number of requests left in the entity's bucket.
106    pub fn get_bucket_remaining(&self, entity: &T) -> Option<usize> {
107        let requests = self.requests.lock().unwrap();
108        requests.get(entity).map(|entry| entry.bucket)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Pretends `elapsed` has passed for `entity` by moving the start of its current window
    /// into the past. This exercises the refill logic deterministically, without
    /// `thread::sleep` and without depending on scheduler timing.
    fn age_entity(limiter: &Limiter<&'static str>, entity: &'static str, elapsed: Duration) {
        let mut requests = limiter.requests.lock().unwrap();
        let entry = requests
            .get_mut(entity)
            .expect("entity must be registered before it can be aged");
        entry.bucket_init = entry
            .bucket_init
            .checked_sub(elapsed)
            .expect("monotonic clock is too close to its origin to rewind");
    }

    #[test]
    fn test_add_limited_entity() {
        let limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));

        let requests = limiter.requests.lock().unwrap();
        let entry = requests.get("user1").expect("user1 should be registered");
        assert_eq!(entry.bucket_max, 5);
        assert_eq!(entry.bucket, 5);
    }

    #[test]
    fn test_limiter_refresh_rate() {
        let mut limiter: Limiter<&str> = Limiter::new();
        let refresh_rate = Duration::from_secs(5);
        let max_requests = 3;

        limiter.add_limited_entity("user1", max_requests, refresh_rate);

        for i in 1..=max_requests {
            assert_eq!(
                limiter.is_entity_limited(&"user1"),
                Some(true),
                "Request {} should be allowed",
                i
            );
        }
        assert_eq!(
            limiter.is_entity_limited(&"user1"),
            Some(false),
            "Request should be denied after limit is reached"
        );

        age_entity(&limiter, "user1", refresh_rate);

        // After the window expires, max_requests are available again.
        for i in 1..=max_requests {
            assert_eq!(
                limiter.is_entity_limited(&"user1"),
                Some(true),
                "Request {} should be allowed after refresh",
                i
            );
        }
        assert_eq!(
            limiter.is_entity_limited(&"user1"),
            Some(false),
            "Request should be denied after refreshed limit is reached"
        );
    }

    #[test]
    fn test_is_entity_limited_allows_requests() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 2, Duration::from_secs(60));

        assert_eq!(limiter.is_entity_limited(&"user1"), Some(true));
        assert_eq!(limiter.is_entity_limited(&"user1"), Some(true));
        assert_eq!(limiter.is_entity_limited(&"user1"), Some(false));
    }

    #[test]
    fn test_is_entity_limited_refills_bucket() {
        let mut limiter: Limiter<&str> = Limiter::new();
        let refresh_rate = Duration::from_secs(5);
        limiter.add_limited_entity("user1", 1, refresh_rate);

        assert_eq!(limiter.is_entity_limited(&"user1"), Some(true));
        assert_eq!(limiter.is_entity_limited(&"user1"), Some(false));
        age_entity(&limiter, "user1", refresh_rate);
        assert_eq!(limiter.is_entity_limited(&"user1"), Some(true));
    }

    #[test]
    fn test_is_entity_limited_not_found() {
        let mut limiter: Limiter<&str> = Limiter::new();
        assert_eq!(limiter.is_entity_limited(&"unknown_user"), None);
    }

    #[test]
    fn test_multiple_entities() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 3, Duration::from_secs(60));
        limiter.add_limited_entity("user2", 5, Duration::from_secs(60));

        for _ in 0..3 {
            assert_eq!(limiter.is_entity_limited(&"user1"), Some(true));
        }
        assert_eq!(limiter.is_entity_limited(&"user1"), Some(false)); // Now should be limited

        // user1 being limited must not affect user2's independent bucket.
        for _ in 0..5 {
            assert_eq!(limiter.is_entity_limited(&"user2"), Some(true));
        }
        assert_eq!(limiter.is_entity_limited(&"user2"), Some(false)); // Now should be limited
    }

    #[test]
    fn test_limiter_with_multiple_threads() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));

        // Each thread gets its own clone; clones share one map, so this exercises the
        // limiter's internal locking rather than an external mutex.
        let workers: Vec<_> = [2usize, 2, 1]
            .iter()
            .map(|&calls| {
                let mut handle = limiter.clone();
                thread::spawn(move || {
                    for _ in 0..calls {
                        assert_eq!(handle.is_entity_limited(&"user1"), Some(true));
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(limiter.is_entity_limited(&"user1"), Some(false));
    }

    #[test]
    fn test_concurrent_requests_never_exceed_limit() {
        let limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));

        // More contenders than tokens: exactly `max_limit` of them may succeed.
        let workers: Vec<_> = (0..8)
            .map(|_| {
                let mut handle = limiter.clone();
                thread::spawn(move || handle.is_entity_limited(&"user1"))
            })
            .collect();

        let allowed = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .filter(|outcome| *outcome == Some(true))
            .count();
        assert_eq!(allowed, 5);
    }

    #[test]
    fn test_remove_limited_entity() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));
        assert!(limiter.requests.lock().unwrap().contains_key("user1"));

        let removed = limiter
            .remove_limited_entity("user1")
            .expect("user1 was registered");
        assert_eq!(removed.bucket_max, 5);
        assert!(!limiter.requests.lock().unwrap().contains_key("user1"));

        // Removing an unknown entity is a no-op.
        assert!(limiter.remove_limited_entity("unknown_user").is_none());

        // A removed entity is no longer known to the limiter.
        assert_eq!(limiter.is_entity_limited(&"user1"), None);

        limiter.add_limited_entity("user2", 5, Duration::from_secs(60));
        let key: &str = "user2";
        let removed = limiter
            .remove_limited_entity(key)
            .expect("user2 was registered");
        assert_eq!(removed.bucket_max, 5);
        assert!(!limiter.requests.lock().unwrap().contains_key("user2"));
    }

    #[test]
    fn test_re_add_after_remove_uses_new_settings() {
        let limiter: Limiter<&str> = Limiter::new();

        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));
        assert!(limiter.remove_limited_entity("user1").is_some());
        assert!(!limiter.requests.lock().unwrap().contains_key("user1"));

        limiter.add_limited_entity("user1", 10, Duration::from_secs(120));
        {
            let requests = limiter.requests.lock().unwrap();
            let entry = requests.get("user1").expect("user1 was re-added");
            assert_eq!(entry.bucket_max, 10);
            assert_eq!(entry.bucket, 10); // Should reflect the new bucket max
            assert_eq!(entry.refresh_rate, Duration::from_secs(120));
        }

        let removed = limiter
            .remove_limited_entity("user1")
            .expect("user1 was registered");
        assert_eq!(removed.bucket_max, 10);
        assert!(!limiter.requests.lock().unwrap().contains_key("user1"));
    }

    #[test]
    fn test_get_bucket_remaining() {
        let mut limiter: Limiter<&str> = Limiter::new();

        assert_eq!(limiter.get_bucket_remaining(&"user1"), None);

        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));
        assert_eq!(limiter.get_bucket_remaining(&"user1"), Some(5));

        for expected in (0..5).rev() {
            assert_eq!(limiter.is_entity_limited(&"user1"), Some(true));
            assert_eq!(limiter.get_bucket_remaining(&"user1"), Some(expected));
        }

        // A denied request does not push the count below zero.
        assert_eq!(limiter.is_entity_limited(&"user1"), Some(false));
        assert_eq!(limiter.get_bucket_remaining(&"user1"), Some(0));
    }
}