Skip to main content

allsource_core/infrastructure/security/
rate_limit.rs

//! Rate limiting implementation using the token bucket algorithm.
//!
//! Features:
//! - Per-tenant rate limiting
//! - Per-user rate limiting
//! - Per-API-key rate limiting
//! - Configurable limits
//! - Efficient in-memory storage with `DashMap`
//! - Automatic token replenishment
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use std::{sync::Arc, time::Duration};
13
/// Limits applied to one rate-limited identifier.
///
/// `requests_per_minute` is the sustained refill rate and `burst_size` is the
/// capacity of the token bucket built from this configuration.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained allowance, in requests per minute.
    pub requests_per_minute: u32,
    /// Largest number of requests that can be served back-to-back.
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Shared constructor behind the named presets below.
    fn preset(requests_per_minute: u32, burst_size: u32) -> Self {
        Self {
            requests_per_minute,
            burst_size,
        }
    }

    /// Free tier: 60 requests/min, burst of 100.
    pub fn free_tier() -> Self {
        Self::preset(60, 100)
    }

    /// Professional tier: 600 requests/min, burst of 1,000.
    pub fn professional() -> Self {
        Self::preset(600, 1000)
    }

    /// Unlimited tier: 10,000 requests/min, burst of 20,000.
    pub fn unlimited() -> Self {
        Self::preset(10_000, 20_000)
    }

    /// Development mode: very high limits (100,000 requests/min, burst of 200,000).
    pub fn dev_mode() -> Self {
        Self::preset(100_000, 200_000)
    }
}
56
57/// Token bucket for rate limiting
58#[derive(Debug, Clone)]
59struct TokenBucket {
60    /// Current number of tokens
61    tokens: f64,
62    /// Maximum tokens (burst size)
63    max_tokens: f64,
64    /// Tokens added per second
65    refill_rate: f64,
66    /// Last refill timestamp
67    last_refill: DateTime<Utc>,
68}
69
70impl TokenBucket {
71    fn new(config: &RateLimitConfig) -> Self {
72        let max_tokens = config.burst_size as f64;
73        Self {
74            tokens: max_tokens,
75            max_tokens,
76            refill_rate: config.requests_per_minute as f64 / 60.0, // tokens per second
77            last_refill: Utc::now(),
78        }
79    }
80
81    /// Try to consume a token. Returns true if successful.
82    fn try_consume(&mut self, tokens: f64) -> bool {
83        self.refill();
84
85        if self.tokens >= tokens {
86            self.tokens -= tokens;
87            true
88        } else {
89            false
90        }
91    }
92
93    /// Refill tokens based on time elapsed
94    fn refill(&mut self) {
95        let now = Utc::now();
96        let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
97
98        if elapsed > 0.0 {
99            let new_tokens = elapsed * self.refill_rate;
100            self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
101            self.last_refill = now;
102        }
103    }
104
105    /// Get remaining tokens
106    fn remaining(&mut self) -> u32 {
107        self.refill();
108        self.tokens.floor() as u32
109    }
110
111    /// Get time until next token is available
112    fn retry_after(&mut self) -> Duration {
113        self.refill();
114
115        if self.tokens >= 1.0 {
116            Duration::from_secs(0)
117        } else {
118            let tokens_needed = 1.0 - self.tokens;
119            let seconds = tokens_needed / self.refill_rate;
120            Duration::from_secs_f64(seconds)
121        }
122    }
123}
124
125/// Rate limiter using token bucket algorithm
126pub struct RateLimiter {
127    /// Buckets keyed by identifier (tenant_id, user_id, or api_key_id)
128    buckets: Arc<DashMap<String, TokenBucket>>,
129    /// Default configuration
130    default_config: RateLimitConfig,
131    /// Custom configs for specific identifiers
132    custom_configs: Arc<DashMap<String, RateLimitConfig>>,
133}
134
135impl RateLimiter {
136    pub fn new(default_config: RateLimitConfig) -> Self {
137        Self {
138            buckets: Arc::new(DashMap::new()),
139            default_config,
140            custom_configs: Arc::new(DashMap::new()),
141        }
142    }
143
144    /// Set custom config for a specific identifier
145    pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
146        self.custom_configs.insert(identifier.to_string(), config);
147
148        // Reset bucket with new config
149        self.buckets.remove(identifier);
150    }
151
152    /// Check if request is allowed
153    pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
154        self.check_rate_limit_with_cost(identifier, 1.0)
155    }
156
157    /// Check rate limit with custom cost (for expensive operations)
158    pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
159        let config = self
160            .custom_configs
161            .get(identifier)
162            .map(|c| c.clone())
163            .unwrap_or_else(|| self.default_config.clone());
164
165        let mut entry = self
166            .buckets
167            .entry(identifier.to_string())
168            .or_insert_with(|| TokenBucket::new(&config));
169
170        let allowed = entry.try_consume(cost);
171        let remaining = entry.remaining();
172        let retry_after = if !allowed {
173            Some(entry.retry_after())
174        } else {
175            None
176        };
177
178        RateLimitResult {
179            allowed,
180            remaining,
181            retry_after,
182            limit: config.requests_per_minute,
183        }
184    }
185
186    /// Get current stats for an identifier
187    pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
188        self.buckets
189            .get_mut(identifier)
190            .map(|mut bucket| RateLimitStats {
191                remaining: bucket.remaining(),
192                retry_after: bucket.retry_after(),
193            })
194    }
195
196    /// Cleanup old buckets (call periodically)
197    pub fn cleanup(&self) {
198        let now = Utc::now();
199        self.buckets.retain(|_, bucket| {
200            // Remove buckets that haven't been used in the last hour
201            (now - bucket.last_refill).num_hours() < 1
202        });
203    }
204}
205
206impl Default for RateLimiter {
207    fn default() -> Self {
208        Self::new(RateLimitConfig::professional())
209    }
210}
211
/// Outcome of a single rate limit check.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// `true` when the request may proceed.
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// How long to wait before retrying; `None` when the request was allowed.
    pub retry_after: Option<Duration>,
    /// The per-minute request limit of the configuration in effect.
    pub limit: u32,
}
224
/// Snapshot of an identifier's bucket at the time of the query.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until the next token becomes available (zero if one already is).
    pub retry_after: Duration,
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// A new bucket starts full at the configured burst size.
    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    /// Consuming tokens lowers the remaining count by the consumed amount.
    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    /// Requests beyond the burst size are denied.
    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Should allow up to burst size
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // Should deny after burst exhausted
        assert!(!bucket.try_consume(1.0));
    }

    /// Tokens come back at the configured rate as time passes.
    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60, // 1 per second
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Consume all tokens
        for _ in 0..10 {
            bucket.try_consume(1.0);
        }

        assert_eq!(bucket.remaining(), 0);

        // Pretend the last refill happened two seconds earlier instead of
        // sleeping: the test stays fast and does not depend on scheduling.
        bucket.last_refill = bucket.last_refill - chrono::Duration::seconds(2);

        // Should have ~2 tokens refilled
        let remaining = bucket.remaining();
        assert!(
            (2..=3).contains(&remaining),
            "Expected 2-3 tokens, got {}",
            remaining
        );
    }

    /// Each identifier draws from its own bucket.
    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        // Different identifiers have separate buckets
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    /// Once the burst is spent the limiter denies and reports a retry delay.
    #[test]
    fn test_rate_limiter_denies_after_burst() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 2,
        });

        assert!(limiter.check_rate_limit("user1").allowed);
        assert!(limiter.check_rate_limit("user1").allowed);

        let denied = limiter.check_rate_limit("user1");
        assert!(!denied.allowed);
        assert_eq!(denied.remaining, 0);
        assert!(denied.retry_after.is_some());
    }

    /// A custom configuration overrides the default for its identifier only.
    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    /// A request with a custom cost deducts that many tokens.
    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        // Expensive operation costs 5 tokens
        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }
}