rust_serv/throttle/
limiter.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7
8use super::config::ThrottleConfig;
9use super::token_bucket::TokenBucket;
10
/// Outcome of a throttle decision.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThrottleResult {
    /// The request may proceed; `remaining` is a token count supplied by
    /// whoever produced the result.
    Allowed { remaining: u64 },
    /// The request must be delayed; `wait_ms` is the suggested wait in
    /// milliseconds before retrying.
    Throttled { wait_ms: u64 },
    /// Throttling is not in effect, so the request proceeds unconditionally.
    Unlimited,
}

impl ThrottleResult {
    /// Returns `true` when the request may proceed, i.e. for
    /// [`ThrottleResult::Allowed`] and [`ThrottleResult::Unlimited`].
    pub fn is_allowed(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here
        // instead of silently falling into a catch-all arm.
        match self {
            ThrottleResult::Allowed { .. } | ThrottleResult::Unlimited => true,
            ThrottleResult::Throttled { .. } => false,
        }
    }
}
28
29#[derive(Debug)]
31pub struct ThrottleLimiter {
32 config: ThrottleConfig,
33 global_bucket: Arc<RwLock<TokenBucket>>,
34 ip_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
35}
36
37impl ThrottleLimiter {
38 pub fn new(config: ThrottleConfig) -> Self {
40 let global_bucket = if config.has_global_limit() {
41 TokenBucket::new(
42 config.bucket_capacity,
43 config.global_limit,
44 )
45 } else {
46 TokenBucket::new(0, 0)
47 };
48
49 Self {
50 config,
51 global_bucket: Arc::new(RwLock::new(global_bucket)),
52 ip_buckets: Arc::new(RwLock::new(HashMap::new())),
53 }
54 }
55
56 pub fn config(&self) -> &ThrottleConfig {
58 &self.config
59 }
60
61 pub async fn check(&self, ip: &str, bytes: u64) -> ThrottleResult {
63 if !self.config.is_active() {
64 return ThrottleResult::Unlimited;
65 }
66
67 if self.config.has_global_limit() {
69 let mut bucket = self.global_bucket.write().await;
70 if !bucket.try_consume(bytes) {
71 let wait = bucket.wait_time(bytes);
72 return ThrottleResult::Throttled {
73 wait_ms: wait.as_millis() as u64,
74 };
75 }
76 }
77
78 if self.config.has_per_ip_limit() {
80 let mut buckets = self.ip_buckets.write().await;
81 let bucket = buckets.entry(ip.to_string()).or_insert_with(|| {
82 TokenBucket::new(
83 self.config.bucket_capacity,
84 self.config.per_ip_limit,
85 )
86 });
87
88 if !bucket.try_consume(bytes) {
89 let wait = bucket.wait_time(bytes);
90 return ThrottleResult::Throttled {
91 wait_ms: wait.as_millis() as u64,
92 };
93 }
94 }
95
96 ThrottleResult::Allowed { remaining: 0 }
97 }
98
99 pub async fn consume(&self, ip: &str, bytes: u64) -> u64 {
102 if !self.config.is_active() {
103 return bytes;
104 }
105
106 let mut total_consumed = bytes;
107
108 if self.config.has_global_limit() {
110 let mut bucket = self.global_bucket.write().await;
111 total_consumed = total_consumed.min(bucket.consume(bytes));
112 }
113
114 if self.config.has_per_ip_limit() && total_consumed > 0 {
116 let mut buckets = self.ip_buckets.write().await;
117 let bucket = buckets.entry(ip.to_string()).or_insert_with(|| {
118 TokenBucket::new(
119 self.config.bucket_capacity,
120 self.config.per_ip_limit,
121 )
122 });
123 total_consumed = total_consumed.min(bucket.consume(bytes));
124 }
125
126 total_consumed
127 }
128
129 pub async fn wait_time(&self, ip: &str, bytes: u64) -> Duration {
131 if !self.config.is_active() {
132 return Duration::ZERO;
133 }
134
135 let mut max_wait = Duration::ZERO;
136
137 if self.config.has_global_limit() {
138 let mut bucket = self.global_bucket.write().await;
139 max_wait = max_wait.max(bucket.wait_time(bytes));
140 }
141
142 if self.config.has_per_ip_limit() {
143 let buckets = self.ip_buckets.write().await;
144 if let Some(bucket) = buckets.get(ip) {
145 let mut bucket = bucket.clone();
146 max_wait = max_wait.max(bucket.wait_time(bytes));
147 }
148 }
149
150 max_wait
151 }
152
153 pub async fn reset(&self) {
155 if self.config.has_global_limit() {
156 let mut bucket = self.global_bucket.write().await;
157 bucket.reset();
158 }
159
160 let mut buckets = self.ip_buckets.write().await;
161 for bucket in buckets.values_mut() {
162 bucket.reset();
163 }
164 }
165
166 pub async fn clear_ip_buckets(&self) {
168 let mut buckets = self.ip_buckets.write().await;
169 buckets.clear();
170 }
171
172 pub async fn tracked_ip_count(&self) -> usize {
174 let buckets = self.ip_buckets.read().await;
175 buckets.len()
176 }
177
178 pub async fn remove_ip(&self, ip: &str) -> bool {
180 let mut buckets = self.ip_buckets.write().await;
181 buckets.remove(ip).is_some()
182 }
183
184 pub fn update_config(&mut self, config: ThrottleConfig) {
186 self.config = config;
187
188 let global_bucket = if self.config.has_global_limit() {
190 TokenBucket::new(
191 self.config.bucket_capacity,
192 self.config.global_limit,
193 )
194 } else {
195 TokenBucket::new(0, 0)
196 };
197
198 self.global_bucket = Arc::new(RwLock::new(global_bucket));
199 }
200
201 pub async fn global_tokens(&self) -> u64 {
203 if !self.config.has_global_limit() {
204 return u64::MAX;
205 }
206 let mut bucket = self.global_bucket.write().await;
207 bucket.tokens()
208 }
209
210 pub async fn ip_tokens(&self, ip: &str) -> u64 {
212 if !self.config.has_per_ip_limit() {
213 return u64::MAX;
214 }
215 let buckets = self.ip_buckets.read().await;
216 if let Some(bucket) = buckets.get(ip) {
217 let mut bucket = bucket.clone();
218 bucket.tokens()
219 } else {
220 self.config.bucket_capacity
221 }
222 }
223}
224
225impl Clone for ThrottleLimiter {
226 fn clone(&self) -> Self {
227 Self {
228 config: self.config.clone(),
229 global_bucket: Arc::clone(&self.global_bucket),
230 ip_buckets: Arc::clone(&self.ip_buckets),
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A default configuration leaves throttling inactive.
    #[test]
    fn test_limiter_creation() {
        let limiter = ThrottleLimiter::new(ThrottleConfig::new());

        assert!(!limiter.config().is_active());
    }

    /// With throttling inactive every check reports `Unlimited`.
    #[tokio::test]
    async fn test_check_unlimited() {
        let limiter = ThrottleLimiter::new(ThrottleConfig::new());

        let outcome = limiter.check("127.0.0.1", 1000).await;
        assert_eq!(outcome, ThrottleResult::Unlimited);
        assert!(outcome.is_allowed());
    }

    /// Two 500-byte requests fit a 1000-token global bucket; the third does not.
    #[tokio::test]
    async fn test_check_global_limit() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        let first = limiter.check("127.0.0.1", 500).await;
        assert!(first.is_allowed());

        let second = limiter.check("127.0.0.1", 500).await;
        assert!(second.is_allowed());

        let third = limiter.check("127.0.0.1", 500).await;
        assert!(!third.is_allowed());
    }

    /// Per-IP buckets are independent: exhausting one IP leaves another untouched.
    #[tokio::test]
    async fn test_check_per_ip_limit() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(500)
                .with_bucket_capacity(500),
        );

        let first = limiter.check("127.0.0.1", 300).await;
        assert!(first.is_allowed());

        let over_budget = limiter.check("127.0.0.1", 300).await;
        assert!(!over_budget.is_allowed());

        let other_client = limiter.check("127.0.0.2", 300).await;
        assert!(other_client.is_allowed());
    }

    /// With throttling inactive `consume` grants the full request.
    #[tokio::test]
    async fn test_consume_unlimited() {
        let limiter = ThrottleLimiter::new(ThrottleConfig::new());

        let granted = limiter.consume("127.0.0.1", 1000).await;
        assert_eq!(granted, 1000);
    }

    /// A request larger than the remaining balance is granted only in part.
    #[tokio::test]
    async fn test_consume_partial() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(500)
                .with_bucket_capacity(500),
        );

        let full_grant = limiter.consume("127.0.0.1", 300).await;
        assert_eq!(full_grant, 300);

        let partial_grant = limiter.consume("127.0.0.1", 300).await;
        assert_eq!(partial_grant, 200);
    }

    /// A request that fits the bucket needs no wait.
    #[tokio::test]
    async fn test_wait_time_zero() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        let delay = limiter.wait_time("127.0.0.1", 500).await;
        assert_eq!(delay, Duration::ZERO);
    }

    /// `reset` refills a drained global bucket.
    #[tokio::test]
    async fn test_reset() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        limiter.check("127.0.0.1", 1000).await;
        limiter.reset().await;

        let available = limiter.global_tokens().await;
        assert_eq!(available, 1000);
    }

    /// `clear_ip_buckets` forgets every tracked IP.
    #[tokio::test]
    async fn test_clear_ip_buckets() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(1000)
                .with_bucket_capacity(1000),
        );

        limiter.check("127.0.0.1", 500).await;
        limiter.check("127.0.0.2", 500).await;
        assert_eq!(limiter.tracked_ip_count().await, 2);

        limiter.clear_ip_buckets().await;
        assert_eq!(limiter.tracked_ip_count().await, 0);
    }

    /// `remove_ip` reports whether a bucket was actually removed.
    #[tokio::test]
    async fn test_remove_ip() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(1000)
                .with_bucket_capacity(1000),
        );

        limiter.check("127.0.0.1", 500).await;
        limiter.check("127.0.0.2", 500).await;

        let removed = limiter.remove_ip("127.0.0.1").await;
        assert!(removed);
        assert_eq!(limiter.tracked_ip_count().await, 1);

        // A second removal finds nothing.
        let removed_again = limiter.remove_ip("127.0.0.1").await;
        assert!(!removed_again);
    }

    /// `update_config` rebuilds the global bucket from the new settings.
    #[tokio::test]
    async fn test_update_config() {
        let initial = ThrottleConfig::new()
            .enable()
            .with_global_limit(1000)
            .with_bucket_capacity(1000);
        let mut limiter = ThrottleLimiter::new(initial);

        limiter.check("127.0.0.1", 500).await;

        let replacement = ThrottleConfig::new()
            .enable()
            .with_global_limit(2000)
            .with_bucket_capacity(2000);
        limiter.update_config(replacement);

        let available = limiter.global_tokens().await;
        assert_eq!(available, 2000);
    }

    /// `global_tokens` reflects what has been consumed.
    #[tokio::test]
    async fn test_global_tokens() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        let before = limiter.global_tokens().await;
        assert_eq!(before, 1000);

        limiter.consume("127.0.0.1", 300).await;

        let after = limiter.global_tokens().await;
        assert_eq!(after, 700);
    }

    /// `ip_tokens` reports capacity for an unseen IP and the balance afterwards.
    #[tokio::test]
    async fn test_ip_tokens() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(1000)
                .with_bucket_capacity(1000),
        );

        let before = limiter.ip_tokens("127.0.0.1").await;
        assert_eq!(before, 1000);

        limiter.consume("127.0.0.1", 300).await;

        let after = limiter.ip_tokens("127.0.0.1").await;
        assert_eq!(after, 700);
    }

    /// A clone observes consumption made through the original.
    #[tokio::test]
    async fn test_clone() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );
        let twin = limiter.clone();

        limiter.consume("127.0.0.1", 500).await;

        let seen_by_twin = twin.global_tokens().await;
        assert_eq!(seen_by_twin, 500);
    }

    /// `is_allowed` is true for `Allowed` and `Unlimited` only.
    #[test]
    fn test_throttle_result_is_allowed() {
        let allowed = ThrottleResult::Allowed { remaining: 100 };
        let unlimited = ThrottleResult::Unlimited;
        let throttled = ThrottleResult::Throttled { wait_ms: 100 };

        assert!(allowed.is_allowed());
        assert!(unlimited.is_allowed());
        assert!(!throttled.is_allowed());
    }
}