use std::{
num::NonZeroUsize,
sync::{Arc, Mutex},
};
use lru::LruCache;
use crate::{
error::FraiseQLError,
utils::clock::{Clock, SystemClock},
};
/// Capacity handed to the `LruCache` of every `DimensionRateLimiter`, i.e. the
/// maximum number of distinct keys one dimension keeps records for.
const MAX_RATE_LIMITER_ENTRIES: usize = 100_000;
/// Limit parameters for a single rate-limiting dimension.
#[derive(Debug, Clone)]
pub struct RateLimitDimension {
/// Maximum number of requests accepted per key within one window.
/// A value of `0` makes `DimensionRateLimiter::check` accept every request.
pub max_requests: u32,
/// Length of the counting window, in seconds.
pub window_secs: u64,
}
impl RateLimitDimension {
    /// Returns `true` when `max_requests` is zero.
    ///
    /// NOTE: the name reads backwards. A `true` result is what makes
    /// `DimensionRateLimiter::check` skip all counting and accept the request,
    /// so it really means "limiting is disabled for this dimension".
    const fn is_rate_limited(&self) -> bool {
        matches!(self.max_requests, 0)
    }
}
/// Settings for validation rate limiting: one `max_requests` / `window_secs`
/// pair per error dimension, plus a master switch.
#[derive(Debug, Clone)]
pub struct ValidationRateLimitingConfig {
/// Master on/off flag.
/// NOTE(review): the limiter code next to this struct never reads it; presumably
/// the caller checks it before consulting the limiter. Confirm.
pub enabled: bool,
/// Requests allowed per key and window in the validation-errors dimension.
pub validation_errors_max_requests: u32,
/// Window length, in seconds, for the validation-errors dimension.
pub validation_errors_window_secs: u64,
/// Requests allowed per key and window in the depth-errors dimension.
pub depth_errors_max_requests: u32,
/// Window length, in seconds, for the depth-errors dimension.
pub depth_errors_window_secs: u64,
/// Requests allowed per key and window in the complexity-errors dimension.
pub complexity_errors_max_requests: u32,
/// Window length, in seconds, for the complexity-errors dimension.
pub complexity_errors_window_secs: u64,
/// Requests allowed per key and window in the malformed-errors dimension.
pub malformed_errors_max_requests: u32,
/// Window length, in seconds, for the malformed-errors dimension.
pub malformed_errors_window_secs: u64,
/// Requests allowed per key and window in the async-validation-errors dimension.
pub async_validation_errors_max_requests: u32,
/// Window length, in seconds, for the async-validation-errors dimension.
pub async_validation_errors_window_secs: u64,
}
impl Default for ValidationRateLimitingConfig {
    /// Default settings: limiting enabled, a 60-second window for every
    /// dimension, and per-dimension quotas of 100 (validation), 50 (depth),
    /// 30 (complexity), 40 (malformed) and 60 (async validation) requests.
    fn default() -> Self {
        // All dimensions share the same window length by default.
        const WINDOW_SECS: u64 = 60;

        Self {
            enabled: true,
            validation_errors_max_requests: 100,
            validation_errors_window_secs: WINDOW_SECS,
            depth_errors_max_requests: 50,
            depth_errors_window_secs: WINDOW_SECS,
            complexity_errors_max_requests: 30,
            complexity_errors_window_secs: WINDOW_SECS,
            malformed_errors_max_requests: 40,
            malformed_errors_window_secs: WINDOW_SECS,
            async_validation_errors_max_requests: 60,
            async_validation_errors_window_secs: WINDOW_SECS,
        }
    }
}
/// Per-key counter state stored in a `DimensionRateLimiter`.
#[derive(Debug, Clone)]
struct RequestRecord {
/// Requests counted so far in the current window.
count: u32,
/// Value of `Clock::now_secs()` at the moment the current window was opened.
window_start: u64,
}
/// Rate limiter for one dimension: keeps a `RequestRecord` per string key.
struct DimensionRateLimiter {
/// Per-key records behind a mutex; the `Arc` lets clones of this limiter share them.
records: Arc<Mutex<LruCache<String, RequestRecord>>>,
/// Quota and window length enforced by `check`.
dimension: RateLimitDimension,
/// Time source read by `check` (injectable so tests can control time).
clock: Arc<dyn Clock>,
}
impl DimensionRateLimiter {
    /// Test-only convenience constructor that uses [`SystemClock`].
    #[cfg(test)]
    fn new(max_requests: u32, window_secs: u64) -> Self {
        Self::new_with_clock(max_requests, window_secs, Arc::new(SystemClock))
    }

