//! Per-identifier token-bucket rate limiting.
//!
//! Source path: `allsource_core/infrastructure/security/rate_limit.rs`.

use chrono::{DateTime, Utc};
use dashmap::DashMap;
use std::{sync::Arc, time::Duration};

/// Token-bucket parameters applied to one rate-limited identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained rate: the bucket regains `requests_per_minute / 60` tokens per second.
    pub requests_per_minute: u32,
    /// Bucket capacity: the most unit-cost requests that can be admitted
    /// back-to-back from a full bucket.
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Free tier: 60 requests per minute, bursts of up to 100.
    pub const fn free_tier() -> Self {
        Self {
            requests_per_minute: 60,
            burst_size: 100,
        }
    }

    /// Professional tier: 600 requests per minute, bursts of up to 1,000.
    pub const fn professional() -> Self {
        Self {
            requests_per_minute: 600,
            burst_size: 1000,
        }
    }

    /// "Unlimited" tier: still bounded, at 10,000 requests per minute and bursts
    /// of up to 20,000.
    pub const fn unlimited() -> Self {
        Self {
            requests_per_minute: 10_000,
            burst_size: 20_000,
        }
    }

    /// Development preset: 100,000 requests per minute, bursts of up to 200,000.
    pub const fn dev_mode() -> Self {
        Self {
            requests_per_minute: 100_000,
            burst_size: 200_000,
        }
    }
}

57#[derive(Debug, Clone)]
59struct TokenBucket {
60 tokens: f64,
62 max_tokens: f64,
64 refill_rate: f64,
66 last_refill: DateTime<Utc>,
68}
69
70impl TokenBucket {
71 fn new(config: &RateLimitConfig) -> Self {
72 let max_tokens = config.burst_size as f64;
73 Self {
74 tokens: max_tokens,
75 max_tokens,
76 refill_rate: config.requests_per_minute as f64 / 60.0, last_refill: Utc::now(),
78 }
79 }
80
81 fn try_consume(&mut self, tokens: f64) -> bool {
83 self.refill();
84
85 if self.tokens >= tokens {
86 self.tokens -= tokens;
87 true
88 } else {
89 false
90 }
91 }
92
93 fn refill(&mut self) {
95 let now = Utc::now();
96 let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
97
98 if elapsed > 0.0 {
99 let new_tokens = elapsed * self.refill_rate;
100 self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
101 self.last_refill = now;
102 }
103 }
104
105 fn remaining(&mut self) -> u32 {
107 self.refill();
108 self.tokens.floor() as u32
109 }
110
111 fn retry_after(&mut self) -> Duration {
113 self.refill();
114
115 if self.tokens >= 1.0 {
116 Duration::from_secs(0)
117 } else {
118 let tokens_needed = 1.0 - self.tokens;
119 let seconds = tokens_needed / self.refill_rate;
120 Duration::from_secs_f64(seconds)
121 }
122 }
123}
124
125pub struct RateLimiter {
127 buckets: Arc<DashMap<String, TokenBucket>>,
129 default_config: RateLimitConfig,
131 custom_configs: Arc<DashMap<String, RateLimitConfig>>,
133}
134
135impl RateLimiter {
136 pub fn new(default_config: RateLimitConfig) -> Self {
137 Self {
138 buckets: Arc::new(DashMap::new()),
139 default_config,
140 custom_configs: Arc::new(DashMap::new()),
141 }
142 }
143
144 pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
146 self.custom_configs.insert(identifier.to_string(), config);
147
148 self.buckets.remove(identifier);
150 }
151
152 pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
154 self.check_rate_limit_with_cost(identifier, 1.0)
155 }
156
157 pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
159 let config = self
160 .custom_configs
161 .get(identifier)
162 .map(|c| c.clone())
163 .unwrap_or_else(|| self.default_config.clone());
164
165 let mut entry = self
166 .buckets
167 .entry(identifier.to_string())
168 .or_insert_with(|| TokenBucket::new(&config));
169
170 let allowed = entry.try_consume(cost);
171 let remaining = entry.remaining();
172 let retry_after = if !allowed {
173 Some(entry.retry_after())
174 } else {
175 None
176 };
177
178 RateLimitResult {
179 allowed,
180 remaining,
181 retry_after,
182 limit: config.requests_per_minute,
183 }
184 }
185
186 pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
188 self.buckets
189 .get_mut(identifier)
190 .map(|mut bucket| RateLimitStats {
191 remaining: bucket.remaining(),
192 retry_after: bucket.retry_after(),
193 })
194 }
195
196 pub fn cleanup(&self) {
198 let now = Utc::now();
199 self.buckets.retain(|_, bucket| {
200 (now - bucket.last_refill).num_hours() < 1
202 });
203 }
204}
205
206impl Default for RateLimiter {
207 fn default() -> Self {
208 Self::new(RateLimitConfig::professional())
209 }
210}
211
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request was admitted, i.e. its cost was deducted.
    pub allowed: bool,
    /// Whole tokens left in the identifier's bucket after this check.
    pub remaining: u32,
    /// How long to wait before retrying; `None` exactly when `allowed` is true.
    pub retry_after: Option<Duration>,
    /// The `requests_per_minute` of the configuration applied to the identifier.
    pub limit: u32,
}

