essence/rate_limit/
mod.rs1use governor::{
2 clock::DefaultClock, state::direct::NotKeyed, state::InMemoryState, Quota, RateLimiter,
3};
4use std::collections::HashMap;
5use std::num::NonZeroU32;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::debug;
9
10#[derive(Debug, thiserror::Error)]
12pub enum RateLimitError {
13 #[error("Rate limit exceeded for API key")]
14 Exceeded,
15}
16
17type ApiRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
19
20pub struct ApiKeyRateLimiter {
22 limiters: Arc<RwLock<HashMap<String, Arc<ApiRateLimiter>>>>,
24
25 default_limit: NonZeroU32,
27
28 api_key_limits: HashMap<String, NonZeroU32>,
30}
31
32impl ApiKeyRateLimiter {
33 pub fn new(default_per_minute: u32, api_key_limits: HashMap<String, u32>) -> Self {
35 let default_limit = NonZeroU32::new(default_per_minute)
36 .unwrap_or_else(|| NonZeroU32::new(60).unwrap());
37
38 let mut limits_map = HashMap::new();
39 for (key, limit) in api_key_limits {
40 if let Some(non_zero) = NonZeroU32::new(limit) {
41 limits_map.insert(key, non_zero);
42 }
43 }
44
45 Self {
46 limiters: Arc::new(RwLock::new(HashMap::new())),
47 default_limit,
48 api_key_limits: limits_map,
49 }
50 }
51
52 pub async fn check_limit(&self, api_key: &str) -> Result<(), RateLimitError> {
54 let limiter = self.get_or_create_limiter(api_key).await;
55
56 match limiter.check() {
57 Ok(_) => {
58 debug!("Rate limit check passed for API key: {}", Self::redact_key(api_key));
59 Ok(())
60 }
61 Err(_) => {
62 debug!("Rate limit exceeded for API key: {}", Self::redact_key(api_key));
63 Err(RateLimitError::Exceeded)
64 }
65 }
66 }
67
68 async fn get_or_create_limiter(&self, api_key: &str) -> Arc<ApiRateLimiter> {
70 {
72 let limiters = self.limiters.read().await;
73 if let Some(limiter) = limiters.get(api_key) {
74 return Arc::clone(limiter);
75 }
76 }
77
78 let limit = self.get_limit_for_key(api_key);
80 let quota = Quota::per_minute(limit);
81 let limiter = Arc::new(RateLimiter::direct(quota));
82
83 {
85 let mut limiters = self.limiters.write().await;
86 limiters.insert(api_key.to_string(), Arc::clone(&limiter));
87 }
88
89 debug!(
90 "Created rate limiter for API key {} with limit: {}/minute",
91 Self::redact_key(api_key),
92 limit
93 );
94
95 limiter
96 }
97
98 fn get_limit_for_key(&self, api_key: &str) -> NonZeroU32 {
100 self.api_key_limits
103 .get(api_key)
104 .copied()
105 .unwrap_or(self.default_limit)
106 }
107
/// Returns a form of `key` that is safe to write to logs.
///
/// Keys of 8 bytes or fewer are replaced entirely by `"***"`. Longer keys are
/// shown as their leading bytes (at most 8) followed by `"..."`.
fn redact_key(key: &str) -> String {
    /// Maximum number of leading bytes of the key that may be shown.
    const VISIBLE_PREFIX_BYTES: usize = 8;

    if key.len() <= VISIBLE_PREFIX_BYTES {
        return "***".to_string();
    }

    // Slicing at a fixed byte offset panics when the offset lies inside a
    // multi-byte UTF-8 character, so back off to the nearest char boundary
    // at or below the cut-off (offset 0 is always a boundary).
    let end = (0..=VISIBLE_PREFIX_BYTES)
        .rev()
        .find(|&idx| key.is_char_boundary(idx))
        .unwrap_or(0);

    format!("{}...", &key[..end])
}
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `key` is granted exactly `allowed` requests and that the
    /// request after those is rejected.
    async fn assert_allows_exactly(limiter: &ApiKeyRateLimiter, key: &str, allowed: usize) {
        for attempt in 1..=allowed {
            assert!(
                limiter.check_limit(key).await.is_ok(),
                "request {} of {} for {} should pass",
                attempt,
                allowed,
                key
            );
        }
        assert!(
            limiter.check_limit(key).await.is_err(),
            "request {} for {} should be rejected",
            allowed + 1,
            key
        );
    }

    /// A key without an override gets the default limit.
    #[tokio::test]
    async fn test_rate_limiter_default_limit() {
        let limiter = ApiKeyRateLimiter::new(2, HashMap::new());

        assert_allows_exactly(&limiter, "test_key", 2).await;
    }

    /// Keys with overrides use them; other keys fall back to the default.
    #[tokio::test]
    async fn test_rate_limiter_per_user_limit() {
        let mut limits = HashMap::new();
        limits.insert("user1".to_string(), 5);
        limits.insert("user2".to_string(), 2);

        let limiter = ApiKeyRateLimiter::new(10, limits);

        assert_allows_exactly(&limiter, "user1", 5).await;
        assert_allows_exactly(&limiter, "user2", 2).await;
        // "user3" has no override, so the default of 10 applies.
        assert_allows_exactly(&limiter, "user3", 10).await;
    }

    /// Requests for one key do not consume another key's quota.
    #[tokio::test]
    async fn test_rate_limiter_separate_keys() {
        let limiter = ApiKeyRateLimiter::new(2, HashMap::new());

        // Interleave the two keys so each uses up its own quota of 2.
        assert!(limiter.check_limit("key1").await.is_ok());
        assert!(limiter.check_limit("key2").await.is_ok());
        assert!(limiter.check_limit("key1").await.is_ok());
        assert!(limiter.check_limit("key2").await.is_ok());

        assert!(limiter.check_limit("key1").await.is_err());
        assert!(limiter.check_limit("key2").await.is_err());
    }

    /// Short keys are fully masked; long keys show only an 8-byte prefix.
    #[test]
    fn test_redact_key() {
        assert_eq!(ApiKeyRateLimiter::redact_key("short"), "***");
        // Exactly 8 bytes is still treated as short.
        assert_eq!(ApiKeyRateLimiter::redact_key("12345678"), "***");
        assert_eq!(
            ApiKeyRateLimiter::redact_key("sk-ant-api03-longkeyhere"),
            "sk-ant-a..."
        );
    }
}