    /// Creates a limiter that accepts at most `max_requests` requests per key
    /// in each window of `window_secs` seconds, reading time from `clock`.
    ///
    /// The record cache is created with capacity `MAX_RATE_LIMITER_ENTRIES`.
    fn new_with_clock(max_requests: u32, window_secs: u64, clock: Arc<dyn Clock>) -> Self {
        // The capacity is a non-zero literal constant, so this `expect` cannot fire.
        #[allow(clippy::expect_used)]
        let cap = NonZeroUsize::new(MAX_RATE_LIMITER_ENTRIES)
            .expect("MAX_RATE_LIMITER_ENTRIES must be > 0");
        Self {
            records: Arc::new(Mutex::new(LruCache::new(cap))),
            dimension: RateLimitDimension {
                max_requests,
                window_secs,
            },
            clock,
        }
    }

    /// Counts one request for `key` and decides whether it is within the quota.
    ///
    /// A key's window opens at its first request. Once `window_secs` seconds
    /// have elapsed (boundary inclusive) the next request opens a fresh window
    /// and counts as its first request. A `max_requests` of `0` disables the
    /// limiter: every request is accepted and nothing is recorded.
    ///
    /// # Errors
    ///
    /// Returns [`FraiseQLError::RateLimited`] when `key` has already used its
    /// quota for the current window. `retry_after_secs` is the full window
    /// length, not the time remaining in the window.
    ///
    /// # Panics
    ///
    /// Panics if the records mutex is poisoned.
    fn check(&self, key: &str) -> Result<(), FraiseQLError> {
        if self.dimension.is_rate_limited() {
            return Ok(());
        }
        let mut records = self.records.lock().expect("records mutex poisoned");
        let now = self.clock.now_secs();
        let record = records.get_or_insert_mut(key.to_string(), || RequestRecord {
            count: 0,
            window_start: now,
        });
        // Saturate instead of using `+`: a very large configured window would
        // otherwise overflow `u64`. That panics in overflow-checked builds
        // (poisoning the mutex for every later caller) and wraps in release
        // builds, which would reset the window on every call.
        let window_end = record
            .window_start
            .saturating_add(self.dimension.window_secs);
        if now >= window_end {
            record.count = 1;
            record.window_start = now;
            Ok(())
        } else if record.count < self.dimension.max_requests {
            record.count += 1;
            Ok(())
        } else {
            Err(FraiseQLError::RateLimited {
                message: "Rate limit exceeded for validation errors".to_string(),
                retry_after_secs: self.dimension.window_secs,
            })
        }
    }

