// grapsus_proxy/inference/rate_limit.rs
//! Token-based rate limiting for inference endpoints
//!
//! Provides dual-bucket rate limiting that tracks both:
//! - Tokens per minute (primary limit for LLM APIs)
//! - Requests per minute (secondary limit to prevent abuse)

use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};

use dashmap::DashMap;
use tracing::{debug, trace};

use grapsus_config::TokenRateLimit;

/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenRateLimitResult {
    /// The request may proceed.
    Allowed,
    /// Rejected: the per-key token budget is exhausted.
    TokensExceeded {
        /// Milliseconds until retry is allowed
        retry_after_ms: u64,
    },
    /// Rejected: the per-key request budget is exhausted.
    RequestsExceeded {
        /// Milliseconds until retry is allowed
        retry_after_ms: u64,
    },
}

impl TokenRateLimitResult {
    /// Returns true if the request is allowed.
    pub fn is_allowed(&self) -> bool {
        *self == Self::Allowed
    }

    /// Milliseconds the caller should wait before retrying (0 when allowed).
    pub fn retry_after_ms(&self) -> u64 {
        // Both rejection variants carry the same field, so a single
        // or-pattern covers them.
        match *self {
            Self::Allowed => 0,
            Self::TokensExceeded { retry_after_ms }
            | Self::RequestsExceeded { retry_after_ms } => retry_after_ms,
        }
    }
}
46
/// Token bucket for rate limiting.
///
/// The token count lives in an atomic so consumption is lock-free; the refill
/// timestamp sits behind a `Mutex` because it is only touched on the refill
/// path.
struct TokenBucket {
    /// Current token count
    tokens: AtomicU64,
    /// Maximum tokens the bucket can hold (burst capacity)
    max_tokens: u64,
    /// Tokens added per elapsed millisecond
    refill_rate: f64,
    /// Instant of the last refill that produced at least one whole token
    last_refill: std::sync::Mutex<Instant>,
}

impl TokenBucket {
    /// Create a bucket that starts full at `burst_tokens` and refills at
    /// `tokens_per_minute / 60_000` tokens per millisecond.
    fn new(tokens_per_minute: u64, burst_tokens: u64) -> Self {
        // Calculate refill rate: tokens per millisecond
        let refill_rate = tokens_per_minute as f64 / 60_000.0;

        Self {
            tokens: AtomicU64::new(burst_tokens),
            max_tokens: burst_tokens,
            refill_rate,
            last_refill: std::sync::Mutex::new(Instant::now()),
        }
    }

