#[ cfg( feature = "rate_limiting" ) ]
mod private
{
use core::time::Duration;
use std::collections::VecDeque;
use std::sync::{ Arc, Mutex };
use std::time::Instant;
use error_tools::untyped::{ format_err, Result };
/// Settings for a [`RateLimiter`], normally assembled through the `with_*` builder methods.
#[ derive( Debug ) ]
#[ derive( Clone ) ]
pub struct RateLimitingConfig
{
  /// Most requests admitted inside one window (read by the sliding-window algorithm).
  max_requests : u32,
  /// Length of the sliding window, in milliseconds.
  window_duration_ms : u64,
  /// Largest number of tokens the bucket holds; the bucket also starts with this many (token bucket).
  burst_capacity : u32,
  /// Tokens credited to the bucket per elapsed second (token bucket).
  refill_rate : f64,
  /// Which limiting algorithm the limiter runs.
  algorithm : RateLimitingAlgorithm,
  /// Timeout in milliseconds. NOTE(review): not read by the limiter logic here; presumably consumed by callers — confirm.
  timeout_ms : u64,
  /// Per-endpoint flag. NOTE(review): not read by the limiter logic here; confirm how callers use it.
  per_endpoint : bool,
}
/// Algorithm a [`RateLimiter`] uses to decide whether a request may proceed.
///
/// The enum carries no data, so it is `Copy` and can be compared with `Eq`
/// or used as a hash key.
#[ derive( Debug, Clone, Copy, PartialEq, Eq, Hash ) ]
pub enum RateLimitingAlgorithm
{
  /// Token bucket: up to `burst_capacity` tokens, refilled at `refill_rate` tokens per second,
  /// one token spent per admitted request.
  TokenBucket,
  /// Sliding window: at most `max_requests` admitted requests within the last `window_duration_ms`.
  SlidingWindow,
}
/// Request limiter whose mutable state sits behind `Arc< Mutex< _ > >`.
///
/// [`RateLimiter::new`] fills only the state slot that matches `config.algorithm`; the other
/// slot stays `None`. Because the state is reference-counted, clones of a limiter share it.
#[ derive( Debug ) ]
#[ derive( Clone ) ]
pub struct RateLimiter
{
  /// Configuration the limiter was built from.
  config : RateLimitingConfig,
  /// Token-bucket state; populated by `new` only for `RateLimitingAlgorithm::TokenBucket`.
  token_bucket_state : Option< Arc< Mutex< TokenBucketState > > >,
  /// Sliding-window state; populated by `new` only for `RateLimitingAlgorithm::SlidingWindow`.
  sliding_window_state : Option< Arc< Mutex< SlidingWindowState > > >,
}
/// Mutable token-bucket bookkeeping, kept behind a mutex by the limiter.
#[ derive( Debug ) ]
struct TokenBucketState
{
  /// Tokens currently available; fractional because refills are proportional to elapsed time.
  tokens : f64,
  /// Moment of the most recent refill.
  last_refill : Instant,
  /// Requests evaluated so far, admitted or not.
  total_requests : u64,
  /// Requests rejected because no whole token was available.
  rate_limited_requests : u64,
}

/// Mutable sliding-window bookkeeping, kept behind a mutex by the limiter.
#[ derive( Debug ) ]
struct SlidingWindowState
{
  /// Arrival times of admitted requests, appended at the back.
  request_timestamps : VecDeque< Instant >,
  /// Requests evaluated so far, admitted or not.
  total_requests : u64,
  /// Requests rejected because the window was full.
  rate_limited_requests : u64,
}

impl TokenBucketState
{
  /// Returns a bucket holding `initial_tokens`, stamped as refilled now, with zeroed counters.
  #[ inline ]
  fn new( initial_tokens : f64 ) -> Self
  {
    let last_refill = Instant::now();
    Self
    {
      total_requests : 0,
      rate_limited_requests : 0,
      tokens : initial_tokens,
      last_refill,
    }
  }
}

impl SlidingWindowState
{
  /// Returns an empty window with zeroed counters.
  #[ inline ]
  fn new() -> Self
  {
    Self
    {
      total_requests : 0,
      rate_limited_requests : 0,
      request_timestamps : VecDeque::default(),
    }
  }
}
impl RateLimitingConfig
{
  /// Creates a configuration populated with the built-in defaults.
  ///
  /// The defaults are:
  /// - 100 requests per 60 000 ms window;
  /// - a burst capacity of 10 tokens, refilled at 1.67 tokens per second;
  /// - the token-bucket algorithm;
  /// - a 5 000 ms timeout;
  /// - `per_endpoint` set to `false`.
  #[ inline ]
  #[ must_use ]
  pub fn new() -> Self
  {
    Self
    {
      max_requests : 100,
      window_duration_ms : 60_000,
      burst_capacity : 10,
      // Numerically ≈ max_requests / window length in seconds (100 / 60).
      refill_rate : 1.67,
      algorithm : RateLimitingAlgorithm::TokenBucket,
      timeout_ms : 5_000,
      per_endpoint : false,
    }
  }

