allsource_core/
rate_limit.rs

//! Rate limiting implementation using the token bucket algorithm.
//!
//! Features:
//! - Per-tenant rate limiting
//! - Per-user rate limiting
//! - Per-API-key rate limiting
//! - Configurable limits
//! - Efficient in-memory storage with DashMap
//! - Automatic token replenishment
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use std::sync::Arc;
13use std::time::Duration;
14
/// Rate limit configuration for different resource types.
///
/// A limiter built from this config admits requests from a token bucket that
/// holds at most `burst_size` tokens and regains `requests_per_minute` tokens
/// per minute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained rate: maximum requests (tokens regained) per minute
    pub requests_per_minute: u32,
    /// Maximum burst size (bucket capacity)
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Free tier: 60 requests/min, bursts of up to 100.
    pub fn free_tier() -> Self {
        Self {
            requests_per_minute: 60,
            burst_size: 100,
        }
    }

    /// Professional tier: 600 requests/min, bursts of up to 1,000.
    pub fn professional() -> Self {
        Self {
            requests_per_minute: 600,
            burst_size: 1_000,
        }
    }

    /// Unlimited tier: 10,000 requests/min, bursts of up to 20,000.
    pub fn unlimited() -> Self {
        Self {
            requests_per_minute: 10_000,
            burst_size: 20_000,
        }
    }

    /// Development mode: very high limits (100,000 requests/min, bursts of up to 200,000).
    pub fn dev_mode() -> Self {
        Self {
            requests_per_minute: 100_000,
            burst_size: 200_000,
        }
    }
}
57
58/// Token bucket for rate limiting
59#[derive(Debug, Clone)]
60struct TokenBucket {
61    /// Current number of tokens
62    tokens: f64,
63    /// Maximum tokens (burst size)
64    max_tokens: f64,
65    /// Tokens added per second
66    refill_rate: f64,
67    /// Last refill timestamp
68    last_refill: DateTime<Utc>,
69}
70
71impl TokenBucket {
72    fn new(config: &RateLimitConfig) -> Self {
73        let max_tokens = config.burst_size as f64;
74        Self {
75            tokens: max_tokens,
76            max_tokens,
77            refill_rate: config.requests_per_minute as f64 / 60.0, // tokens per second
78            last_refill: Utc::now(),
79        }
80    }
81
82    /// Try to consume a token. Returns true if successful.
83    fn try_consume(&mut self, tokens: f64) -> bool {
84        self.refill();
85
86        if self.tokens >= tokens {
87            self.tokens -= tokens;
88            true
89        } else {
90            false
91        }
92    }
93
94    /// Refill tokens based on time elapsed
95    fn refill(&mut self) {
96        let now = Utc::now();
97        let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
98
99        if elapsed > 0.0 {
100            let new_tokens = elapsed * self.refill_rate;
101            self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
102            self.last_refill = now;
103        }
104    }
105
106    /// Get remaining tokens
107    fn remaining(&mut self) -> u32 {
108        self.refill();
109        self.tokens.floor() as u32
110    }
111
112    /// Get time until next token is available
113    fn retry_after(&mut self) -> Duration {
114        self.refill();
115
116        if self.tokens >= 1.0 {
117            Duration::from_secs(0)
118        } else {
119            let tokens_needed = 1.0 - self.tokens;
120            let seconds = tokens_needed / self.refill_rate;
121            Duration::from_secs_f64(seconds)
122        }
123    }
124}
125
126/// Rate limiter using token bucket algorithm
127pub struct RateLimiter {
128    /// Buckets keyed by identifier (tenant_id, user_id, or api_key_id)
129    buckets: Arc<DashMap<String, TokenBucket>>,
130    /// Default configuration
131    default_config: RateLimitConfig,
132    /// Custom configs for specific identifiers
133    custom_configs: Arc<DashMap<String, RateLimitConfig>>,
134}
135
136impl RateLimiter {
137    pub fn new(default_config: RateLimitConfig) -> Self {
138        Self {
139            buckets: Arc::new(DashMap::new()),
140            default_config,
141            custom_configs: Arc::new(DashMap::new()),
142        }
143    }
144
145    /// Set custom config for a specific identifier
146    pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
147        self.custom_configs.insert(identifier.to_string(), config);
148
149        // Reset bucket with new config
150        self.buckets.remove(identifier);
151    }
152
153    /// Check if request is allowed
154    pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
155        self.check_rate_limit_with_cost(identifier, 1.0)
156    }
157
158    /// Check rate limit with custom cost (for expensive operations)
159    pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
160        let config = self
161            .custom_configs
162            .get(identifier)
163            .map(|c| c.clone())
164            .unwrap_or_else(|| self.default_config.clone());
165
166        let mut entry = self
167            .buckets
168            .entry(identifier.to_string())
169            .or_insert_with(|| TokenBucket::new(&config));
170
171        let allowed = entry.try_consume(cost);
172        let remaining = entry.remaining();
173        let retry_after = if !allowed {
174            Some(entry.retry_after())
175        } else {
176            None
177        };
178
179        RateLimitResult {
180            allowed,
181            remaining,
182            retry_after,
183            limit: config.requests_per_minute,
184        }
185    }
186
187    /// Get current stats for an identifier
188    pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
189        self.buckets
190            .get_mut(identifier)
191            .map(|mut bucket| RateLimitStats {
192                remaining: bucket.remaining(),
193                retry_after: bucket.retry_after(),
194            })
195    }
196
197    /// Cleanup old buckets (call periodically)
198    pub fn cleanup(&self) {
199        let now = Utc::now();
200        self.buckets.retain(|_, bucket| {
201            // Remove buckets that haven't been used in the last hour
202            (now - bucket.last_refill).num_hours() < 1
203        });
204    }
205}
206
207impl Default for RateLimiter {
208    fn default() -> Self {
209        Self::new(RateLimitConfig::professional())
210    }
211}
212
/// Result of a rate limit check
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request is allowed (tokens are deducted only when `true`)
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check, rounded down
    pub remaining: u32,
    /// Time to wait before retrying; `Some` only when the request was denied
    pub retry_after: Option<Duration>,
    /// Configured sustained rate in requests per minute (not the burst size)
    pub limit: u32,
}

