use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use parking_lot::RwLock;
use super::config::{ExceededAction, PriorityLevel, RateLimitConfig};
use super::concurrency::ConcurrencyLimiter;
use super::cost_estimator::QueryCostEstimator;
use super::metrics::RateLimitMetrics;
use super::sliding_window::{SlidingWindow, SlidingWindowExceeded};
use super::token_bucket::{TokenBucket, TokenBucketExceeded};
/// Identifies the scope that a rate limit is tracked against.
///
/// The derives make keys usable as hash-map keys and comparable for equality.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum LimiterKey {
    /// A single unqualified scope; rendered as `global`.
    Global,
    /// A user, identified by name; rendered as `user:<name>`.
    User(String),
    /// A client IP address; rendered as `ip:<addr>`.
    ClientIp(IpAddr),
    /// A database, identified by name; rendered as `db:<name>`.
    Database(String),
    /// A tenant, identified by id; rendered as `tenant:<id>`.
    Tenant(String),
    /// A query pattern; rendered as `pattern:<pattern>`.
    QueryPattern(String),
    /// A role; rendered as `role:<role>`.
    Role(String),
    /// An ordered list of other keys treated as one key; rendered as
    /// `composite:[<k1>,<k2>,...]`.
    Composite(Vec<LimiterKey>),
}

impl LimiterKey {
    /// Builds a [`LimiterKey::User`] from anything convertible into a `String`.
    pub fn user(name: impl Into<String>) -> Self {
        let name: String = name.into();
        LimiterKey::User(name)
    }

    /// Builds a [`LimiterKey::Database`] from anything convertible into a `String`.
    pub fn database(name: impl Into<String>) -> Self {
        let name: String = name.into();
        LimiterKey::Database(name)
    }

    /// Builds a [`LimiterKey::Tenant`] from anything convertible into a `String`.
    pub fn tenant(id: impl Into<String>) -> Self {
        let id: String = id.into();
        LimiterKey::Tenant(id)
    }

    /// Builds a [`LimiterKey::QueryPattern`] from anything convertible into a `String`.
    pub fn pattern(pattern: impl Into<String>) -> Self {
        let pattern: String = pattern.into();
        LimiterKey::QueryPattern(pattern)
    }

    /// Builds a [`LimiterKey::Composite`] that wraps `keys` in the given order.
    pub fn composite(keys: Vec<LimiterKey>) -> Self {
        LimiterKey::Composite(keys)
    }
}

impl std::fmt::Display for LimiterKey {
    /// Renders the key as `<prefix>:<value>`, `global`, or `composite:[a,b,...]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Shared rendering for every single-valued variant.
        fn write_scoped(
            f: &mut std::fmt::Formatter<'_>,
            prefix: &str,
            value: &dyn std::fmt::Display,
        ) -> std::fmt::Result {
            write!(f, "{}:{}", prefix, value)
        }

