allsource_core/infrastructure/security/
rate_limit.rs1use chrono::{DateTime, Utc};
21use dashmap::DashMap;
22use std::{sync::Arc, time::Duration};
23
/// Token-bucket parameters applied to a rate-limited identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained refill rate, in requests (tokens) per minute.
    pub requests_per_minute: u32,
    /// Bucket capacity: the most requests that can be made back-to-back
    /// before the refill rate becomes the limiting factor.
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Free tier: 60 requests/minute, bursts of up to 100.
    #[must_use]
    pub const fn free_tier() -> Self {
        Self {
            requests_per_minute: 60,
            burst_size: 100,
        }
    }

    /// Professional tier: 600 requests/minute, bursts of up to 1,000.
    #[must_use]
    pub const fn professional() -> Self {
        Self {
            requests_per_minute: 600,
            burst_size: 1000,
        }
    }

    /// "Unlimited" tier: despite the name, still capped at 10,000
    /// requests/minute with bursts of up to 20,000.
    #[must_use]
    pub const fn unlimited() -> Self {
        Self {
            requests_per_minute: 10_000,
            burst_size: 20_000,
        }
    }

    /// Very high limits (100,000 requests/minute, bursts of up to 200,000);
    /// presumably meant for local development — confirm before using elsewhere.
    #[must_use]
    pub const fn dev_mode() -> Self {
        Self {
            requests_per_minute: 100_000,
            burst_size: 200_000,
        }
    }
}
66
67#[derive(Debug, Clone)]
69struct TokenBucket {
70 tokens: f64,
72 max_tokens: f64,
74 refill_rate: f64,
76 last_refill: DateTime<Utc>,
78}
79
80impl TokenBucket {
81 fn new(config: &RateLimitConfig) -> Self {
82 let max_tokens = f64::from(config.burst_size);
83 Self {
84 tokens: max_tokens,
85 max_tokens,
86 refill_rate: f64::from(config.requests_per_minute) / 60.0, last_refill: Utc::now(),
88 }
89 }
90
91 fn try_consume(&mut self, tokens: f64) -> bool {
93 self.refill();
94
95 if self.tokens >= tokens {
96 self.tokens -= tokens;
97 true
98 } else {
99 false
100 }
101 }
102
103 fn refill(&mut self) {
105 let now = Utc::now();
106 let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
107
108 if elapsed > 0.0 {
109 let new_tokens = elapsed * self.refill_rate;
110 self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
111 self.last_refill = now;
112 }
113 }
114
115 fn remaining(&mut self) -> u32 {
117 self.refill();
118 self.tokens.floor() as u32
119 }
120
121 fn retry_after(&mut self) -> Duration {
123 self.refill();
124
125 if self.tokens >= 1.0 {
126 Duration::from_secs(0)
127 } else {
128 let tokens_needed = 1.0 - self.tokens;
129 let seconds = tokens_needed / self.refill_rate;
130 Duration::from_secs_f64(seconds)
131 }
132 }
133}
134
135pub struct RateLimiter {
137 buckets: Arc<DashMap<String, TokenBucket>>,
139 default_config: RateLimitConfig,
141 custom_configs: Arc<DashMap<String, RateLimitConfig>>,
143}
144
145impl RateLimiter {
146 pub fn new(default_config: RateLimitConfig) -> Self {
147 Self {
148 buckets: Arc::new(DashMap::new()),
149 default_config,
150 custom_configs: Arc::new(DashMap::new()),
151 }
152 }
153
154 pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
156 self.custom_configs.insert(identifier.to_string(), config);
157
158 self.buckets.remove(identifier);
160 }
161
162 pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
164 self.check_rate_limit_with_cost(identifier, 1.0)
165 }
166
167 pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
169 let config = self
170 .custom_configs
171 .get(identifier)
172 .map_or_else(|| self.default_config.clone(), |c| c.clone());
173
174 let mut entry = self
175 .buckets
176 .entry(identifier.to_string())
177 .or_insert_with(|| TokenBucket::new(&config));
178
179 let allowed = entry.try_consume(cost);
180 let remaining = entry.remaining();
181 let retry_after = if allowed {
182 None
183 } else {
184 Some(entry.retry_after())
185 };
186
187 RateLimitResult {
188 allowed,
189 remaining,
190 retry_after,
191 limit: config.requests_per_minute,
192 }
193 }
194
195 pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
197 self.buckets
198 .get_mut(identifier)
199 .map(|mut bucket| RateLimitStats {
200 remaining: bucket.remaining(),
201 retry_after: bucket.retry_after(),
202 })
203 }
204
205 pub fn cleanup(&self) {
207 let now = Utc::now();
208 self.buckets.retain(|_, bucket| {
209 (now - bucket.last_refill).num_hours() < 1
211 });
212 }
213}
214
215impl Default for RateLimiter {
216 fn default() -> Self {
217 Self::new(RateLimitConfig::professional())
218 }
219}
220
/// Outcome of a [`RateLimiter::check_rate_limit`] or
/// [`RateLimiter::check_rate_limit_with_cost`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request fit in the bucket; if so, its tokens were consumed.
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// For denied requests, the wait until one token is available; `None` when allowed.
    pub retry_after: Option<Duration>,
    /// The requests-per-minute rate configured for the identifier.
    pub limit: u32,
}
233
/// Snapshot of an identifier's bucket, as returned by [`RateLimiter::get_stats`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until at least one token is available; zero if one is available now.
    pub retry_after: Duration,
}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes a bucket's refill timestamp `secs` seconds into the past,
    /// simulating elapsed time without sleeping the test thread.
    fn rewind(bucket: &mut TokenBucket, secs: i64) {
        bucket.last_refill -= chrono::Duration::seconds(secs);
    }

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            bucket.try_consume(1.0);
        }

        assert_eq!(bucket.remaining(), 0);

        // 60 requests/minute refills one token per second; pretend two
        // seconds have passed instead of sleeping.
        rewind(&mut bucket, 2);

        let remaining = bucket.remaining();
        assert!(
            (1..=3).contains(&remaining),
            "Expected 1-3 tokens, got {remaining}"
        );
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }

    #[test]
    fn test_denied_request_reports_retry_after() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 1,
        });

        assert!(limiter.check_rate_limit("user1").allowed);

        let denied = limiter.check_rate_limit("user1");
        assert!(!denied.allowed);
        let wait = denied
            .retry_after
            .expect("denied requests carry a retry_after");
        assert!(wait > Duration::ZERO && wait <= Duration::from_secs(1));
    }

    #[test]
    fn test_cleanup_evicts_idle_buckets() {
        let limiter = RateLimiter::default();
        limiter.check_rate_limit("idle");
        limiter.check_rate_limit("active");

        {
            // Scope the shard guard: `cleanup` needs the same lock.
            let mut bucket = limiter
                .buckets
                .get_mut("idle")
                .expect("bucket created by the check above");
            rewind(&mut bucket, 2 * 60 * 60);
        }

        limiter.cleanup();

        assert!(limiter.get_stats("idle").is_none());
        assert!(limiter.get_stats("active").is_some());
    }
}