Skip to main content

primitives/utils/rate_limiter/
token_bucket.rs

//! Token Bucket Rate Limiter
//!
//! This module provides an implementation of a token bucket rate limiter.
//! Tokens are replenished at a fixed interval by a background Tokio task, and
//! actions can only proceed if enough tokens are available.
6
7use std::{
8    sync::{
9        atomic::{AtomicBool, AtomicU64, Ordering},
10        Arc,
11    },
12    time::Duration,
13};
14
15use tokio::{sync::RwLock, task::JoinHandle};
16
17use crate::utils::RateLimiter;
18
/// Configuration for the token bucket rate limiter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenBucketConfig {
    /// Number of tokens in the bucket when it is created
    pub initial_tokens: u64,
    /// Number of tokens added on each replenishment
    pub tokens_per_interval: u64,
    /// Time to wait between two replenishments
    pub replenish_interval: Duration,
    /// Maximum number of tokens the bucket can hold (its capacity)
    pub max_tokens: u64,
}
31
32impl Default for TokenBucketConfig {
33    fn default() -> Self {
34        Self {
35            initial_tokens: 100,
36            tokens_per_interval: 10,
37            replenish_interval: Duration::from_secs(1),
38            max_tokens: 100,
39        }
40    }
41}
42
43/// Token bucket rate limiter implementation
44pub struct TokenBucket {
45    tokens: Arc<AtomicU64>,
46    config: Arc<RwLock<TokenBucketConfig>>,
47    task_handle: Option<JoinHandle<()>>,
48    shutdown_flag: Arc<AtomicBool>,
49}
50
51impl std::fmt::Debug for TokenBucket {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("TokenBucket")
54            .field("tokens", &self.tokens.load(Ordering::Acquire))
55            .field("config", &self.config.blocking_read())
56            .field("shutdown", &self.shutdown_flag.load(Ordering::Acquire))
57            .finish()
58    }
59}
60
61impl TokenBucket {
62    /// Create a new token bucket rate limiter with the given configuration
63    pub fn new(config: TokenBucketConfig) -> Self {
64        let tokens = Arc::new(AtomicU64::new(config.initial_tokens));
65        Self {
66            tokens,
67            config: Arc::new(RwLock::new(config)),
68            task_handle: None,
69            shutdown_flag: Arc::new(AtomicBool::new(false)),
70        }
71    }
72
73    /// Create a new token bucket rate limiter with the given configuration, and
74    /// start the replenishment task
75    pub fn initialize(config: TokenBucketConfig) -> Self {
76        let mut limiter = Self::new(config);
77        let handle = limiter.start();
78        limiter.task_handle = Some(handle);
79        limiter
80    }
81
82    /// Get the current configuration
83    pub async fn get_config(&self) -> TokenBucketConfig {
84        let config_guard = self.config.read().await;
85        *config_guard
86    }
87
88    /// Update the rate limiter configuration dynamically
89    /// The new configuration will take effect on the next replenishment cycle
90    pub async fn update_config(&self, new_config: TokenBucketConfig) {
91        let mut config_guard = self.config.write().await;
92        *config_guard = new_config;
93    }
94
95    /// Set the token refreshment rate
96    pub async fn set_tokens_per_interval(&self, tokens_per_interval: u64) {
97        let mut config_guard = self.config.write().await;
98        config_guard.tokens_per_interval = tokens_per_interval;
99    }
100
101    /// Set the replenishment interval
102    pub async fn set_replenish_interval(&self, replenish_interval: Duration) {
103        let mut config_guard = self.config.write().await;
104        config_guard.replenish_interval = replenish_interval;
105    }
106}
107
108impl TokenBucket {
109    pub fn start(&mut self) -> JoinHandle<()> {
110        let tokens = self.tokens.clone();
111        let config = self.config.clone();
112        let shutdown_flag = self.shutdown_flag.clone();
113
114        tokio::spawn(async move {
115            loop {
116                // Check shutdown flag
117                if shutdown_flag.load(Ordering::Acquire) {
118                    break;
119                }
120
121                // Read current config
122                let current_config = *config.read().await;
123
124                // Wait for the replenishment interval
125                tokio::time::sleep(current_config.replenish_interval).await;
126
127                // Replenish tokens atomically (add tokens up to max_tokens)
128                let _ = tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
129                    let new_value = std::cmp::min(
130                        current.saturating_add(current_config.tokens_per_interval),
131                        current_config.max_tokens,
132                    );
133                    Some(new_value)
134                });
135            }
136        })
137    }
138
139    pub async fn stop(&mut self) {
140        self.shutdown_flag.store(true, Ordering::Release);
141        if let Some(handle) = self.task_handle.take() {
142            let _ = handle.await;
143        }
144    }
145
146    pub fn get_tokens(&self) -> &Arc<AtomicU64> {
147        &self.tokens
148    }
149}
150
151impl RateLimiter for TokenBucket {
152    type TokenType = u64;
153
154    /// Try to consume a specified number of tokens
155    /// Returns true if successful, false if not enough tokens available
156    fn try_consume(&self, tokens: u64) -> bool {
157        self.tokens
158            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
159                if current >= tokens {
160                    Some(current - tokens)
161                } else {
162                    None
163                }
164            })
165            .is_ok()
166    }
167
168    /// Get the number of available tokens
169    fn available_tokens(&self) -> u64 {
170        self.tokens.load(Ordering::Acquire)
171    }
172}
173
174impl Drop for TokenBucket {
175    fn drop(&mut self) {
176        // Signal shutdown
177        self.shutdown_flag.store(true, Ordering::Release);
178        // Abort the task if it's still running
179        if let Some(handle) = self.task_handle.take() {
180            handle.abort();
181        }
182    }
183}
184
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use tokio::time::{sleep, timeout};

    use super::*;

    /// Build a config; `interval_ms` is the replenish interval in milliseconds.
    fn config(
        initial_tokens: u64,
        tokens_per_interval: u64,
        interval_ms: u64,
        max_tokens: u64,
    ) -> TokenBucketConfig {
        TokenBucketConfig {
            initial_tokens,
            tokens_per_interval,
            replenish_interval: Duration::from_millis(interval_ms),
            max_tokens,
        }
    }

    #[tokio::test]
    async fn test_initial_tokens() {
        let limiter = TokenBucket::initialize(config(50, 10, 100, 100));
        assert_eq!(limiter.available_tokens(), 50);
    }

    #[tokio::test]
    async fn test_try_consume_success() {
        let limiter = TokenBucket::initialize(config(50, 10, 1000, 100));

        assert!(limiter.try_consume(20));
        assert_eq!(limiter.available_tokens(), 30);

        // Consuming exactly the remaining tokens is allowed.
        assert!(limiter.try_consume(30));
        assert_eq!(limiter.available_tokens(), 0);
    }

    #[tokio::test]
    async fn test_try_consume_failure() {
        let limiter = TokenBucket::initialize(config(10, 5, 1000, 100));

        assert!(limiter.try_consume(5));
        assert_eq!(limiter.available_tokens(), 5);

        // Not enough tokens: the request fails and the bucket is left untouched.
        assert!(!limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 5);
    }

    #[tokio::test]
    async fn test_token_replenishment() {
        let limiter = TokenBucket::initialize(config(10, 20, 100, 100));

        assert!(limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 0);

        // One full replenish interval plus some slack.
        sleep(Duration::from_millis(150)).await;

        let tokens = limiter.available_tokens();
        assert!(tokens >= 20, "Expected at least 20 tokens, got {tokens}");
    }

    #[tokio::test]
    async fn test_max_tokens_cap() {
        let limiter = TokenBucket::initialize(config(90, 20, 100, 100));

        sleep(Duration::from_millis(150)).await;

        // 90 + 20 would be 110; replenishment must stop at max_tokens.
        let tokens = limiter.available_tokens();
        assert_eq!(tokens, 100, "Expected tokens to be capped at 100, got {tokens}");
    }

    #[tokio::test]
    async fn test_dynamic_config_update() {
        let limiter = TokenBucket::initialize(config(10, 5, 100, 50));

        assert!(limiter.try_consume(10));
        assert_eq!(limiter.available_tokens(), 0);

        // Switch to a larger refill amount.
        limiter.update_config(config(10, 30, 100, 50)).await;

        sleep(Duration::from_millis(150)).await;

        let tokens = limiter.available_tokens();
        assert!(tokens >= 30, "Expected at least 30 tokens, got {tokens}");
    }

    #[tokio::test]
    async fn test_concurrent_consumption() {
        let limiter = TokenBucket::initialize(config(1000, 100, 100, 1000));
        let tokens = limiter.get_tokens();

        // 10 consumers x 10 attempts x 10 tokens = exactly the 1000 initial tokens.
        // Replenishment never removes tokens here, so every attempt must succeed.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let tokens = Arc::clone(tokens);
                tokio::spawn(async move {
                    let mut succeeded = 0u32;
                    for _ in 0..10 {
                        if tokens.try_consume(10) {
                            succeeded += 1;
                        }
                        sleep(Duration::from_millis(5)).await;
                    }
                    succeeded
                })
            })
            .collect();

        let mut succeeded = 0;
        for handle in handles {
            succeeded += handle.await.unwrap();
        }

        assert_eq!(succeeded, 100, "every consumption attempt should have succeeded");
        let remaining = limiter.available_tokens();
        assert!(remaining <= 1000, "Tokens exceeded max: {remaining}");
    }

    #[tokio::test]
    async fn test_rate_limiting_behavior() {
        let limiter = TokenBucket::initialize(config(5, 5, 100, 10));
        let start = Instant::now();

        assert!(limiter.try_consume(5));
        // The bucket stays empty until the first replenishment.
        assert!(!limiter.try_consume(5));

        // Poll for the refill, but fail instead of hanging if it never arrives.
        timeout(Duration::from_secs(2), async {
            while limiter.available_tokens() < 5 {
                sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("tokens were not replenished within 2s");

        // Tokens can only arrive after at least one full replenish interval.
        assert!(start.elapsed() >= Duration::from_millis(100));

        assert!(limiter.try_consume(5));
    }

    #[tokio::test]
    async fn test_get_config() {
        let limiter = TokenBucket::initialize(config(42, 13, 250, 200));
        let retrieved_config = limiter.get_config().await;

        assert_eq!(retrieved_config.initial_tokens, 42);
        assert_eq!(retrieved_config.tokens_per_interval, 13);
        assert_eq!(
            retrieved_config.replenish_interval,
            Duration::from_millis(250)
        );
        assert_eq!(retrieved_config.max_tokens, 200);
    }
}