        match self {
            Self::Global => f.write_str("global"),
            Self::User(name) => write_scoped(f, "user", name),
            Self::ClientIp(addr) => write_scoped(f, "ip", addr),
            Self::Database(name) => write_scoped(f, "db", name),
            Self::Tenant(id) => write_scoped(f, "tenant", id),
            Self::QueryPattern(pattern) => write_scoped(f, "pattern", pattern),
            Self::Role(role) => write_scoped(f, "role", role),
            Self::Composite(members) => {
                f.write_str("composite:[")?;
                // Empty before the first member, a comma before every later one.
                let mut separator = "";
                for member in members {
                    f.write_str(separator)?;
                    write!(f, "{}", member)?;
                    separator = ",";
                }
                f.write_str("]")
            }
        }
    }
}
/// Outcome of a rate-limit check.
///
/// Only `Denied` counts as "not allowed" (see `is_allowed`); the other non-`Allowed`
/// variants tell the caller how a configured exceeded-action was applied.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
/// No limit was exceeded, or rate limiting is disabled.
Allowed,
/// A limit was exceeded and the action is `ExceededAction::Queue`; carries the wait,
/// which is the limiter's retry hint capped at the action's `max_wait`.
Queued(Duration),
/// A limit was exceeded and the action is `ExceededAction::Throttle`; carries the
/// action's configured delay.
Throttled(Duration),
/// A limit was exceeded and the action is `ExceededAction::Warn`; carries a
/// human-readable warning message.
Warned(String),
/// A limit was exceeded and the action is `ExceededAction::Reject`; carries the details.
Denied(RateLimitExceeded),
}
impl RateLimitResult {
pub fn is_allowed(&self) -> bool {
!matches!(self, RateLimitResult::Denied(_))
}
pub fn wait_duration(&self) -> Option<Duration> {
match self {
RateLimitResult::Queued(d) | RateLimitResult::Throttled(d) => Some(*d),
_ => None,
}
}
}
/// Details of an exceeded limit; also usable as an error value.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
/// Key whose limit was exceeded.
pub key: LimiterKey,
/// Which kind of limiter reported the violation.
pub limit_type: LimitType,
/// Observed value at check time. Its meaning depends on `limit_type`: this file fills it
/// with the bucket's current tokens, the window's current count, or the active count.
pub current: u64,
/// Value `current` is reported against. This file fills it with the requested tokens
/// (token bucket), the window's max count, or the configured max concurrency.
pub limit: u64,
/// Suggested delay before retrying.
pub retry_after: Duration,
/// Short description; used as the prefix of the `Display` output.
pub message: String,
}
impl std::fmt::Display for RateLimitExceeded {
    /// Formats as `<message>: <limit_type> exceeded for <key> (<current>/<limit>), retry after <n>ms`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            key,
            limit_type,
            current,
            limit,
            retry_after,
            message,
        } = self;
        // The retry hint is reported in whole milliseconds.
        let retry_ms = retry_after.as_millis();
        write!(
            f,
            "{}: {} exceeded for {} ({}/{}), retry after {}ms",
            message, limit_type, key, current, limit, retry_ms
        )
    }
}
// Marker impl with no overridden methods: lets `RateLimitExceeded` be used wherever a
// `std::error::Error` is expected.
impl std::error::Error for RateLimitExceeded {}
/// Kind of limiter that reported a violation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitType {
    /// Token-bucket limit; displayed as `qps`.
    TokenBucket,
    /// Sliding-window limit; displayed as `window`.
    SlidingWindow,
    /// Concurrency limit; displayed as `concurrency`.
    Concurrency,
}

