allsource_core/
rate_limit.rs1use chrono::{DateTime, Utc};
12use dashmap::DashMap;
13use std::sync::Arc;
14use std::time::Duration;
15
/// Rate-limit parameters for one identifier (or the limiter-wide default).
///
/// Requests draw from a token bucket that holds at most `burst_size` tokens
/// and refills at `requests_per_minute / 60` tokens per second.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained refill rate, in requests (tokens) per minute.
    pub requests_per_minute: u32,
    /// Bucket capacity: the most requests that can be admitted back-to-back.
    pub burst_size: u32,
}

impl RateLimitConfig {
    /// Shared constructor behind the preset tiers below.
    fn preset(requests_per_minute: u32, burst_size: u32) -> Self {
        RateLimitConfig {
            requests_per_minute,
            burst_size,
        }
    }

    /// Free tier: 60 requests/minute sustained, bursts of up to 100.
    pub fn free_tier() -> Self {
        Self::preset(60, 100)
    }

    /// Professional tier: 600 requests/minute sustained, bursts of up to 1,000.
    pub fn professional() -> Self {
        Self::preset(600, 1_000)
    }

    /// "Unlimited" tier, which is really a high cap: 10,000 requests/minute
    /// sustained, bursts of up to 20,000.
    pub fn unlimited() -> Self {
        Self::preset(10_000, 20_000)
    }

