1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use std::time::{Duration, Instant};
4
5use hashbrown::HashMap;
6
/// A fixed-window rate limiter keyed by entity.
///
/// Each registered entity owns an [`AssociatedEntity`]: a bucket of up to `bucket_max`
/// tokens plus the start time of its current window. Every call to
/// [`Limiter::is_entity_limited`] consumes one token. Once `refresh_rate` has elapsed since
/// the window started, the next call refills the bucket to `bucket_max` and starts a new
/// window at the moment of that call.
///
/// All state sits behind an `Arc<Mutex<..>>`, so every method takes `&self`, and
/// [`Clone`] yields another handle onto the *same* buckets rather than a copy.
#[derive(Debug)]
pub struct Limiter<T> {
    requests: Arc<Mutex<HashMap<T, AssociatedEntity>>>,
}

// Manual impls: `#[derive]` would require `T: Default` / `T: Clone`, which neither the
// empty map nor the shared `Arc` handle actually needs.
impl<T> Default for Limiter<T> {
    /// Creates a limiter with no registered entities.
    fn default() -> Self {
        Limiter {
            requests: Arc::new(Mutex::new(HashMap::new())),
        }
    }
}

impl<T> Clone for Limiter<T> {
    /// Returns a new handle sharing this limiter's state; the map itself is not copied.
    fn clone(&self) -> Self {
        Limiter {
            requests: Arc::clone(&self.requests),
        }
    }
}

/// Per-entity rate-limit state: a token bucket and the window it belongs to.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct AssociatedEntity {
    /// Tokens still recorded as available in the current window.
    bucket: usize,
    /// When the current window started.
    bucket_init: Instant,
    /// Number of tokens the bucket is refilled to when a new window starts.
    bucket_max: usize,
    /// Window length; once this much time has passed since `bucket_init`, the bucket refills.
    refresh_rate: Duration,
}

impl AssociatedEntity {
    /// Tokens recorded as remaining in the current window.
    ///
    /// This is the stored counter. It does not account for a window that has already
    /// expired but not yet been refilled.
    pub fn bucket(&self) -> usize {
        self.bucket
    }

    /// Capacity the bucket is refilled to at the start of each window.
    pub fn bucket_max(&self) -> usize {
        self.bucket_max
    }

    /// Start time of the current window.
    pub fn bucket_init(&self) -> Instant {
        self.bucket_init
    }

    /// Length of one window.
    pub fn refresh_rate(&self) -> Duration {
        self.refresh_rate
    }

    /// Whether the current window has fully elapsed as of `now`.
    fn window_expired(&self, now: Instant) -> bool {
        now.duration_since(self.bucket_init) >= self.refresh_rate
    }
}

