Skip to main content

essence/rate_limit/
mod.rs

1use governor::{
2    clock::DefaultClock, state::direct::NotKeyed, state::InMemoryState, Quota, RateLimiter,
3};
4use std::collections::HashMap;
5use std::num::NonZeroU32;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::debug;
9
10/// Rate limit error
11#[derive(Debug, thiserror::Error)]
12pub enum RateLimitError {
13    #[error("Rate limit exceeded for API key")]
14    Exceeded,
15}
16
17/// Type alias for the rate limiter we use
18type ApiRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
19
20/// Per-API-key rate limiter
21pub struct ApiKeyRateLimiter {
22    /// Map of API key to rate limiter
23    limiters: Arc<RwLock<HashMap<String, Arc<ApiRateLimiter>>>>,
24
25    /// Default rate limit (requests per minute)
26    default_limit: NonZeroU32,
27
28    /// Per-user/API key limits
29    api_key_limits: HashMap<String, NonZeroU32>,
30}
31
32impl ApiKeyRateLimiter {
33    /// Create a new API key rate limiter
34    pub fn new(default_per_minute: u32, api_key_limits: HashMap<String, u32>) -> Self {
35        let default_limit = NonZeroU32::new(default_per_minute)
36            .unwrap_or_else(|| NonZeroU32::new(60).unwrap());
37
38        let mut limits_map = HashMap::new();
39        for (key, limit) in api_key_limits {
40            if let Some(non_zero) = NonZeroU32::new(limit) {
41                limits_map.insert(key, non_zero);
42            }
43        }
44
45        Self {
46            limiters: Arc::new(RwLock::new(HashMap::new())),
47            default_limit,
48            api_key_limits: limits_map,
49        }
50    }
51
52    /// Check rate limit for a specific API key
53    pub async fn check_limit(&self, api_key: &str) -> Result<(), RateLimitError> {
54        let limiter = self.get_or_create_limiter(api_key).await;
55        
56        match limiter.check() {
57            Ok(_) => {
58                debug!("Rate limit check passed for API key: {}", Self::redact_key(api_key));
59                Ok(())
60            }
61            Err(_) => {
62                debug!("Rate limit exceeded for API key: {}", Self::redact_key(api_key));
63                Err(RateLimitError::Exceeded)
64            }
65        }
66    }
67
68    /// Get or create a rate limiter for the given API key
69    async fn get_or_create_limiter(&self, api_key: &str) -> Arc<ApiRateLimiter> {
70        // Check if limiter already exists
71        {
72            let limiters = self.limiters.read().await;
73            if let Some(limiter) = limiters.get(api_key) {
74                return Arc::clone(limiter);
75            }
76        }
77
78        // Create new limiter
79        let limit = self.get_limit_for_key(api_key);
80        let quota = Quota::per_minute(limit);
81        let limiter = Arc::new(RateLimiter::direct(quota));
82
83        // Store it
84        {
85            let mut limiters = self.limiters.write().await;
86            limiters.insert(api_key.to_string(), Arc::clone(&limiter));
87        }
88
89        debug!(
90            "Created rate limiter for API key {} with limit: {}/minute",
91            Self::redact_key(api_key),
92            limit
93        );
94
95        limiter
96    }
97
98    /// Get the rate limit for a specific API key
99    fn get_limit_for_key(&self, api_key: &str) -> NonZeroU32 {
100        // Try to find user ID from API key mapping
101        // For now, we use the API key directly as the user identifier
102        self.api_key_limits
103            .get(api_key)
104            .copied()
105            .unwrap_or(self.default_limit)
106    }
107
108    /// Redact API key for logging (show only first 8 chars)
109    fn redact_key(key: &str) -> String {
110        if key.len() > 8 {
111            format!("{}...", &key[..8])
112        } else {
113            "***".to_string()
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Expect `key` to be admitted exactly `allowed` times in a row and then
    /// rejected on the following request.
    async fn expect_burst_then_reject(limiter: &ApiKeyRateLimiter, key: &str, allowed: usize) {
        for _ in 0..allowed {
            assert!(limiter.check_limit(key).await.is_ok());
        }
        assert!(limiter.check_limit(key).await.is_err());
    }

    /// With no overrides, a key gets the default quota (2 per minute here).
    #[tokio::test]
    async fn test_rate_limiter_default_limit() {
        let limiter = ApiKeyRateLimiter::new(2, HashMap::new());

        // Two requests fit in the quota; the third is rejected.
        expect_burst_then_reject(&limiter, "test_key", 2).await;
    }

    /// Keys with an override use it; all other keys use the default.
    #[tokio::test]
    async fn test_rate_limiter_per_user_limit() {
        let limits: HashMap<String, u32> = [("user1", 5), ("user2", 2)]
            .into_iter()
            .map(|(key, per_minute)| (key.to_string(), per_minute))
            .collect();

        let limiter = ApiKeyRateLimiter::new(10, limits);

        // user1 is overridden to 5 requests.
        expect_burst_then_reject(&limiter, "user1", 5).await;

        // user2 is overridden to 2 requests.
        expect_burst_then_reject(&limiter, "user2", 2).await;

        // user3 has no override and therefore gets the default of 10.
        expect_burst_then_reject(&limiter, "user3", 10).await;
    }

    /// Requests for one key must not consume another key's quota.
    #[tokio::test]
    async fn test_rate_limiter_separate_keys() {
        let limiter = ApiKeyRateLimiter::new(2, HashMap::new());

        // Interleave the keys: each one still gets its own two requests.
        for _ in 0..2 {
            for key in ["key1", "key2"] {
                assert!(limiter.check_limit(key).await.is_ok());
            }
        }

        // Both keys have now used up their quota.
        for key in ["key1", "key2"] {
            assert!(limiter.check_limit(key).await.is_err());
        }
    }

    /// Short keys are fully masked; long keys keep only their first 8 characters.
    #[test]
    fn test_redact_key() {
        let masked_short = ApiKeyRateLimiter::redact_key("short");
        assert_eq!(masked_short, "***");

        let masked_long = ApiKeyRateLimiter::redact_key("sk-ant-api03-longkeyhere");
        assert_eq!(masked_long, "sk-ant-a...");
    }
}