  /// Sets the maximum number of requests admitted per window (sliding-window algorithm).
  #[ inline ]
  #[ must_use ]
  pub fn with_max_requests( mut self, max_requests : u32 ) -> Self
  {
    self.max_requests = max_requests;
    self
  }

  /// Sets the sliding-window length in milliseconds.
  #[ inline ]
  #[ must_use ]
  pub fn with_window_duration( mut self, duration_ms : u64 ) -> Self
  {
    self.window_duration_ms = duration_ms;
    self
  }

  /// Sets the bucket capacity, which is also the initial token count (token-bucket algorithm).
  #[ inline ]
  #[ must_use ]
  pub fn with_burst_capacity( mut self, capacity : u32 ) -> Self
  {
    self.burst_capacity = capacity;
    self
  }

  /// Sets the refill rate in tokens per second (token-bucket algorithm).
  #[ inline ]
  #[ must_use ]
  pub fn with_refill_rate( mut self, rate : f64 ) -> Self
  {
    self.refill_rate = rate;
    self
  }

  /// Selects the limiting algorithm.
  #[ inline ]
  #[ must_use ]
  pub fn with_algorithm( mut self, algorithm : RateLimitingAlgorithm ) -> Self
  {
    self.algorithm = algorithm;
    self
  }

  /// Sets `timeout_ms`. NOTE(review): the limiter in this file never reads it; confirm caller usage.
  #[ inline ]
  #[ must_use ]
  pub fn with_timeout( mut self, timeout_ms : u64 ) -> Self
  {
    self.timeout_ms = timeout_ms;
    self
  }

  /// Sets `per_endpoint`. NOTE(review): the limiter in this file never reads it; confirm caller usage.
  #[ inline ]
  #[ must_use ]
  pub fn with_per_endpoint( mut self, per_endpoint : bool ) -> Self
  {
    self.per_endpoint = per_endpoint;
    self
  }

  /// Checks that every numeric setting is usable by the limiter.
  ///
  /// # Errors
  ///
  /// Returns a human-readable message in any of these cases:
  /// - `max_requests`, `window_duration_ms` or `burst_capacity` is zero;
  /// - `refill_rate` is not strictly positive;
  /// - `refill_rate` is NaN or infinite.
  #[ inline ]
  pub fn validate( &self ) -> std::result::Result< (), String >
  {
    if self.max_requests == 0
    {
      return Err( "max_requests must be greater than 0".to_string() );
    }
    if self.window_duration_ms == 0
    {
      return Err( "window_duration_ms must be greater than 0".to_string() );
    }
    if self.burst_capacity == 0
    {
      return Err( "burst_capacity must be greater than 0".to_string() );
    }
    if self.refill_rate <= 0.0
    {
      return Err( "refill_rate must be greater than 0".to_string() );
    }
    // `NaN <= 0.0` is false, so NaN (and +inf) would pass the check above. In the token bucket,
    // `( tokens + NaN_or_inf ).min( capacity )` evaluates to `capacity`, which refills the bucket
    // on every call and silently disables limiting.
    if !self.refill_rate.is_finite()
    {
      return Err( "refill_rate must be a finite number".to_string() );
    }
    Ok( () )
  }

  /// Maximum number of requests admitted per window.
  #[ inline ]
  #[ must_use ]
  pub fn max_requests( &self ) -> u32
  {
    self.max_requests
  }

  /// Sliding-window length in milliseconds.
  #[ inline ]
  #[ must_use ]
  pub fn window_duration_ms( &self ) -> u64
  {
    self.window_duration_ms
  }

  /// Token-bucket capacity.
  #[ inline ]
  #[ must_use ]
  pub fn burst_capacity( &self ) -> u32
  {
    self.burst_capacity
  }

