use crate::utils::error::gateway_error::{GatewayError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use super::types::{LimiterConfig, RateLimitKey, RateLimitResult, SlidingWindow, TokenBucket};
/// Async rate limiter whose state lives behind shared, lock-protected maps.
///
/// Every field is an `Arc<RwLock<..>>`, so cloning a `RateLimiter` yields a
/// handle onto the same underlying maps rather than an independent copy.
#[derive(Debug, Clone)]
pub struct RateLimiter {
/// Limit configurations keyed by string. `check_rate_limit` looks up the
/// built key string first and falls back to the `"default"` entry.
pub(super) configs: Arc<RwLock<HashMap<String, LimiterConfig>>>,
/// Token-bucket state keyed by string.
// NOTE(review): not touched in this file; presumably maintained by sibling
// modules with `pub(super)` access — confirm.
pub(super) buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
/// Sliding-window state keyed by string.
// NOTE(review): presumably backs the `check_sliding_window_*` helpers, which
// are defined outside this file — confirm.
pub(super) windows: Arc<RwLock<HashMap<String, SlidingWindow>>>,
}
impl RateLimiter {
    /// Creates a limiter with no configurations and empty bucket/window state.
    pub fn new() -> Self {
        Self {
            configs: Arc::new(RwLock::new(HashMap::new())),
            buckets: Arc::new(RwLock::new(HashMap::new())),
            windows: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Stores `config` under `key`, replacing any configuration already
    /// registered for that key.
    ///
    /// The key `"default"` is the fallback consulted by `check_rate_limit`
    /// when no entry matches the request's own key string.
    pub async fn add_config(&self, key: String, config: LimiterConfig) {
        // The write guard is a temporary, so it is released as soon as the
        // insert completes.
        self.configs.write().await.insert(key, config);
    }

    /// Checks whether a request identified by `key` that consumes `tokens`
    /// tokens is within the configured limits, and records it if so.
    ///
    /// The configuration is looked up by the key string built from `key`,
    /// falling back to the `"default"` entry. Limits that are `None` in the
    /// configuration are skipped. The remaining ones are evaluated in the
    /// order requests-per-minute, tokens-per-minute, requests-per-day,
    /// tokens-per-day, and evaluation stops at the first limit that rejects
    /// the request.
    ///
    /// # Returns
    ///
    /// * When a limit rejects the request: `allowed` is `false`, `limit_type`
    ///   is one of `"rpm"`, `"tpm"`, `"rpd"`, `"tpd"`, and `retry_after` is
    ///   taken from the failing check. The request is not recorded.
    /// * Otherwise: `allowed` is `true` and the request is recorded via
    ///   `record_request`.
    ///
    /// In both cases `remaining_requests` / `remaining_tokens` are populated
    /// only from per-minute checks that passed; the per-day checks do not
    /// contribute to them.
    ///
    /// # Errors
    ///
    /// Returns `GatewayError::Config` when neither the request's key string
    /// nor `"default"` has a configuration. Errors from the individual limit
    /// checks and from `record_request` are propagated unchanged.
    pub async fn check_rate_limit(
        &self,
        key: &RateLimitKey,
        tokens: u32,
    ) -> Result<RateLimitResult> {
        const MINUTE: Duration = Duration::from_secs(60);
        const DAY: Duration = Duration::from_secs(86400);

        /// Marks the accumulated `result` as rejected by `limit_type`,
        /// carrying over the retry hint reported by the check that failed.
        fn deny(
            mut result: RateLimitResult,
            limit_type: &str,
            failed: RateLimitResult,
        ) -> RateLimitResult {
            result.allowed = false;
            result.limit_type = Some(limit_type.to_string());
            result.retry_after = failed.retry_after;
            result
        }

        let key_str = self.build_key_string(key);

        // Copy the limits out and release the read guard before awaiting
        // anything else. Holding it across the checks below would keep
        // `add_config` writers (and readers queued behind them) waiting for
        // the whole evaluation.
        let (rpm, tpm, rpd, tpd) = {
            let configs = self.configs.read().await;
            let config = configs
                .get(&key_str)
                .or_else(|| configs.get("default"))
                .ok_or_else(|| GatewayError::Config("No rate limit config found".to_string()))?;
            (config.rpm, config.tpm, config.rpd, config.tpd)
        };

        let mut result = RateLimitResult {
            allowed: true,
            remaining_requests: None,
            remaining_tokens: None,
            reset_time: None,
            retry_after: None,
            limit_type: None,
        };

        if let Some(rpm) = rpm {
            let outcome = self
                .check_sliding_window_requests(&format!("{}_rpm", key_str), MINUTE, rpm)
                .await?;
            if !outcome.allowed {
                return Ok(deny(result, "rpm", outcome));
            }
            result.remaining_requests = outcome.remaining_requests;
        }

        if let Some(tpm) = tpm {
            let outcome = self
                .check_sliding_window_tokens(&format!("{}_tpm", key_str), MINUTE, tpm, tokens)
                .await?;
            if !outcome.allowed {
                return Ok(deny(result, "tpm", outcome));
            }
            result.remaining_tokens = outcome.remaining_tokens;
        }

        if let Some(rpd) = rpd {
            let outcome = self
                .check_sliding_window_requests(&format!("{}_rpd", key_str), DAY, rpd)
                .await?;
            if !outcome.allowed {
                return Ok(deny(result, "rpd", outcome));
            }
        }

        if let Some(tpd) = tpd {
            let outcome = self
                .check_sliding_window_tokens(&format!("{}_tpd", key_str), DAY, tpd, tokens)
                .await?;
            if !outcome.allowed {
                return Ok(deny(result, "tpd", outcome));
            }
        }

        // Every rejection returned early above, so reaching this point means
        // the request passed all configured limits.
        self.record_request(&key_str, tokens).await?;
        Ok(result)
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}