1use std::collections::{HashMap, VecDeque};
14use std::time::{Duration, Instant};
15
16use serde::Serialize;
17
/// Configuration for a sliding-window [`RateLimiter`].
///
/// A client may make at most `max_requests` requests within any span of
/// `window`. When `enabled` is `false`, every request is allowed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum number of requests a single client may make within `window`.
    pub max_requests: u32,
    /// Length of the sliding window.
    pub window: Duration,
    /// Whether limiting is enforced at all.
    pub enabled: bool,
}

impl RateLimitConfig {
    /// The stock configuration: 100 requests per 60-second window, enabled.
    pub fn default_config() -> Self {
        RateLimitConfig {
            max_requests: 100,
            window: Duration::from_secs(60),
            enabled: true,
        }
    }

    /// A configuration that turns rate limiting off entirely.
    pub fn disabled() -> Self {
        RateLimitConfig {
            max_requests: 0,
            window: Duration::from_secs(0),
            enabled: false,
        }
    }
}

impl Default for RateLimitConfig {
    /// Same as [`RateLimitConfig::default_config`].
    fn default() -> Self {
        Self::default_config()
    }
}
50
51#[derive(Debug, Clone, Serialize)]
55pub struct RateLimitResult {
56 pub allowed: bool,
58 pub remaining: u32,
60 pub limit: u32,
62 pub reset_secs: u64,
64}
65
/// Snapshot of one client's counters, as returned by
/// [`RateLimiter::client_metrics`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientRateMetric {
    /// The key the client was tracked under.
    pub client_key: String,
    /// Every request seen for this client, allowed or rejected.
    pub total_requests: u64,
    /// Requests rejected because the window was full.
    pub rejected: u64,
    /// Requests currently counted inside the window.
    pub current_window_count: u32,
}
76
/// Per-client sliding-window state: admission times of requests still inside
/// the window (oldest first) plus lifetime counters.
struct ClientBucket {
    /// Admission times of in-window requests, oldest at the front.
    timestamps: VecDeque<Instant>,
    /// Every request seen for this client, admitted or rejected.
    total_requests: u64,
    /// Requests rejected because the window was full.
    rejected: u64,
}

impl ClientBucket {
    /// Creates a bucket with no history.
    fn new() -> Self {
        ClientBucket {
            timestamps: VecDeque::new(),
            total_requests: 0,
            rejected: 0,
        }
    }

    /// Drops timestamps that have left the window ending at `now` and returns
    /// how many remain.
    ///
    /// An entry is kept while its age is at most `window`. Ages are computed
    /// with `saturating_duration_since` instead of forming `now - window`:
    /// that subtraction is unrepresentable when the window reaches back before
    /// the platform's `Instant` origin. Falling back to `now` as the cutoff
    /// would wrongly discard every in-window entry.
    fn prune_and_count(&mut self, now: Instant, window: Duration) -> u32 {
        while let Some(&front) = self.timestamps.front() {
            if now.saturating_duration_since(front) > window {
                self.timestamps.pop_front();
            } else {
                break;
            }
        }
        // Saturate rather than truncate if the deque were ever larger than u32.
        self.timestamps.len().min(u32::MAX as usize) as u32
    }

