1use redis::aio::ConnectionManager;
6use redis::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Instant;
9
/// Prefix of the Redis key under which `RateLimiter::record_usage` accumulates
/// the number of tokens consumed; the caller-supplied key is appended to it.
pub const USAGE_TOKENS_KEY_PREFIX: &str = "hyperinfer:usage:tokens:";

/// Prefix of the Redis key under which `RateLimiter::record_usage` counts
/// recorded requests; the caller-supplied key is appended to it.
pub const USAGE_REQUESTS_KEY_PREFIX: &str = "hyperinfer:usage:requests:";
12
/// Lua script implementing a GCRA (generic cell rate algorithm) limiter on a
/// single Redis key.
///
/// Inputs:
/// * `KEYS[1]` - key holding the theoretical arrival time (TAT), in milliseconds.
/// * `ARGV[1]` - sustained rate in tokens per second (may be fractional).
/// * `ARGV[2]` - burst capacity in tokens.
/// * `ARGV[3]` - current time in milliseconds since the Unix epoch.
/// * `ARGV[4]` - cost of this request in tokens.
///
/// Reply: `{1, 0}` when the request is admitted (the TAT is advanced), or
/// `{0, retry_after_ms}` when it is rejected (the TAT is left untouched).
/// A non-positive rate or capacity rejects with `{0, 0}` instead of dividing
/// by zero.
///
/// All time arithmetic is done in milliseconds: one token "occupies"
/// `1000 / rate` ms, and the burst tolerance is `capacity` tokens' worth of
/// that interval. The key expires once the stored TAT lies in the past, at
/// which point a missing key and an expired TAT are equivalent.
const GCRA_SCRIPT: &str = r#"
local key = KEYS[1]
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])

if rate <= 0 or capacity <= 0 then
    return {0, 0}
end

local emission_interval = 1000 / rate
local burst_tolerance = capacity * emission_interval
local tat = redis.call('GET', key)

if not tat then
    tat = now
else
    tat = tonumber(tat)
end

local new_tat = math.max(tat, now) + cost * emission_interval
local allow_at = new_tat - burst_tolerance

if allow_at <= now then
    redis.call('SET', key, new_tat, 'PX', math.max(1, math.ceil(new_tat - now)))
    return {1, 0}
else
    return {0, math.ceil(allow_at - now)}
end
"#;
39
/// Lua script implementing a fixed-window request counter on a single Redis key.
///
/// Inputs:
/// * `KEYS[1]` - counter key.
/// * `ARGV[1]` - maximum number of requests allowed per window.
/// * `ARGV[2]` - window length in seconds.
///
/// Reply: `{1, remaining, 0}` when the request is within the limit, or
/// `{0, 0, ttl_seconds}` when it is over the limit. The counter is incremented
/// on every call, including rejected ones, and the window starts with the
/// first increment.
///
/// If an over-limit counter is found without an expiry (`TTL` < 0), the expiry
/// is re-armed so the key cannot block forever and the reply never carries a
/// negative TTL.
const RPM_SCRIPT: &str = r#"
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])

local current = redis.call('INCR', key)
if current == 1 then
    redis.call('EXPIRE', key, window)
end

if current > limit then
    local ttl = redis.call('TTL', key)
    if ttl < 0 then
        redis.call('EXPIRE', key, window)
        ttl = window
    end
    return {0, 0, ttl}