    /// Development mode: 100,000 requests/minute sustained, bursts of up to 200,000.
    pub fn dev_mode() -> Self {
        Self::preset(100_000, 200_000)
    }
}
58
59#[derive(Debug, Clone)]
61struct TokenBucket {
62 tokens: f64,
64 max_tokens: f64,
66 refill_rate: f64,
68 last_refill: DateTime<Utc>,
70}
71
72impl TokenBucket {
73 fn new(config: &RateLimitConfig) -> Self {
74 let max_tokens = config.burst_size as f64;
75 Self {
76 tokens: max_tokens,
77 max_tokens,
78 refill_rate: config.requests_per_minute as f64 / 60.0, last_refill: Utc::now(),
80 }
81 }
82
83 fn try_consume(&mut self, tokens: f64) -> bool {
85 self.refill();
86
87 if self.tokens >= tokens {
88 self.tokens -= tokens;
89 true
90 } else {
91 false
92 }
93 }
94
95 fn refill(&mut self) {
97 let now = Utc::now();
98 let elapsed = (now - self.last_refill).num_milliseconds() as f64 / 1000.0;
99
100 if elapsed > 0.0 {
101 let new_tokens = elapsed * self.refill_rate;
102 self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
103 self.last_refill = now;
104 }
105 }
106
107 fn remaining(&mut self) -> u32 {
109 self.refill();
110 self.tokens.floor() as u32
111 }
112
113 fn retry_after(&mut self) -> Duration {
115 self.refill();
116
117 if self.tokens >= 1.0 {
118 Duration::from_secs(0)
119 } else {
120 let tokens_needed = 1.0 - self.tokens;
121 let seconds = tokens_needed / self.refill_rate;
122 Duration::from_secs_f64(seconds)
123 }
124 }
125}
126
127pub struct RateLimiter {
129 buckets: Arc<DashMap<String, TokenBucket>>,
131 default_config: RateLimitConfig,
133 custom_configs: Arc<DashMap<String, RateLimitConfig>>,
135}
136
137impl RateLimiter {
138 pub fn new(default_config: RateLimitConfig) -> Self {
139 Self {
140 buckets: Arc::new(DashMap::new()),
141 default_config,
142 custom_configs: Arc::new(DashMap::new()),
143 }
144 }
145
146 pub fn set_config(&self, identifier: &str, config: RateLimitConfig) {
148 self.custom_configs.insert(identifier.to_string(), config);
149
150 self.buckets.remove(identifier);
152 }
153
154 pub fn check_rate_limit(&self, identifier: &str) -> RateLimitResult {
156 self.check_rate_limit_with_cost(identifier, 1.0)
157 }
158
159 pub fn check_rate_limit_with_cost(&self, identifier: &str, cost: f64) -> RateLimitResult {
161 let config = self.custom_configs
162 .get(identifier)
163 .map(|c| c.clone())
164 .unwrap_or_else(|| self.default_config.clone());
165
166 let mut entry = self.buckets
167 .entry(identifier.to_string())
168 .or_insert_with(|| TokenBucket::new(&config));
169
170 let allowed = entry.try_consume(cost);
171 let remaining = entry.remaining();
172 let retry_after = if !allowed {
173 Some(entry.retry_after())
174 } else {
175 None
176 };
177
178 RateLimitResult {
179 allowed,
180 remaining,
181 retry_after,
182 limit: config.requests_per_minute,
183 }
184 }
185
186 pub fn get_stats(&self, identifier: &str) -> Option<RateLimitStats> {
188 self.buckets.get_mut(identifier).map(|mut bucket| {
189 RateLimitStats {
190 remaining: bucket.remaining(),
191 retry_after: bucket.retry_after(),
192 }
193 })
194 }
195
196 pub fn cleanup(&self) {
198 let now = Utc::now();
199 self.buckets.retain(|_, bucket| {
200 (now - bucket.last_refill).num_hours() < 1
202 });
203 }
204}
205
206impl Default for RateLimiter {
207 fn default() -> Self {
208 Self::new(RateLimitConfig::professional())
209 }
210}
211
/// Outcome of one `RateLimiter::check_rate_limit` /
/// `RateLimiter::check_rate_limit_with_cost` call.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether the request was admitted (its tokens were consumed).
    pub allowed: bool,
    /// Whole tokens left in the identifier's bucket after this check.
    pub remaining: u32,
    /// Suggested wait before retrying; `None` when the request was allowed.
    pub retry_after: Option<Duration>,
    /// The identifier's configured `requests_per_minute`.
    pub limit: u32,
}
224
/// Point-in-time view of an identifier's bucket, as returned by
/// `RateLimiter::get_stats`.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Whole tokens currently available.
    pub remaining: u32,
    /// Wait until at least one token is available (zero if one is available now).
    pub retry_after: Duration,
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Moves `bucket`'s refill clock back by `seconds`, so the next refill
    /// credits that much elapsed time without the test having to sleep.
    fn backdate(bucket: &mut TokenBucket, seconds: i64) {
        bucket.last_refill = bucket.last_refill - chrono::Duration::seconds(seconds);
    }

    #[test]
    fn test_token_bucket_creation() {
        let config = RateLimitConfig::free_tier();
        let bucket = TokenBucket::new(&config);

        assert_eq!(bucket.max_tokens, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_consumption() {
        let config = RateLimitConfig::free_tier();
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(1.0));
        assert_eq!(bucket.remaining(), 99);

        assert!(bucket.try_consume(10.0));
        assert_eq!(bucket.remaining(), 89);
    }

    #[test]
    fn test_rate_limit_enforcement() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        // The whole burst is available up front...
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // ...and the next request is denied.
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_refill() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }
        assert_eq!(bucket.remaining(), 0);

        // Simulate two seconds passing: at 60/min (1 token/s) that refills two
        // tokens. Backdating the clock keeps the test fast and deterministic.
        backdate(&mut bucket, 2);
        assert_eq!(bucket.remaining(), 2);
    }

    #[test]
    fn test_refill_caps_at_burst() {
        let config = RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        };
        let mut bucket = TokenBucket::new(&config);

        assert!(bucket.try_consume(5.0));
        backdate(&mut bucket, 3600);
        assert_eq!(bucket.remaining(), 10);
    }

    #[test]
    fn test_rate_limiter_per_identifier() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });

        // Each identifier gets its own bucket.
        let result1 = limiter.check_rate_limit("user1");
        let result2 = limiter.check_rate_limit("user2");

        assert!(result1.allowed);
        assert!(result2.allowed);
        assert_eq!(result1.remaining, 4);
        assert_eq!(result2.remaining, 4);
    }

    #[test]
    fn test_custom_config() {
        let limiter = RateLimiter::new(RateLimitConfig::free_tier());

        limiter.set_config("premium_user", RateLimitConfig::unlimited());

        let free_result = limiter.check_rate_limit("free_user");
        let premium_result = limiter.check_rate_limit("premium_user");

        assert!(free_result.limit < premium_result.limit);
    }

    #[test]
    fn test_rate_limit_with_cost() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
        });

        let result = limiter.check_rate_limit_with_cost("user1", 5.0);
        assert!(result.allowed);
        assert_eq!(result.remaining, 5);
    }

    #[test]
    fn test_cleanup_evicts_idle_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 5,
        });
        limiter.check_rate_limit("idle");
        limiter.check_rate_limit("active");

        // Scope the guard so it is released before `cleanup` walks the map.
        {
            let mut bucket = limiter
                .buckets
                .get_mut("idle")
                .expect("bucket was created by the check above");
            backdate(&mut bucket, 2 * 60 * 60);
        }

        limiter.cleanup();
        assert!(limiter.get_stats("idle").is_none());
        assert!(limiter.get_stats("active").is_some());
    }
}