    /// Seconds until the oldest in-window entry expires, rounded up.
    ///
    /// Returns 0 when the bucket is empty or the oldest entry has already
    /// expired. Rounding up means a client that waits the reported time finds
    /// a freed slot; truncating could report 0 while still denying. The value
    /// is computed as `window - age`, because `oldest + window` panics on
    /// overflow for very large windows.
    fn reset_time(&self, now: Instant, window: Duration) -> u64 {
        match self.timestamps.front() {
            Some(&oldest) => {
                let left = window.saturating_sub(now.saturating_duration_since(oldest));
                let whole = left.as_secs();
                if left.subsec_nanos() > 0 {
                    whole.saturating_add(1)
                } else {
                    whole
                }
            }
            None => 0,
        }
    }
}
117
118pub struct RateLimiter {
120 config: RateLimitConfig,
121 buckets: HashMap<String, ClientBucket>,
122}
123
124impl RateLimiter {
125 pub fn new(config: RateLimitConfig) -> Self {
127 RateLimiter {
128 config,
129 buckets: HashMap::new(),
130 }
131 }
132
133 pub fn check(&mut self, client_key: &str) -> RateLimitResult {
136 if !self.config.enabled {
137 return RateLimitResult {
138 allowed: true,
139 remaining: u32::MAX,
140 limit: 0,
141 reset_secs: 0,
142 };
143 }
144
145 let now = Instant::now();
146 let bucket = self.buckets
147 .entry(client_key.to_string())
148 .or_insert_with(ClientBucket::new);
149
150 let count = bucket.prune_and_count(now, self.config.window);
151 bucket.total_requests += 1;
152
153 if count >= self.config.max_requests {
154 bucket.rejected += 1;
155 let reset = bucket.reset_time(now, self.config.window);
156 return RateLimitResult {
157 allowed: false,
158 remaining: 0,
159 limit: self.config.max_requests,
160 reset_secs: reset,
161 };
162 }
163
164 bucket.timestamps.push_back(now);
166 let remaining = self.config.max_requests - count - 1;
167 let reset = bucket.reset_time(now, self.config.window);
168
169 RateLimitResult {
170 allowed: true,
171 remaining,
172 limit: self.config.max_requests,
173 reset_secs: reset,
174 }
175 }
176
177 pub fn peek(&mut self, client_key: &str) -> RateLimitResult {
179 if !self.config.enabled {
180 return RateLimitResult {
181 allowed: true,
182 remaining: u32::MAX,
183 limit: 0,
184 reset_secs: 0,
185 };
186 }
187
188 let now = Instant::now();
189 let bucket = self.buckets
190 .entry(client_key.to_string())
191 .or_insert_with(ClientBucket::new);
192
193 let count = bucket.prune_and_count(now, self.config.window);
194 let remaining = self.config.max_requests.saturating_sub(count);
195 let reset = bucket.reset_time(now, self.config.window);
196
197 RateLimitResult {
198 allowed: remaining > 0,
199 remaining,
200 limit: self.config.max_requests,
201 reset_secs: reset,
202 }
203 }
204
205 pub fn client_count(&self) -> usize {
207 self.buckets.len()
208 }
209
210 pub fn cleanup(&mut self) {
212 let now = Instant::now();
213 let window = self.config.window;
214 self.buckets.retain(|_, bucket| {
215 bucket.prune_and_count(now, window);
216 !bucket.timestamps.is_empty()
217 });
218 }
219
220 pub fn config(&self) -> &RateLimitConfig {
222 &self.config
223 }
224
225 pub fn update_config(&mut self, max_requests: Option<u32>, window_secs: Option<u64>, enabled: Option<bool>) {
227 if let Some(max) = max_requests {
228 self.config.max_requests = max;
229 }
230 if let Some(secs) = window_secs {
231 self.config.window = Duration::from_secs(secs);
232 }
233 if let Some(en) = enabled {
234 self.config.enabled = en;
235 }
236 }
237
238 pub fn client_metrics(&mut self) -> Vec<ClientRateMetric> {
240 let now = Instant::now();
241 let window = self.config.window;
242 self.buckets.iter_mut().map(|(key, bucket)| {
243 let current = bucket.prune_and_count(now, window);
244 ClientRateMetric {
245 client_key: key.clone(),
246 total_requests: bucket.total_requests,
247 rejected: bucket.rejected,
248 current_window_count: current,
249 }
250 }).collect()
251 }
252}
253
254use crate::tenant::TenantPlan;
257
258pub struct TenantQuotas {
261 pub requests_per_min: u32,
263 pub tokens_per_day: u64,
265}
266
267impl TenantQuotas {
268 pub fn for_plan(plan: &TenantPlan) -> Self {
269 match plan {
270 TenantPlan::Starter => Self { requests_per_min: 60, tokens_per_day: 100_000 },
271 TenantPlan::Pro => Self { requests_per_min: 300, tokens_per_day: 1_000_000 },
272 TenantPlan::Enterprise => Self { requests_per_min: 2000, tokens_per_day: u64::MAX },
273 }
274 }
275}
276
/// Fixed-window daily token counter for one tenant.
///
/// The window opens when the bucket is created. On the first access at least
/// 24 hours later, the counter is zeroed and a new window opens.
struct TokenBucket {
    /// Tokens recorded since `window_start`.
    used: u64,
    /// When the current window opened.
    window_start: std::time::Instant,
}

