use governor::{
clock::{Clock, DefaultClock},
state::{InMemoryState, NotKeyed},
Quota, RateLimiter as Governor,
};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Service tier that selects which request quota a tenant is held to.
///
/// The derived `Default` is [`RateLimitTier::Free`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[derive(serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
pub enum RateLimitTier {
    /// Entry-level tier with the smallest quota; the default variant.
    #[default]
    Free,
    /// Mid-level tier.
    Standard,
    /// High-volume tier.
    Pro,
    /// No quota is enforced for this tier.
    Unlimited,
}
impl RateLimitTier {
    /// Returns the per-minute [`Quota`] for this tier, or `None` for
    /// [`RateLimitTier::Unlimited`] (no limit is enforced).
    ///
    /// Each limited tier reads an override from the environment on every call:
    /// `VEX_LIMIT_FREE` (default 10), `VEX_LIMIT_STANDARD` (default 100) and
    /// `VEX_LIMIT_PRO` (default 1000). A missing, unparsable or zero value
    /// falls back to the built-in default instead of panicking.
    pub fn quota(&self) -> Option<Quota> {
        let (var, default) = match self {
            Self::Free => ("VEX_LIMIT_FREE", 10),
            Self::Standard => ("VEX_LIMIT_STANDARD", 100),
            Self::Pro => ("VEX_LIMIT_PRO", 1000),
            Self::Unlimited => return None,
        };
        Some(Quota::per_minute(Self::limit_from_env(var, default)))
    }

    /// Reads a positive request limit from the environment variable `var`.
    ///
    /// Parsing straight into `NonZeroU32` rejects `0` (which previously
    /// parsed as a `u32` and then panicked in `NonZeroU32::new(..).unwrap()`),
    /// as well as negative or overflowing values. Surrounding whitespace is
    /// ignored. Any rejected or absent value yields `default`.
    ///
    /// # Panics
    ///
    /// Only if `default` is 0, which the call sites above never pass.
    fn limit_from_env(var: &str, default: u32) -> NonZeroU32 {
        std::env::var(var)
            .ok()
            .and_then(|raw| raw.trim().parse::<NonZeroU32>().ok())
            .or_else(|| NonZeroU32::new(default))
            .expect("built-in default rate limits are non-zero")
    }
}
/// Direct (un-keyed), in-memory governor limiter; one instance is cached per tenant.
type TenantLimiter = Governor<NotKeyed, InMemoryState, DefaultClock>;
/// Rate limiter that tracks a separate quota for each tenant ID, chosen by the
/// tenant's [`RateLimitTier`].
#[derive(Debug)]
pub struct TenantRateLimiter {
    // Lazily created limiter per tenant ID; an entry is dropped when the
    // tenant's tier is reassigned.
    limiters: RwLock<HashMap<String, Arc<TenantLimiter>>>,
    // Tier applied to tenants that have no explicit assignment.
    default_tier: RateLimitTier,
    // Explicit tier assignments keyed by tenant ID.
    tier_assignments: RwLock<HashMap<String, RateLimitTier>>,
}
/// Builds a limiter whose default tier is `RateLimitTier::default()` (`Free`).
impl Default for TenantRateLimiter {
    fn default() -> Self {
        let tier = RateLimitTier::default();
        Self::new(tier)
    }
}
impl TenantRateLimiter {
    /// Creates a limiter that applies `default_tier` to every tenant without an
    /// explicit assignment. Per-tenant limiters are created lazily on first use.
    pub fn new(default_tier: RateLimitTier) -> Self {
        Self {
            limiters: RwLock::new(HashMap::new()),
            default_tier,
            tier_assignments: RwLock::new(HashMap::new()),
        }
    }

    /// Assigns `tier` to `tenant_id` and discards the tenant's cached limiter,
    /// so the next [`check`](Self::check) builds a fresh one from the new
    /// tier's quota. This also resets the tenant's consumed budget.
    pub async fn set_tier(&self, tenant_id: &str, tier: RateLimitTier) {
        let mut assignments = self.tier_assignments.write().await;
        assignments.insert(tenant_id.to_string(), tier);
        // The assignment write guard is still held here. `check` holds the read
        // guard of the same lock while resolving a limiter, so it can never
        // observe the old tier and then cache a limiter after this removal.
        // Lock order (tier_assignments, then limiters) matches `check`.
        self.limiters.write().await.remove(tenant_id);
    }

    /// Returns the tier in effect for `tenant_id`: its explicit assignment if
    /// one exists, otherwise the limiter's default tier.
    pub async fn get_tier(&self, tenant_id: &str) -> RateLimitTier {
        let assignments = self.tier_assignments.read().await;
        self.resolve_tier(&assignments, tenant_id)
    }

