allsource_core/infrastructure/security/
rate_limit.rs1use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use std::{sync::Arc, time::Duration};
13
/// Rate-limit settings applied to one identifier.
///
/// A bucket built from this config starts with `burst_size` tokens and is
/// refilled at `requests_per_minute / 60` tokens per second.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained request rate; determines how quickly tokens are refilled.
    pub requests_per_minute: u32,
    /// Bucket capacity: how many tokens a fresh (or fully refilled) bucket holds.
    pub burst_size: u32,
}
22
23impl RateLimitConfig {
24 pub fn free_tier() -> Self {
26 Self {
27 requests_per_minute: 60,
28 burst_size: 100,
29 }
30 }
31
32 pub fn professional() -> Self {
34 Self {
35 requests_per_minute: 600,
36 burst_size: 1000,
37 }
38 }
39
40 pub fn unlimited() -> Self {
42 Self {
43 requests_per_minute: 10_000,
44 burst_size: 20_000,
45 }
46 }
47
48 pub fn dev_mode() -> Self {
50 Self {
51 requests_per_minute: 100_000,
52 burst_size: 200_000,
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
59struct TokenBucket {
60 tokens: f64,
62 max_tokens: f64,
64 refill_rate: f64,
66 last_refill: DateTime<Utc>,
68}
69
70impl TokenBucket {
71 fn new(config: &RateLimitConfig) -> Self {
72 let max_tokens = f64::from(config.burst_size);
73 Self {
74 tokens: max_tokens,
75 max_tokens,
76 refill_rate: f64::from(config.requests_per_minute) / 60.0, last_refill: Utc::now(),
78 }
79 }
80
81 fn try_consume(&mut self, tokens: f64) -> bool {
83 self.refill();
84
85 if self.tokens >= tokens {
86 self.tokens -= tokens;
87 true
88 } else {
89 false
90 }
91 }
92
93 fn refill(&mut self) {
95 let now = Utc::now();
96 let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
97
98 if elapsed > 0.0 {
99 let new_tokens = elapsed * self.refill_rate;
100 self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
101 self.last_refill = now;
102 }
103 }
104
105 fn remaining(&mut self) -> u32 {
107 self.refill();
108 self.tokens.floor() as u32
109 }
110
111 fn retry_after(&mut self) -> Duration {
113 self.refill();
114
115 if self.tokens >= 1.0 {
116 Duration::from_secs(0)
117 } else {
118 let tokens_needed = 1.0 - self.tokens;
119 let seconds = tokens_needed / self.refill_rate;
120 Duration::from_secs_f64(seconds)
121 }
122 }
123}
124
125pub struct RateLimiter {
127 buckets: Arc<DashMap<String, TokenBucket>>,
129 default_config: RateLimitConfig,
131 custom_configs: Arc<DashMap<String, RateLimitConfig>>,
133}
134
135impl RateLimiter {
136 pub fn new(default_config: RateLimitConfig) -> Self {
137 Self {
138 buckets: Arc::new(DashMap::new()),
139 default_config,
140 custom_configs: Arc::new(DashMap::new()),
141 }
142 }
143
144 pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
146 self.custom_configs.insert(identifier.to_string(), config);
147
148 self.buckets.remove(identifier);
150 }
151
152 pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
154 self.check_rate_limit_with_cost(identifier, 1.0)
155 }
156
157 pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
159 let config = self
160 .custom_configs
161 .get(identifier)
162 .map_or_else(|| self.default_config.clone(), |c| c.clone());
163
164 let mut entry = self
165 .buckets
166 .entry(identifier.to_string())
167 .or_insert_with(|| TokenBucket::new(&config));
168
169 let allowed = entry.try_consume(cost);
170 let remaining = entry.remaining();
171 let retry_after = if allowed {
172 None
173 } else {
174 Some(entry.retry_after())
175 };
176
177 RateLimitResult {
178 allowed,
179 remaining,
180 retry_after,
181 limit: config.requests_per_minute,
182 }
183 }
184
185 pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
187 self.buckets
188 .get_mut(identifier)
189 .map(|mut bucket| RateLimitStats {
190 remaining: bucket.remaining(),
191 retry_after: bucket.retry_after(),
192 })
193 }
194
195 pub fn cleanup(&self) {
197 let now = Utc::now();
198 self.buckets.retain(|_, bucket| {
199 (now - bucket.last_refill).num_hours() < 1
201 });
202 }
203}
204
205impl Default for RateLimiter {
206 fn default() -> Self {
207 Self::new(RateLimitConfig::professional())
208 }
209}
210
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request was admitted (its tokens were deducted).
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// How long to wait before retrying; `None` when the request was allowed.
    pub retry_after: Option<Duration>,
    /// The `requests_per_minute` of the config that applied to this check.
    pub limit: u32,
}
223
/// Snapshot of an identifier's bucket, as returned by `RateLimiter::get_stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until one whole token is available; zero if one already is.
    pub retry_after: Duration,
}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// A new bucket starts full, at the configured burst size.
    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    /// Consuming tokens lowers the remaining count by the cost.
    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    /// Once the burst is spent, further requests are rejected.
    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        assert!(!bucket.try_consume(1.0));
    }

    /// An emptied bucket regains roughly one token per second at 60 rpm.
    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            bucket.try_consume(1.0);
        }

        assert_eq!(bucket.remaining(), 0);

        // Rewind the refill timestamp instead of sleeping, so the test is
        // fast and does not depend on scheduler timing.
        bucket.last_refill = bucket.last_refill - chrono::Duration::seconds(2);

        let remaining = bucket.remaining();
        assert!(
            (1..=3).contains(&remaining),
            "Expected 1-3 tokens, got {remaining}"
        );
    }

    /// Each identifier gets its own independent bucket.
    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    /// A custom config overrides the default for that identifier only.
    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    /// A request with a cost above one deducts that many tokens.
    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }
}