use governor::{
clock::DefaultClock, state::direct::NotKeyed, state::InMemoryState, Quota, RateLimiter,
};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::debug;
#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
#[error("Rate limit exceeded for API key")]
Exceeded,
}
type ApiRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
pub struct ApiKeyRateLimiter {
limiters: Arc<RwLock<HashMap<String, Arc<ApiRateLimiter>>>>,
default_limit: NonZeroU32,
api_key_limits: HashMap<String, NonZeroU32>,
}
impl ApiKeyRateLimiter {
pub fn new(default_per_minute: u32, api_key_limits: HashMap<String, u32>) -> Self {
let default_limit = NonZeroU32::new(default_per_minute)
.unwrap_or_else(|| NonZeroU32::new(60).unwrap());
let mut limits_map = HashMap::new();
for (key, limit) in api_key_limits {
if let Some(non_zero) = NonZeroU32::new(limit) {
limits_map.insert(key, non_zero);
}
}
Self {
limiters: Arc::new(RwLock::new(HashMap::new())),
default_limit,
api_key_limits: limits_map,
}
}
pub async fn check_limit(&self, api_key: &str) -> Result<(), RateLimitError> {
let limiter = self.get_or_create_limiter(api_key).await;
match limiter.check() {
Ok(_) => {
debug!("Rate limit check passed for API key: {}", Self::redact_key(api_key));
Ok(())
}
Err(_) => {
debug!("Rate limit exceeded for API key: {}", Self::redact_key(api_key));
Err(RateLimitError::Exceeded)
}
}
}
async fn get_or_create_limiter(&self, api_key: &str) -> Arc<ApiRateLimiter> {
{
let limiters = self.limiters.read().await;
if let Some(limiter) = limiters.get(api_key) {
return Arc::clone(limiter);
}
}
let limit = self.get_limit_for_key(api_key);
let quota = Quota::per_minute(limit);
let limiter = Arc::new(RateLimiter::direct(quota));
{
let mut limiters = self.limiters.write().await;
limiters.insert(api_key.to_string(), Arc::clone(&limiter));
}
debug!(
"Created rate limiter for API key {} with limit: {}/minute",
Self::redact_key(api_key),
limit
);
limiter
}
fn get_limit_for_key(&self, api_key: &str) -> NonZeroU32 {
self.api_key_limits
.get(api_key)
.copied()
.unwrap_or(self.default_limit)
}
fn redact_key(key: &str) -> String {
if key.len() > 8 {
format!("{}...", &key[..8])
} else {
"***".to_string()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_rate_limiter_default_limit() {
let limiter = ApiKeyRateLimiter::new(2, HashMap::new());
assert!(limiter.check_limit("test_key").await.is_ok());
assert!(limiter.check_limit("test_key").await.is_ok());
assert!(limiter.check_limit("test_key").await.is_err());
}
#[tokio::test]
async fn test_rate_limiter_per_user_limit() {
let mut limits = HashMap::new();
limits.insert("user1".to_string(), 5);
limits.insert("user2".to_string(), 2);
let limiter = ApiKeyRateLimiter::new(10, limits);
for _ in 0..5 {
assert!(limiter.check_limit("user1").await.is_ok());
}
assert!(limiter.check_limit("user1").await.is_err());
for _ in 0..2 {
assert!(limiter.check_limit("user2").await.is_ok());
}
assert!(limiter.check_limit("user2").await.is_err());
for _ in 0..10 {
assert!(limiter.check_limit("user3").await.is_ok());
}
assert!(limiter.check_limit("user3").await.is_err());
}
#[tokio::test]
async fn test_rate_limiter_separate_keys() {
let limiter = ApiKeyRateLimiter::new(2, HashMap::new());
assert!(limiter.check_limit("key1").await.is_ok());
assert!(limiter.check_limit("key2").await.is_ok());
assert!(limiter.check_limit("key1").await.is_ok());
assert!(limiter.check_limit("key2").await.is_ok());
assert!(limiter.check_limit("key1").await.is_err());
assert!(limiter.check_limit("key2").await.is_err());
}
#[test]
fn test_redact_key() {
assert_eq!(ApiKeyRateLimiter::redact_key("short"), "***");
assert_eq!(
ApiKeyRateLimiter::redact_key("sk-ant-api03-longkeyhere"),
"sk-ant-a..."
);
}
}