  /// Token-bucket refill rate in tokens per second.
  #[ inline ]
  #[ must_use ]
  pub fn refill_rate( &self ) -> f64
  {
    self.refill_rate
  }

  /// Selected limiting algorithm.
  #[ inline ]
  #[ must_use ]
  pub fn algorithm( &self ) -> &RateLimitingAlgorithm
  {
    &self.algorithm
  }

  /// Configured timeout in milliseconds.
  #[ inline ]
  #[ must_use ]
  pub fn timeout_ms( &self ) -> u64
  {
    self.timeout_ms
  }

  /// Value of the `per_endpoint` flag.
  #[ inline ]
  #[ must_use ]
  pub fn is_per_endpoint( &self ) -> bool
  {
    self.per_endpoint
  }
}
impl Default for RateLimitingConfig
{
#[ inline ]
fn default() -> Self
{
Self::new()
}
}
impl RateLimiter
{
#[ inline ]
pub fn new( config : RateLimitingConfig ) -> Result< Self >
{
config.validate().map_err( |e| format_err!( "Rate limiting configuration validation failed : {}", e ) )?;
let ( token_bucket_state, sliding_window_state ) = match config.algorithm
{
RateLimitingAlgorithm::TokenBucket =>
{
let state = Arc::new( Mutex::new( TokenBucketState::new( config.burst_capacity as f64 ) ) );
( Some( state ), None )
},
RateLimitingAlgorithm::SlidingWindow =>
{
let state = Arc::new( Mutex::new( SlidingWindowState::new() ) );
( None, Some( state ) )
},
};
Ok( Self
{
config,
token_bucket_state,
sliding_window_state,
} )
}
#[ inline ]
#[ must_use ]
pub fn config( &self ) -> &RateLimitingConfig
{
&self.config
}
#[ inline ]
#[ must_use ]
pub fn should_allow_request( &self ) -> bool
{
match self.config.algorithm
{
RateLimitingAlgorithm::TokenBucket =>
{
if let Some( ref state ) = self.token_bucket_state
{
let mut bucket = state.lock().unwrap();
let now = Instant::now();
let elapsed = now.duration_since( bucket.last_refill ).as_secs_f64();
let tokens_to_add = elapsed * self.config.refill_rate;
bucket.tokens = ( bucket.tokens + tokens_to_add ).min( self.config.burst_capacity as f64 );
bucket.last_refill = now;
bucket.total_requests += 1;
if bucket.tokens >= 1.0
{
bucket.tokens -= 1.0;
true
}
else
{
bucket.rate_limited_requests += 1;
false
}
}
else
{
true
}
},
RateLimitingAlgorithm::SlidingWindow =>
{
if let Some( ref state ) = self.sliding_window_state
{
let mut window = state.lock().unwrap();
let now = Instant::now();
let window_duration = Duration::from_millis( self.config.window_duration_ms );
while let Some( ×tamp ) = window.request_timestamps.front()
{
if now.duration_since( timestamp ) > window_duration
{
window.request_timestamps.pop_front();
}
else
{
break;
}
}
window.total_requests += 1;
if window.request_timestamps.len() < self.config.max_requests as usize
{
window.request_timestamps.push_back( now );
true
}
else
{
window.rate_limited_requests += 1;
false
}
}
else
{
true
}
},
}
}
#[ inline ]
pub fn reset( &self )
{
match self.config.algorithm
{
RateLimitingAlgorithm::TokenBucket =>
{
if let Some( ref state ) = self.token_bucket_state
{
let mut bucket = state.lock().unwrap();
*bucket = TokenBucketState::new( self.config.burst_capacity as f64 );
}
},
RateLimitingAlgorithm::SlidingWindow =>
{
if let Some( ref state ) = self.sliding_window_state
{
let mut window = state.lock().unwrap();
*window = SlidingWindowState::new();
}
}
}
}
}
}
// Publish the three public rate-limiting types from `private` via `crate::mod_interface!`
// under its `exposed` layer (layer semantics are defined by that macro). Gated on the same
// `rate_limiting` feature as the `private` module.
#[ cfg( feature = "rate_limiting" ) ]
crate ::mod_interface!
{
exposed use private::RateLimiter;
exposed use private::RateLimitingConfig;
exposed use private::RateLimitingAlgorithm;
}