primitives/utils/rate_limiter/
token_bucket.rs1use std::{
8 sync::{
9 atomic::{AtomicBool, AtomicU64, Ordering},
10 Arc,
11 },
12 time::Duration,
13};
14
15use tokio::{sync::RwLock, task::JoinHandle};
16
17use crate::utils::RateLimiter;
18
/// Static parameters of a [`TokenBucket`].
///
/// The type is `Copy`, so readers such as `TokenBucket::get_config` hand out
/// snapshots. It also derives `PartialEq`/`Eq`, so whole configurations can be
/// compared directly instead of field by field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenBucketConfig {
    /// Number of tokens the bucket holds when it is constructed.
    pub initial_tokens: u64,
    /// Tokens added to the bucket on every replenishment tick.
    pub tokens_per_interval: u64,
    /// Time the background task sleeps between replenishment ticks.
    pub replenish_interval: Duration,
    /// Upper bound on the number of stored tokens; replenishment never exceeds it.
    pub max_tokens: u64,
}

impl Default for TokenBucketConfig {
    /// A full bucket of 100 tokens, refilled by 10 tokens every second.
    fn default() -> Self {
        Self {
            initial_tokens: 100,
            tokens_per_interval: 10,
            replenish_interval: Duration::from_secs(1),
            max_tokens: 100,
        }
    }
}
42
43pub struct TokenBucket {
45 tokens: Arc<AtomicU64>,
46 config: Arc<RwLock<TokenBucketConfig>>,
47 task_handle: Option<JoinHandle<()>>,
48 shutdown_flag: Arc<AtomicBool>,
49}
50
51impl std::fmt::Debug for TokenBucket {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("TokenBucket")
54 .field("tokens", &self.tokens.load(Ordering::Acquire))
55 .field("config", &self.config.blocking_read())
56 .field("shutdown", &self.shutdown_flag.load(Ordering::Acquire))
57 .finish()
58 }
59}
60
61impl TokenBucket {
62 pub fn new(config: TokenBucketConfig) -> Self {
64 let tokens = Arc::new(AtomicU64::new(config.initial_tokens));
65 Self {
66 tokens,
67 config: Arc::new(RwLock::new(config)),
68 task_handle: None,
69 shutdown_flag: Arc::new(AtomicBool::new(false)),
70 }
71 }
72
73 pub fn initialize(config: TokenBucketConfig) -> Self {
76 let mut limiter = Self::new(config);
77 let handle = limiter.start();
78 limiter.task_handle = Some(handle);
79 limiter
80 }
81
82 pub async fn get_config(&self) -> TokenBucketConfig {
84 let config_guard = self.config.read().await;
85 *config_guard
86 }
87
88 pub async fn update_config(&self, new_config: TokenBucketConfig) {
91 let mut config_guard = self.config.write().await;
92 *config_guard = new_config;
93 }
94
95 pub async fn set_tokens_per_interval(&self, tokens_per_interval: u64) {
97 let mut config_guard = self.config.write().await;
98 config_guard.tokens_per_interval = tokens_per_interval;
99 }
100
101 pub async fn set_replenish_interval(&self, replenish_interval: Duration) {
103 let mut config_guard = self.config.write().await;
104 config_guard.replenish_interval = replenish_interval;
105 }
106}
107
108impl TokenBucket {
109 pub fn start(&mut self) -> JoinHandle<()> {
110 let tokens = self.tokens.clone();
111 let config = self.config.clone();
112 let shutdown_flag = self.shutdown_flag.clone();
113
114 tokio::spawn(async move {
115 loop {
116 if shutdown_flag.load(Ordering::Acquire) {
118 break;
119 }
120
121 let current_config = *config.read().await;
123
124 tokio::time::sleep(current_config.replenish_interval).await;
126
127 let _ = tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
129 let new_value = std::cmp::min(
130 current.saturating_add(current_config.tokens_per_interval),
131 current_config.max_tokens,
132 );
133 Some(new_value)
134 });
135 }
136 })
137 }
138
139 pub async fn stop(&mut self) {
140 self.shutdown_flag.store(true, Ordering::Release);
141 if let Some(handle) = self.task_handle.take() {
142 let _ = handle.await;
143 }
144 }
145
146 pub fn get_tokens(&self) -> &Arc<AtomicU64> {
147 &self.tokens
148 }
149}
150
151impl RateLimiter for TokenBucket {
152 type TokenType = u64;
153
154 fn try_consume(&self, tokens: u64) -> bool {
157 self.tokens
158 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
159 if current >= tokens {
160 Some(current - tokens)
161 } else {
162 None
163 }
164 })
165 .is_ok()
166 }
167
168 fn available_tokens(&self) -> u64 {
170 self.tokens.load(Ordering::Acquire)
171 }
172}
173
174impl Drop for TokenBucket {
175 fn drop(&mut self) {
176 self.shutdown_flag.store(true, Ordering::Release);
178 if let Some(handle) = self.task_handle.take() {
180 handle.abort();
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use tokio::time::sleep;

    use super::*;

    #[tokio::test]
    async fn test_initial_tokens() {
        let config = TokenBucketConfig {
            initial_tokens: 50,
            tokens_per_interval: 10,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 100,
        };

        let limiter = TokenBucket::initialize(config);
        assert_eq!(limiter.available_tokens(), 50);
    }

    #[tokio::test]
    async fn test_try_consume_success() {
        let config = TokenBucketConfig {
            initial_tokens: 50,
            tokens_per_interval: 10,
            replenish_interval: Duration::from_secs(1),
            max_tokens: 100,
        };

        let limiter = TokenBucket::initialize(config);

        assert!(limiter.try_consume(20));
        assert_eq!(limiter.available_tokens(), 30);

        assert!(limiter.try_consume(30));
        assert_eq!(limiter.available_tokens(), 0);
    }

    #[tokio::test]
    async fn test_try_consume_failure() {
        let config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 5,
            replenish_interval: Duration::from_secs(1),
            max_tokens: 100,
        };

        let limiter = TokenBucket::initialize(config);

        assert!(limiter.try_consume(5));
        assert_eq!(limiter.available_tokens(), 5);

        // A failed consume must leave the balance untouched.
        assert!(!limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 5);
    }

    #[tokio::test]
    async fn test_token_replenishment() {
        let config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 20,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 100,
        };

        let limiter = TokenBucket::initialize(config);

        assert!(limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 0);

        sleep(Duration::from_millis(150)).await;

        let tokens = limiter.available_tokens();
        assert!(tokens >= 20, "Expected at least 20 tokens, got {tokens}");
    }

    #[tokio::test]
    async fn test_max_tokens_cap() {
        let config = TokenBucketConfig {
            initial_tokens: 90,
            tokens_per_interval: 20,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 100,
        };

        let limiter = TokenBucket::initialize(config);

        sleep(Duration::from_millis(150)).await;

        let tokens = limiter.available_tokens();
        assert!(tokens <= 100, "Tokens exceeded max: {tokens}");
        assert_eq!(tokens, 100, "Expected tokens to be capped at 100");
    }

    #[tokio::test]
    async fn test_dynamic_config_update() {
        let config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 5,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 50,
        };

        let limiter = TokenBucket::initialize(config);

        assert!(limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 0);

        let new_config = TokenBucketConfig {
            initial_tokens: 10,
            tokens_per_interval: 30,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 50,
        };
        limiter.update_config(new_config).await;

        sleep(Duration::from_millis(150)).await;

        let tokens = limiter.available_tokens();
        assert!(tokens >= 30, "Expected at least 30 tokens, got {tokens}");
    }

    #[tokio::test]
    async fn test_concurrent_consumption() {
        let config = TokenBucketConfig {
            initial_tokens: 1000,
            tokens_per_interval: 100,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 1000,
        };

        // Share the limiter itself (not its raw counter) so every task goes
        // through `try_consume`'s atomic check-and-subtract.
        let limiter = Arc::new(TokenBucket::initialize(config));
        let successes = Arc::new(AtomicU64::new(0));
        let mut handles = Vec::with_capacity(10);

        for _ in 0..10 {
            let limiter = Arc::clone(&limiter);
            let successes = Arc::clone(&successes);
            handles.push(tokio::spawn(async move {
                for _ in 0..10 {
                    if limiter.try_consume(10) {
                        successes.fetch_add(1, Ordering::Relaxed);
                    }
                    sleep(Duration::from_millis(5)).await;
                }
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        // Total demand (10 tasks x 10 attempts x 10 tokens = 1000) never exceeds
        // the initial, already-capped supply, so every attempt must succeed.
        assert_eq!(successes.load(Ordering::Relaxed), 100);

        let tokens = limiter.available_tokens();
        assert!(tokens <= 1000, "Tokens exceeded max");
    }

    #[tokio::test]
    async fn test_rate_limiting_behavior() {
        let config = TokenBucketConfig {
            initial_tokens: 5,
            tokens_per_interval: 5,
            replenish_interval: Duration::from_millis(100),
            max_tokens: 10,
        };

        let limiter = TokenBucket::initialize(config);
        let start = Instant::now();

        assert!(limiter.try_consume(5));

        assert!(!limiter.try_consume(5));

        while limiter.available_tokens() < 5 {
            sleep(Duration::from_millis(10)).await;
        }

        let elapsed = start.elapsed();

        assert!(elapsed >= Duration::from_millis(100));

        assert!(limiter.try_consume(5));
    }

    #[tokio::test]
    async fn test_get_config() {
        let config = TokenBucketConfig {
            initial_tokens: 42,
            tokens_per_interval: 13,
            replenish_interval: Duration::from_millis(250),
            max_tokens: 200,
        };

        let limiter = TokenBucket::initialize(config);
        let retrieved_config = limiter.get_config().await;

        assert_eq!(retrieved_config.initial_tokens, 42);
        assert_eq!(retrieved_config.tokens_per_interval, 13);
        assert_eq!(
            retrieved_config.replenish_interval,
            Duration::from_millis(250)
        );
        assert_eq!(retrieved_config.max_tokens, 200);
    }
}