allsource_core/
rate_limit.rs

//! Rate limiting implementation using the token bucket algorithm.
//!
//! Features:
//! - Per-tenant rate limiting
//! - Per-user rate limiting
//! - Per-API-key rate limiting
//! - Configurable limits
//! - Efficient in-memory storage with `DashMap`
//! - Automatic token replenishment
10
11use chrono::{DateTime, Utc};
12use dashmap::DashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
/// Rate limit configuration for one rate-limited identifier.
///
/// `requests_per_minute` is the sustained rate and `burst_size` is the number of
/// requests that may be served back-to-back before the sustained rate applies.
/// Values compare field by field, so two configs with the same limits are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum sustained requests per minute.
    pub requests_per_minute: u32,
    /// Maximum burst size (requests that can be served without waiting).
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Free tier: 60 requests/min, burst of 100.
    pub const fn free_tier() -> Self {
        Self {
            requests_per_minute: 60,
            burst_size: 100,
        }
    }

    /// Professional tier: 600 requests/min, burst of 1,000.
    pub const fn professional() -> Self {
        Self {
            requests_per_minute: 600,
            burst_size: 1000,
        }
    }

    /// "Unlimited" tier: still bounded, at 10,000 requests/min with a burst of 20,000.
    pub const fn unlimited() -> Self {
        Self {
            requests_per_minute: 10_000,
            burst_size: 20_000,
        }
    }

    /// Development mode: very high limits (100,000 requests/min, burst of 200,000).
    pub const fn dev_mode() -> Self {
        Self {
            requests_per_minute: 100_000,
            burst_size: 200_000,
        }
    }
}
58
59/// Token bucket for rate limiting
60#[derive(Debug, Clone)]
61struct TokenBucket {
62    /// Current number of tokens
63    tokens: f64,
64    /// Maximum tokens (burst size)
65    max_tokens: f64,
66    /// Tokens added per second
67    refill_rate: f64,
68    /// Last refill timestamp
69    last_refill: DateTime<Utc>,
70}
71
72impl TokenBucket {
73    fn new(config: &RateLimitConfig) -> Self {
74        let max_tokens = config.burst_size as f64;
75        Self {
76            tokens: max_tokens,
77            max_tokens,
78            refill_rate: config.requests_per_minute as f64 / 60.0, // tokens per second
79            last_refill: Utc::now(),
80        }
81    }
82
83    /// Try to consume a token. Returns true if successful.
84    fn try_consume(&mut self, tokens: f64) -> bool {
85        self.refill();
86
87        if self.tokens >= tokens {
88            self.tokens -= tokens;
89            true
90        } else {
91            false
92        }
93    }
94
95    /// Refill tokens based on time elapsed
96    fn refill(&mut self) {
97        let now = Utc::now();
98        let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
99
100        if elapsed > 0.0 {
101            let new_tokens = elapsed * self.refill_rate;
102            self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
103            self.last_refill = now;
104        }
105    }
106
107    /// Get remaining tokens
108    fn remaining(&mut self) -> u32 {
109        self.refill();
110        self.tokens.floor() as u32
111    }
112
113    /// Get time until next token is available
114    fn retry_after(&mut self) -> Duration {
115        self.refill();
116
117        if self.tokens >= 1.0 {
118            Duration::from_secs(0)
119        } else {
120            let tokens_needed = 1.0 - self.tokens;
121            let seconds = tokens_needed / self.refill_rate;
122            Duration::from_secs_f64(seconds)
123        }
124    }
125}
126
127/// Rate limiter using token bucket algorithm
128pub struct RateLimiter {
129    /// Buckets keyed by identifier (tenant_id, user_id, or api_key_id)
130    buckets: Arc<DashMap<String, TokenBucket>>,
131    /// Default configuration
132    default_config: RateLimitConfig,
133    /// Custom configs for specific identifiers
134    custom_configs: Arc<DashMap<String, RateLimitConfig>>,
135}
136
137impl RateLimiter {
138    pub fn new(default_config: RateLimitConfig) -> Self {
139        Self {
140            buckets: Arc::new(DashMap::new()),
141            default_config,
142            custom_configs: Arc::new(DashMap::new()),
143        }
144    }
145
146    /// Set custom config for a specific identifier
147    pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
148        self.custom_configs.insert(identifier.to_string(), config);
149
150        // Reset bucket with new config
151        self.buckets.remove(identifier);
152    }
153
154    /// Check if request is allowed
155    pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
156        self.check_rate_limit_with_cost(identifier, 1.0)
157    }
158
159    /// Check rate limit with custom cost (for expensive operations)
160    pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
161        let config = self.custom_configs
162            .get(identifier)
163            .map(|c| c.clone())
164            .unwrap_or_else(|| self.default_config.clone());
165
166        let mut entry = self.buckets
167            .entry(identifier.to_string())
168            .or_insert_with(|| TokenBucket::new(&config));
169
170        let allowed = entry.try_consume(cost);
171        let remaining = entry.remaining();
172        let retry_after = if !allowed {
173            Some(entry.retry_after())
174        } else {
175            None
176        };
177
178        RateLimitResult {
179            allowed,
180            remaining,
181            retry_after,
182            limit: config.requests_per_minute,
183        }
184    }
185
186    /// Get current stats for an identifier
187    pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
188        self.buckets.get_mut(identifier).map(|mut bucket| {
189            RateLimitStats {
190                remaining: bucket.remaining(),
191                retry_after: bucket.retry_after(),
192            }
193        })
194    }
195
196    /// Cleanup old buckets (call periodically)
197    pub fn cleanup(&self) {
198        let now = Utc::now();
199        self.buckets.retain(|_, bucket| {
200            // Remove buckets that haven't been used in the last hour
201            (now - bucket.last_refill).num_hours() < 1
202        });
203    }
204}
205
206impl Default for RateLimiter {
207    fn default() -> Self {
208        Self::new(RateLimitConfig::professional())
209    }
210}
211
/// Result of a rate limit check.
///
/// Compares field by field, so results can be checked with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request is allowed
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check
    pub remaining: u32,
    /// Time to wait before retrying; `None` when the request was allowed
    pub retry_after: Option<Duration>,
    /// Total limit per minute (the `requests_per_minute` of the config that applied)
    pub limit: u32,
}
224
/// Current rate limit statistics for one identifier.
///
/// Compares field by field, so snapshots can be checked with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Whole tokens currently available
    pub remaining: u32,
    /// Time until at least one whole token is available (zero if one already is)
    pub retry_after: Duration,
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// A new bucket starts full, with capacity equal to the configured burst size.
    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    /// Consuming tokens lowers the remaining count by the amount consumed.
    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    /// Requests are denied once the burst allowance is used up.
    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Should allow up to burst size
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // Should deny after burst exhausted
        assert!(!bucket.try_consume(1.0));
    }

    /// Tokens come back at the configured rate as time passes.
    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60, // 1 per second
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Consume all tokens
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        assert_eq!(bucket.remaining(), 0);

        // Simulate two seconds passing by moving the last-refill timestamp back,
        // rather than sleeping for two real seconds.
        bucket.last_refill = bucket.last_refill - chrono::Duration::seconds(2);

        // Should have ~2 tokens refilled
        let remaining = bucket.remaining();
        assert!(
            (2..=3).contains(&remaining),
            "Expected 2-3 tokens, got {}",
            remaining
        );
    }

    /// Different identifiers are limited independently.
    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        // Different identifiers have separate buckets
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    /// A custom config overrides the default for its identifier only.
    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    /// A costed check deducts the full cost from the bucket.
    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        // Expensive operation costs 5 tokens
        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }

    /// Allowed results carry no retry hint; denied results always carry one.
    #[test]
    fn test_denied_result_has_retry_after() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 1,
        });

        let first = limiter.check_rate_limit("user1");
        assert!(first.allowed);
        assert!(first.retry_after.is_none());

        let second = limiter.check_rate_limit("user1");
        assert!(!second.allowed);
        assert_eq!(second.remaining, 0);
        assert!(second.retry_after.is_some());
    }

    /// Setting a config discards the old bucket, so the identifier starts full again.
    #[test]
    fn test_set_config_resets_bucket() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 1,
        };
        let limiter = RateLimiter::new(config.clone());

        assert!(limiter.check_rate_limit("user1").allowed);
        assert!(!limiter.check_rate_limit("user1").allowed);

        limiter.set_config("user1", config);
        assert!(limiter.check_rate_limit("user1").allowed);
    }

    /// Stats exist only for identifiers that have been checked at least once.
    #[test]
    fn test_get_stats_unknown_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());
        assert!(limiter.get_stats("nobody").is_none());

        limiter.check_rate_limit("somebody");
        let stats = limiter.get_stats("somebody").expect("bucket was just created");
        assert_eq!(stats.remaining, 99);
    }

    /// Cleanup only evicts buckets idle for an hour or more.
    #[test]
    fn test_cleanup_keeps_recent_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());
        limiter.check_rate_limit("user1");

        limiter.cleanup();

        assert!(limiter.get_stats("user1").is_some());
    }
}