use std::{
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use dashmap::DashMap;
use super::{ApiKeyError, ApiKeyId, ApiKeyRole};
#[derive(Clone, Debug)]
struct RateLimiter {
current_subscriptions: Arc<AtomicU64>,
requests_this_minute: Arc<AtomicU64>,
minute_start_time: Arc<AtomicU64>,
}
impl RateLimiter {
fn new() -> Self {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs();
RateLimiter {
current_subscriptions: Arc::new(AtomicU64::new(0)),
requests_this_minute: Arc::new(AtomicU64::new(0)),
minute_start_time: Arc::new(AtomicU64::new(current_time)),
}
}
pub fn current_subscriptions(&self) -> u64 {
self.current_subscriptions.load(Ordering::Relaxed)
}
pub fn minute_start_time(&self) -> u64 {
self.minute_start_time.load(Ordering::Relaxed)
}
pub fn record_request(&self) -> u64 {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs();
let minute_start = self.minute_start_time();
if current_time >= minute_start + 60 {
self.requests_this_minute.store(1, Ordering::Relaxed);
self.minute_start_time
.store(current_time, Ordering::Relaxed);
1
} else {
let current =
self.requests_this_minute.fetch_add(1, Ordering::Relaxed);
current + 1
}
}
}
#[derive(Debug, Clone, Default)]
pub struct RateLimitsController {
map: DashMap<ApiKeyId, RateLimiter>,
}
impl RateLimitsController {
pub fn arc(self) -> Arc<Self> {
Arc::new(self)
}
fn get_or_create(
&self,
id: &ApiKeyId,
) -> dashmap::mapref::one::Ref<'_, ApiKeyId, RateLimiter> {
if !self.map.contains_key(id) {
let new_limiter = RateLimiter::new();
self.map.insert(id.to_owned(), new_limiter);
}
self.map.get(id).unwrap() }
pub fn add_active_key_sub(&self, id: &ApiKeyId) {
if let Some(user_rate_limiter) = self.map.get_mut(id) {
user_rate_limiter
.current_subscriptions
.fetch_add(1, Ordering::Relaxed);
} else {
let rate_limiter = RateLimiter::new();
rate_limiter
.current_subscriptions
.fetch_add(1, Ordering::Relaxed);
self.map.insert(id.to_owned(), rate_limiter);
}
}
pub fn remove_active_key_sub(&self, id: &ApiKeyId) {
if let Some(user_rate_limiter) = self.map.get_mut(id) {
user_rate_limiter
.current_subscriptions
.fetch_sub(1, Ordering::Relaxed);
}
}
pub fn check_subscriptions(
&self,
api_key_id: &ApiKeyId,
role: &ApiKeyRole,
) -> Result<(bool, u64), ApiKeyError> {
let rate_limiter = self.get_or_create(api_key_id);
let current_subscriptions = rate_limiter.current_subscriptions();
let current_rate_limit =
role.validate_subscription_limit(current_subscriptions.into())?;
Ok((true, current_rate_limit.into()))
}
pub fn check_rate_limit(
&self,
api_key_id: &ApiKeyId,
role: &ApiKeyRole,
) -> Result<(bool, u64), ApiKeyError> {
let rate_limiter = self.get_or_create(api_key_id);
let request_count = rate_limiter.record_request();
let current_rate_limit =
role.validate_rate_limit(request_count.into())?;
Ok((true, current_rate_limit.into()))
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use pretty_assertions::assert_eq;
use super::*;
use crate::api_key::{ApiKeyId, MockApiKeyRole};
fn setup_rate_limiter(
id: u64,
role_type: &str,
) -> (RateLimitsController, ApiKeyId, ApiKeyRole) {
let api_key_id = ApiKeyId::from(id);
let role = match role_type {
"admin" => MockApiKeyRole::admin().into_inner(),
"builder" => MockApiKeyRole::builder().into_inner(),
"web_client" => MockApiKeyRole::web_client().into_inner(),
_ => MockApiKeyRole::web_client().into_inner(),
};
let rate_limiter = RateLimitsController::default();
(rate_limiter, api_key_id, role)
}
#[test]
fn test_rate_limiter_add_remove_subscriptions() {
let (rate_limiter, api_key_id, role) = setup_rate_limiter(1, "admin");
rate_limiter.add_active_key_sub(&api_key_id);
rate_limiter.add_active_key_sub(&api_key_id);
let (_, count) = rate_limiter
.check_subscriptions(&api_key_id, &role)
.expect("Failed to check rate limit");
assert_eq!(count, 2, "Should have 2 active subscriptions");
rate_limiter.remove_active_key_sub(&api_key_id);
let (_, count) = rate_limiter
.check_subscriptions(&api_key_id, &role)
.expect("Failed to check rate limit");
assert_eq!(count, 1, "Should have 1 active subscription after removal");
}
#[test]
fn test_builder_subscription_limit() {
let (rate_limiter, api_key_id, role) = setup_rate_limiter(3, "builder");
for _ in 0..50 {
rate_limiter.add_active_key_sub(&api_key_id);
}
let result = rate_limiter.check_subscriptions(&api_key_id, &role);
assert!(
result.is_ok(),
"Subscription check should succeed at the limit"
);
rate_limiter.add_active_key_sub(&api_key_id);
let result = rate_limiter.check_subscriptions(&api_key_id, &role);
assert!(
matches!(result, Err(ApiKeyError::SubscriptionLimitExceeded(_))),
"Expected SubscriptionLimitExceeded error when exceeding subscription limit"
);
}
#[test]
fn test_new_key_initialization() {
let (rate_limiter, api_key_id, role) =
setup_rate_limiter(8, "web_client");
let (success, count) = rate_limiter
.check_subscriptions(&api_key_id, &role)
.expect("Failed to check rate limit for new key");
assert!(success, "New key check should succeed");
assert_eq!(count, 0, "New key should start with 0 subscriptions");
rate_limiter.add_active_key_sub(&api_key_id);
let (_, updated_count) = rate_limiter
.check_subscriptions(&api_key_id, &role)
.expect("Failed to check updated rate limit");
assert_eq!(updated_count, 1, "Key should now have 1 subscription");
}
#[test]
fn test_admin_role_unlimited() {
let (rate_limiter, api_key_id, role) = setup_rate_limiter(4, "admin");
for _ in 0..2000 {
rate_limiter.add_active_key_sub(&api_key_id);
}
let sub_result = rate_limiter.check_subscriptions(&api_key_id, &role);
assert!(
sub_result.is_ok(),
"Admin role should have no subscription limit"
);
for i in 0..2000 {
let rate_result = rate_limiter.check_rate_limit(&api_key_id, &role);
assert!(
rate_result.is_ok(),
"Request {} should be allowed for admin",
i
);
}
}
#[test]
fn test_web_client_rate_limit_and_reset() {
let (rate_limiter, api_key_id, role) =
setup_rate_limiter(10, "web_client");
for i in 0..1000 {
let result = rate_limiter.check_rate_limit(&api_key_id, &role);
assert!(result.is_ok(), "Request {} should be allowed", i);
let (_, count) = result.unwrap();
assert_eq!(count, i + 1, "Request count should match");
}
let result = rate_limiter.check_rate_limit(&api_key_id, &role);
assert!(
matches!(result, Err(ApiKeyError::RateLimitExceeded(_))),
"Request 1001 should be denied"
);
if let Some(limiter) = rate_limiter.map.get_mut(&api_key_id) {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs();
limiter
.minute_start_time
.store(current_time - 61, Ordering::Relaxed);
limiter.requests_this_minute.store(999, Ordering::Relaxed);
}
let result = rate_limiter.check_rate_limit(&api_key_id, &role);
assert!(result.is_ok(), "Request after reset should be allowed");
let (_, count) = result.unwrap();
assert_eq!(count, 1, "Counter should be reset to 1");
}
#[test]
fn test_multiple_keys_management() {
let (rate_limiter, admin_key_id, admin_role) =
setup_rate_limiter(5, "admin");
let builder_key_id = ApiKeyId::from(6);
let web_client_key_id = ApiKeyId::from(7);
let builder_role = MockApiKeyRole::builder().into_inner();
let web_client_role = MockApiKeyRole::web_client().into_inner();
for _ in 0..100 {
rate_limiter.add_active_key_sub(&admin_key_id);
}
for _ in 0..30 {
rate_limiter.add_active_key_sub(&builder_key_id);
}
for _ in 0..500 {
rate_limiter.add_active_key_sub(&web_client_key_id);
}
let (_, admin_count) = rate_limiter
.check_subscriptions(&admin_key_id, &admin_role)
.expect("Failed to check rate limit for admin key");
assert_eq!(
admin_count, 100,
"Admin key should have 100 active subscriptions"
);
let (_, builder_count) = rate_limiter
.check_subscriptions(&builder_key_id, &builder_role)
.expect("Failed to check rate limit for builder key");
assert_eq!(
builder_count, 30,
"Builder key should have 30 active subscriptions"
);
let (_, web_client_count) = rate_limiter
.check_subscriptions(&web_client_key_id, &web_client_role)
.expect("Failed to check rate limit for web client key");
assert_eq!(
web_client_count, 500,
"Web client key should have 500 active subscriptions"
);
rate_limiter.remove_active_key_sub(&admin_key_id);
let (_, updated_admin_count) = rate_limiter
.check_subscriptions(&admin_key_id, &admin_role)
.expect("Failed to check updated rate limit for admin key");
assert_eq!(
updated_admin_count, 99,
"Admin key should now have 99 active subscriptions"
);
let (_, unchanged_builder_count) = rate_limiter
.check_subscriptions(&builder_key_id, &builder_role)
.expect("Failed to check unchanged rate limit for builder key");
assert_eq!(
unchanged_builder_count, 30,
"Builder key should still have 30 active subscriptions"
);
for i in 0..1200 {
let admin_result =
rate_limiter.check_rate_limit(&admin_key_id, &admin_role);
assert!(
admin_result.is_ok(),
"Admin request {} should be allowed",
i
);
let web_client_result = rate_limiter
.check_rate_limit(&web_client_key_id, &web_client_role);
if i < 1000 {
assert!(
web_client_result.is_ok(),
"Web client request {} should be allowed",
i
);
} else {
assert!(
matches!(
web_client_result,
Err(ApiKeyError::RateLimitExceeded(_))
),
"Web client request {} should be denied",
i
);
}
}
}
}