impl<T> Limiter<T>
where
    T: Hash + Eq + Send + 'static,
{
    /// Creates an empty limiter with no registered entities.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `entity` with a full bucket of `max_limit` tokens that refills every
    /// `refresh_rate`.
    ///
    /// If `entity` was already registered, its previous state is replaced, so its bucket is
    /// reset to full and a new window starts now.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn add_limited_entity(&self, entity: T, max_limit: usize, refresh_rate: Duration) {
        let mut requests = self.requests.lock().unwrap();
        let fresh = AssociatedEntity {
            bucket: max_limit,
            bucket_init: Instant::now(),
            bucket_max: max_limit,
            refresh_rate,
        };
        requests.insert(entity, fresh);
    }

    /// Unregisters `entity`, returning its last recorded state, or `None` if it was not
    /// registered.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn remove_limited_entity(&self, entity: T) -> Option<AssociatedEntity> {
        self.requests.lock().unwrap().remove(&entity)
    }

    /// Attempts to consume one token for `entity`.
    ///
    /// Despite the name, the boolean reports whether the request is **allowed**:
    /// - `Some(true)`: a token was available and has been consumed.
    /// - `Some(false)`: the bucket is empty for the current window; nothing changes.
    /// - `None`: `entity` is not registered.
    ///
    /// If the current window has elapsed, the bucket is first refilled to its maximum and a
    /// new window starts at this call.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn is_entity_limited(&self, entity: &T) -> Option<bool> {
        let mut requests = self.requests.lock().unwrap();
        let now = Instant::now();
        let entry = requests.get_mut(entity)?;

        if entry.window_expired(now) {
            entry.bucket = entry.bucket_max;
            entry.bucket_init = now;
        }

        if entry.bucket == 0 {
            return Some(false);
        }
        entry.bucket -= 1;
        Some(true)
    }

    /// Returns how many tokens `entity` could consume right now, or `None` if it is not
    /// registered.
    ///
    /// If the current window has already elapsed, this reports the full `bucket_max`, which
    /// is what the next [`Limiter::is_entity_limited`] call would see after refilling. It
    /// does not report the stale, depleted counter. No state is modified.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn get_bucket_remaining(&self, entity: &T) -> Option<usize> {
        let requests = self.requests.lock().unwrap();
        let now = Instant::now();
        requests.get(entity).map(|entry| {
            if entry.window_expired(now) {
                entry.bucket_max
            } else {
                entry.bucket
            }
        })
    }
}
111
#[cfg(test)]
mod tests {
    //! Unit tests for `Limiter`: registration, token consumption, window refills,
    //! removal, and access from several threads.

    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Issues `count` consecutive checks for `key` and returns every verdict in call order.
    fn run_checks(
        limiter: &mut Limiter<&'static str>,
        key: &'static str,
        count: usize,
    ) -> Vec<Option<bool>> {
        (0..count).map(|_| limiter.is_entity_limited(&key)).collect()
    }

