// hyperinfer_core/rate_limiting.rs

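//! Rate limiting utilities for HyperInfer.
//!
//! Provides distributed quota enforcement backed by Redis: a fixed-window
//! counter for requests per minute and the GCRA (generic cell rate algorithm)
//! for tokens per minute. When no Redis connection is configured, every check
//! admits the request.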
4
5use redis::aio::ConnectionManager;
6use redis::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Instant;
9
/// Redis key prefix for per-key cumulative token counters (`INCRBY` in
/// `RateLimiter::record_usage`; no expiry is set on these keys).
pub const USAGE_TOKENS_KEY_PREFIX: &str = "hyperinfer:usage:tokens:";
/// Redis key prefix for per-key cumulative request counters (`INCR` in
/// `RateLimiter::record_usage`; no expiry is set on these keys).
pub const USAGE_REQUESTS_KEY_PREFIX: &str = "hyperinfer:usage:requests:";
12
/// Lua implementation of GCRA (generic cell rate algorithm) for token limits.
///
/// * `KEYS[1]` — key holding the theoretical arrival time (TAT), in ms.
/// * `ARGV[1]` — sustained rate, in units per second (may be fractional).
/// * `ARGV[2]` — burst capacity, in units.
/// * `ARGV[3]` — current time, in milliseconds since the Unix epoch.
/// * `ARGV[4]` — cost of this request, in units.
///
/// Replies `{1, 0}` when the request is admitted (the new TAT is stored), or
/// `{0, retry_after_ms}` when it is rejected (state is left untouched). A
/// non-positive rate or capacity rejects everything with `{0, 0}`.
///
/// Each unit occupies `1000 / rate` ms (the emission interval) and the TAT may
/// run at most `capacity * emission_interval` ms ahead of `now`, so a fresh key
/// admits up to `capacity` units at once and then refills at `rate` units per
/// second. The key expires once its TAT has passed, when it is equivalent to an
/// absent key.
const GCRA_SCRIPT: &str = r#"
local key = KEYS[1]
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])

-- Nothing can be admitted; also avoids dividing by zero below.
if rate <= 0 or capacity <= 0 then
    return {0, 0}
end

-- Milliseconds per unit, and how far the TAT may run ahead of now.
local emission_interval = 1000 / rate
local tolerance = capacity * emission_interval

local tat = redis.call('GET', key)
if not tat then
    tat = now
else
    tat = tonumber(tat)
end

local new_tat = math.max(tat, now) + cost * emission_interval
local allow_at = new_tat - tolerance

if allow_at <= now then
    -- PX must be a positive integer; a zero cost can leave new_tat == now.
    local ttl_ms = math.max(1, math.ceil(new_tat - now))
    redis.call('SET', key, new_tat, 'PX', ttl_ms)
    return {1, 0}
end

return {0, math.ceil(allow_at - now)}
"#;
39
/// Lua fixed-window request counter.
///
/// * `KEYS[1]` — counter key.
/// * `ARGV[1]` — maximum requests per window.
/// * `ARGV[2]` — window length, in seconds.
///
/// Replies `{1, remaining, 0}` when the request fits in the window, or
/// `{0, 0, ttl_secs}` when it does not, where `ttl_secs` (never negative) is
/// the time until the window resets.
const RPM_SCRIPT: &str = r#"
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])

local current = redis.call('INCR', key)
-- Arm the window on the first hit. Also re-arm a counter that has no expiry
-- (TTL == -1); such a key would otherwise never reset and reject forever.
if current == 1 or redis.call('TTL', key) == -1 then
    redis.call('EXPIRE', key, window)
end

if current > limit then
    local ttl = redis.call('TTL', key)
    -- Negative TTL codes would not fit the unsigned reply the caller parses.
    if ttl < 0 then
        ttl = 0
    end
    return {0, 0, ttl}
end
return {1, limit - current, 0}
"#;
56
/// Plain token-bucket state record.
///
/// NOTE(review): `RateLimiter` in this file does not use this type; the units
/// of `refill_rate` are not established here — confirm with its users.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold.
    pub capacity: u64,
    /// Tokens currently available.
    pub tokens: u64,
    /// Refill amount per time unit (unit unspecified here).
    pub refill_rate: u64,
    /// Monotonic timestamp of the most recent refill.
    pub last_refill: Instant,
}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct Quota {
67    pub max_requests_per_minute: Option<u64>,
68    pub max_tokens_per_minute: Option<u64>,
69    pub budget_cents: Option<u64>,
70}
71
72#[derive(Clone)]
73pub struct RateLimiter {
74    redis_manager: Option<ConnectionManager>,
75    default_rpm: u64,
76    default_tpm: u64,
77}
78
79impl RateLimiter {
80    pub async fn new(
81        redis_url: Option<&str>,
82    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
83        let redis_manager = match redis_url {
84            Some(url) => {
85                let client = Client::open(url)?;
86                Some(ConnectionManager::new(client).await?)
87            }
88            None => None,
89        };
90        Ok(Self {
91            redis_manager,
92            default_rpm: 60,
93            default_tpm: 100000,
94        })
95    }
96
97    pub async fn is_allowed(
98        &self,
99        key: &str,
100        amount: u64,
101    ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
102        if let Some(ref manager) = self.redis_manager {
103            let mut conn = manager.clone();
104
105            let result: Vec<u64> = redis::cmd("EVAL")
106                .arg(RPM_SCRIPT)
107                .arg(1)
108                .arg(format!("hyperinfer:ratelimit:rpm:{}", key))
109                .arg(self.default_rpm)
110                .arg(60)
111                .query_async(&mut conn)
112                .await?;
113
114            let allowed = result.first().copied().unwrap_or(0);
115            if allowed == 0 {
116                return Ok(false);
117            }
118
119            let tpm_key = format!("hyperinfer:ratelimit:tpm:{}", key);
120            let now = std::time::SystemTime::now()
121                .duration_since(std::time::UNIX_EPOCH)
122                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
123                .as_millis() as u64;
124            let rate = self.default_tpm / 60;
125            let tpm_result: Vec<u64> = redis::cmd("EVAL")
126                .arg(GCRA_SCRIPT)
127                .arg(1)
128                .arg(&tpm_key)
129                .arg(rate)
130                .arg(self.default_tpm)
131                .arg(now)
132                .arg(amount)
133                .query_async(&mut conn)
134                .await?;
135
136            Ok(tpm_result.first().copied().unwrap_or(0) == 1)
137        } else {
138            Ok(true)
139        }
140    }
141
142    pub async fn check_rpm(
143        &self,
144        key: &str,
145        limit: u64,
146    ) -> Result<(bool, u64), Box<dyn std::error::Error + Send + Sync>> {
147        if let Some(ref manager) = self.redis_manager {
148            let mut conn = manager.clone();
149
150            let result: Vec<u64> = redis::cmd("EVAL")
151                .arg(RPM_SCRIPT)
152                .arg(1)
153                .arg(format!("hyperinfer:ratelimit:rpm:{}", key))
154                .arg(limit)
155                .arg(60)
156                .query_async(&mut conn)
157                .await?;
158
159            let allowed = result.first().copied().unwrap_or(0) == 1;
160            let remaining = result.get(1).copied().unwrap_or(0);
161            Ok((allowed, remaining))
162        } else {
163            Ok((true, limit))
164        }
165    }
166
167    pub async fn check_tpm(
168        &self,
169        key: &str,
170        limit: u64,
171        tokens: u64,
172    ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
173        if let Some(ref manager) = self.redis_manager {
174            let mut conn = manager.clone();
175
176            let now = std::time::SystemTime::now()
177                .duration_since(std::time::UNIX_EPOCH)
178                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
179                .as_millis() as u64;
180            let rate = limit / 60;
181
182            let result: Vec<u64> = redis::cmd("EVAL")
183                .arg(GCRA_SCRIPT)
184                .arg(1)
185                .arg(format!("hyperinfer:ratelimit:tpm:{}", key))
186                .arg(rate)
187                .arg(limit)
188                .arg(now)
189                .arg(tokens)
190                .query_async(&mut conn)
191                .await?;
192
193            Ok(result.first().copied().unwrap_or(0) == 1)
194        } else {
195            Ok(true)
196        }
197    }
198
199    pub async fn record_usage(
200        &self,
201        key: &str,
202        tokens_used: u64,
203    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
204        if let Some(ref manager) = self.redis_manager {
205            let mut conn = manager.clone();
206
207            redis::pipe()
208                .atomic()
209                .cmd("INCRBY")
210                .arg(format!("{}{}", USAGE_TOKENS_KEY_PREFIX, key))
211                .arg(tokens_used)
212                .cmd("INCR")
213                .arg(format!("{}{}", USAGE_REQUESTS_KEY_PREFIX, key))
214                .query_async::<()>(&mut conn)
215                .await?;
216        }
217        Ok(())
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter with no Redis backend, so every check takes the
    /// "always allow" path.
    async fn offline_limiter() -> RateLimiter {
        RateLimiter::new(None)
            .await
            .expect("building a limiter without Redis should succeed")
    }

    /// Builds a token bucket stamped with the current instant.
    fn bucket(capacity: u64, tokens: u64, refill_rate: u64) -> TokenBucket {
        TokenBucket {
            capacity,
            tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_new_without_redis() {
        let limiter = offline_limiter().await;
        assert_eq!(limiter.default_rpm, 60);
        assert_eq!(limiter.default_tpm, 100000);
    }

    #[tokio::test]
    async fn test_rate_limiter_is_allowed_without_redis() {
        let limiter = offline_limiter().await;
        // No Redis: every request is admitted.
        assert!(matches!(limiter.is_allowed("test-key", 1).await, Ok(true)));
    }

    #[tokio::test]
    async fn test_rate_limiter_check_rpm_without_redis() {
        let limiter = offline_limiter().await;
        // No Redis: admitted, with the whole limit reported as remaining.
        let (allowed, remaining) = limiter
            .check_rpm("test-key", 100)
            .await
            .expect("check_rpm without Redis should succeed");
        assert!(allowed);
        assert_eq!(remaining, 100);
    }

    #[tokio::test]
    async fn test_rate_limiter_check_tpm_without_redis() {
        let limiter = offline_limiter().await;
        // No Redis: every token charge is admitted.
        assert!(matches!(
            limiter.check_tpm("test-key", 1000, 100).await,
            Ok(true)
        ));
    }

    #[tokio::test]
    async fn test_rate_limiter_record_usage_without_redis() {
        let limiter = offline_limiter().await;
        assert!(limiter.record_usage("test-key", 50).await.is_ok());
    }

    #[test]
    fn test_token_bucket_creation() {
        let b = bucket(100, 100, 10);
        assert_eq!((b.capacity, b.tokens, b.refill_rate), (100, 100, 10));
    }

    #[test]
    fn test_token_bucket_clone() {
        let original = bucket(50, 25, 5);
        let copy = original.clone();
        assert_eq!(
            (original.capacity, original.tokens, original.refill_rate),
            (copy.capacity, copy.tokens, copy.refill_rate)
        );
    }

    #[test]
    fn test_quota_creation() {
        let q = Quota {
            max_requests_per_minute: Some(60),
            max_tokens_per_minute: Some(100000),
            budget_cents: Some(1000),
        };
        assert_eq!(
            (q.max_requests_per_minute, q.max_tokens_per_minute, q.budget_cents),
            (Some(60), Some(100000), Some(1000))
        );
    }

    #[test]
    fn test_quota_with_none_values() {
        let q = Quota {
            max_requests_per_minute: None,
            max_tokens_per_minute: None,
            budget_cents: None,
        };
        assert!(q.max_requests_per_minute.is_none());
        assert!(q.max_tokens_per_minute.is_none());
        assert!(q.budget_cents.is_none());
    }

    #[test]
    fn test_quota_clone() {
        let original = Quota {
            max_requests_per_minute: Some(100),
            max_tokens_per_minute: Some(50000),
            budget_cents: Some(2000),
        };
        let copy = original.clone();
        assert_eq!(
            (
                original.max_requests_per_minute,
                original.max_tokens_per_minute,
                original.budget_cents
            ),
            (
                copy.max_requests_per_minute,
                copy.max_tokens_per_minute,
                copy.budget_cents
            )
        );
    }

    #[tokio::test]
    async fn test_rate_limiter_is_allowed_with_zero_amount() {
        let limiter = offline_limiter().await;
        assert!(matches!(limiter.is_allowed("test-key", 0).await, Ok(true)));
    }

    #[tokio::test]
    async fn test_rate_limiter_is_allowed_with_large_amount() {
        let limiter = offline_limiter().await;
        assert!(matches!(
            limiter.is_allowed("test-key", 999999).await,
            Ok(true)
        ));
    }

    #[tokio::test]
    async fn test_rate_limiter_check_rpm_with_different_limits() {
        let limiter = offline_limiter().await;
        for (key, limit) in [("key1", 10), ("key2", 1000)] {
            let (_, remaining) = limiter
                .check_rpm(key, limit)
                .await
                .expect("check_rpm without Redis should succeed");
            assert_eq!(remaining, limit);
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_record_usage_multiple_times() {
        let limiter = offline_limiter().await;
        for tokens in [100, 200, 300] {
            assert!(limiter.record_usage("key", tokens).await.is_ok());
        }
    }
}