#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss,
    clippy::cast_possible_wrap
)]
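//! Rate limiting primitives: local token bucket and sliding window
//! limiters, per-task and thread-safe worker registries, and Lua script
//! specifications for enforcing the same limits across workers via Redis.
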
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

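/// Configuration for a rate limiter.
///
/// `rate` is the sustained rate in permits per second. `burst` caps how many
/// permits may accumulate (defaulting to `rate.ceil()`); when `sliding_window`
/// is set, acquisitions are instead counted over the trailing `window_size`
/// seconds.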
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    pub rate: f64,
    pub burst: Option<u32>,
    pub sliding_window: bool,
    pub window_size: u64,
}

impl RateLimitConfig {
    #[must_use]
    pub fn new(rate: f64) -> Self {
        Self {
            rate,
            burst: None,
            sliding_window: false,
            window_size: 1,
        }
    }

    #[must_use]
    pub fn with_burst(mut self, burst: u32) -> Self {
        self.burst = Some(burst);
        self
    }

    #[must_use]
    pub fn with_sliding_window(mut self, window_size: u64) -> Self {
        self.sliding_window = true;
        self.window_size = window_size;
        self
    }

    #[must_use]
    #[inline]
    pub fn effective_burst(&self) -> u32 {
        self.burst.unwrap_or(self.rate.ceil() as u32)
    }
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        Self {
            rate: 100.0,
            burst: None,
            sliding_window: false,
            window_size: 1,
        }
    }
}

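/// Common interface for in-process rate limiters. `try_acquire` is
/// non-blocking; `acquire` sleeps until a permit is granted and returns the
/// time spent waiting.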
pub trait RateLimiter: Send + Sync {
    fn try_acquire(&mut self) -> bool;

    fn acquire(&mut self) -> Duration;

    fn time_until_available(&self) -> Duration;

    fn available_permits(&self) -> u32;

    fn reset(&mut self);

    fn set_rate(&mut self, rate: f64);

    fn config(&self) -> &RateLimitConfig;
}

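/// Token bucket limiter: tokens accrue continuously at `rate` per second up
/// to `effective_burst()`, and each acquisition consumes one token.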
#[derive(Debug)]
pub struct TokenBucket {
    config: RateLimitConfig,
    tokens: f64,
    last_refill: Instant,
}

impl TokenBucket {
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        let tokens = f64::from(config.effective_burst());
        Self {
            config,
            tokens,
            last_refill: Instant::now(),
        }
    }

    #[must_use]
    pub fn with_rate(rate: f64) -> Self {
        Self::new(RateLimitConfig::new(rate))
    }

    #[inline]
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        let new_tokens = elapsed.as_secs_f64() * self.config.rate;
        let max_tokens = f64::from(self.config.effective_burst());
        self.tokens = (self.tokens + new_tokens).min(max_tokens);
        self.last_refill = now;
    }
}

impl RateLimiter for TokenBucket {
    fn try_acquire(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    fn acquire(&mut self) -> Duration {
        let start = Instant::now();
        while !self.try_acquire() {
            let wait_time = self.time_until_available();
            if wait_time > Duration::ZERO {
                std::thread::sleep(wait_time);
            }
        }
        start.elapsed()
    }

    fn time_until_available(&self) -> Duration {
        if self.tokens >= 1.0 {
            Duration::ZERO
        } else {
            let tokens_needed = 1.0 - self.tokens;
            let seconds = tokens_needed / self.config.rate;
            Duration::from_secs_f64(seconds)
        }
    }

    fn available_permits(&self) -> u32 {
        self.tokens.floor() as u32
    }

    fn reset(&mut self) {
        self.tokens = f64::from(self.config.effective_burst());
        self.last_refill = Instant::now();
    }

    fn set_rate(&mut self, rate: f64) {
        // Settle tokens accrued at the old rate before switching, so the
        // elapsed interval is not re-priced at the new rate.
        self.refill();
        self.config.rate = rate;
    }

    fn config(&self) -> &RateLimitConfig {
        &self.config
    }
}

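/// Sliding window limiter: records a timestamp per acquisition and allows at
/// most `ceil(rate * window_size)` acquisitions within any trailing window.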
#[derive(Debug)]
pub struct SlidingWindow {
    config: RateLimitConfig,
    timestamps: Vec<Instant>,
}

impl SlidingWindow {
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            timestamps: Vec::new(),
        }
    }

    #[must_use]
    pub fn with_rate(rate: f64, window_size: u64) -> Self {
        let config = RateLimitConfig::new(rate).with_sliding_window(window_size);
        Self::new(config)
    }

    #[inline]
    fn cleanup(&mut self) {
        let window = Duration::from_secs(self.config.window_size);
        // `Instant::checked_sub` can return `None` very early in the process
        // lifetime; with no valid cutoff, no timestamp can have expired yet.
        if let Some(cutoff) = Instant::now().checked_sub(window) {
            self.timestamps.retain(|&t| t > cutoff);
        }
    }

    #[inline]
    fn max_executions(&self) -> usize {
        (self.config.rate * self.config.window_size as f64).ceil() as usize
    }
}

impl RateLimiter for SlidingWindow {
    fn try_acquire(&mut self) -> bool {
        self.cleanup();
        if self.timestamps.len() < self.max_executions() {
            self.timestamps.push(Instant::now());
            true
        } else {
            false
        }
    }

    fn acquire(&mut self) -> Duration {
        let start = Instant::now();
        while !self.try_acquire() {
            let wait_time = self.time_until_available();
            if wait_time > Duration::ZERO {
                std::thread::sleep(wait_time);
            }
        }
        start.elapsed()
    }

    fn time_until_available(&self) -> Duration {
        if self.timestamps.len() < self.max_executions() {
            Duration::ZERO
        } else if let Some(&oldest) = self.timestamps.first() {
            let window = Duration::from_secs(self.config.window_size);
            let expires = oldest + window;
            let now = Instant::now();
            if expires > now {
                expires - now
            } else {
                Duration::ZERO
            }
        } else {
            Duration::ZERO
        }
    }

    fn available_permits(&self) -> u32 {
        let max = self.max_executions();
        let current = self.timestamps.len();
        (max.saturating_sub(current)) as u32
    }

    fn reset(&mut self) {
        self.timestamps.clear();
    }

    fn set_rate(&mut self, rate: f64) {
        self.config.rate = rate;
    }

    fn config(&self) -> &RateLimitConfig {
        &self.config
    }
}

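/// Per-task registry of token buckets. Tasks without an explicit limit fall
/// back to the optional default config (a bucket is created lazily), or are
/// unlimited if no default is set.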
#[derive(Debug)]
pub struct TaskRateLimiter {
    limiters: HashMap<String, TokenBucket>,
    default_config: Option<RateLimitConfig>,
}

impl TaskRateLimiter {
    #[must_use]
    pub fn new() -> Self {
        Self {
            limiters: HashMap::new(),
            default_config: None,
        }
    }

    #[must_use]
    pub fn with_default(config: RateLimitConfig) -> Self {
        Self {
            limiters: HashMap::new(),
            default_config: Some(config),
        }
    }

    pub fn set_task_rate(&mut self, task_name: impl Into<String>, config: RateLimitConfig) {
        let name = task_name.into();
        self.limiters.insert(name, TokenBucket::new(config));
    }

    pub fn remove_task_rate(&mut self, task_name: &str) {
        self.limiters.remove(task_name);
    }

    pub fn try_acquire(&mut self, task_name: &str) -> bool {
        if let Some(limiter) = self.limiters.get_mut(task_name) {
            limiter.try_acquire()
        } else if let Some(ref config) = self.default_config {
            let mut limiter = TokenBucket::new(config.clone());
            let result = limiter.try_acquire();
            self.limiters.insert(task_name.to_string(), limiter);
            result
        } else {
            true
        }
    }

    #[must_use]
    pub fn time_until_available(&self, task_name: &str) -> Duration {
        if let Some(limiter) = self.limiters.get(task_name) {
            limiter.time_until_available()
        } else {
            Duration::ZERO
        }
    }

    #[inline]
    #[must_use]
    pub fn has_rate_limit(&self, task_name: &str) -> bool {
        self.limiters.contains_key(task_name) || self.default_config.is_some()
    }

    #[inline]
    pub fn get_rate_limit(&self, task_name: &str) -> Option<&RateLimitConfig> {
        self.limiters
            .get(task_name)
            .map(RateLimiter::config)
            .or(self.default_config.as_ref())
    }

    pub fn reset_all(&mut self) {
        for limiter in self.limiters.values_mut() {
            limiter.reset();
        }
    }
}

impl Default for TaskRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

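/// Thread-safe handle around a shared `TaskRateLimiter`. Lock failures are
/// treated fail-open: acquisition succeeds rather than blocking work.
///
/// Shared across worker threads by cloning; a minimal sketch:
///
/// ```ignore
/// let limiter = WorkerRateLimiter::new();
/// limiter.set_task_rate("send_email", RateLimitConfig::new(5.0).with_burst(10));
/// if limiter.try_acquire("send_email") {
///     // run the task
/// }
/// ```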
#[derive(Debug, Clone)]
pub struct WorkerRateLimiter {
    inner: Arc<RwLock<TaskRateLimiter>>,
}

impl WorkerRateLimiter {
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(TaskRateLimiter::new())),
        }
    }

    #[must_use]
    pub fn with_default(config: RateLimitConfig) -> Self {
        Self {
            inner: Arc::new(RwLock::new(TaskRateLimiter::with_default(config))),
        }
    }

    pub fn set_task_rate(&self, task_name: impl Into<String>, config: RateLimitConfig) {
        if let Ok(mut guard) = self.inner.write() {
            guard.set_task_rate(task_name, config);
        }
    }

    pub fn remove_task_rate(&self, task_name: &str) {
        if let Ok(mut guard) = self.inner.write() {
            guard.remove_task_rate(task_name);
        }
    }

    #[must_use]
    pub fn try_acquire(&self, task_name: &str) -> bool {
        if let Ok(mut guard) = self.inner.write() {
            guard.try_acquire(task_name)
        } else {
            true
        }
    }

    #[must_use]
    pub fn time_until_available(&self, task_name: &str) -> Duration {
        if let Ok(guard) = self.inner.read() {
            guard.time_until_available(task_name)
        } else {
            Duration::ZERO
        }
    }

    #[inline]
    #[must_use]
    pub fn has_rate_limit(&self, task_name: &str) -> bool {
        if let Ok(guard) = self.inner.read() {
            guard.has_rate_limit(task_name)
        } else {
            false
        }
    }

    pub fn reset_all(&self) {
        if let Ok(mut guard) = self.inner.write() {
            guard.reset_all();
        }
    }
}

impl Default for WorkerRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

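/// Builds a boxed limiter matching the config: a `SlidingWindow` when
/// `config.sliding_window` is set, otherwise a `TokenBucket`.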
#[must_use]
pub fn create_rate_limiter(config: RateLimitConfig) -> Box<dyn RateLimiter> {
    if config.sliding_window {
        Box::new(SlidingWindow::new(config))
    } else {
        Box::new(TokenBucket::new(config))
    }
}

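/// Async interface for limiters whose state lives in a shared backend, so
/// that multiple workers enforce a single global rate.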
#[async_trait]
pub trait DistributedRateLimiter: Send + Sync {
    async fn try_acquire(&self) -> crate::Result<bool>;

    async fn time_until_available(&self) -> crate::Result<Duration>;

    async fn available_permits(&self) -> crate::Result<u32>;

    async fn reset(&self) -> crate::Result<()>;

    async fn set_rate(&self, rate: f64) -> crate::Result<()>;

    fn config(&self) -> &RateLimitConfig;

    fn backend_name(&self) -> &str;
}

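/// State shared by distributed limiter specs: the backend key prefix, the
/// config, and a local token bucket used as a fallback when the backend is
/// unavailable.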
#[derive(Debug, Clone)]
pub struct DistributedRateLimiterState {
    pub key: String,
    pub config: RateLimitConfig,
    pub fallback: Arc<RwLock<TokenBucket>>,
}

impl DistributedRateLimiterState {
    #[must_use]
    pub fn new(key: String, config: RateLimitConfig) -> Self {
        let fallback = Arc::new(RwLock::new(TokenBucket::new(config.clone())));
        Self {
            key,
            config,
            fallback,
        }
    }

    #[inline]
    #[must_use]
    pub fn token_key(&self) -> String {
        format!("{}:tokens", self.key)
    }

    #[inline]
    #[must_use]
    pub fn refill_key(&self) -> String {
        format!("{}:refill", self.key)
    }

    #[inline]
    #[must_use]
    pub fn window_key(&self) -> String {
        format!("{}:window", self.key)
    }

    fn try_acquire_fallback(&self) -> bool {
        if let Ok(mut guard) = self.fallback.write() {
            guard.try_acquire()
        } else {
            true
        }
    }
}

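/// Distributed token bucket whose refill-and-consume step runs atomically on
/// Redis via a Lua script.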
#[derive(Debug, Clone)]
pub struct DistributedTokenBucketSpec {
    state: DistributedRateLimiterState,
}

impl DistributedTokenBucketSpec {
    #[must_use]
    pub fn new(key: String, config: RateLimitConfig) -> Self {
        Self {
            state: DistributedRateLimiterState::new(key, config),
        }
    }

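    /// Lua script that refills the bucket and consumes one token atomically.
    /// KEYS: tokens key, refill key. ARGV: rate, burst, now (ms), TTL (s).
    /// Returns `{acquired (1 or 0), remaining tokens}`.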
    #[must_use]
    pub fn lua_acquire_script() -> &'static str {
        r"
        local tokens_key = KEYS[1]
        local refill_key = KEYS[2]
        local rate = tonumber(ARGV[1])
        local burst = tonumber(ARGV[2])
        local now = tonumber(ARGV[3])
        local ttl = tonumber(ARGV[4])

        local last_refill = redis.call('GET', refill_key)
        local tokens = redis.call('GET', tokens_key)

        if not tokens then
            tokens = burst
        else
            tokens = tonumber(tokens)
        end

        if last_refill then
            local elapsed = (now - tonumber(last_refill)) / 1000.0
            tokens = math.min(tokens + elapsed * rate, burst)
        end

        if tokens >= 1.0 then
            tokens = tokens - 1.0
            redis.call('SET', tokens_key, tostring(tokens), 'EX', ttl)
            redis.call('SET', refill_key, tostring(now), 'EX', ttl)
            return {1, tokens}
        else
            redis.call('SET', tokens_key, tostring(tokens), 'EX', ttl)
            redis.call('SET', refill_key, tostring(now), 'EX', ttl)
            return {0, tokens}
        end
        "
    }

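    /// Lua script that reports the refilled token count without consuming
    /// any. KEYS: tokens key, refill key. ARGV: rate, burst, now (ms).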
    #[must_use]
    pub fn lua_available_script() -> &'static str {
        r"
        local tokens_key = KEYS[1]
        local refill_key = KEYS[2]
        local rate = tonumber(ARGV[1])
        local burst = tonumber(ARGV[2])
        local now = tonumber(ARGV[3])

        local last_refill = redis.call('GET', refill_key)
        local tokens = redis.call('GET', tokens_key)

        if not tokens then
            return burst
        else
            tokens = tonumber(tokens)
        end

        if last_refill then
            local elapsed = (now - tonumber(last_refill)) / 1000.0
            tokens = math.min(tokens + elapsed * rate, burst)
        end

        return math.floor(tokens)
        "
    }

    #[inline]
    #[must_use]
    pub fn state(&self) -> &DistributedRateLimiterState {
        &self.state
    }

    #[must_use]
    pub fn try_acquire_fallback(&self) -> bool {
        self.state.try_acquire_fallback()
    }
}

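/// Distributed sliding window backed by a Redis sorted set of acquisition
/// timestamps scored in milliseconds.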
#[derive(Debug, Clone)]
pub struct DistributedSlidingWindowSpec {
    state: DistributedRateLimiterState,
}

impl DistributedSlidingWindowSpec {
    #[must_use]
    pub fn new(key: String, config: RateLimitConfig) -> Self {
        Self {
            state: DistributedRateLimiterState::new(key, config),
        }
    }

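    /// Lua script that prunes expired entries and records a new acquisition
    /// if under the limit. KEYS: window key. ARGV: now (ms), window size (s),
    /// max count, unique member id. Returns `{acquired (1 or 0), remaining}`.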
    #[must_use]
    pub fn lua_acquire_script() -> &'static str {
        r"
        local window_key = KEYS[1]
        local now = tonumber(ARGV[1])
        local window_size = tonumber(ARGV[2])
        local max_count = tonumber(ARGV[3])
        local uuid = ARGV[4]

        local cutoff = now - window_size * 1000
        redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)

        local count = redis.call('ZCARD', window_key)
        if count < max_count then
            redis.call('ZADD', window_key, now, uuid)
            redis.call('EXPIRE', window_key, window_size * 2)
            return {1, max_count - count - 1}
        else
            return {0, 0}
        end
        "
    }

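    /// Lua script that prunes expired entries and returns the number of
    /// remaining permits in the current window.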
    #[must_use]
    pub fn lua_available_script() -> &'static str {
        r"
        local window_key = KEYS[1]
        local now = tonumber(ARGV[1])
        local window_size = tonumber(ARGV[2])
        local max_count = tonumber(ARGV[3])

        local cutoff = now - window_size * 1000
        redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)

        local count = redis.call('ZCARD', window_key)
        return math.max(0, max_count - count)
        "
    }

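    /// Lua script that returns the milliseconds until the oldest entry falls
    /// out of the window, or 0 if a permit is already available.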
    #[must_use]
    pub fn lua_time_until_script() -> &'static str {
        r"
        local window_key = KEYS[1]
        local now = tonumber(ARGV[1])
        local window_size = tonumber(ARGV[2])
        local max_count = tonumber(ARGV[3])

        local cutoff = now - window_size * 1000
        redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)

        local count = redis.call('ZCARD', window_key)
        if count < max_count then
            return 0
        else
            local oldest = redis.call('ZRANGE', window_key, 0, 0, 'WITHSCORES')
            if #oldest >= 2 then
                local oldest_timestamp = tonumber(oldest[2])
                local expires = oldest_timestamp + window_size * 1000
                return math.max(0, expires - now)
            else
                return 0
            end
        end
        "
    }

    #[must_use]
    #[inline]
    pub fn max_executions(&self) -> usize {
        (self.state.config.rate * self.state.config.window_size as f64).ceil() as usize
    }

    #[inline]
    #[must_use]
    pub fn state(&self) -> &DistributedRateLimiterState {
        &self.state
    }

    #[must_use]
    pub fn try_acquire_fallback(&self) -> bool {
        self.state.try_acquire_fallback()
    }
}

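/// Registry of distributed limiter specs keyed by task name. Backend keys are
/// namespaced as `{namespace}:ratelimit:{task}`, and acquisition can fall
/// back to the local token buckets when the backend is unreachable.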
#[derive(Debug, Clone)]
pub struct DistributedRateLimiterCoordinator {
    namespace: String,
    token_buckets: Arc<RwLock<HashMap<String, DistributedTokenBucketSpec>>>,
    sliding_windows: Arc<RwLock<HashMap<String, DistributedSlidingWindowSpec>>>,
    default_config: Option<RateLimitConfig>,
}

impl DistributedRateLimiterCoordinator {
    pub fn new(namespace: impl Into<String>) -> Self {
        Self {
            namespace: namespace.into(),
            token_buckets: Arc::new(RwLock::new(HashMap::new())),
            sliding_windows: Arc::new(RwLock::new(HashMap::new())),
            default_config: None,
        }
    }

    pub fn with_default(namespace: impl Into<String>, config: RateLimitConfig) -> Self {
        Self {
            namespace: namespace.into(),
            token_buckets: Arc::new(RwLock::new(HashMap::new())),
            sliding_windows: Arc::new(RwLock::new(HashMap::new())),
            default_config: Some(config),
        }
    }

    pub fn set_task_rate(&self, task_name: impl Into<String>, config: RateLimitConfig) {
        let name = task_name.into();
        let key = self.redis_key(&name);

        if config.sliding_window {
            if let Ok(mut guard) = self.sliding_windows.write() {
                guard.insert(name, DistributedSlidingWindowSpec::new(key, config));
            }
        } else if let Ok(mut guard) = self.token_buckets.write() {
            guard.insert(name, DistributedTokenBucketSpec::new(key, config));
        }
    }

    pub fn remove_task_rate(&self, task_name: &str) {
        if let Ok(mut guard) = self.token_buckets.write() {
            guard.remove(task_name);
        }
        if let Ok(mut guard) = self.sliding_windows.write() {
            guard.remove(task_name);
        }
    }

    #[inline]
    #[must_use]
    pub fn get_token_bucket_spec(&self, task_name: &str) -> Option<DistributedTokenBucketSpec> {
        if let Ok(guard) = self.token_buckets.read() {
            guard.get(task_name).cloned()
        } else {
            None
        }
    }

    #[inline]
    #[must_use]
    pub fn get_sliding_window_spec(&self, task_name: &str) -> Option<DistributedSlidingWindowSpec> {
        if let Ok(guard) = self.sliding_windows.read() {
            guard.get(task_name).cloned()
        } else {
            None
        }
    }

    #[inline]
    #[must_use]
    pub fn has_rate_limit(&self, task_name: &str) -> bool {
        let has_bucket = if let Ok(guard) = self.token_buckets.read() {
            guard.contains_key(task_name)
        } else {
            false
        };

        let has_window = if let Ok(guard) = self.sliding_windows.read() {
            guard.contains_key(task_name)
        } else {
            false
        };

        has_bucket || has_window || self.default_config.is_some()
    }

    #[must_use]
    pub fn try_acquire_fallback(&self, task_name: &str) -> bool {
        if let Some(spec) = self.get_token_bucket_spec(task_name) {
            return spec.try_acquire_fallback();
        }

        if let Some(spec) = self.get_sliding_window_spec(task_name) {
            return spec.try_acquire_fallback();
        }

        if let Some(ref config) = self.default_config {
            // Cache the default-config spec so repeated calls share one
            // bucket; a fresh bucket on every call would never limit.
            let key = self.redis_key(task_name);
            if let Ok(mut guard) = self.token_buckets.write() {
                return guard
                    .entry(task_name.to_string())
                    .or_insert_with(|| DistributedTokenBucketSpec::new(key, config.clone()))
                    .try_acquire_fallback();
            }
            return true;
        }

        true
    }

    #[must_use]
    pub fn redis_key(&self, task_name: &str) -> String {
        format!("{}:ratelimit:{}", self.namespace, task_name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_token_bucket_basic() {
        let config = RateLimitConfig::new(10.0).with_burst(5);
        let mut limiter = TokenBucket::new(config);

        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }

        assert!(!limiter.try_acquire());
    }

    #[test]
    fn test_token_bucket_refill() {
        let config = RateLimitConfig::new(100.0).with_burst(10);
        let mut limiter = TokenBucket::new(config);

        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());

        thread::sleep(Duration::from_millis(15));

        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_sliding_window_basic() {
        let config = RateLimitConfig::new(5.0).with_sliding_window(1);
        let mut limiter = SlidingWindow::new(config);

        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }

        assert!(!limiter.try_acquire());
    }

    #[test]
    fn test_task_rate_limiter() {
        let mut manager = TaskRateLimiter::new();

        manager.set_task_rate("task_a", RateLimitConfig::new(10.0).with_burst(2));

        assert!(manager.try_acquire("task_a"));
        assert!(manager.try_acquire("task_a"));
        assert!(!manager.try_acquire("task_a"));

        assert!(manager.try_acquire("task_b"));
        assert!(manager.try_acquire("task_b"));
        assert!(manager.try_acquire("task_b"));
    }

    #[test]
    fn test_task_rate_limiter_default() {
        let mut manager = TaskRateLimiter::with_default(RateLimitConfig::new(10.0).with_burst(2));

        assert!(manager.try_acquire("task_a"));
        assert!(manager.try_acquire("task_a"));
        assert!(!manager.try_acquire("task_a"));

        assert!(manager.try_acquire("task_b"));
        assert!(manager.try_acquire("task_b"));
        assert!(!manager.try_acquire("task_b"));
    }

    #[test]
    fn test_worker_rate_limiter_thread_safe() {
        let limiter = WorkerRateLimiter::new();
        limiter.set_task_rate("task_a", RateLimitConfig::new(0.1).with_burst(10));

        let limiter_clone = limiter.clone();

        let handles: Vec<_> = (0..4)
            .map(|_| {
                let l = limiter_clone.clone();
                thread::spawn(move || {
                    let mut count = 0;
                    for _ in 0..5 {
                        if l.try_acquire("task_a") {
                            count += 1;
                        }
                    }
                    count
                })
            })
            .collect();

        let total: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();

        assert!(total <= 10);
    }

    #[test]
    fn test_rate_limit_config_serialization() {
        let config = RateLimitConfig::new(50.0)
            .with_burst(100)
            .with_sliding_window(10);

        let json = serde_json::to_string(&config).unwrap();
        let parsed: RateLimitConfig = serde_json::from_str(&json).unwrap();

        assert!((parsed.rate - 50.0).abs() < f64::EPSILON);
        assert_eq!(parsed.burst, Some(100));
        assert!(parsed.sliding_window);
        assert_eq!(parsed.window_size, 10);
    }

    #[test]
    fn test_time_until_available() {
        let config = RateLimitConfig::new(10.0).with_burst(1);
        let mut limiter = TokenBucket::new(config);

        assert!(limiter.try_acquire());

        let wait_time = limiter.time_until_available();
        assert!(wait_time > Duration::ZERO);
        assert!(wait_time <= Duration::from_millis(150));
    }

    #[test]
    fn test_reset() {
        let config = RateLimitConfig::new(10.0).with_burst(5);
        let mut limiter = TokenBucket::new(config);

        for _ in 0..5 {
            limiter.try_acquire();
        }
        assert!(!limiter.try_acquire());

        limiter.reset();
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_set_rate() {
        let config = RateLimitConfig::new(10.0).with_burst(10);
        let mut limiter = TokenBucket::new(config);

        limiter.set_rate(100.0);
        assert!((limiter.config().rate - 100.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_create_rate_limiter() {
        let config = RateLimitConfig::new(10.0);
        let mut limiter = create_rate_limiter(config);
        assert!(limiter.try_acquire());

        let config = RateLimitConfig::new(10.0).with_sliding_window(1);
        let mut limiter = create_rate_limiter(config);
        assert!(limiter.try_acquire());
    }
}