Skip to main content

axon/
rate_limiter.rs

//! Rate Limiter — sliding window rate limiting for AxonServer.
//!
//! Implements a sliding-window log algorithm per client key (e.g. an IP address or token):
//! each client keeps a queue of its request timestamps, and entries older than the window
//! are pruned whenever the client is accessed.
//!
//! Configuration (`RateLimitConfig`):
//!   - `max_requests` — maximum requests per window (default: 100)
//!   - `window` — window duration (default: 60 seconds)
//!   - `enabled` — whether rate limiting is applied at all
//!
//! Integration: called from AxonServer handlers before processing requests.
//! Returns `RateLimitResult` with allowed/denied status and remaining quota.
12
13use std::collections::{HashMap, VecDeque};
14use std::time::{Duration, Instant};
15
16use serde::Serialize;
17
18// ── Configuration ────────────────────────────────────────────────────────
19
/// Tunable parameters for [`RateLimiter`].
///
/// A request is admitted while fewer than `max_requests` requests from the same
/// client fall inside the trailing `window`.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Upper bound on requests admitted within one window.
    pub max_requests: u32,
    /// Length of the sliding window.
    pub window: Duration,
    /// When `false`, every request is admitted without any bookkeeping.
    pub enabled: bool,
}
30
31impl RateLimitConfig {
32    /// Default: 100 requests per 60 seconds.
33    pub fn default_config() -> Self {
34        RateLimitConfig {
35            max_requests: 100,
36            window: Duration::from_secs(60),
37            enabled: true,
38        }
39    }
40
41    /// Disabled rate limiter.
42    pub fn disabled() -> Self {
43        RateLimitConfig {
44            max_requests: 0,
45            window: Duration::from_secs(0),
46            enabled: false,
47        }
48    }
49}
50
51// ── Result ───────────────────────────────────────────────────────────────
52
53/// Result of a rate limit check.
54#[derive(Debug, Clone, Serialize)]
55pub struct RateLimitResult {
56    /// Whether the request is allowed.
57    pub allowed: bool,
58    /// Remaining requests in the current window.
59    pub remaining: u32,
60    /// Total limit per window.
61    pub limit: u32,
62    /// Seconds until the window resets (oldest entry expires).
63    pub reset_secs: u64,
64}
65
66// ── Limiter ──────────────────────────────────────────────────────────────
67
/// Snapshot of one client's counters, as produced by `RateLimiter::client_metrics`
/// (for Prometheus exposition).
#[derive(Clone, Debug)]
pub struct ClientRateMetric {
    /// Key the client is tracked under.
    pub client_key: String,
    /// Every `check` made for this key, admitted or not.
    pub total_requests: u64,
    /// `check` calls for this key that were denied.
    pub rejected: u64,
    /// Requests still inside the sliding window when the snapshot was taken.
    pub current_window_count: u32,
}
76
/// Bookkeeping for a single client key.
struct ClientBucket {
    /// Admission times of recorded requests, oldest at the front.
    timestamps: VecDeque<Instant>,
    /// Lifetime count of checks, including rejected ones.
    total_requests: u64,
    /// Lifetime count of rejected checks.
    rejected: u64,
}
83
84impl ClientBucket {
85    fn new() -> Self {
86        ClientBucket {
87            timestamps: VecDeque::new(),
88            total_requests: 0,
89            rejected: 0,
90        }
91    }
92
93    /// Prune expired timestamps and return current count.
94    fn prune_and_count(&mut self, now: Instant, window: Duration) -> u32 {
95        let cutoff = now.checked_sub(window).unwrap_or(now);
96        while let Some(&front) = self.timestamps.front() {
97            if front < cutoff {
98                self.timestamps.pop_front();
99            } else {
100                break;
101            }
102        }
103        self.timestamps.len() as u32
104    }
105
106    /// Time until the oldest entry expires (window reset).
107    fn reset_time(&self, now: Instant, window: Duration) -> u64 {
108        if let Some(&oldest) = self.timestamps.front() {
109            let expires_at = oldest + window;
110            if expires_at > now {
111                return (expires_at - now).as_secs();
112            }
113        }
114        0
115    }
116}
117
118/// Sliding window rate limiter.
119pub struct RateLimiter {
120    config: RateLimitConfig,
121    buckets: HashMap<String, ClientBucket>,
122}
123
124impl RateLimiter {
125    /// Create a new rate limiter with the given configuration.
126    pub fn new(config: RateLimitConfig) -> Self {
127        RateLimiter {
128            config,
129            buckets: HashMap::new(),
130        }
131    }
132
133    /// Check if a request from the given client key is allowed.
134    /// If allowed, records the request timestamp.
135    pub fn check(&mut self, client_key: &str) -> RateLimitResult {
136        if !self.config.enabled {
137            return RateLimitResult {
138                allowed: true,
139                remaining: u32::MAX,
140                limit: 0,
141                reset_secs: 0,
142            };
143        }
144
145        let now = Instant::now();
146        let bucket = self.buckets
147            .entry(client_key.to_string())
148            .or_insert_with(ClientBucket::new);
149
150        let count = bucket.prune_and_count(now, self.config.window);
151        bucket.total_requests += 1;
152
153        if count >= self.config.max_requests {
154            bucket.rejected += 1;
155            let reset = bucket.reset_time(now, self.config.window);
156            return RateLimitResult {
157                allowed: false,
158                remaining: 0,
159                limit: self.config.max_requests,
160                reset_secs: reset,
161            };
162        }
163
164        // Allow and record
165        bucket.timestamps.push_back(now);
166        let remaining = self.config.max_requests - count - 1;
167        let reset = bucket.reset_time(now, self.config.window);
168
169        RateLimitResult {
170            allowed: true,
171            remaining,
172            limit: self.config.max_requests,
173            reset_secs: reset,
174        }
175    }
176
177    /// Get current stats for a client without recording a request.
178    pub fn peek(&mut self, client_key: &str) -> RateLimitResult {
179        if !self.config.enabled {
180            return RateLimitResult {
181                allowed: true,
182                remaining: u32::MAX,
183                limit: 0,
184                reset_secs: 0,
185            };
186        }
187
188        let now = Instant::now();
189        let bucket = self.buckets
190            .entry(client_key.to_string())
191            .or_insert_with(ClientBucket::new);
192
193        let count = bucket.prune_and_count(now, self.config.window);
194        let remaining = self.config.max_requests.saturating_sub(count);
195        let reset = bucket.reset_time(now, self.config.window);
196
197        RateLimitResult {
198            allowed: remaining > 0,
199            remaining,
200            limit: self.config.max_requests,
201            reset_secs: reset,
202        }
203    }
204
205    /// Number of tracked client keys.
206    pub fn client_count(&self) -> usize {
207        self.buckets.len()
208    }
209
210    /// Prune all empty buckets (cleanup).
211    pub fn cleanup(&mut self) {
212        let now = Instant::now();
213        let window = self.config.window;
214        self.buckets.retain(|_, bucket| {
215            bucket.prune_and_count(now, window);
216            !bucket.timestamps.is_empty()
217        });
218    }
219
220    /// Get the configuration.
221    pub fn config(&self) -> &RateLimitConfig {
222        &self.config
223    }
224
225    /// Update the configuration at runtime.
226    pub fn update_config(&mut self, max_requests: Option<u32>, window_secs: Option<u64>, enabled: Option<bool>) {
227        if let Some(max) = max_requests {
228            self.config.max_requests = max;
229        }
230        if let Some(secs) = window_secs {
231            self.config.window = Duration::from_secs(secs);
232        }
233        if let Some(en) = enabled {
234            self.config.enabled = en;
235        }
236    }
237
238    /// Get per-client rate limiter metrics (for Prometheus).
239    pub fn client_metrics(&mut self) -> Vec<ClientRateMetric> {
240        let now = Instant::now();
241        let window = self.config.window;
242        self.buckets.iter_mut().map(|(key, bucket)| {
243            let current = bucket.prune_and_count(now, window);
244            ClientRateMetric {
245                client_key: key.clone(),
246                total_requests: bucket.total_requests,
247                rejected: bucket.rejected,
248                current_window_count: current,
249            }
250        }).collect()
251    }
252}
253
254// ── Per-tenant rate limiter (M4) ─────────────────────────────────────────
255
256use crate::tenant::TenantPlan;
257
/// Per-plan default quotas for tenant rate limiting.
///
/// Enterprise has no daily token cap (`tokens_per_day == u64::MAX`), but its
/// request rate is still bounded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TenantQuotas {
    /// Maximum requests per minute.
    pub requests_per_min: u32,
    /// Maximum tokens per 24-hour window; `u64::MAX` means unlimited.
    ///
    /// The window is fixed, not rolling. It starts when the tenant's token
    /// counter is first touched, and the counter resets to zero on the first
    /// access after 24 hours have elapsed.
    pub tokens_per_day: u64,
}
266
267impl TenantQuotas {
268    pub fn for_plan(plan: &TenantPlan) -> Self {
269        match plan {
270            TenantPlan::Starter    => Self { requests_per_min: 60,   tokens_per_day: 100_000 },
271            TenantPlan::Pro        => Self { requests_per_min: 300,  tokens_per_day: 1_000_000 },
272            TenantPlan::Enterprise => Self { requests_per_min: 2000, tokens_per_day: u64::MAX },
273        }
274    }
275}
276
/// Token consumption of one tenant within its current 24-hour window.
struct TokenBucket {
    /// Tokens recorded since `window_start`.
    used: u64,
    /// When the current window began.
    window_start: Instant,
}
282
283impl TokenBucket {
284    fn new() -> Self {
285        Self { used: 0, window_start: std::time::Instant::now() }
286    }
287
288    /// Reset counter if the 24-hour window has rolled over.
289    fn refresh(&mut self) {
290        if self.window_start.elapsed() >= Duration::from_secs(86400) {
291            self.used = 0;
292            self.window_start = std::time::Instant::now();
293        }
294    }
295
296    fn add(&mut self, tokens: u64) {
297        self.refresh();
298        self.used = self.used.saturating_add(tokens);
299    }
300
301    fn can_consume(&mut self, limit: u64) -> bool {
302        self.refresh();
303        self.used < limit
304    }
305}
306
307/// Per-tenant combined rate limiter: request rate + daily token quota.
308///
309/// Both dimensions are enforced independently:
310///   - Request rate: sliding window per minute (same algorithm as `RateLimiter`)
311///   - Token quota: rolling 24-hour counter reset at midnight of first request
312///
313/// Quotas are derived from `TenantPlan` at check time — no DB round-trip needed.
314/// An unknown tenant defaults to `Starter` quotas (safest for open-source compat).
315pub struct TenantRateLimiter {
316    /// Per-tenant request-rate limiters (created lazily).
317    request_limiters: HashMap<String, RateLimiter>,
318    /// Per-tenant daily token counters (created lazily).
319    token_buckets: HashMap<String, TokenBucket>,
320}
321
322impl TenantRateLimiter {
323    pub fn new() -> Self {
324        Self {
325            request_limiters: HashMap::new(),
326            token_buckets: HashMap::new(),
327        }
328    }
329
330    /// Check and record one request for a tenant.
331    /// Returns `RateLimitResult` — caller should reject if `!result.allowed`.
332    pub fn check_request(&mut self, tenant_id: &str, plan: &TenantPlan) -> RateLimitResult {
333        let quotas = TenantQuotas::for_plan(plan);
334        let limiter = self.request_limiters
335            .entry(tenant_id.to_string())
336            .or_insert_with(|| {
337                RateLimiter::new(RateLimitConfig {
338                    max_requests: quotas.requests_per_min,
339                    window: Duration::from_secs(60),
340                    enabled: true,
341                })
342            });
343        // Update quota if plan changed (e.g. tenant upgraded)
344        limiter.update_config(Some(quotas.requests_per_min), None, None);
345        limiter.check(tenant_id)
346    }
347
348    /// Record tokens consumed by a tenant (called after a successful LLM response).
349    pub fn record_tokens(&mut self, tenant_id: &str, tokens: u64) {
350        self.token_buckets
351            .entry(tenant_id.to_string())
352            .or_insert_with(TokenBucket::new)
353            .add(tokens);
354    }
355
356    /// Check whether a tenant is within their daily token quota.
357    /// Does NOT consume tokens — call `record_tokens` after the LLM call.
358    pub fn check_token_quota(&mut self, tenant_id: &str, plan: &TenantPlan) -> bool {
359        let limit = TenantQuotas::for_plan(plan).tokens_per_day;
360        if limit == u64::MAX {
361            return true; // Enterprise = unlimited
362        }
363        self.token_buckets
364            .entry(tenant_id.to_string())
365            .or_insert_with(TokenBucket::new)
366            .can_consume(limit)
367    }
368
369    /// Current token usage for a tenant (used, daily_limit).
370    pub fn token_usage(&mut self, tenant_id: &str, plan: &TenantPlan) -> (u64, u64) {
371        let limit = TenantQuotas::for_plan(plan).tokens_per_day;
372        let bucket = self.token_buckets
373            .entry(tenant_id.to_string())
374            .or_insert_with(TokenBucket::new);
375        bucket.refresh();
376        (bucket.used, limit)
377    }
378
379    /// Number of tracked tenants.
380    pub fn tenant_count(&self) -> usize {
381        self.request_limiters.len()
382    }
383
384    /// Remove stale tenant entries to prevent unbounded growth.
385    pub fn cleanup(&mut self) {
386        let now = std::time::Instant::now();
387        // Remove tenants whose request windows are all expired (24h+)
388        self.token_buckets.retain(|_, b| {
389            b.window_start.elapsed() < Duration::from_secs(86400 * 2)
390        });
391        // Also prune the request limiters
392        for limiter in self.request_limiters.values_mut() {
393            limiter.cleanup();
394        }
395        let _ = now; // suppress unused warning
396    }
397}
398
399// ── Tests ────────────────────────────────────────────────────────────────
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled config with a millisecond-resolution window, for fast tests.
    fn fast_config(max: u32, window_ms: u64) -> RateLimitConfig {
        RateLimitConfig {
            max_requests: max,
            window: Duration::from_millis(window_ms),
            enabled: true,
        }
    }

    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(fast_config(5, 1000));
        for i in 0..5 {
            let result = limiter.check("client_a");
            assert!(result.allowed, "request {} should be allowed", i);
            assert_eq!(result.remaining, 4 - i as u32);
            assert_eq!(result.limit, 5);
        }
    }

    #[test]
    fn denies_over_limit() {
        let mut limiter = RateLimiter::new(fast_config(3, 60_000));
        for _ in 0..3 {
            assert!(limiter.check("client_a").allowed);
        }
        let result = limiter.check("client_a");
        assert!(!result.allowed);
        assert_eq!(result.remaining, 0);
    }

    #[test]
    fn separate_clients_independent() {
        let mut limiter = RateLimiter::new(fast_config(2, 60_000));
        assert!(limiter.check("alice").allowed);
        assert!(limiter.check("alice").allowed);
        assert!(!limiter.check("alice").allowed);

        // Bob should still be able to make requests
        assert!(limiter.check("bob").allowed);
        assert!(limiter.check("bob").allowed);
        assert!(!limiter.check("bob").allowed);
    }

    #[test]
    fn window_expiry_allows_again() {
        let mut limiter = RateLimiter::new(fast_config(2, 1)); // 1ms window
        assert!(limiter.check("client").allowed);
        assert!(limiter.check("client").allowed);
        assert!(!limiter.check("client").allowed);

        // Wait for window to expire
        std::thread::sleep(Duration::from_millis(5));
        assert!(limiter.check("client").allowed);
    }

    #[test]
    fn disabled_always_allows() {
        let mut limiter = RateLimiter::new(RateLimitConfig::disabled());
        for _ in 0..1000 {
            let result = limiter.check("anyone");
            assert!(result.allowed);
            assert_eq!(result.remaining, u32::MAX);
        }
    }

    #[test]
    fn peek_does_not_consume() {
        let mut limiter = RateLimiter::new(fast_config(3, 60_000));
        limiter.check("client"); // consume 1

        let peek1 = limiter.peek("client");
        assert!(peek1.allowed);
        assert_eq!(peek1.remaining, 2);

        let peek2 = limiter.peek("client");
        assert_eq!(peek2.remaining, 2); // unchanged
    }

    #[test]
    fn client_count_tracks_unique() {
        let mut limiter = RateLimiter::new(fast_config(10, 60_000));
        assert_eq!(limiter.client_count(), 0);

        limiter.check("a");
        assert_eq!(limiter.client_count(), 1);

        limiter.check("b");
        assert_eq!(limiter.client_count(), 2);

        limiter.check("a"); // same client
        assert_eq!(limiter.client_count(), 2);
    }

    #[test]
    fn cleanup_removes_expired() {
        let mut limiter = RateLimiter::new(fast_config(5, 1)); // 1ms window
        limiter.check("temp");
        assert_eq!(limiter.client_count(), 1);

        std::thread::sleep(Duration::from_millis(5));
        limiter.cleanup();
        assert_eq!(limiter.client_count(), 0);
    }

    #[test]
    fn reset_secs_positive_when_active() {
        let mut limiter = RateLimiter::new(fast_config(5, 60_000)); // 60s window
        let result = limiter.check("client");
        assert!(result.allowed);
        // The request just recorded is the oldest entry, so the reset lies a
        // full window ahead: positive, and never more than 60 seconds.
        assert!(result.reset_secs > 0);
        assert!(result.reset_secs <= 60);
    }

    #[test]
    fn result_serializes_to_json() {
        let result = RateLimitResult {
            allowed: true,
            remaining: 42,
            limit: 100,
            reset_secs: 30,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"allowed\":true"));
        assert!(json.contains("\"remaining\":42"));
        assert!(json.contains("\"limit\":100"));
    }

    #[test]
    fn default_config_values() {
        let cfg = RateLimitConfig::default_config();
        assert_eq!(cfg.max_requests, 100);
        assert_eq!(cfg.window, Duration::from_secs(60));
        assert!(cfg.enabled);
    }

    #[test]
    fn single_request_limit() {
        let mut limiter = RateLimiter::new(fast_config(1, 60_000));
        assert!(limiter.check("client").allowed);
        assert!(!limiter.check("client").allowed);
    }

    #[test]
    fn remaining_decrements_correctly() {
        let mut limiter = RateLimiter::new(fast_config(5, 60_000));
        assert_eq!(limiter.check("c").remaining, 4);
        assert_eq!(limiter.check("c").remaining, 3);
        assert_eq!(limiter.check("c").remaining, 2);
        assert_eq!(limiter.check("c").remaining, 1);
        assert_eq!(limiter.check("c").remaining, 0);
        // Next should be denied
        let denied = limiter.check("c");
        assert!(!denied.allowed);
        assert_eq!(denied.remaining, 0);
    }

    // ── TenantRateLimiter tests ───────────────────────────────────────────

    #[test]
    fn tenant_limiter_starter_quota() {
        let quotas = TenantQuotas::for_plan(&TenantPlan::Starter);
        assert_eq!(quotas.requests_per_min, 60);
        assert_eq!(quotas.tokens_per_day, 100_000);
    }

    #[test]
    fn tenant_limiter_pro_quota() {
        let quotas = TenantQuotas::for_plan(&TenantPlan::Pro);
        assert_eq!(quotas.requests_per_min, 300);
        assert_eq!(quotas.tokens_per_day, 1_000_000);
    }

    #[test]
    fn tenant_limiter_enterprise_unlimited_tokens() {
        let quotas = TenantQuotas::for_plan(&TenantPlan::Enterprise);
        assert_eq!(quotas.tokens_per_day, u64::MAX);
    }

    #[test]
    fn tenant_limiter_check_request_allowed() {
        let mut trl = TenantRateLimiter::new();
        let result = trl.check_request("acme", &TenantPlan::Pro);
        assert!(result.allowed);
    }

    #[test]
    fn tenant_limiter_two_tenants_independent() {
        let mut trl = TenantRateLimiter::new();
        let starter_limit = TenantQuotas::for_plan(&TenantPlan::Starter).requests_per_min;

        // Exhaust tenant-a's per-minute request quota...
        for i in 0..starter_limit {
            let result = trl.check_request("tenant-a", &TenantPlan::Starter);
            assert!(result.allowed, "tenant-a request {} should be allowed", i);
        }
        assert!(!trl.check_request("tenant-a", &TenantPlan::Starter).allowed);

        // ...which must not affect tenant-b.
        assert!(trl.check_request("tenant-b", &TenantPlan::Starter).allowed);
        assert_eq!(trl.tenant_count(), 2);
    }

    #[test]
    fn tenant_limiter_token_tracking() {
        let mut trl = TenantRateLimiter::new();
        trl.record_tokens("acme", 50_000);
        let (used, limit) = trl.token_usage("acme", &TenantPlan::Starter);
        assert_eq!(used, 50_000);
        assert_eq!(limit, 100_000);
    }

    #[test]
    fn tenant_limiter_token_quota_check() {
        let mut trl = TenantRateLimiter::new();
        // Fresh tenant: under quota
        assert!(trl.check_token_quota("acme", &TenantPlan::Starter));
        // One token short of the 100k Starter quota: still allowed
        trl.record_tokens("acme", 99_999);
        assert!(trl.check_token_quota("acme", &TenantPlan::Starter));
        // Reaching the quota exactly exhausts it
        trl.record_tokens("acme", 1);
        assert!(!trl.check_token_quota("acme", &TenantPlan::Starter));
    }

    #[test]
    fn tenant_limiter_enterprise_token_quota_always_ok() {
        let mut trl = TenantRateLimiter::new();
        trl.record_tokens("big-corp", u64::MAX / 2);
        assert!(trl.check_token_quota("big-corp", &TenantPlan::Enterprise));
    }
}