    /// Try to consume `amount` tokens from the bucket.
    ///
    /// Returns `Ok(())` on success, or `Err(wait_ms)` with the milliseconds
    /// until enough tokens will have accrued. With a zero refill rate the
    /// wait saturates to `u64::MAX` (f64-to-u64 `as` casts saturate).
    fn try_consume(&self, amount: u64) -> Result<(), u64> {
        // First, refill based on elapsed time
        self.refill();

        // CAS loop: re-read and retry until the decrement lands or the
        // balance is insufficient.
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            if current < amount {
                // Not enough tokens - calculate wait time. A rate of 0.0
                // yields infinity, which `as u64` saturates to u64::MAX.
                let needed = amount - current;
                let wait_ms = (needed as f64 / self.refill_rate).ceil() as u64;
                return Err(wait_ms);
            }

            // Try to atomically subtract
            if self
                .tokens
                .compare_exchange(
                    current,
                    current - amount,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return Ok(());
            }
            // CAS failed (another consumer or a refill won the race); retry.
        }
    }

    /// Refill tokens based on elapsed time.
    fn refill(&self) {
        let mut last = self.last_refill.lock().unwrap();
        let now = Instant::now();
        let elapsed_ms = now.duration_since(*last).as_millis() as f64;

        let refill_amount = (elapsed_ms * self.refill_rate) as u64;
        if refill_amount == 0 {
            // Less than a whole token accrued; leave `last` untouched so the
            // fractional credit is not discarded.
            return;
        }

        // BUGFIX: the previous load-then-store could overwrite a concurrent
        // `try_consume` CAS that landed in between, silently restoring tokens
        // a consumer had just taken. `fetch_update` makes the add-and-cap a
        // CAS loop of its own, so concurrent decrements are never lost.
        let max = self.max_tokens;
        let _ = self
            .tokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
                Some(cur.saturating_add(refill_amount).min(max))
            });
        *last = now;
    }

    /// Get current token count (after applying any pending refill).
    fn current_tokens(&self) -> u64 {
        self.refill();
        self.tokens.load(Ordering::Acquire)
    }
}
127
128/// Token-based rate limiter for inference endpoints
129///
130/// Tracks rate limits per key (typically client IP or API key).
131pub struct TokenRateLimiter {
132    /// Token buckets per key
133    token_buckets: DashMap<String, TokenBucket>,
134    /// Request buckets per key (optional)
135    request_buckets: Option<DashMap<String, TokenBucket>>,
136    /// Configuration
137    config: TokenRateLimit,
138}
139
140impl TokenRateLimiter {
141    /// Create a new token rate limiter
142    pub fn new(config: TokenRateLimit) -> Self {
143        let request_buckets = config.requests_per_minute.map(|rpm| DashMap::new());
144
145        Self {
146            token_buckets: DashMap::new(),
147            request_buckets,
148            config,
149        }
150    }
151
152    /// Check if a request is allowed
153    ///
154    /// Both token and request limits must pass for the request to be allowed.
155    pub fn check(&self, key: &str, estimated_tokens: u64) -> TokenRateLimitResult {
156        // Check token limit
157        let token_bucket = self
158            .token_buckets
159            .entry(key.to_string())
160            .or_insert_with(|| {
161                TokenBucket::new(self.config.tokens_per_minute, self.config.burst_tokens)
162            });
163
164        if let Err(retry_ms) = token_bucket.try_consume(estimated_tokens) {
165            trace!(
166                key = key,
167                estimated_tokens = estimated_tokens,
168                retry_after_ms = retry_ms,
169                "Token rate limit exceeded"
170            );
171            return TokenRateLimitResult::TokensExceeded {
172                retry_after_ms: retry_ms,
173            };
174        }
175
176        // Check request limit if configured
177        if let (Some(rpm), Some(ref request_buckets)) =
178            (self.config.requests_per_minute, &self.request_buckets)
179        {
180            let request_bucket = request_buckets.entry(key.to_string()).or_insert_with(|| {
181                // For request limiting, use burst = rpm / 6 (10 second burst)
182                let burst = rpm.max(1) / 6;
183                TokenBucket::new(rpm, burst.max(1))
184            });
185
186            if let Err(retry_ms) = request_bucket.try_consume(1) {
187                trace!(
188                    key = key,
189                    retry_after_ms = retry_ms,
190                    "Request rate limit exceeded"
191                );
192                return TokenRateLimitResult::RequestsExceeded {
193                    retry_after_ms: retry_ms,
194                };
195            }
196        }
197
198        trace!(
199            key = key,
200            estimated_tokens = estimated_tokens,
201            "Rate limit check passed"
202        );
203        TokenRateLimitResult::Allowed
204    }
205
206    /// Record actual token usage after response
207    ///
208    /// This allows adjusting the bucket based on actual vs estimated usage.
209    /// If actual < estimated, refund the difference.
210    /// If actual > estimated, consume the extra (best effort).
211    pub fn record_actual(&self, key: &str, actual_tokens: u64, estimated_tokens: u64) {
212        if let Some(bucket) = self.token_buckets.get(key) {
213            if actual_tokens < estimated_tokens {
214                // Refund over-estimation
215                let refund = estimated_tokens - actual_tokens;
216                let current = bucket.tokens.load(Ordering::Acquire);
217                let new_tokens = (current + refund).min(bucket.max_tokens);
218                bucket.tokens.store(new_tokens, Ordering::Release);
219
220                debug!(
221                    key = key,
222                    actual = actual_tokens,
223                    estimated = estimated_tokens,
224                    refund = refund,
225                    "Refunded over-estimated tokens"
226                );
227            } else if actual_tokens > estimated_tokens {
228                // Under-estimation - try to consume extra (don't block)
229                let extra = actual_tokens - estimated_tokens;
230                let current = bucket.tokens.load(Ordering::Acquire);
231                let to_consume = extra.min(current);
232                if to_consume > 0 {
233                    bucket.tokens.fetch_sub(to_consume, Ordering::AcqRel);
234                }
235
236                debug!(
237                    key = key,
238                    actual = actual_tokens,
239                    estimated = estimated_tokens,
240                    consumed_extra = to_consume,
241                    "Consumed under-estimated tokens"
242                );
243            }
244        }
245    }
246
247    /// Get current token count for a key
248    pub fn current_tokens(&self, key: &str) -> Option<u64> {
249        self.token_buckets.get(key).map(|b| b.current_tokens())
250    }
251
252    /// Get stats for metrics
253    pub fn stats(&self) -> TokenRateLimiterStats {
254        TokenRateLimiterStats {
255            active_keys: self.token_buckets.len(),
256            tokens_per_minute: self.config.tokens_per_minute,
257            requests_per_minute: self.config.requests_per_minute,
258        }
259    }
260}
261
/// Stats for the token rate limiter
///
/// A point-in-time snapshot produced by [`TokenRateLimiter::stats`]; values
/// are not updated after creation.
#[derive(Debug, Clone)]
pub struct TokenRateLimiterStats {
    /// Number of active rate limit keys (token buckets currently tracked)
    pub active_keys: usize,
    /// Configured tokens per minute
    pub tokens_per_minute: u64,
    /// Configured requests per minute (`None` if request limiting is disabled)
    pub requests_per_minute: Option<u64>,
}
272
#[cfg(test)]
mod tests {
    use super::*;
    use grapsus_config::TokenEstimation;