/// Snapshot of an identifier's bucket.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until at least one token is available; zero if one already is.
    pub retry_after: Duration,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Moves `bucket`'s refill anchor into the past, as if `delta` had elapsed.
    /// This lets refill behaviour be tested without sleeping.
    fn backdate(bucket: &mut TokenBucket, delta: chrono::Duration) {
        bucket.last_refill = bucket.last_refill - delta;
    }

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            bucket.try_consume(1.0);
        }

        assert_eq!(bucket.remaining(), 0);

        // At 60 requests/minute, two seconds are worth two tokens. Back-dating
        // the anchor replaces a real two-second sleep.
        backdate(&mut bucket, chrono::Duration::seconds(2));

        let remaining = bucket.remaining();
        assert!(
            (2..=3).contains(&remaining),
            "Expected 2-3 tokens, got {}",
            remaining
        );
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }

    #[test]
    fn test_retry_after_only_when_denied() {
        // One token per minute, so the second request cannot be refilled in time.
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 1,
            burst_size: 1,
        });

        let first = limiter.check_rate_limit("user1");
        assert!(first.allowed);
        assert!(first.retry_after.is_none());

        let second = limiter.check_rate_limit("user1");
        assert!(!second.allowed);
        let wait = second
            .retry_after
            .expect("denied requests report a retry delay");
        assert!(wait > Duration::ZERO && wait <= Duration::from_secs(60));
    }

    #[test]
    fn test_set_config_resets_bucket() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 1,
            burst_size: 1,
        });

        assert!(limiter.check_rate_limit("user1").allowed);
        assert!(!limiter.check_rate_limit("user1").allowed);

        limiter.set_config("user1", RateLimitConfig::professional());

        let result = limiter.check_rate_limit("user1");
        assert!(result.allowed);
        assert_eq!(result.limit, 600);
    }

    #[test]
    fn test_get_stats() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        assert!(limiter.get_stats("user1").is_none());

        assert!(limiter.check_rate_limit("user1").allowed);

        let stats = limiter
            .get_stats("user1")
            .expect("bucket exists after a check");
        assert_eq!(stats.remaining, 4);
        assert_eq!(stats.retry_after, Duration::ZERO);
    }

    #[test]
    fn test_cleanup_drops_only_idle_buckets() {
        let limiter = RateLimiter::default();
        assert!(limiter.check_rate_limit("fresh").allowed);
        assert!(limiter.check_rate_limit("stale").allowed);

        {
            let mut stale = limiter
                .buckets
                .get_mut("stale")
                .expect("bucket was just created");
            backdate(&mut stale, chrono::Duration::hours(2));
        } // Release the shard lock before `cleanup` iterates the map.

        limiter.cleanup();

        assert!(limiter.get_stats("stale").is_none());
        assert!(limiter.get_stats("fresh").is_some());
    }
}