/// Current rate limit statistics for one identifier's bucket
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Whole tokens currently available, rounded down
    pub remaining: u32,
    /// Time until at least one token is available (zero if one already is)
    pub retry_after: Duration,
}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Should allow up to burst size
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // Should deny after burst exhausted
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60, // 1 per second
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // Consume all tokens
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }
        assert_eq!(bucket.remaining(), 0);

        // Simulate two seconds passing by back-dating the last refill instead
        // of sleeping, which keeps the suite fast and the timing deterministic.
        bucket.last_refill = bucket.last_refill - chrono::Duration::seconds(2);

        // Should have ~2 tokens refilled
        let remaining = bucket.remaining();
        assert!(
            (2..=3).contains(&remaining),
            "Expected 2-3 tokens, got {}",
            remaining
        );
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        // Different identifiers have separate buckets
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        // Expensive operation costs 5 tokens
        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }

    #[test]
    fn test_denied_request_reports_retry_after() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 1,
        });

        assert!(limiter.check_rate_limit("user1").allowed);

        let denied = limiter.check_rate_limit("user1");
        assert!(!denied.allowed);
        assert_eq!(denied.remaining, 0);
        assert!(denied.retry_after.is_some());
    }

    #[test]
    fn test_set_config_resets_bucket() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });
        for _ in 0..3 {
            limiter.check_rate_limit("user1");
        }

        limiter.set_config(
            "user1",
            RateLimitConfig {
                requests_per_minute: 120,
                burst_size: 5,
            },
        );

        // The drained bucket was discarded; a fresh one starts full.
        let result = limiter.check_rate_limit("user1");
        assert_eq!(result.remaining, 4);
        assert_eq!(result.limit, 120);
    }

    #[test]
    fn test_get_stats() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });
        assert!(limiter.get_stats("user1").is_none());

        limiter.check_rate_limit("user1");
        let stats = limiter
            .get_stats("user1")
            .expect("bucket is created by the first check");
        assert_eq!(stats.remaining, 4);
        assert_eq!(stats.retry_after, Duration::from_secs(0));
    }

    #[test]
    fn test_cleanup_keeps_recent_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());
        limiter.check_rate_limit("user1");

        limiter.cleanup();

        assert!(limiter.get_stats("user1").is_some());
    }

    #[test]
    fn test_cleanup_evicts_idle_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());
        limiter.check_rate_limit("user1");

        // Back-date the bucket two hours; the guard is dropped before cleanup
        // so `retain` can lock the shard.
        {
            let mut bucket = limiter
                .buckets
                .get_mut("user1")
                .expect("bucket is created by the check");
            bucket.last_refill = bucket.last_refill - chrono::Duration::hours(2);
        }

        limiter.cleanup();

        assert!(limiter.get_stats("user1").is_none());
    }
}