impl std::fmt::Display for LimitType {
    /// Writes the short lowercase label for this limit type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            LimitType::TokenBucket => "qps",
            LimitType::SlidingWindow => "window",
            LimitType::Concurrency => "concurrency",
        };
        f.write_str(label)
    }
}
/// Rate limiter that tracks a token bucket, a sliding window and a concurrency limiter
/// per `LimiterKey`.
pub struct RateLimiter {
/// Active configuration; read on every check and replaced wholesale by `update_config`.
config: RwLock<RateLimitConfig>,
/// Token buckets, inserted the first time a key is checked.
token_buckets: DashMap<LimiterKey, TokenBucket>,
/// Sliding windows, inserted the first time a key is checked.
sliding_windows: DashMap<LimiterKey, SlidingWindow>,
/// Concurrency limiters, inserted by `check_concurrency` and shared with callers via `Arc`.
concurrency: DashMap<LimiterKey, Arc<ConcurrencyLimiter>>,
/// Estimates a query's cost when cost estimation is enabled in the config.
cost_estimator: QueryCostEstimator,
/// Decision metrics; a shared handle is exposed through `metrics()`.
metrics: Arc<RateLimitMetrics>,
/// Construction time; basis for `uptime()`.
created_at: Instant,
}
impl RateLimiter {
/// Creates a limiter from `config` with a default-constructed [`QueryCostEstimator`].
///
/// All per-key maps start empty and fresh metrics are allocated.
pub fn new(config: RateLimitConfig) -> Self {
    let config = RwLock::new(config);
    let token_buckets = DashMap::new();
    let sliding_windows = DashMap::new();
    let concurrency = DashMap::new();
    let cost_estimator = QueryCostEstimator::new();
    let metrics = Arc::new(RateLimitMetrics::new());
    // Captured last, matching the moment the limiter becomes usable.
    let created_at = Instant::now();
    Self {
        config,
        token_buckets,
        sliding_windows,
        concurrency,
        cost_estimator,
        metrics,
        created_at,
    }
}
/// Creates a limiter from `config` that uses the supplied cost `estimator`.
///
/// Apart from the estimator, the result is set up the same way as in `new`: empty per-key
/// maps, fresh metrics and the creation time set to now.
pub fn with_cost_estimator(config: RateLimitConfig, estimator: QueryCostEstimator) -> Self {
    let shared_config = RwLock::new(config);
    let buckets = DashMap::new();
    let windows = DashMap::new();
    let limiters = DashMap::new();
    let metrics = Arc::new(RateLimitMetrics::new());
    let started = Instant::now();
    Self {
        config: shared_config,
        token_buckets: buckets,
        sliding_windows: windows,
        concurrency: limiters,
        cost_estimator: estimator,
        metrics,
        created_at: started,
    }
}
/// Checks `key` for a request of the given `cost` at `PriorityLevel::Normal`.
///
/// Shorthand for `check_with_priority` with the normal priority.
pub fn check(&self, key: &LimiterKey, cost: u32) -> RateLimitResult {
    let priority = PriorityLevel::Normal;
    self.check_with_priority(key, cost, priority)
}
/// Checks `key` against its token bucket and then its sliding window.
///
/// Returns `Allowed` immediately (without recording metrics) when rate limiting is
/// disabled. Otherwise the token bucket is consulted first; the sliding window is only
/// consulted when the bucket check succeeds. The first failing check is converted into a
/// result according to the configured exceeded-action for `key`. Every decision made while
/// enabled is recorded in the metrics together with the time the check took.
///
/// The config read guard is held for the whole call.
///
/// NOTE(review): if the bucket check succeeds and the window check then fails, nothing
/// here undoes the bucket acquisition — confirm that is intended.
pub fn check_with_priority(
    &self,
    key: &LimiterKey,
    cost: u32,
    priority: PriorityLevel,
) -> RateLimitResult {
    let config = self.config.read();
    if !config.enabled {
        return RateLimitResult::Allowed;
    }
    let started = Instant::now();
    let outcome = match self.check_token_bucket(key, cost, priority, &config) {
        Err(exceeded) => self.handle_exceeded(key, exceeded, &config),
        Ok(()) => match self.check_sliding_window(key, cost, &config) {
            Err(exceeded) => self.handle_exceeded_window(key, exceeded, &config),
            Ok(()) => RateLimitResult::Allowed,
        },
    };
    // Single recording point for all three outcomes.
    self.metrics.record_decision(key, &outcome, started.elapsed());
    outcome
}
/// Returns the shared concurrency limiter for `key`, or an error if it is at capacity.
///
/// When rate limiting is disabled, a fresh limiter with a `u32::MAX` limit is returned and
/// nothing is stored. Otherwise the limiter for `key` is created on first use with the
/// configured normal-priority concurrency, and reused on later calls.
///
/// This method only inspects `at_capacity()`; it does not call anything that takes a slot.
/// Presumably the caller acquires one through the returned limiter — confirm against callers.
///
/// # Errors
///
/// Returns [`RateLimitExceeded`] with `LimitType::Concurrency` and a fixed 100 ms retry
/// hint when the limiter reports that it is at capacity.
pub fn check_concurrency(&self, key: &LimiterKey) -> Result<Arc<ConcurrencyLimiter>, RateLimitExceeded> {
    let config = self.config.read();
    if !config.enabled {
        return Ok(Arc::new(ConcurrencyLimiter::new(u32::MAX)));
    }
    let max = config.effective_concurrency(key, PriorityLevel::Normal);
    // `max` is only used when the entry is vacant; an existing limiter keeps the limit it
    // was created with.
    let limiter = self
        .concurrency
        .entry(key.clone())
        .or_insert_with(|| Arc::new(ConcurrencyLimiter::new(max)))
        .clone();
    if limiter.at_capacity() {
        return Err(RateLimitExceeded {
            key: key.clone(),
            limit_type: LimitType::Concurrency,
            current: limiter.active_count() as u64,
            // Report the limit the stored limiter enforces, not the freshly computed `max`,
            // which may differ if the configuration changed after the limiter was created.
            limit: limiter.max_concurrent() as u64,
            retry_after: Duration::from_millis(100),
            message: "Concurrency limit exceeded".to_string(),
        });
    }
    Ok(limiter)
}
/// Checks `key` for `query` at `PriorityLevel::Normal`.
///
/// Shorthand for `check_query_with_priority` with the normal priority.
pub fn check_query(&self, key: &LimiterKey, query: &str) -> RateLimitResult {
    let priority = PriorityLevel::Normal;
    self.check_query_with_priority(key, query, priority)
}
/// Derives a cost for `query` and checks `key` with that cost at the given `priority`.
///
/// The cost comes from the cost estimator when `cost_estimation_enabled` is set in the
/// config, and is `1` otherwise.
pub fn check_query_with_priority(
    &self,
    key: &LimiterKey,
    query: &str,
    priority: PriorityLevel,
) -> RateLimitResult {
    // The read guard lives only inside this block, so it is released before
    // `check_with_priority` takes its own read lock.
    let cost = {
        let config = self.config.read();
        if config.cost_estimation_enabled {
            self.cost_estimator.estimate_cost_with_hint(query)
        } else {
            1
        }
    };
    self.check_with_priority(key, cost, priority)
}
pub fn check_all(&self, keys: &[LimiterKey], cost: u32) -> RateLimitResult {
for key in keys {
let result = self.check(key, cost);
if !result.is_allowed() {
return result;
}
}
RateLimitResult::Allowed
}
/// Resets the state tracked for `key`.
///
/// Resets the key's token bucket and sliding window if they exist, resets the statistics
/// of its concurrency limiter if one exists, and always resets the key's metrics. Entries
/// are reset in place, not removed from the maps.
pub fn reset(&self, key: &LimiterKey) {
    if let Some(tokens) = self.token_buckets.get(key) {
        tokens.reset();
    }
    if let Some(events) = self.sliding_windows.get(key) {
        events.reset();
    }
    if let Some(slots) = self.concurrency.get(key) {
        // Only the limiter's statistics are reset here.
        slots.reset_stats();
    }
    self.metrics.reset_key(key);
}
/// Returns a snapshot of the counters currently tracked for `key`.
///
/// Only limiters that already exist for `key` contribute entries:
///
/// * token bucket: `tokens_available`, `bucket_capacity`
/// * sliding window: `window_count`, `window_max`
/// * concurrency limiter: `active_concurrent`, `max_concurrent`, `queued`
///
/// The map is empty when nothing has been created for `key` yet.
pub fn get_key_stats(&self, key: &LimiterKey) -> HashMap<String, u64> {
    let mut snapshot: HashMap<String, u64> = HashMap::new();
    if let Some(tokens) = self.token_buckets.get(key) {
        snapshot.insert(String::from("tokens_available"), tokens.current_tokens() as u64);
        snapshot.insert(String::from("bucket_capacity"), tokens.capacity() as u64);
    }
    if let Some(events) = self.sliding_windows.get(key) {
        snapshot.insert(String::from("window_count"), events.current_count() as u64);
        snapshot.insert(String::from("window_max"), events.max_events() as u64);
    }
    if let Some(slots) = self.concurrency.get(key) {
        snapshot.insert(String::from("active_concurrent"), slots.active_count() as u64);
        snapshot.insert(String::from("max_concurrent"), slots.max_concurrent() as u64);
        snapshot.insert(String::from("queued"), slots.queue_length() as u64);
    }
    snapshot
}
/// Returns a shared handle to this limiter's metrics.
pub fn metrics(&self) -> Arc<RateLimitMetrics> {
    let shared = &self.metrics;
    Arc::clone(shared)
}
pub fn uptime(&self) -> Duration {
self.created_at.elapsed()
}
/// Replaces the active configuration with `config`.
///
/// Only the stored configuration is swapped. Buckets, windows and concurrency limiters
/// that already exist are left untouched by this method.
pub fn update_config(&self, config: RateLimitConfig) {
    let mut active = self.config.write();
    *active = config;
}
/// Returns a clone of the active configuration.
pub fn config(&self) -> RateLimitConfig {
    let active = self.config.read();
    active.clone()
}
/// Tries to take `cost` tokens from the bucket for `key`, creating the bucket on first use.
///
/// The effective QPS and burst for `key` and `priority` are read from `config` on every
/// call, but they are only used when a new bucket has to be inserted.
///
/// NOTE(review): `entry(key.clone())` clones the key on every call, even when the bucket
/// already exists.
fn check_token_bucket(
    &self,
    key: &LimiterKey,
    cost: u32,
    priority: PriorityLevel,
    config: &RateLimitConfig,
) -> Result<(), TokenBucketExceeded> {
    let rate = config.effective_qps(key, priority);
    let burst_size = config.effective_burst(key, priority);
    let slot = self.token_buckets.entry(key.clone());
    let bucket = slot.or_insert_with(|| TokenBucket::from_qps(rate, burst_size));
    bucket.try_acquire(cost)
}
/// Records `cost` events in the sliding window for `key`, creating the window on first use.
///
/// New windows are always created with `SlidingWindow::per_minute(60_000)`; `_config` is
/// currently unused, so the window size is not configurable here.
fn check_sliding_window(
    &self,
    key: &LimiterKey,
    cost: u32,
    _config: &RateLimitConfig,
) -> Result<(), SlidingWindowExceeded> {
    // Fixed argument used for every newly created window.
    let per_minute_limit = 60_000;
    let slot = self.sliding_windows.entry(key.clone());
    let window = slot.or_insert_with(|| SlidingWindow::per_minute(per_minute_limit));
    window.try_record_n(cost)
}
/// Converts a token-bucket failure into a result using the exceeded-action for `key`.
///
/// The reported `current` is the bucket's current tokens and the reported `limit` is the
/// number of tokens that were requested.
///
/// NOTE(review): `limit` carries the requested amount rather than a bucket capacity —
/// confirm that this is the intended meaning.
fn handle_exceeded(
    &self,
    key: &LimiterKey,
    exceeded: TokenBucketExceeded,
    config: &RateLimitConfig,
) -> RateLimitResult {
    let current = exceeded.current_tokens as u64;
    let limit = exceeded.requested_tokens as u64;
    let error = RateLimitExceeded {
        key: key.clone(),
        limit_type: LimitType::TokenBucket,
        current,
        limit,
        retry_after: exceeded.retry_after,
        message: String::from("QPS rate limit exceeded"),
    };
    let action = config.action_for_key(key);
    self.apply_action(&action, error)
}
/// Converts a sliding-window failure into a result using the exceeded-action for `key`.
///
/// The reported `current` and `limit` are the window's current and maximum counts.
fn handle_exceeded_window(
    &self,
    key: &LimiterKey,
    exceeded: SlidingWindowExceeded,
    config: &RateLimitConfig,
) -> RateLimitResult {
    let current = exceeded.current_count as u64;
    let limit = exceeded.max_count as u64;
    let error = RateLimitExceeded {
        key: key.clone(),
        limit_type: LimitType::SlidingWindow,
        current,
        limit,
        retry_after: exceeded.retry_after,
        message: String::from("Window rate limit exceeded"),
    };
    let action = config.action_for_key(key);
    self.apply_action(&action, error)
}
fn apply_action(&self, action: &ExceededAction, error: RateLimitExceeded) -> RateLimitResult {
match action {
ExceededAction::Reject => RateLimitResult::Denied(error),
ExceededAction::Queue { max_wait } => {
let wait = error.retry_after.min(*max_wait);
RateLimitResult::Queued(wait)
}
ExceededAction::Throttle { delay } => {
RateLimitResult::Throttled(*delay)
}
ExceededAction::Warn => {
RateLimitResult::Warned(format!("Rate limit warning: {}", error))
}
}
}
/// Runs `cleanup_expired` on the configuration under the write lock.
///
/// Only the configuration is touched; the per-key limiter maps are not pruned here.
pub fn cleanup(&self) {
    self.config.write().cleanup_expired();
}
}
impl std::fmt::Debug for RateLimiter {
    /// Prints a summary: the enabled flag, the number of entries in each per-key map, and
    /// the uptime. Individual keys and limiter state are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("RateLimiter");
        summary.field("enabled", &self.config.read().enabled);
        summary.field("token_buckets", &self.token_buckets.len());
        summary.field("sliding_windows", &self.sliding_windows.len());
        summary.field("concurrency_limiters", &self.concurrency.len());
        summary.field("uptime", &self.uptime());
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built limiter reports a sane, non-decreasing uptime.
    #[test]
    fn test_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);
        // `Instant` resolution is platform dependent, so a reading taken right after
        // construction may legitimately be zero. Assert monotonicity and a sane upper
        // bound instead of a strictly positive value.
        let first = limiter.uptime();
        let second = limiter.uptime();
        assert!(second >= first);
        assert!(first < Duration::from_secs(60));
    }

    /// A single request well within the configured limits is allowed.
    #[test]
    fn test_check_allowed() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        let result = limiter.check(&key, 1);
        assert!(result.is_allowed());
    }

    /// With a burst of one and the `Reject` action, the second request is denied.
    #[test]
    fn test_check_exceeded() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Reject)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        assert!(limiter.check(&key, 1).is_allowed());
        let result = limiter.check(&key, 1);
        assert!(!result.is_allowed());
    }

    /// High-priority requests get through 20 times against a default burst of 10.
    #[test]
    fn test_check_with_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(10)
            .default_burst(10)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        for _ in 0..20 {
            assert!(limiter.check_with_priority(&key, 1, PriorityLevel::High).is_allowed());
        }
    }

    /// A disabled limiter allows everything regardless of the configured QPS.
    #[test]
    fn test_check_disabled() {
        let config = RateLimitConfig::builder()
            .enabled(false)
            .default_qps(1)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        for _ in 0..100 {
            assert!(limiter.check(&key, 1).is_allowed());
        }
    }

    /// A simple query is allowed when cost estimation is enabled.
    #[test]
    fn test_check_query() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .cost_estimation(true)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        let result = limiter.check_query(&key, "SELECT * FROM users WHERE id = 1");
        assert!(result.is_allowed());
    }

    /// Checking several keys at once is allowed when each is within its limits.
    #[test]
    fn test_check_all_keys() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);
        let keys = vec![
            LimiterKey::User("test".to_string()),
            LimiterKey::Database("db1".to_string()),
            LimiterKey::Global,
        ];
        let result = limiter.check_all(&keys, 1);
        assert!(result.is_allowed());
    }

    /// Resetting a key makes a previously exhausted key usable again.
    #[test]
    fn test_reset() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        assert!(limiter.check(&key, 1).is_allowed());
        assert!(!limiter.check(&key, 1).is_allowed());
        limiter.reset(&key);
        assert!(limiter.check(&key, 1).is_allowed());
    }

    /// After a check, the key's stats contain the token-bucket entries.
    #[test]
    fn test_get_key_stats() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        let _ = limiter.check(&key, 1);
        let stats = limiter.get_key_stats(&key);
        assert!(stats.contains_key("tokens_available"));
        assert!(stats.contains_key("bucket_capacity"));
    }

    /// The `Queue` action yields `Queued` with a wait no longer than `max_wait`.
    #[test]
    fn test_exceeded_action_queue() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Queue {
                max_wait: Duration::from_secs(5),
            })
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        assert!(limiter.check(&key, 1).is_allowed());
        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Queued(wait) => {
                assert!(wait.as_secs() <= 5);
            }
            _ => panic!("Expected Queued result"),
        }
    }

    /// The `Warn` action yields `Warned` with a message mentioning the rate limit.
    #[test]
    fn test_exceeded_action_warn() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Warn)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        assert!(limiter.check(&key, 1).is_allowed());
        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Warned(msg) => {
                assert!(msg.contains("Rate limit"));
            }
            _ => panic!("Expected Warned result"),
        }
    }

    /// `Display` renders keys with their scope prefix, including nested composites.
    #[test]
    fn test_limiter_key_display() {
        assert_eq!(LimiterKey::Global.to_string(), "global");
        assert_eq!(LimiterKey::User("alice".to_string()).to_string(), "user:alice");
        assert_eq!(LimiterKey::Database("mydb".to_string()).to_string(), "db:mydb");
        assert_eq!(LimiterKey::Role("admin".to_string()).to_string(), "role:admin");
        assert_eq!(
            LimiterKey::Composite(vec![
                LimiterKey::Global,
                LimiterKey::User("alice".to_string()),
            ])
            .to_string(),
            "composite:[global,user:alice]"
        );
    }

    /// `update_config` replaces the configuration returned by `config()`.
    #[test]
    fn test_update_config() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .build();
        let limiter = RateLimiter::new(config);
        assert_eq!(limiter.config().default_qps, 100);
        let new_config = RateLimitConfig::builder()
            .default_qps(200)
            .build();
        limiter.update_config(new_config);
        assert_eq!(limiter.config().default_qps, 200);
    }

    /// The concurrency limiter for a new key is created with the configured maximum.
    #[test]
    fn test_concurrency_check() {
        let config = RateLimitConfig::builder()
            .default_concurrency(10)
            .build();
        let limiter = RateLimiter::new(config);
        let key = LimiterKey::User("test".to_string());
        let result = limiter.check_concurrency(&key);
        assert!(result.is_ok());
        let conc_limiter = result.unwrap();
        assert_eq!(conc_limiter.max_concurrent(), 10);
    }

    /// `is_allowed` is false only for `Denied`; `wait_duration` is set only for waits.
    #[test]
    fn test_rate_limit_result_methods() {
        assert!(RateLimitResult::Allowed.is_allowed());
        assert!(RateLimitResult::Queued(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Throttled(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Warned("test".to_string()).is_allowed());
        let error = RateLimitExceeded {
            key: LimiterKey::Global,
            limit_type: LimitType::TokenBucket,
            current: 0,
            limit: 100,
            retry_after: Duration::from_secs(1),
            message: "test".to_string(),
        };
        assert!(!RateLimitResult::Denied(error).is_allowed());
        assert_eq!(
            RateLimitResult::Queued(Duration::from_secs(5)).wait_duration(),
            Some(Duration::from_secs(5))
        );
        assert_eq!(RateLimitResult::Allowed.wait_duration(), None);
    }
}