end
return {1, limit - current, 0}
"#;
56
/// Plain-data description of a token bucket.
///
/// This type carries state only; nothing in this file refills or drains it,
/// and `RateLimiter` does not reference it.
#[derive(Debug, Clone)]
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold.
    pub capacity: u64,
    /// Number of tokens currently in the bucket.
    pub tokens: u64,
    /// Refill rate. NOTE(review): the unit (tokens per second?) is not
    /// established anywhere in this file - confirm with the consumer.
    pub refill_rate: u64,
    /// Monotonic timestamp of the most recent refill.
    pub last_refill: Instant,
}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct Quota {
67 pub max_requests_per_minute: Option<u64>,
68 pub max_tokens_per_minute: Option<u64>,
69 pub budget_cents: Option<u64>,
70}
71
72#[derive(Clone)]
73pub struct RateLimiter {
74 redis_manager: Option<ConnectionManager>,
75 default_rpm: u64,
76 default_tpm: u64,
77}
78
79impl RateLimiter {
80 pub async fn new(
81 redis_url: Option<&str>,
82 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
83 let redis_manager = match redis_url {
84 Some(url) => {
85 let client = Client::open(url)?;
86 Some(ConnectionManager::new(client).await?)
87 }
88 None => None,
89 };
90 Ok(Self {
91 redis_manager,
92 default_rpm: 60,
93 default_tpm: 100000,
94 })
95 }
96
97 pub async fn is_allowed(
98 &self,
99 key: &str,
100 amount: u64,
101 ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
102 if let Some(ref manager) = self.redis_manager {
103 let mut conn = manager.clone();
104
105 let result: Vec<u64> = redis::cmd("EVAL")
106 .arg(RPM_SCRIPT)
107 .arg(1)
108 .arg(format!("hyperinfer:ratelimit:rpm:{}", key))
109 .arg(self.default_rpm)
110 .arg(60)
111 .query_async(&mut conn)
112 .await?;
113
114 let allowed = result.first().copied().unwrap_or(0);
115 if allowed == 0 {
116 return Ok(false);
117 }
118
119 let tpm_key = format!("hyperinfer:ratelimit:tpm:{}", key);
120 let now = std::time::SystemTime::now()
121 .duration_since(std::time::UNIX_EPOCH)
122 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
123 .as_millis() as u64;
124 let rate = self.default_tpm / 60;
125 let tpm_result: Vec<u64> = redis::cmd("EVAL")
126 .arg(GCRA_SCRIPT)
127 .arg(1)
128 .arg(&tpm_key)
129 .arg(rate)
130 .arg(self.default_tpm)
131 .arg(now)
132 .arg(amount)
133 .query_async(&mut conn)
134 .await?;
135
136 Ok(tpm_result.first().copied().unwrap_or(0) == 1)
137 } else {
138 Ok(true)
139 }
140 }
141
142 pub async fn check_rpm(
143 &self,
144 key: &str,
145 limit: u64,
146 ) -> Result<(bool, u64), Box<dyn std::error::Error + Send + Sync>> {
147 if let Some(ref manager) = self.redis_manager {
148 let mut conn = manager.clone();
149
150 let result: Vec<u64> = redis::cmd("EVAL")
151 .arg(RPM_SCRIPT)
152 .arg(1)
153 .arg(format!("hyperinfer:ratelimit:rpm:{}", key))
154 .arg(limit)
155 .arg(60)
156 .query_async(&mut conn)
157 .await?;
158
159 let allowed = result.first().copied().unwrap_or(0) == 1;
160 let remaining = result.get(1).copied().unwrap_or(0);
161 Ok((allowed, remaining))
162 } else {
163 Ok((true, limit))
164 }
165 }
166
167 pub async fn check_tpm(
168 &self,
169 key: &str,
170 limit: u64,
171 tokens: u64,
172 ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
173 if let Some(ref manager) = self.redis_manager {
174 let mut conn = manager.clone();
175
176 let now = std::time::SystemTime::now()
177 .duration_since(std::time::UNIX_EPOCH)
178 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
179 .as_millis() as u64;
180 let rate = limit / 60;
181
182 let result: Vec<u64> = redis::cmd("EVAL")
183 .arg(GCRA_SCRIPT)
184 .arg(1)
185 .arg(format!("hyperinfer:ratelimit:tpm:{}", key))
186 .arg(rate)
187 .arg(limit)
188 .arg(now)
189 .arg(tokens)
190 .query_async(&mut conn)
191 .await?;
192
193 Ok(result.first().copied().unwrap_or(0) == 1)
194 } else {
195 Ok(true)
196 }
197 }
198
199 pub async fn record_usage(
200 &self,
201 key: &str,
202 tokens_used: u64,
203 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
204 if let Some(ref manager) = self.redis_manager {
205 let mut conn = manager.clone();
206
207 redis::pipe()
208 .atomic()
209 .cmd("INCRBY")
210 .arg(format!("{}{}", USAGE_TOKENS_KEY_PREFIX, key))
211 .arg(tokens_used)
212 .cmd("INCR")
213 .arg(format!("{}{}", USAGE_REQUESTS_KEY_PREFIX, key))
214 .query_async::<()>(&mut conn)
215 .await?;
216 }
217 Ok(())
218 }
219}
220
/// Unit tests that run without a Redis server: they exercise the
/// backend-less code paths of `RateLimiter` and the plain data types.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rate_limiter_new_without_redis() {
        let result = RateLimiter::new(None).await;
        assert!(result.is_ok());
        let limiter = result.unwrap();
        assert_eq!(limiter.default_rpm, 60);
        assert_eq!(limiter.default_tpm, 100000);
    }

    #[tokio::test]
    async fn test_rate_limiter_is_allowed_without_redis() {
        let limiter = RateLimiter::new(None).await.unwrap();
        let result = limiter.is_allowed("test-key", 1).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[tokio::test]
    async fn test_rate_limiter_check_rpm_without_redis() {
        let limiter = RateLimiter::new(None).await.unwrap();
        let result = limiter.check_rpm("test-key", 100).await;
        assert!(result.is_ok());
        let (allowed, remaining) = result.unwrap();
        assert!(allowed);
        assert_eq!(remaining, 100);
    }

    #[tokio::test]
    async fn test_rate_limiter_check_tpm_without_redis() {
        let limiter = RateLimiter::new(None).await.unwrap();
        let result = limiter.check_tpm("test-key", 1000, 100).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[tokio::test]
    async fn test_rate_limiter_record_usage_without_redis() {
        let limiter = RateLimiter::new(None).await.unwrap();
        let result = limiter.record_usage("test-key", 50).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_token_bucket_creation() {
        let bucket = TokenBucket {
            capacity: 100,
            tokens: 100,
            refill_rate: 10,
            last_refill: Instant::now(),
        };

        assert_eq!(bucket.capacity, 100);
        assert_eq!(bucket.tokens, 100);
        assert_eq!(bucket.refill_rate, 10);
    }

    #[test]
    fn test_token_bucket_clone() {
        let bucket = TokenBucket {
            capacity: 50,
            tokens: 25,
            refill_rate: 5,
            last_refill: Instant::now(),
        };

        let cloned = bucket.clone();
        assert_eq!(bucket.capacity, cloned.capacity);
        assert_eq!(bucket.tokens, cloned.tokens);
        assert_eq!(bucket.refill_rate, cloned.refill_rate);
    }

    #[test]
    fn test_quota_creation() {
        let quota = Quota {
            max_requests_per_minute: Some(60),
            max_tokens_per_minute: Some(100000),
            budget_cents: Some(1000),
        };

        assert_eq!(quota.max_requests_per_minute, Some(60));
        assert_eq!(quota.max_tokens_per_minute, Some(100000));
        assert_eq!(quota.budget_cents, Some(1000));
    }

    #[test]
    fn test_quota_with_none_values() {
        let quota = Quota {
            max_requests_per_minute: None,
            max_tokens_per_minute: None,
            budget_cents: None,
        };

        assert_eq!(quota.max_requests_per_minute, None);
        assert_eq!(quota.max_tokens_per_minute, None);
        assert_eq!(quota.budget_cents, None);
    }

    #[test]
    fn test_quota_clone() {
        let quota = Quota {
            max_requests_per_minute: Some(100),
            max_tokens_per_minute: Some(50000),
            budget_cents: Some(2000),
        };

        let cloned = quota.clone();
        assert_eq!(
            quota.max_requests_per_minute,
            cloned.max_requests_per_minute
        );
        assert_eq!(quota.max_tokens_per_minute, cloned.max_tokens_per_minute);
        assert_eq!(quota.budget_cents, cloned.budget_cents);
    }

    #[tokio::test]
    async fn test_rate_limiter_is_allowed_with_zero_amount() {
        let limiter = RateLimiter::new(None).await.unwrap();
        let result = limiter.is_allowed("test-key", 0).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[tokio::test]
    async fn test_rate_limiter_is_allowed_with_large_amount() {
        let limiter = RateLimiter::new(None).await.unwrap();
        let result = limiter.is_allowed("test-key", 999999).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[tokio::test]
    async fn test_rate_limiter_check_rpm_with_different_limits() {
        let limiter = RateLimiter::new(None).await.unwrap();

        let result1 = limiter.check_rpm("key1", 10).await;
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap().1, 10);

        let result2 = limiter.check_rpm("key2", 1000).await;
        assert!(result2.is_ok());
        assert_eq!(result2.unwrap().1, 1000);
    }

    #[tokio::test]
    async fn test_rate_limiter_record_usage_multiple_times() {
        let limiter = RateLimiter::new(None).await.unwrap();

        assert!(limiter.record_usage("key", 100).await.is_ok());
        assert!(limiter.record_usage("key", 200).await.is_ok());
        assert!(limiter.record_usage("key", 300).await.is_ok());
    }
}