    #[test]
    fn test_add_limited_entity() {
        let limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));

        let requests = limiter.requests.lock().unwrap();
        assert!(requests.contains_key("user1"));
        let entry = &requests["user1"];
        assert_eq!((entry.bucket_max, entry.bucket), (5, 5));
    }

    #[test]
    fn test_limiter_refresh_rate() {
        let mut limiter: Limiter<&str> = Limiter::new();
        let max_requests = 3;
        let refresh_rate = Duration::from_millis(500);
        limiter.add_limited_entity("user1", max_requests, refresh_rate);

        for verdict in run_checks(&mut limiter, "user1", max_requests) {
            assert_eq!(verdict, Some(true), "Request should be allowed");
        }
        assert_eq!(
            limiter.is_entity_limited(&"user1"),
            Some(false),
            "Request should be denied after limit is reached"
        );

        // Wait out the window (plus slack) so the next check refills the bucket.
        thread::sleep(refresh_rate + Duration::from_millis(50));

        let after_refresh = run_checks(&mut limiter, "user1", max_requests);
        for (i, verdict) in after_refresh.into_iter().enumerate() {
            assert_eq!(
                verdict,
                Some(true),
                "Request {} should be allowed after refresh",
                i + 1
            );
        }
        assert_eq!(
            limiter.is_entity_limited(&"user1"),
            Some(false),
            "Request should be denied after refreshed limit is reached"
        );
    }

    #[test]
    fn test_is_entity_limited_allows_requests() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 2, Duration::from_secs(60));

        assert_eq!(
            run_checks(&mut limiter, "user1", 3),
            [Some(true), Some(true), Some(false)]
        );
    }

    #[test]
    fn test_is_entity_limited_refills_bucket() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 1, Duration::from_millis(10));

        assert_eq!(run_checks(&mut limiter, "user1", 2), [Some(true), Some(false)]);
        thread::sleep(Duration::from_millis(25));
        assert_eq!(run_checks(&mut limiter, "user1", 1), [Some(true)]);
    }

    #[test]
    fn test_is_entity_limited_not_found() {
        let mut limiter: Limiter<&str> = Limiter::new();
        assert_eq!(run_checks(&mut limiter, "unknown_user", 1), [None]);
    }

    #[test]
    fn test_multiple_entities() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 3, Duration::from_secs(60));
        limiter.add_limited_entity("user2", 5, Duration::from_secs(60));

        // Each entity drains only its own bucket.
        assert_eq!(
            run_checks(&mut limiter, "user1", 4),
            [Some(true), Some(true), Some(true), Some(false)]
        );
        assert_eq!(
            run_checks(&mut limiter, "user2", 6),
            [Some(true), Some(true), Some(true), Some(true), Some(true), Some(false)]
        );
    }

    #[test]
    fn test_limiter_with_multiple_threads() {
        let limiter = Arc::new(Mutex::new(Limiter::new()));
        limiter
            .lock()
            .unwrap()
            .add_limited_entity("user1", 5, Duration::from_secs(60));

        // Three workers make 2 + 2 + 1 = 5 calls, exactly the bucket size.
        let workers: Vec<_> = [2, 2, 1]
            .iter()
            .map(|&calls| {
                let shared = Arc::clone(&limiter);
                thread::spawn(move || {
                    for _ in 0..calls {
                        assert_eq!(
                            shared.lock().unwrap().is_entity_limited(&"user1"),
                            Some(true)
                        );
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(
            limiter.lock().unwrap().is_entity_limited(&"user1"),
            Some(false)
        );
    }

    #[test]
    fn test_remove_limited_entity() {
        let mut limiter: Limiter<&str> = Limiter::new();
        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));
        assert!(limiter.requests.lock().unwrap().contains_key("user1"));

        let removed = limiter.remove_limited_entity("user1");
        assert_eq!(removed.map(|entity| entity.bucket_max), Some(5));
        assert!(!limiter.requests.lock().unwrap().contains_key("user1"));

        assert!(limiter.remove_limited_entity("unknown_user").is_none());
        assert_eq!(run_checks(&mut limiter, "user1", 1), [None]);

        limiter.add_limited_entity("user2", 5, Duration::from_secs(60));
        let borrowed_key: &str = "user2";
        let removed = limiter.remove_limited_entity(borrowed_key);
        assert_eq!(removed.map(|entity| entity.bucket_max), Some(5));
        assert!(!limiter.requests.lock().unwrap().contains_key("user2"));
    }

    #[test]
    fn test_remove_limited_entity_memory_reuse() {
        let limiter: Limiter<&str> = Limiter::new();

        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));
        assert!(limiter.requests.lock().unwrap().contains_key("user1"));
        assert!(limiter.remove_limited_entity("user1").is_some());
        assert!(!limiter.requests.lock().unwrap().contains_key("user1"));

        // Re-registering the same key must start from the new configuration.
        limiter.add_limited_entity("user1", 10, Duration::from_secs(120));
        {
            let requests = limiter.requests.lock().unwrap();
            assert!(requests.contains_key("user1"));
            let entry = &requests["user1"];
            assert_eq!((entry.bucket_max, entry.bucket), (10, 10));
        }

        let removed = limiter.remove_limited_entity("user1");
        assert_eq!(removed.map(|entity| entity.bucket_max), Some(10));
        assert!(!limiter.requests.lock().unwrap().contains_key("user1"));
    }

    #[test]
    fn test_get_bucket_remaining() {
        let mut limiter: Limiter<&str> = Limiter::new();
        assert_eq!(limiter.get_bucket_remaining(&"user1"), None);

        limiter.add_limited_entity("user1", 5, Duration::from_secs(60));
        assert_eq!(limiter.get_bucket_remaining(&"user1"), Some(5));

        let _ = run_checks(&mut limiter, "user1", 1);
        assert_eq!(limiter.get_bucket_remaining(&"user1"), Some(4));

        let _ = run_checks(&mut limiter, "user1", 2);
        assert_eq!(limiter.get_bucket_remaining(&"user1"), Some(2));
    }
}