allsource_core/
rate_limit.rs1use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use std::sync::Arc;
13use std::time::Duration;
14
/// Rate-limit settings applied to one identifier.
///
/// The limiter in this file turns these values into a token bucket:
/// `burst_size` becomes the bucket capacity and `requests_per_minute / 60`
/// the number of tokens refilled per second.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained request rate, expressed per minute.
    pub requests_per_minute: u32,
    /// Maximum number of tokens the bucket can hold (and its initial fill).
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Preset: 60 requests per minute with a burst capacity of 100.
    pub fn free_tier() -> Self {
        Self {
            requests_per_minute: 60,
            burst_size: 100,
        }
    }

    /// Preset: 600 requests per minute with a burst capacity of 1000.
    pub fn professional() -> Self {
        Self {
            requests_per_minute: 600,
            burst_size: 1000,
        }
    }

    /// Preset: 10,000 requests per minute with a burst capacity of 20,000.
    ///
    /// Despite the name this is a high but finite limit.
    pub fn unlimited() -> Self {
        Self {
            requests_per_minute: 10_000,
            burst_size: 20_000,
        }
    }

    /// Preset: 100,000 requests per minute with a burst capacity of 200,000.
    pub fn dev_mode() -> Self {
        Self {
            requests_per_minute: 100_000,
            burst_size: 200_000,
        }
    }
}
57
58#[derive(Debug, Clone)]
60struct TokenBucket {
61 tokens: f64,
63 max_tokens: f64,
65 refill_rate: f64,
67 last_refill: DateTime<Utc>,
69}
70
71impl TokenBucket {
72 fn new(config: &RateLimitConfig) -> Self {
73 let max_tokens = config.burst_size as f64;
74 Self {
75 tokens: max_tokens,
76 max_tokens,
77 refill_rate: config.requests_per_minute as f64 / 60.0, last_refill: Utc::now(),
79 }
80 }
81
82 fn try_consume(&mut self, tokens: f64) -> bool {
84 self.refill();
85
86 if self.tokens >= tokens {
87 self.tokens -= tokens;
88 true
89 } else {
90 false
91 }
92 }
93
94 fn refill(&mut self) {
96 let now = Utc::now();
97 let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
98
99 if elapsed > 0.0 {
100 let new_tokens = elapsed * self.refill_rate;
101 self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
102 self.last_refill = now;
103 }
104 }
105
106 fn remaining(&mut self) -> u32 {
108 self.refill();
109 self.tokens.floor() as u32
110 }
111
112 fn retry_after(&mut self) -> Duration {
114 self.refill();
115
116 if self.tokens >= 1.0 {
117 Duration::from_secs(0)
118 } else {
119 let tokens_needed = 1.0 - self.tokens;
120 let seconds = tokens_needed / self.refill_rate;
121 Duration::from_secs_f64(seconds)
122 }
123 }
124}
125
126pub struct RateLimiter {
128 buckets: Arc<DashMap<String, TokenBucket>>,
130 default_config: RateLimitConfig,
132 custom_configs: Arc<DashMap<String, RateLimitConfig>>,
134}
135
136impl RateLimiter {
137 pub fn new(default_config: RateLimitConfig) -> Self {
138 Self {
139 buckets: Arc::new(DashMap::new()),
140 default_config,
141 custom_configs: Arc::new(DashMap::new()),
142 }
143 }
144
145 pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
147 self.custom_configs.insert(identifier.to_string(), config);
148
149 self.buckets.remove(identifier);
151 }
152
153 pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
155 self.check_rate_limit_with_cost(identifier, 1.0)
156 }
157
158 pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
160 let config = self
161 .custom_configs
162 .get(identifier)
163 .map(|c| c.clone())
164 .unwrap_or_else(|| self.default_config.clone());
165
166 let mut entry = self
167 .buckets
168 .entry(identifier.to_string())
169 .or_insert_with(|| TokenBucket::new(&config));
170
171 let allowed = entry.try_consume(cost);
172 let remaining = entry.remaining();
173 let retry_after = if !allowed {
174 Some(entry.retry_after())
175 } else {
176 None
177 };
178
179 RateLimitResult {
180 allowed,
181 remaining,
182 retry_after,
183 limit: config.requests_per_minute,
184 }
185 }
186
187 pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
189 self.buckets
190 .get_mut(identifier)
191 .map(|mut bucket| RateLimitStats {
192 remaining: bucket.remaining(),
193 retry_after: bucket.retry_after(),
194 })
195 }
196
197 pub fn cleanup(&self) {
199 let now = Utc::now();
200 self.buckets.retain(|_, bucket| {
201 (now - bucket.last_refill).num_hours() < 1
203 });
204 }
205}
206
207impl Default for RateLimiter {
208 fn default() -> Self {
209 Self::new(RateLimitConfig::professional())
210 }
211}
212
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitResult {
    /// Whether the request was admitted.
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// Present only when the request was denied: how long to wait before retrying.
    pub retry_after: Option<Duration>,
    /// The `requests_per_minute` value of the configuration applied to this check.
    pub limit: u32,
}
225
/// Snapshot of one identifier's bucket, as returned by `RateLimiter::get_stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Time until at least one token is available; zero when one is available now.
    pub retry_after: Duration,
}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // The full burst is admitted...
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // ...and the next request is rejected.
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60, // one token per second
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            bucket.try_consume(1.0);
        }

        assert_eq!(bucket.remaining(), 0);

        // Rewind the refill timestamp by two seconds instead of sleeping:
        // the test stays fast and does not depend on scheduler timing.
        bucket.last_refill = chrono::Utc::now() - chrono::Duration::seconds(2);

        let remaining = bucket.remaining();
        assert!(
            (2..=3).contains(&remaining),
            "Expected 2-3 tokens, got {}",
            remaining
        );
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        // Each identifier draws from its own bucket.
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }

    #[test]
    fn test_retry_after_when_exhausted() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60, // one token per second
            burst_size: 1,
        });

        let first = limiter.check_rate_limit("user1");
        assert!(first.allowed);
        assert!(first.retry_after.is_none());

        let second = limiter.check_rate_limit("user1");
        assert!(!second.allowed);
        let wait = second.retry_after.expect("denied requests carry retry_after");
        assert!(wait > Duration::from_secs(0));
        assert!(wait <= Duration::from_secs(1));
    }

    #[test]
    fn test_get_stats_unknown_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        assert!(limiter.get_stats("nobody").is_none());
    }

    #[test]
    fn test_cleanup_keeps_recent_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.check_rate_limit("user1");
        limiter.cleanup();

        assert!(limiter.get_stats("user1").is_some());
    }
}