// riglr_core/util/rate_limiter.rs

//! Rate limiting utilities for riglr-core
//!
//! This module provides flexible, strategy-based rate limiting that supports
//! multiple algorithms including token bucket, fixed window, and custom strategies.
use std::sync::Arc;
use std::time::Duration;

use super::rate_limit_strategy::{FixedWindowStrategy, RateLimitStrategy};
use super::token_bucket::TokenBucketStrategy;
use crate::ToolError;
/// Rate limiting strategy type
///
/// Identifies one of the built-in algorithms. `TokenBucket` is the
/// [`Default`] variant, matching the fallback used when no strategy is
/// selected explicitly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum RateLimitStrategyType {
    /// Token bucket algorithm (default)
    #[default]
    TokenBucket,
    /// Fixed window algorithm
    FixedWindow,
}
22/// A configurable rate limiter for controlling request rates.
23///
24/// This rate limiter supports multiple strategies:
25/// - Token bucket: Continuous token replenishment with burst capacity
26/// - Fixed window: Fixed request count per time window
27/// - Custom strategies via the RateLimitStrategy trait
28///
29/// # Example
30///
31/// ```rust
32/// use riglr_core::util::{RateLimiter, RateLimitStrategyType};
33/// use std::time::Duration;
34///
35/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
36/// // Default token bucket strategy
37/// let rate_limiter = RateLimiter::new(10, Duration::from_secs(60));
38///
39/// // Check rate limit for a user
40/// rate_limiter.check_rate_limit("user123")?;
41///
42/// // Use fixed window strategy
43/// let fixed_limiter = RateLimiter::builder()
44///     .strategy(RateLimitStrategyType::FixedWindow)
45///     .max_requests(100)
46///     .time_window(Duration::from_secs(60))
47///     .build();
48/// # Ok(())
49/// # }
50/// ```
51#[derive(Clone)]
52pub struct RateLimiter {
53    /// The underlying rate limiting strategy
54    strategy: Arc<dyn RateLimitStrategy>,
55}
57impl std::fmt::Debug for RateLimiter {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("RateLimiter")
60            .field("strategy", &self.strategy.strategy_name())
61            .finish()
62    }
63}
65impl RateLimiter {
66    /// Create a new rate limiter with the default token bucket strategy
67    pub fn new(max_requests: usize, time_window: Duration) -> Self {
68        Self {
69            strategy: Arc::new(TokenBucketStrategy::new(max_requests, time_window)),
70        }
71    }
72
73    /// Create a rate limiter with a custom strategy
74    pub fn with_strategy<S: RateLimitStrategy + 'static>(strategy: S) -> Self {
75        Self {
76            strategy: Arc::new(strategy),
77        }
78    }
79
80    /// Create a new rate limiter builder for advanced configuration
81    pub fn builder() -> RateLimiterBuilder {
82        RateLimiterBuilder::default()
83    }
84
85    /// Check if a client has exceeded the rate limit
86    ///
87    /// # Arguments
88    /// * `client_id` - Unique identifier for the client (e.g., IP address, user ID)
89    ///
90    /// # Returns
91    /// * `Ok(())` if the request is allowed
92    /// * `Err(ToolError::RateLimited)` if the rate limit is exceeded
93    pub fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
94        self.strategy.check_rate_limit(client_id)
95    }
96
97    /// Reset rate limit for a specific client
98    pub fn reset_client(&self, client_id: &str) {
99        self.strategy.reset_client(client_id)
100    }
101
102    /// Clear all rate limit data
103    pub fn clear_all(&self) {
104        self.strategy.clear_all()
105    }
106
107    /// Get current request count for a client
108    pub fn get_request_count(&self, client_id: &str) -> usize {
109        self.strategy.get_request_count(client_id)
110    }
111
112    /// Get the name of the current strategy
113    pub fn strategy_name(&self) -> &str {
114        self.strategy.strategy_name()
115    }
116}
118/// Builder for creating customized RateLimiter instances
119#[derive(Debug, Default)]
120pub struct RateLimiterBuilder {
121    strategy_type: Option<RateLimitStrategyType>,
122    max_requests: Option<usize>,
123    time_window: Option<Duration>,
124    burst_size: Option<usize>,
125}
127impl RateLimiterBuilder {
128    /// Set the rate limiting strategy type
129    pub fn strategy(mut self, strategy: RateLimitStrategyType) -> Self {
130        self.strategy_type = Some(strategy);
131        self
132    }
133
134    /// Set the maximum number of requests allowed in the time window
135    pub fn max_requests(mut self, max: usize) -> Self {
136        self.max_requests = Some(max);
137        self
138    }
139
140    /// Set the time window for rate limiting
141    pub fn time_window(mut self, window: Duration) -> Self {
142        self.time_window = Some(window);
143        self
144    }
145
146    /// Set the burst size for temporary spikes
147    pub fn burst_size(mut self, size: usize) -> Self {
148        self.burst_size = Some(size);
149        self
150    }
151
152    /// Build the RateLimiter
153    pub fn build(self) -> RateLimiter {
154        let max_requests = self.max_requests.unwrap_or(10);
155        let time_window = self.time_window.unwrap_or_else(|| Duration::from_secs(60));
156        let strategy_type = self
157            .strategy_type
158            .unwrap_or(RateLimitStrategyType::TokenBucket);
159
160        let strategy: Arc<dyn RateLimitStrategy> = match strategy_type {
161            RateLimitStrategyType::TokenBucket => {
162                if let Some(burst_size) = self.burst_size {
163                    Arc::new(TokenBucketStrategy::with_burst(
164                        max_requests,
165                        time_window,
166                        burst_size,
167                    ))
168                } else {
169                    Arc::new(TokenBucketStrategy::new(max_requests, time_window))
170                }
171            }
172            RateLimitStrategyType::FixedWindow => {
173                Arc::new(FixedWindowStrategy::new(max_requests, time_window))
174            }
175        };
176
177        RateLimiter { strategy }
178    }
179}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Requests up to the configured limit are all accepted.
    #[test]
    fn test_rate_limiter_allows_requests_within_limit() {
        let limiter = RateLimiter::new(3, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    /// The first request past the limit is rejected.
    #[test]
    fn test_rate_limiter_blocks_requests_over_limit() {
        let limiter = RateLimiter::new(2, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    /// Each client id is limited independently of the others.
    #[test]
    fn test_rate_limiter_with_different_clients() {
        let limiter = RateLimiter::new(1, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user2").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
        assert!(limiter.check_rate_limit("user2").is_err());
    }

    /// A builder-configured limiter accepts its first request.
    #[test]
    fn test_rate_limiter_builder() {
        let limiter = RateLimiter::builder()
            .max_requests(5)
            .time_window(Duration::from_secs(10))
            .burst_size(2)
            .build();

        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    /// Resetting a client lets it make requests again immediately.
    #[test]
    fn test_reset_client() {
        let limiter = RateLimiter::new(1, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());

        limiter.reset_client("user1");
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_time_based_token_replenishment() {
        // Create a rate limiter with 10 requests per second (for faster testing)
        let limiter = RateLimiter::new(10, Duration::from_millis(1000));

        // Exhaust initial tokens
        for _ in 0..10 {
            assert!(limiter.check_rate_limit("user1").is_ok());
        }

        // Should be blocked now
        assert!(limiter.check_rate_limit("user1").is_err());

        // 150ms at 10 tokens/sec refills ~1.5 tokens: enough for one
        // request but not for two.
        thread::sleep(Duration::from_millis(150));

        // Should be allowed now due to token replenishment
        assert!(limiter.check_rate_limit("user1").is_ok());

        // Should be blocked again
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_burst_size_cap() {
        // Create a rate limiter with burst size
        let limiter = RateLimiter::builder()
            .max_requests(5)
            .time_window(Duration::from_secs(1))
            .burst_size(3) // Burst size smaller than max_requests
            .build();

        // Should be able to make 3 burst requests (initial bucket capacity)
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());

        // Now should be blocked (burst tokens exhausted)
        assert!(limiter.check_rate_limit("user1").is_err());

        // 250ms at 5 tokens/sec refills ~1.25 tokens (one token needs 200ms)
        thread::sleep(Duration::from_millis(250));

        // Should be allowed now due to replenishment
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_token_accumulation_capped() {
        // 10 tokens/sec with a burst cap of 5. The original 100ms window left
        // only 5ms of scheduler slack before a second token refilled; a 1s
        // window keeps the same token ratios with 50ms of slack.
        let limiter = RateLimiter::builder()
            .max_requests(10)
            .time_window(Duration::from_secs(1))
            .burst_size(5)
            .build();

        // Wait long enough that tokens would accumulate beyond burst size if uncapped
        // (200ms would generate 2 extra tokens at 10 tokens/sec, but capped at 5)
        thread::sleep(Duration::from_millis(200));

        // Should only be able to make burst_size (5) requests rapidly
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("user1").is_ok());
        }

        // Should be blocked now (all 5 tokens consumed)
        assert!(limiter.check_rate_limit("user1").is_err());

        // 150ms at 10 tokens/sec refills ~1.5 tokens (one token needs 100ms)
        thread::sleep(Duration::from_millis(150));

        // Should be allowed one more request
        assert!(limiter.check_rate_limit("user1").is_ok());

        // Should be blocked again
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_fractional_token_replenishment() {
        // Create a rate limiter with 1 request per second
        let limiter = RateLimiter::new(1, Duration::from_secs(1));

        // Use the token
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());

        // Wait for half a second (should get 0.5 tokens)
        thread::sleep(Duration::from_millis(500));

        // Should still be blocked (need 1.0 token)
        assert!(limiter.check_rate_limit("user1").is_err());

        // Wait another 600ms (total 1.1 seconds, should have > 1 token)
        thread::sleep(Duration::from_millis(600));

        // Should be allowed now
        assert!(limiter.check_rate_limit("user1").is_ok());
    }
}