    /// Shared fixture: 1000 tokens/min, 200-token burst, 10 requests/min.
    fn test_config() -> TokenRateLimit {
        TokenRateLimit {
            tokens_per_minute: 1000,
            requests_per_minute: Some(10),
            burst_tokens: 200,
            estimation_method: TokenEstimation::Chars,
        }
    }

    #[test]
    fn test_basic_rate_limiting() {
        let limiter = TokenRateLimiter::new(test_config());

        // A single request well under the burst budget passes.
        assert!(limiter.check("test-key", 50).is_allowed());

        // The bucket still holds tokens afterwards.
        assert!(limiter.current_tokens("test-key").unwrap() > 0);
    }

    #[test]
    fn test_token_exhaustion() {
        let limiter = TokenRateLimiter::new(test_config());

        // Drain the 200-token burst with four 50-token requests.
        for _ in 0..4 {
            let _ = limiter.check("test-key", 50);
        }

        // The fifth request no longer fits and is rejected on tokens.
        let verdict = limiter.check("test-key", 50);
        assert!(!verdict.is_allowed());
        assert!(matches!(
            verdict,
            TokenRateLimitResult::TokensExceeded { .. }
        ));
    }

    #[test]
    fn test_actual_token_refund() {
        let limiter = TokenRateLimiter::new(test_config());

        // Reserve with a deliberately high estimate...
        let _ = limiter.check("test-key", 100);
        let before = limiter.current_tokens("test-key").unwrap();

        // ...then report actual usage as half of that.
        limiter.record_actual("test-key", 50, 100);
        let after = limiter.current_tokens("test-key").unwrap();

        // The over-estimate came back into the bucket.
        assert!(after > before);
    }
}
333}