Skip to main content

allsource_core/infrastructure/security/
rate_limit.rs

//! Rate limiting using the token bucket algorithm.
//!
//! Features:
//! - Per-tenant rate limiting
//! - Per-user rate limiting
//! - Per-API-key rate limiting
//! - Configurable limits
//! - Efficient in-memory storage with DashMap
//! - Automatic token replenishment
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use std::{sync::Arc, time::Duration};
13
/// Token-bucket parameters for one class of caller.
///
/// `requests_per_minute` is the sustained refill rate and `burst_size` is the bucket
/// capacity. Buckets start full, so a caller can make up to `burst_size` requests
/// back-to-back, and roughly `burst_size + requests_per_minute` within its first minute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained rate at which the bucket refills, in requests per minute.
    pub requests_per_minute: u32,
    /// Bucket capacity: the largest burst admitted back-to-back.
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Free tier: 60 requests/min sustained, bursts of up to 100.
    pub fn free_tier() -> Self {
        Self {
            requests_per_minute: 60,
            burst_size: 100,
        }
    }

    /// Professional tier: 600 requests/min sustained, bursts of up to 1,000.
    pub fn professional() -> Self {
        Self {
            requests_per_minute: 600,
            burst_size: 1000,
        }
    }

    /// Unlimited tier: 10,000 requests/min sustained, bursts of up to 20,000.
    pub fn unlimited() -> Self {
        Self {
            requests_per_minute: 10_000,
            burst_size: 20_000,
        }
    }

    /// Development mode: very high limits (100,000 requests/min, bursts of up to 200,000).
    pub fn dev_mode() -> Self {
        Self {
            requests_per_minute: 100_000,
            burst_size: 200_000,
        }
    }
}
56
57/// Token bucket for rate limiting
58#[derive(Debug, Clone)]
59struct TokenBucket {
60    /// Current number of tokens
61    tokens: f64,
62    /// Maximum tokens (burst size)
63    max_tokens: f64,
64    /// Tokens added per second
65    refill_rate: f64,
66    /// Last refill timestamp
67    last_refill: DateTime<Utc>,
68}
69
70impl TokenBucket {
71    fn new(config: &RateLimitConfig) -> Self {
72        let max_tokens = f64::from(config.burst_size);
73        Self {
74            tokens: max_tokens,
75            max_tokens,
76            refill_rate: f64::from(config.requests_per_minute) / 60.0, // tokens per second
77            last_refill: Utc::now(),
78        }
79    }
80
81    /// Try to consume a token. Returns true if successful.
82    fn try_consume(&mut self, tokens: f64) -> bool {
83        self.refill();
84
85        if self.tokens >= tokens {
86            self.tokens -= tokens;
87            true
88        } else {
89            false
90        }
91    }
92
93    /// Refill tokens based on time elapsed
94    fn refill(&mut self) {
95        let now = Utc::now();
96        let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
97
98        if elapsed > 0.0 {
99            let new_tokens = elapsed * self.refill_rate;
100            self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
101            self.last_refill = now;
102        }
103    }
104
105    /// Get remaining tokens
106    fn remaining(&mut self) -> u32 {
107        self.refill();
108        self.tokens.floor() as u32
109    }
110
111    /// Get time until next token is available
112    fn retry_after(&mut self) -> Duration {
113        self.refill();
114
115        if self.tokens >= 1.0 {
116            Duration::from_secs(0)
117        } else {
118            let tokens_needed = 1.0 - self.tokens;
119            let seconds = tokens_needed / self.refill_rate;
120            Duration::from_secs_f64(seconds)
121        }
122    }
123}
124
125/// Rate limiter using token bucket algorithm
126pub struct RateLimiter {
127    /// Buckets keyed by identifier (tenant_id, user_id, or api_key_id)
128    buckets: Arc<DashMap<String, TokenBucket>>,
129    /// Default configuration
130    default_config: RateLimitConfig,
131    /// Custom configs for specific identifiers
132    custom_configs: Arc<DashMap<String, RateLimitConfig>>,
133}
134
135impl RateLimiter {
136    pub fn new(default_config: RateLimitConfig) -> Self {
137        Self {
138            buckets: Arc::new(DashMap::new()),
139            default_config,
140            custom_configs: Arc::new(DashMap::new()),
141        }
142    }
143
144    /// Set custom config for a specific identifier
145    pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
146        self.custom_configs.insert(identifier.to_string(), config);
147
148        // Reset bucket with new config
149        self.buckets.remove(identifier);
150    }
151
152    /// Check if request is allowed
153    pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
154        self.check_rate_limit_with_cost(identifier, 1.0)
155    }
156
157    /// Check rate limit with custom cost (for expensive operations)
158    pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
159        let config = self
160            .custom_configs
161            .get(identifier)
162            .map_or_else(|| self.default_config.clone(), |c| c.clone());
163
164        let mut entry = self
165            .buckets
166            .entry(identifier.to_string())
167            .or_insert_with(|| TokenBucket::new(&config));
168
169        let allowed = entry.try_consume(cost);
170        let remaining = entry.remaining();
171        let retry_after = if allowed {
172            None
173        } else {
174            Some(entry.retry_after())
175        };
176
177        RateLimitResult {
178            allowed,
179            remaining,
180            retry_after,
181            limit: config.requests_per_minute,
182        }
183    }
184
185    /// Get current stats for an identifier
186    pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
187        self.buckets
188            .get_mut(identifier)
189            .map(|mut bucket| RateLimitStats {
190                remaining: bucket.remaining(),
191                retry_after: bucket.retry_after(),
192            })
193    }
194
195    /// Cleanup old buckets (call periodically)
196    pub fn cleanup(&self) {
197        let now = Utc::now();
198        self.buckets.retain(|_, bucket| {
199            // Remove buckets that haven't been used in the last hour
200            (now - bucket.last_refill).num_hours() < 1
201        });
202    }
203}
204
205impl Default for RateLimiter {
206    fn default() -> Self {
207        Self::new(RateLimitConfig::professional())
208    }
209}
210
/// Outcome of a single `RateLimiter::check_rate_limit` /
/// `RateLimiter::check_rate_limit_with_cost` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request was admitted (its cost has then been deducted).
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// Suggested wait before retrying; `None` when the request was allowed.
    pub retry_after: Option<Duration>,
    /// Configured sustained rate for the identifier (`requests_per_minute`).
    pub limit: u32,
}
223
/// Snapshot of an identifier's bucket, as returned by `RateLimiter::get_stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until one whole token is available (zero if one already is).
    pub retry_after: Duration,
}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Should allow up to burst size
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // Should deny after burst exhausted
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60, // 1 per second
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Consume all tokens
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        assert_eq!(bucket.remaining(), 0);

        // Pretend two seconds passed by backdating the last refill, rather than
        // sleeping: same effect on the bucket, without stalling the test suite.
        bucket.last_refill = bucket.last_refill - chrono::Duration::seconds(2);

        // Should have ~2 tokens refilled
        let remaining = bucket.remaining();
        assert!(
            (2..=3).contains(&remaining),
            "Expected 2-3 tokens, got {remaining}"
        );
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        // Different identifiers have separate buckets
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        // Expensive operation costs 5 tokens
        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }

    #[test]
    fn test_denied_request_reports_retry_after() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60, // 1 per second
            burst_size: 1,
        });

        assert!(limiter.check_rate_limit("user1").allowed);

        let denied = limiter.check_rate_limit("user1");
        assert!(!denied.allowed);
        assert_eq!(denied.remaining, 0);
        let wait = denied
            .retry_after
            .expect("a denied request must carry retry_after");
        assert!(wait > Duration::ZERO && wait <= Duration::from_secs(1));
    }

    #[test]
    fn test_get_stats_unknown_identifier() {
        let limiter = RateLimiter::default();
        assert!(limiter.get_stats("nobody").is_none());
    }
}