    /// Drops every stored record, so all keys start again with a fresh window.
    ///
    /// # Panics
    ///
    /// Panics if the records mutex is poisoned.
    fn clear(&self) {
        let mut records = self.records.lock().expect("records mutex poisoned");
        records.clear();
    }
}
impl Clone for DimensionRateLimiter {
    /// Produces a handle onto the same limiter: the record store and the clock
    /// are shared through their `Arc`s, while the dimension settings are copied.
    fn clone(&self) -> Self {
        let Self {
            records,
            dimension,
            clock,
        } = self;
        Self {
            records: Arc::clone(records),
            dimension: dimension.clone(),
            clock: Arc::clone(clock),
        }
    }
}
/// Bundle of five independent `DimensionRateLimiter`s, one per category of
/// validation failure.
///
/// Cloning yields a handle onto the same counters, because each inner limiter's
/// `Clone` shares its record store.
// The lint allowances presumably cover the shared `_errors` field suffix and the
// type name echoing its module name — confirm.
#[derive(Clone)]
#[allow(clippy::module_name_repetitions, clippy::struct_field_names)] pub struct ValidationRateLimiter {
/// Limiter consulted by `check_validation_errors`.
validation_errors: DimensionRateLimiter,
/// Limiter consulted by `check_depth_errors`.
depth_errors: DimensionRateLimiter,
/// Limiter consulted by `check_complexity_errors`.
complexity_errors: DimensionRateLimiter,
/// Limiter consulted by `check_malformed_errors`.
malformed_errors: DimensionRateLimiter,
/// Limiter consulted by `check_async_validation_errors`.
async_validation_errors: DimensionRateLimiter,
}
impl ValidationRateLimiter {
    /// Builds a limiter from `config`, timed by the system clock.
    pub fn new(config: &ValidationRateLimitingConfig) -> Self {
        let system_clock: Arc<dyn Clock> = Arc::new(SystemClock);
        Self::new_with_clock(config, system_clock)
    }

    /// Builds a limiter from `config` whose five dimensions all read time from
    /// the same `clock`.
    ///
    /// NOTE(review): `config.enabled` is not consulted here, so a limiter built
    /// from a disabled config still enforces its quotas. Presumably the caller
    /// gates on that flag; confirm.
    pub fn new_with_clock(config: &ValidationRateLimitingConfig, clock: Arc<dyn Clock>) -> Self {
        // One factory for all dimensions; each gets its own handle on the clock.
        let dimension = |max_requests: u32, window_secs: u64| {
            DimensionRateLimiter::new_with_clock(max_requests, window_secs, Arc::clone(&clock))
        };

        Self {
            validation_errors: dimension(
                config.validation_errors_max_requests,
                config.validation_errors_window_secs,
            ),
            depth_errors: dimension(
                config.depth_errors_max_requests,
                config.depth_errors_window_secs,
            ),
            complexity_errors: dimension(
                config.complexity_errors_max_requests,
                config.complexity_errors_window_secs,
            ),
            malformed_errors: dimension(
                config.malformed_errors_max_requests,
                config.malformed_errors_window_secs,
            ),
            async_validation_errors: dimension(
                config.async_validation_errors_max_requests,
                config.async_validation_errors_window_secs,
            ),
        }
    }

    /// Counts a validation error for `key`; fails once that dimension's quota is spent.
    pub fn check_validation_errors(&self, key: &str) -> Result<(), FraiseQLError> {
        let limiter = &self.validation_errors;
        limiter.check(key)
    }

    /// Counts a depth error for `key`; fails once that dimension's quota is spent.
    pub fn check_depth_errors(&self, key: &str) -> Result<(), FraiseQLError> {
        let limiter = &self.depth_errors;
        limiter.check(key)
    }

    /// Counts a complexity error for `key`; fails once that dimension's quota is spent.
    pub fn check_complexity_errors(&self, key: &str) -> Result<(), FraiseQLError> {
        let limiter = &self.complexity_errors;
        limiter.check(key)
    }

    /// Counts a malformed-request error for `key`; fails once that dimension's quota is spent.
    pub fn check_malformed_errors(&self, key: &str) -> Result<(), FraiseQLError> {
        let limiter = &self.malformed_errors;
        limiter.check(key)
    }

    /// Counts an async validation error for `key`; fails once that dimension's quota is spent.
    pub fn check_async_validation_errors(&self, key: &str) -> Result<(), FraiseQLError> {
        let limiter = &self.async_validation_errors;
        limiter.check(key)
    }

