Skip to main content

allsource_core/infrastructure/security/
rate_limit.rs

1// Copyright 2024-2025 AllSource Team
2// Licensed under the Business Source License 1.1 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     See LICENSE-BSL in the repository root
7//
8// Change Date: 2029-03-01
9// Change License: Apache License, Version 2.0
10
//! Rate limiting implementation using the token bucket algorithm.
//!
//! Features:
//! - Per-tenant rate limiting
//! - Per-user rate limiting
//! - Per-API key rate limiting
//! - Configurable limits
//! - Efficient in-memory storage with DashMap
//! - Automatic token replenishment
20use chrono::{DateTime, Utc};
21use dashmap::DashMap;
22use std::{sync::Arc, time::Duration};
23
/// Limits applied to one rate-limited identity (tenant, user, API key, ...).
///
/// `requests_per_minute` is the sustained refill rate of the bucket, while
/// `burst_size` is its capacity: the number of requests a full bucket can
/// serve back-to-back.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained rate: tokens regained per minute.
    pub requests_per_minute: u32,
    /// Bucket capacity, i.e. the largest burst served at once.
    pub burst_size: u32,
}
32
33impl RateLimitConfig {
34    /// Free tier: 60 requests/min
35    pub fn free_tier() -> Self {
36        Self {
37            requests_per_minute: 60,
38            burst_size: 100,
39        }
40    }
41
42    /// Professional tier: 600 requests/min
43    pub fn professional() -> Self {
44        Self {
45            requests_per_minute: 600,
46            burst_size: 1000,
47        }
48    }
49
50    /// Unlimited tier: 10,000 requests/min
51    pub fn unlimited() -> Self {
52        Self {
53            requests_per_minute: 10_000,
54            burst_size: 20_000,
55        }
56    }
57
58    /// Development mode: Very high limits
59    pub fn dev_mode() -> Self {
60        Self {
61            requests_per_minute: 100_000,
62            burst_size: 200_000,
63        }
64    }
65}
66
67/// Token bucket for rate limiting
68#[derive(Debug, Clone)]
69struct TokenBucket {
70    /// Current number of tokens
71    tokens: f64,
72    /// Maximum tokens (burst size)
73    max_tokens: f64,
74    /// Tokens added per second
75    refill_rate: f64,
76    /// Last refill timestamp
77    last_refill: DateTime<Utc>,
78}
79
80impl TokenBucket {
81    fn new(config: &RateLimitConfig) -> Self {
82        let max_tokens = f64::from(config.burst_size);
83        Self {
84            tokens: max_tokens,
85            max_tokens,
86            refill_rate: f64::from(config.requests_per_minute) / 60.0, // tokens per second
87            last_refill: Utc::now(),
88        }
89    }
90
91    /// Try to consume a token. Returns true if successful.
92    fn try_consume(&mut self, tokens: f64) -> bool {
93        self.refill();
94
95        if self.tokens >= tokens {
96            self.tokens -= tokens;
97            true
98        } else {
99            false
100        }
101    }
102
103    /// Refill tokens based on time elapsed
104    fn refill(&mut self) {
105        let now = Utc::now();
106        let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
107
108        if elapsed > 0.0 {
109            let new_tokens = elapsed * self.refill_rate;
110            self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
111            self.last_refill = now;
112        }
113    }
114
115    /// Get remaining tokens
116    fn remaining(&mut self) -> u32 {
117        self.refill();
118        self.tokens.floor() as u32
119    }
120
121    /// Get time until next token is available
122    fn retry_after(&mut self) -> Duration {
123        self.refill();
124
125        if self.tokens >= 1.0 {
126            Duration::from_secs(0)
127        } else {
128            let tokens_needed = 1.0 - self.tokens;
129            let seconds = tokens_needed / self.refill_rate;
130            Duration::from_secs_f64(seconds)
131        }
132    }
133}
134
135/// Rate limiter using token bucket algorithm
136pub struct RateLimiter {
137    /// Buckets keyed by identifier (tenant_id, user_id, or api_key_id)
138    buckets: Arc<DashMap<String, TokenBucket>>,
139    /// Default configuration
140    default_config: RateLimitConfig,
141    /// Custom configs for specific identifiers
142    custom_configs: Arc<DashMap<String, RateLimitConfig>>,
143}
144
145impl RateLimiter {
146    pub fn new(default_config: RateLimitConfig) -> Self {
147        Self {
148            buckets: Arc::new(DashMap::new()),
149            default_config,
150            custom_configs: Arc::new(DashMap::new()),
151        }
152    }
153
154    /// Set custom config for a specific identifier
155    pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
156        self.custom_configs.insert(identifier.to_string(), config);
157
158        // Reset bucket with new config
159        self.buckets.remove(identifier);
160    }
161
162    /// Check if request is allowed
163    pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
164        self.check_rate_limit_with_cost(identifier, 1.0)
165    }
166
167    /// Check rate limit with custom cost (for expensive operations)
168    pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
169        let config = self
170            .custom_configs
171            .get(identifier)
172            .map_or_else(|| self.default_config.clone(), |c| c.clone());
173
174        let mut entry = self
175            .buckets
176            .entry(identifier.to_string())
177            .or_insert_with(|| TokenBucket::new(&config));
178
179        let allowed = entry.try_consume(cost);
180        let remaining = entry.remaining();
181        let retry_after = if allowed {
182            None
183        } else {
184            Some(entry.retry_after())
185        };
186
187        RateLimitResult {
188            allowed,
189            remaining,
190            retry_after,
191            limit: config.requests_per_minute,
192        }
193    }
194
195    /// Get current stats for an identifier
196    pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
197        self.buckets
198            .get_mut(identifier)
199            .map(|mut bucket| RateLimitStats {
200                remaining: bucket.remaining(),
201                retry_after: bucket.retry_after(),
202            })
203    }
204
205    /// Cleanup old buckets (call periodically)
206    pub fn cleanup(&self) {
207        let now = Utc::now();
208        self.buckets.retain(|_, bucket| {
209            // Remove buckets that haven't been used in the last hour
210            (now - bucket.last_refill).num_hours() < 1
211        });
212    }
213}
214
215impl Default for RateLimiter {
216    fn default() -> Self {
217        Self::new(RateLimitConfig::professional())
218    }
219}
220
/// Outcome of asking a `RateLimiter` whether a request may proceed.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// `true` when the request was admitted and its cost deducted.
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// Suggested back-off before retrying; `None` when the request was allowed.
    pub retry_after: Option<Duration>,
    /// Configured sustained rate (requests per minute) for the identifier.
    pub limit: u32,
}
233
/// Snapshot of an identifier's bucket, as reported by `RateLimiter::get_stats`.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until at least one token is available (zero if one already is).
    pub retry_after: Duration,
}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Should allow up to burst size
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // Should deny after burst exhausted
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60, // 1 per second
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Consume all tokens
        for _ in 0..10 {
            bucket.try_consume(1.0);
        }

        assert_eq!(bucket.remaining(), 0);

        // Simulate two seconds passing by backdating the last refill. This avoids
        // a real sleep, keeping the test fast and independent of scheduler timing.
        bucket.last_refill = chrono::Utc::now() - chrono::Duration::seconds(2);

        // At 1 token/second, ~2 tokens should have been refilled
        let remaining = bucket.remaining();
        assert!(
            (2..=3).contains(&remaining),
            "Expected 2-3 tokens, got {remaining}"
        );
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        // Different identifiers have separate buckets
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        // Expensive operation costs 5 tokens
        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }
}