    /// Resolves `tenant_id`'s tier from an assignment map the caller has
    /// already locked, falling back to `default_tier`.
    fn resolve_tier(
        &self,
        assignments: &HashMap<String, RateLimitTier>,
        tenant_id: &str,
    ) -> RateLimitTier {
        assignments
            .get(tenant_id)
            .copied()
            .unwrap_or(self.default_tier)
    }

    /// Consumes one request from `tenant_id`'s budget.
    ///
    /// Tenants whose tier has no quota (`Unlimited`) always succeed.
    ///
    /// # Errors
    ///
    /// Returns `Err(wait)` when the tenant is over quota, where `wait` is the
    /// time until the next request would be allowed. A blank (empty or
    /// whitespace-only) tenant ID is always rejected with a one-hour wait.
    pub async fn check(&self, tenant_id: &str) -> Result<(), Duration> {
        if tenant_id.trim().is_empty() {
            return Err(Duration::from_secs(3600));
        }

        // Hold the assignment read guard until the limiter is resolved.
        // Releasing it earlier lets a `set_tier` run between reading the tier
        // and caching a limiter built from it. That would leave a stale-quota
        // limiter in place until the next `set_tier`.
        // The tier is resolved from this guard rather than via `get_tier`:
        // tokio's RwLock is fair, so re-acquiring a read lock while a writer
        // is queued would deadlock.
        let assignments = self.tier_assignments.read().await;
        let tier = self.resolve_tier(&assignments, tenant_id);
        let quota = match tier.quota() {
            Some(q) => q,
            None => return Ok(()),
        };
        let limiter = self.get_or_create_limiter(tenant_id, quota).await;
        drop(assignments);

        match limiter.check() {
            Ok(_) => Ok(()),
            Err(not_until) => Err(not_until.wait_time_from(DefaultClock::default().now())),
        }
    }

    /// Returns the cached limiter for `tenant_id`, creating it from `quota` if
    /// absent.
    ///
    /// Hits are served under a shared read lock. On a miss, the write lock is
    /// taken and the entry API re-checks, so racing callers share a single
    /// limiter.
    async fn get_or_create_limiter(&self, tenant_id: &str, quota: Quota) -> Arc<TenantLimiter> {
        {
            let limiters = self.limiters.read().await;
            if let Some(limiter) = limiters.get(tenant_id) {
                return Arc::clone(limiter);
            }
        } // Read guard dropped here, before the write lock is requested.

        let mut limiters = self.limiters.write().await;
        Arc::clone(
            limiters
                .entry(tenant_id.to_string())
                .or_insert_with(|| Arc::new(Governor::direct(quota))),
        )
    }

    /// Logs how many per-tenant limiters are currently cached.
    ///
    /// NOTE: despite its name this does not evict anything. Cached limiters
    /// are removed only by `set_tier`. Only a shared read lock is taken, so
    /// concurrent `check` calls are not blocked (the previous version took an
    /// exclusive write lock just to read the length).
    pub async fn cleanup(&self) {
        let limiters = self.limiters.read().await;
        tracing::debug!(limiter_count = limiters.len(), "Tenant limiter cleanup");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fires `times` checks for `tenant` and discards every outcome.
    async fn drain(limiter: &TenantRateLimiter, tenant: &str, times: usize) {
        for _ in 0..times {
            let _ = limiter.check(tenant).await;
        }
    }

    /// Ten requests stay within the Standard tier's budget.
    #[tokio::test]
    async fn test_rate_limiter_allows_within_quota() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Standard);
        for _ in 1..=10 {
            let outcome = limiter.check("tenant1").await;
            assert!(outcome.is_ok());
        }
    }

    /// The eleventh Free-tier request is rejected. This assumes
    /// `VEX_LIMIT_FREE` is unset, so the default of 10 applies.
    #[tokio::test]
    async fn test_rate_limiter_blocks_over_quota() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Free);
        drain(&limiter, "tenant1", 10).await;
        let outcome = limiter.check("tenant1").await;
        assert!(outcome.is_err());
    }

    /// Exhausting one tenant's budget does not affect another tenant.
    #[tokio::test]
    async fn test_different_tenants_independent() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Free);
        drain(&limiter, "tenant1", 15).await;
        let outcome = limiter.check("tenant2").await;
        assert!(outcome.is_ok());
    }

    /// A tenant promoted to Unlimited is never rejected.
    #[tokio::test]
    async fn test_unlimited_tier() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Free);
        let tenant = "vip";
        limiter.set_tier(tenant, RateLimitTier::Unlimited).await;
        for _ in 1..=1000 {
            let outcome = limiter.check(tenant).await;
            assert!(outcome.is_ok());
        }
    }
}