    /// Clears the stored records of all five dimensions.
    pub fn clear(&self) {
        let dimensions = [
            &self.validation_errors,
            &self.depth_errors,
            &self.complexity_errors,
            &self.malformed_errors,
            &self.async_validation_errors,
        ];
        for limiter in dimensions.iter() {
            limiter.clear();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is a `RateLimited` rejection.
    ///
    /// Takes the already-computed result so the failure message reports the very
    /// call that was checked, without consuming another request from the limiter.
    fn assert_rate_limited(result: &Result<(), FraiseQLError>, context: &str) {
        assert!(
            matches!(result, Err(FraiseQLError::RateLimited { .. })),
            "{}, got: {:?}",
            context,
            result
        );
    }

    /// Requests up to the quota are all accepted.
    #[test]
    fn test_dimension_rate_limiter_allows_within_limit() {
        let limiter = DimensionRateLimiter::new(3, 60);
        for attempt in 1..=3u32 {
            limiter
                .check("key")
                .unwrap_or_else(|e| panic!("expected Ok on request {attempt}: {e}"));
        }
    }

    /// The first request beyond the quota is rejected.
    #[test]
    fn test_dimension_rate_limiter_rejects_over_limit() {
        let limiter = DimensionRateLimiter::new(2, 60);
        for attempt in 1..=2u32 {
            limiter
                .check("key")
                .unwrap_or_else(|e| panic!("expected Ok on request {attempt}: {e}"));
        }
        assert_rate_limited(&limiter.check("key"), "expected RateLimited error on request 3");
    }

    /// Exhausting one key's quota does not affect another key.
    #[test]
    fn test_dimension_rate_limiter_per_key() {
        let limiter = DimensionRateLimiter::new(2, 60);
        limiter
            .check("key1")
            .unwrap_or_else(|e| panic!("expected Ok for key1 request 1: {e}"));
        limiter
            .check("key1")
            .unwrap_or_else(|e| panic!("expected Ok for key1 request 2: {e}"));
        limiter
            .check("key2")
            .unwrap_or_else(|e| panic!("expected Ok for key2 request 1 (independent key): {e}"));
    }

    /// `clear` forgets spent quota.
    #[test]
    fn test_dimension_rate_limiter_clear() {
        let limiter = DimensionRateLimiter::new(1, 60);
        limiter
            .check("key")
            .unwrap_or_else(|e| panic!("expected Ok before limit: {e}"));
        assert_rate_limited(&limiter.check("key"), "expected RateLimited error at limit");
        limiter.clear();
        limiter
            .check("key")
            .unwrap_or_else(|e| panic!("expected Ok after clear: {e}"));
    }

    /// The default configuration is enabled and has a non-zero quota everywhere.
    #[test]
    fn test_config_defaults() {
        let config = ValidationRateLimitingConfig::default();
        assert!(config.enabled);
        assert!(config.validation_errors_max_requests > 0);
        assert!(config.depth_errors_max_requests > 0);
        assert!(config.complexity_errors_max_requests > 0);
        assert!(config.malformed_errors_max_requests > 0);
        assert!(config.async_validation_errors_max_requests > 0);
    }

    /// Spending the validation-errors quota leaves the other dimensions untouched.
    #[test]
    fn test_validation_limiter_independent_dimensions() {
        let config = ValidationRateLimitingConfig::default();
        let limiter = ValidationRateLimiter::new(&config);
        let key = "test-key";
        // Derive the loop bound from the config so the test keeps exhausting the
        // quota even if the default value changes.
        for attempt in 1..=config.validation_errors_max_requests {
            limiter
                .check_validation_errors(key)
                .unwrap_or_else(|e| panic!("expected Ok while filling quota ({attempt}): {e}"));
        }
        assert_rate_limited(
            &limiter.check_validation_errors(key),
            "expected RateLimited after exhausting validation_errors quota",
        );
        limiter
            .check_depth_errors(key)
            .unwrap_or_else(|e| panic!("depth_errors should still allow: {e}"));
        limiter
            .check_complexity_errors(key)
            .unwrap_or_else(|e| panic!("complexity_errors should still allow: {e}"));
        limiter
            .check_malformed_errors(key)
            .unwrap_or_else(|e| panic!("malformed_errors should still allow: {e}"));
        limiter
            .check_async_validation_errors(key)
            .unwrap_or_else(|e| panic!("async_validation_errors should still allow: {e}"));
    }

    /// A clone observes quota spent through the original.
    #[test]
    fn test_validation_limiter_clone_shares_state() {
        let config = ValidationRateLimitingConfig::default();
        let limiter1 = ValidationRateLimiter::new(&config);
        let limiter2 = limiter1.clone();
        let key = "shared-key";
        for attempt in 1..=config.validation_errors_max_requests {
            limiter1
                .check_validation_errors(key)
                .unwrap_or_else(|e| panic!("expected Ok while filling quota ({attempt}): {e}"));
        }
        assert_rate_limited(
            &limiter2.check_validation_errors(key),
            "cloned limiter should share rate limit state",
        );
    }

    /// Once the window has elapsed, the key gets a fresh quota.
    #[test]
    fn test_window_rollover_does_not_leak_across_windows() {
        use std::time::Duration;

        use crate::utils::clock::ManualClock;

        let clock = ManualClock::new();
        let clock_arc: Arc<dyn Clock> = Arc::new(clock.clone());
        let config = ValidationRateLimitingConfig {
            enabled: true,
            validation_errors_max_requests: 2,
            validation_errors_window_secs: 60,
            ..ValidationRateLimitingConfig::default()
        };
        let limiter = ValidationRateLimiter::new_with_clock(&config, clock_arc);
        limiter
            .check_validation_errors("u1")
            .unwrap_or_else(|e| panic!("expected Ok on 1st request: {e}"));
        limiter
            .check_validation_errors("u1")
            .unwrap_or_else(|e| panic!("expected Ok on 2nd request: {e}"));
        assert_rate_limited(
            &limiter.check_validation_errors("u1"),
            "expected RateLimited on 3rd request (over limit)",
        );
        clock.advance(Duration::from_secs(61));
        limiter
            .check_validation_errors("u1")
            .unwrap_or_else(|e| panic!("expected Ok after window rollover: {e}"));
    }

    /// The window boundary is inclusive: exactly `window_secs` later counts as a new window.
    #[test]
    fn test_window_exact_boundary_triggers_rollover() {
        use std::time::Duration;

        use crate::utils::clock::ManualClock;

        let clock = ManualClock::new();
        let clock_arc: Arc<dyn Clock> = Arc::new(clock.clone());
        let window_secs = 60u64;
        let max = 2u32;
        let limiter = DimensionRateLimiter::new_with_clock(max, window_secs, clock_arc);
        for _ in 0..max {
            limiter
                .check("u")
                .unwrap_or_else(|e| panic!("expected Ok filling window: {e}"));
        }
        assert_rate_limited(&limiter.check("u"), "expected RateLimited when over limit");
        clock.advance(Duration::from_secs(window_secs));
        limiter
            .check("u")
            .unwrap_or_else(|e| panic!("expected Ok at exact window boundary (>= not >): {e}"));
    }

    /// A quota of zero switches the limiter off.
    #[test]
    fn test_max_requests_zero_disables_limiter() {
        let limiter = DimensionRateLimiter::new(0, 60);
        for i in 0..10u32 {
            limiter
                .check("key")
                .unwrap_or_else(|e| panic!("expected Ok with max_requests=0 on request {i}: {e}"));
        }
    }

    /// A zero-length window neither panics nor rejects.
    #[test]
    fn test_window_secs_zero_does_not_panic() {
        use crate::utils::clock::ManualClock;

        let clock_arc: Arc<dyn Clock> = Arc::new(ManualClock::new());
        let limiter = DimensionRateLimiter::new_with_clock(5, 0, clock_arc);
        limiter
            .check("key")
            .unwrap_or_else(|e| panic!("expected Ok with window_secs=0 (1st): {e}"));
        limiter
            .check("key")
            .unwrap_or_else(|e| panic!("expected Ok with window_secs=0 (2nd): {e}"));
        limiter
            .check("key")
            .unwrap_or_else(|e| panic!("expected Ok with window_secs=0 (3rd): {e}"));
    }
}