impl TokenBucket {
    /// Length of one quota window: 24 hours.
    const DAY: Duration = Duration::from_secs(86400);

    /// Opens a fresh, empty window starting now.
    fn new() -> Self {
        TokenBucket {
            window_start: std::time::Instant::now(),
            used: 0,
        }
    }

    /// Zeroes the counter and reopens the window once a full day has passed.
    fn refresh(&mut self) {
        let expired = self.window_start.elapsed() >= Self::DAY;
        if expired {
            self.window_start = std::time::Instant::now();
            self.used = 0;
        }
    }

    /// Charges `tokens` to the current window, saturating at `u64::MAX`.
    fn add(&mut self, tokens: u64) {
        self.refresh();
        self.used = self.used.saturating_add(tokens);
    }

    /// True while usage in the current window is strictly below `limit`.
    fn can_consume(&mut self, limit: u64) -> bool {
        self.refresh();
        limit > self.used
    }
}
306
307pub struct TenantRateLimiter {
316 request_limiters: HashMap<String, RateLimiter>,
318 token_buckets: HashMap<String, TokenBucket>,
320}
321
322impl TenantRateLimiter {
323 pub fn new() -> Self {
324 Self {
325 request_limiters: HashMap::new(),
326 token_buckets: HashMap::new(),
327 }
328 }
329
330 pub fn check_request(&mut self, tenant_id: &str, plan: &TenantPlan) -> RateLimitResult {
333 let quotas = TenantQuotas::for_plan(plan);
334 let limiter = self.request_limiters
335 .entry(tenant_id.to_string())
336 .or_insert_with(|| {
337 RateLimiter::new(RateLimitConfig {
338 max_requests: quotas.requests_per_min,
339 window: Duration::from_secs(60),
340 enabled: true,
341 })
342 });
343 limiter.update_config(Some(quotas.requests_per_min), None, None);
345 limiter.check(tenant_id)
346 }
347
348 pub fn record_tokens(&mut self, tenant_id: &str, tokens: u64) {
350 self.token_buckets
351 .entry(tenant_id.to_string())
352 .or_insert_with(TokenBucket::new)
353 .add(tokens);
354 }
355
356 pub fn check_token_quota(&mut self, tenant_id: &str, plan: &TenantPlan) -> bool {
359 let limit = TenantQuotas::for_plan(plan).tokens_per_day;
360 if limit == u64::MAX {
361 return true; }
363 self.token_buckets
364 .entry(tenant_id.to_string())
365 .or_insert_with(TokenBucket::new)
366 .can_consume(limit)
367 }
368
369 pub fn token_usage(&mut self, tenant_id: &str, plan: &TenantPlan) -> (u64, u64) {
371 let limit = TenantQuotas::for_plan(plan).tokens_per_day;
372 let bucket = self.token_buckets
373 .entry(tenant_id.to_string())
374 .or_insert_with(TokenBucket::new);
375 bucket.refresh();
376 (bucket.used, limit)
377 }
378
379 pub fn tenant_count(&self) -> usize {
381 self.request_limiters.len()
382 }
383
384 pub fn cleanup(&mut self) {
386 let now = std::time::Instant::now();
387 self.token_buckets.retain(|_, b| {
389 b.window_start.elapsed() < Duration::from_secs(86400 * 2)
390 });
391 for limiter in self.request_limiters.values_mut() {
393 limiter.cleanup();
394 }
395 let _ = now; }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled config with a millisecond-granularity window, for fast tests.
    fn fast_config(max: u32, window_ms: u64) -> RateLimitConfig {
        RateLimitConfig {
            max_requests: max,
            window: Duration::from_millis(window_ms),
            enabled: true,
        }
    }

    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(fast_config(5, 1000));
        for i in 0..5 {
            let result = limiter.check("client_a");
            assert!(result.allowed, "request {} should be allowed", i);
            assert_eq!(result.remaining, 4 - i as u32);
            assert_eq!(result.limit, 5);
        }
    }

    #[test]
    fn denies_over_limit() {
        let mut limiter = RateLimiter::new(fast_config(3, 60_000));
        for _ in 0..3 {
            assert!(limiter.check("client_a").allowed);
        }
        let result = limiter.check("client_a");
        assert!(!result.allowed);
        assert_eq!(result.remaining, 0);
    }

    #[test]
    fn separate_clients_independent() {
        let mut limiter = RateLimiter::new(fast_config(2, 60_000));
        assert!(limiter.check("alice").allowed);
        assert!(limiter.check("alice").allowed);
        assert!(!limiter.check("alice").allowed);

        // Alice being exhausted must not affect Bob.
        assert!(limiter.check("bob").allowed);
        assert!(limiter.check("bob").allowed);
        assert!(!limiter.check("bob").allowed);
    }

    #[test]
    fn window_expiry_allows_again() {
        // A 50 ms window leaves headroom for the three back-to-back checks on
        // a loaded machine; the sleep comfortably exceeds the window.
        let mut limiter = RateLimiter::new(fast_config(2, 50));
        assert!(limiter.check("client").allowed);
        assert!(limiter.check("client").allowed);
        assert!(!limiter.check("client").allowed);

        std::thread::sleep(Duration::from_millis(100));
        assert!(limiter.check("client").allowed);
    }

    #[test]
    fn disabled_always_allows() {
        let mut limiter = RateLimiter::new(RateLimitConfig::disabled());
        for _ in 0..1000 {
            let result = limiter.check("anyone");
            assert!(result.allowed);
            assert_eq!(result.remaining, u32::MAX);
        }
    }

    #[test]
    fn peek_does_not_consume() {
        let mut limiter = RateLimiter::new(fast_config(3, 60_000));
        limiter.check("client");
        let peek1 = limiter.peek("client");
        assert!(peek1.allowed);
        assert_eq!(peek1.remaining, 2);

        // A second peek sees the same state: peeking consumed nothing.
        let peek2 = limiter.peek("client");
        assert_eq!(peek2.remaining, 2);
    }

    #[test]
    fn peek_unknown_client_reports_full_quota() {
        let mut limiter = RateLimiter::new(fast_config(3, 60_000));
        let result = limiter.peek("stranger");
        assert!(result.allowed);
        assert_eq!(result.remaining, 3);
        assert_eq!(result.limit, 3);
        assert_eq!(result.reset_secs, 0);
    }

    #[test]
    fn client_count_tracks_unique() {
        let mut limiter = RateLimiter::new(fast_config(10, 60_000));
        assert_eq!(limiter.client_count(), 0);

        limiter.check("a");
        assert_eq!(limiter.client_count(), 1);

        limiter.check("b");
        assert_eq!(limiter.client_count(), 2);

        // A repeat client is not counted twice.
        limiter.check("a");
        assert_eq!(limiter.client_count(), 2);
    }

    #[test]
    fn cleanup_removes_expired() {
        // 1 ms window so the entry expires after a short sleep.
        let mut limiter = RateLimiter::new(fast_config(5, 1));
        limiter.check("temp");
        assert_eq!(limiter.client_count(), 1);

        std::thread::sleep(Duration::from_millis(5));
        limiter.cleanup();
        assert_eq!(limiter.client_count(), 0);
    }

    #[test]
    fn reset_secs_positive_when_active() {
        let mut limiter = RateLimiter::new(fast_config(5, 60_000));
        let result = limiter.check("client");
        assert!(result.allowed);
        assert!(result.reset_secs > 0, "an occupied window must report a reset time");
        assert!(result.reset_secs <= 60);
    }

    #[test]
    fn update_config_raises_limit() {
        let mut limiter = RateLimiter::new(fast_config(1, 60_000));
        assert!(limiter.check("c").allowed);
        assert!(!limiter.check("c").allowed);

        limiter.update_config(Some(2), None, None);
        assert_eq!(limiter.config().max_requests, 2);
        let result = limiter.check("c");
        assert!(result.allowed);
        assert_eq!(result.remaining, 0);
    }

    #[test]
    fn client_metrics_counts_rejections() {
        let mut limiter = RateLimiter::new(fast_config(2, 60_000));
        for _ in 0..3 {
            limiter.check("m");
        }
        let metrics = limiter.client_metrics();
        assert_eq!(metrics.len(), 1);
        let m = &metrics[0];
        assert_eq!(m.client_key, "m");
        assert_eq!(m.total_requests, 3);
        assert_eq!(m.rejected, 1);
        assert_eq!(m.current_window_count, 2);
    }

    #[test]
    fn result_serializes_to_json() {
        let result = RateLimitResult {
            allowed: true,
            remaining: 42,
            limit: 100,
            reset_secs: 30,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"allowed\":true"));
        assert!(json.contains("\"remaining\":42"));
        assert!(json.contains("\"limit\":100"));
    }

    #[test]
    fn default_config_values() {
        let cfg = RateLimitConfig::default_config();
        assert_eq!(cfg.max_requests, 100);
        assert_eq!(cfg.window, Duration::from_secs(60));
        assert!(cfg.enabled);
    }

    #[test]
    fn single_request_limit() {
        let mut limiter = RateLimiter::new(fast_config(1, 60_000));
        assert!(limiter.check("client").allowed);
        assert!(!limiter.check("client").allowed);
    }

    #[test]
    fn remaining_decrements_correctly() {
        let mut limiter = RateLimiter::new(fast_config(5, 60_000));
        assert_eq!(limiter.check("c").remaining, 4);
        assert_eq!(limiter.check("c").remaining, 3);
        assert_eq!(limiter.check("c").remaining, 2);
        assert_eq!(limiter.check("c").remaining, 1);
        assert_eq!(limiter.check("c").remaining, 0);
        // The sixth request is over the limit.
        let denied = limiter.check("c");
        assert!(!denied.allowed);
        assert_eq!(denied.remaining, 0);
    }

    #[test]
    fn tenant_limiter_starter_quota() {
        let quotas = TenantQuotas::for_plan(&TenantPlan::Starter);
        assert_eq!(quotas.requests_per_min, 60);
        assert_eq!(quotas.tokens_per_day, 100_000);
    }

    #[test]
    fn tenant_limiter_pro_quota() {
        let quotas = TenantQuotas::for_plan(&TenantPlan::Pro);
        assert_eq!(quotas.requests_per_min, 300);
        assert_eq!(quotas.tokens_per_day, 1_000_000);
    }

    #[test]
    fn tenant_limiter_enterprise_unlimited_tokens() {
        let quotas = TenantQuotas::for_plan(&TenantPlan::Enterprise);
        assert_eq!(quotas.tokens_per_day, u64::MAX);
    }

    #[test]
    fn tenant_limiter_check_request_allowed() {
        let mut trl = TenantRateLimiter::new();
        let result = trl.check_request("acme", &TenantPlan::Pro);
        assert!(result.allowed);
    }

    #[test]
    fn tenant_limiter_two_tenants_independent() {
        let mut trl = TenantRateLimiter::new();
        let r_a = trl.check_request("tenant-a", &TenantPlan::Starter);
        let r_b = trl.check_request("tenant-b", &TenantPlan::Starter);
        assert!(r_a.allowed);
        assert!(r_b.allowed);
        assert_eq!(trl.tenant_count(), 2);
    }

    #[test]
    fn tenant_limiter_token_tracking() {
        let mut trl = TenantRateLimiter::new();
        trl.record_tokens("acme", 50_000);
        let (used, limit) = trl.token_usage("acme", &TenantPlan::Starter);
        assert_eq!(used, 50_000);
        assert_eq!(limit, 100_000);
    }

    #[test]
    fn tenant_limiter_token_quota_check() {
        let mut trl = TenantRateLimiter::new();
        assert!(trl.check_token_quota("acme", &TenantPlan::Starter));
        // Going one token over the Starter quota exhausts it.
        trl.record_tokens("acme", 100_001);
        assert!(!trl.check_token_quota("acme", &TenantPlan::Starter));
    }

    #[test]
    fn tenant_limiter_enterprise_token_quota_always_ok() {
        let mut trl = TenantRateLimiter::new();
        trl.record_tokens("big-corp", u64::MAX / 2);
        assert!(trl.check_token_quota("big-corp", &TenantPlan::Enterprise));
    }
}