#![ allow( clippy::missing_inline_in_public_items ) ]
#[ cfg( feature = "rate_limiting" ) ]
mod private
{
use crate::
{
error ::Result,
};
use core::time::Duration;
use std::
{
sync ::{ Arc, Mutex },
time ::Instant,
collections ::VecDeque,
};
use serde::{ Serialize, Deserialize };
/// Strategy the rate limiter uses to decide whether a request is admitted.
///
/// The limiter keeps a different kind of state for each variant: a token
/// bucket for `TokenBucket`, a queue of request timestamps for `SlidingWindow`.
#[ derive( Debug, Clone, PartialEq, Serialize, Deserialize ) ]
pub enum RateLimitingAlgorithm
{
/// Each admitted request spends one token; tokens are replenished over time
/// up to a fixed capacity.
TokenBucket,
/// Admitted requests are counted inside a trailing time window and capped at
/// a maximum count.
SlidingWindow,
}
/// Settings for the enhanced rate limiter.
///
/// The token-bucket algorithm reads `burst_capacity` and `refill_rate`; the
/// sliding-window algorithm reads `max_requests` and `window_duration_ms`.
/// `validate` checks every numeric field regardless of the chosen algorithm.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct EnhancedRateLimitingConfig
{
/// Maximum number of requests admitted within one sliding window.
pub max_requests : u32,
/// Length of the sliding window, in milliseconds.
pub window_duration_ms : u64,
/// Capacity of the token bucket; a new bucket starts with this many tokens.
pub burst_capacity : u32,
/// Tokens added to the bucket per second of elapsed time.
pub refill_rate : f64,
/// Algorithm the limiter applies.
pub algorithm : RateLimitingAlgorithm,
/// Timeout in milliseconds.
/// NOTE(review): validated as non-zero but not otherwise read in this module —
/// confirm where it is enforced.
pub timeout_ms : u64,
/// Whether limits are meant to be tracked per endpoint.
/// NOTE(review): not read anywhere in this module — presumably consumed by
/// callers; confirm.
pub per_endpoint : bool,
}
impl Default for EnhancedRateLimitingConfig
{
  /// Builds the stock configuration: token-bucket algorithm, 100 requests per
  /// 60 000 ms window, a burst capacity of 10 tokens refilled at 1.66 tokens
  /// per second, a 5 000 ms timeout, and per-endpoint limiting switched off.
  fn default() -> Self
  {
    let algorithm = RateLimitingAlgorithm::TokenBucket;
    Self
    {
      algorithm,
      // Sliding-window limits.
      max_requests : 100,
      window_duration_ms : 60_000,
      // Token-bucket limits.
      burst_capacity : 10,
      refill_rate : 1.66,
      // General settings.
      timeout_ms : 5_000,
      per_endpoint : false,
    }
  }
}
impl EnhancedRateLimitingConfig
{
  /// Creates a configuration populated with the default values.
  #[ must_use ]
  pub fn new() -> Self
  {
    Self::default()
  }

  /// Sets the maximum number of requests admitted per window.
  #[ must_use ]
  pub fn with_max_requests( mut self, max_requests : u32 ) -> Self
  {
    self.max_requests = max_requests;
    self
  }

  /// Sets the window length in milliseconds.
  #[ must_use ]
  pub fn with_window_duration( mut self, duration_ms : u64 ) -> Self
  {
    self.window_duration_ms = duration_ms;
    self
  }

  /// Sets the token bucket capacity.
  #[ must_use ]
  pub fn with_burst_capacity( mut self, capacity : u32 ) -> Self
  {
    self.burst_capacity = capacity;
    self
  }

  /// Sets the token refill rate, in tokens per second.
  #[ must_use ]
  pub fn with_refill_rate( mut self, rate : f64 ) -> Self
  {
    self.refill_rate = rate;
    self
  }

  /// Selects the rate limiting algorithm.
  #[ must_use ]
  pub fn with_algorithm( mut self, algorithm : RateLimitingAlgorithm ) -> Self
  {
    self.algorithm = algorithm;
    self
  }

  /// Sets the timeout in milliseconds.
  #[ must_use ]
  pub fn with_timeout( mut self, timeout_ms : u64 ) -> Self
  {
    self.timeout_ms = timeout_ms;
    self
  }

  /// Sets the per-endpoint flag.
  #[ must_use ]
  pub fn with_per_endpoint( mut self, per_endpoint : bool ) -> Self
  {
    self.per_endpoint = per_endpoint;
    self
  }

  /// Checks that every numeric setting is usable.
  ///
  /// # Errors
  ///
  /// Returns a message naming the first offending field when `max_requests`,
  /// `window_duration_ms`, `burst_capacity` or `timeout_ms` is zero, or when
  /// `refill_rate` is NaN, infinite, zero or negative.
  pub fn validate( &self ) -> core::result::Result< (), String >
  {
    if self.max_requests == 0
    {
      return Err( "max_requests must be greater than 0".to_string() );
    }
    if self.window_duration_ms == 0
    {
      return Err( "window_duration_ms must be greater than 0".to_string() );
    }
    if self.burst_capacity == 0
    {
      return Err( "burst_capacity must be greater than 0".to_string() );
    }
    // NaN compares false against everything, so `<= 0.0` alone would accept it
    // (and +infinity). A NaN rate makes the token arithmetic produce NaN, which
    // `f64::min` resolves to the bucket capacity — the bucket would never drain.
    if !self.refill_rate.is_finite()
    {
      return Err( "refill_rate must be a finite number".to_string() );
    }
    if self.refill_rate <= 0.0
    {
      return Err( "refill_rate must be greater than 0".to_string() );
    }
    if self.timeout_ms == 0
    {
      return Err( "timeout_ms must be greater than 0".to_string() );
    }
    Ok( () )
  }
}
/// Mutable bookkeeping for a token-bucket rate limiter.
#[ derive( Debug ) ]
pub struct TokenBucketState
{
  /// Tokens currently available; one is spent per admitted request.
  pub tokens : f64,
  /// Moment the bucket was last topped up (or created / reset).
  pub last_refill : Instant,
  /// Number of `try_consume` calls since creation or the last `reset`.
  pub total_requests : u64,
  /// Number of those calls that were refused for lack of a token.
  pub rate_limited_requests : u64,
}

impl TokenBucketState
{
  /// Creates a bucket holding `initial_tokens`, with zeroed counters and the
  /// refill clock started now.
  #[ must_use ]
  pub fn new( initial_tokens : f64 ) -> Self
  {
    let created_at = Instant::now();
    Self
    {
      tokens : initial_tokens,
      last_refill : created_at,
      total_requests : 0,
      rate_limited_requests : 0,
    }
  }

  /// Adds `refill_rate` tokens per second elapsed since the previous refill,
  /// capping the total at `burst_capacity`, and restarts the refill clock.
  pub fn refill_tokens( &mut self, refill_rate : f64, burst_capacity : f64 )
  {
    let now = Instant::now();
    let seconds_since_refill = now.duration_since( self.last_refill ).as_secs_f64();
    let replenished = self.tokens + seconds_since_refill * refill_rate;
    self.tokens = f64::min( replenished, burst_capacity );
    self.last_refill = now;
  }

  /// Records a request and spends one token if at least one whole token is
  /// available.
  ///
  /// Returns `true` when the request is admitted, `false` when it is refused
  /// (in which case the refusal counter is incremented).
  #[ must_use ]
  pub fn try_consume( &mut self ) -> bool
  {
    self.total_requests += 1;
    let granted = self.tokens >= 1.0;
    if granted
    {
      self.tokens -= 1.0;
    }
    else
    {
      self.rate_limited_requests += 1;
    }
    granted
  }

  /// Restores the bucket to a freshly created state holding `initial_tokens`.
  pub fn reset( &mut self, initial_tokens : f64 )
  {
    *self = Self
    {
      tokens : initial_tokens,
      last_refill : Instant::now(),
      total_requests : 0,
      rate_limited_requests : 0,
    };
  }
}
/// Bookkeeping for a sliding-window rate limiter.
///
/// Admission times are appended at the back of the queue, so the front always
/// holds the oldest request still counted against the window.
#[ derive( Debug, Default ) ]
pub struct SlidingWindowState
{
  /// Admission times of the requests currently counted, oldest first.
  pub request_timestamps : VecDeque< Instant >,
  /// Number of `try_add_request` calls since creation or the last `reset`.
  pub total_requests : u64,
  /// Number of those calls that were refused because the window was full.
  pub rate_limited_requests : u64,
}

impl SlidingWindowState
{
  /// Creates an empty window with zeroed counters.
  #[ must_use ]
  pub fn new() -> Self
  {
    Self::default()
  }

  /// Evicts every timestamp strictly older than `window_duration` before now.
  ///
  /// When `window_duration` reaches back further than the platform's `Instant`
  /// can represent, no recorded request can lie outside the window, so nothing
  /// is evicted. (Subtracting such a duration yields `None`; unwrapping it
  /// would panic.)
  pub fn cleanup_old_timestamps( &mut self, window_duration : Duration )
  {
    let cutoff_time = match Instant::now().checked_sub( window_duration )
    {
      Some( cutoff ) => cutoff,
      // The window covers the clock's whole representable past: keep everything.
      None => return,
    };
    // Timestamps are ordered oldest-first, so stop at the first one in range.
    while let Some( &oldest ) = self.request_timestamps.front()
    {
      if oldest >= cutoff_time
      {
        break;
      }
      self.request_timestamps.pop_front();
    }
  }

  /// Records a request and admits it if fewer than `max_requests` timestamps
  /// are currently held.
  ///
  /// Returns `true` (and stores the current time) when admitted, `false` (and
  /// increments the refusal counter) otherwise. Call
  /// `cleanup_old_timestamps` first so that expired requests are not counted.
  #[ must_use ]
  pub fn try_add_request( &mut self, max_requests : u32 ) -> bool
  {
    self.total_requests += 1;
    if self.request_timestamps.len() < max_requests as usize
    {
      self.request_timestamps.push_back( Instant::now() );
      true
    }
    else
    {
      self.rate_limited_requests += 1;
      false
    }
  }

  /// Forgets all recorded requests and zeroes both counters.
  pub fn reset( &mut self )
  {
    self.request_timestamps.clear();
    self.total_requests = 0;
    self.rate_limited_requests = 0;
  }
}
/// Rate limiter that admits or rejects operations according to its
/// configuration.
///
/// The mutable state sits behind `Arc< Mutex< .. > >`, so values produced by
/// the derived `Clone` share the same counters as the original.
#[ derive( Debug, Clone ) ]
pub struct EnhancedRateLimiter
{
/// Configuration supplied at construction.
config : EnhancedRateLimitingConfig,
/// Token-bucket state; populated only when the configured algorithm is
/// `TokenBucket`.
token_bucket_state : Option< Arc< Mutex< TokenBucketState > > >,
/// Sliding-window state; populated only when the configured algorithm is
/// `SlidingWindow`.
sliding_window_state : Option< Arc< Mutex< SlidingWindowState > > >,
}
impl EnhancedRateLimiter
{
  /// Creates a rate limiter from `config`.
  ///
  /// Only the state needed by `config.algorithm` is allocated; a token bucket
  /// starts full (`burst_capacity` tokens).
  ///
  /// # Errors
  ///
  /// Returns the message produced by `config.validate()` when the
  /// configuration is rejected.
  pub fn new( config : EnhancedRateLimitingConfig ) -> core::result::Result< Self, String >
  {
    config.validate()?;
    let ( token_bucket_state, sliding_window_state ) = match config.algorithm
    {
      RateLimitingAlgorithm::TokenBucket =>
      {
        let state = TokenBucketState::new( f64::from( config.burst_capacity ) );
        ( Some( Arc::new( Mutex::new( state ) ) ), None )
      },
      RateLimitingAlgorithm::SlidingWindow =>
      {
        let state = SlidingWindowState::new();
        ( None, Some( Arc::new( Mutex::new( state ) ) ) )
      },
    };
    Ok( Self { config, token_bucket_state, sliding_window_state } )
  }

  /// Runs `operation` if the rate limit currently permits a request.
  ///
  /// The closure is invoked at most once, so any `FnOnce` is accepted. The
  /// limiter's lock is released before the returned future is awaited.
  ///
  /// # Errors
  ///
  /// Returns a "Rate limit exceeded" error without invoking `operation` when
  /// the request is refused; otherwise returns whatever `operation` yields.
  pub async fn execute< F, Fut, T >( &self, operation : F ) -> Result< T >
  where
    F : FnOnce() -> Fut,
    Fut : core::future::Future< Output = Result< T > >,
  {
    if !self.should_allow_request()
    {
      return Err( error_tools::untyped::Error::msg( "Rate limit exceeded - request rejected" ) );
    }
    operation().await
  }

  /// Locks `state`, taking over the guard if an earlier holder panicked.
  ///
  /// The guarded values are plain counters and timestamps, so continuing with
  /// them after a panic is preferable to making every later call panic too.
  fn lock_state< S >( state : &Mutex< S > ) -> std::sync::MutexGuard< '_, S >
  {
    state.lock().unwrap_or_else( std::sync::PoisonError::into_inner )
  }

  /// Updates the state for the configured algorithm and reports whether the
  /// current request is admitted.
  ///
  /// If the state for the configured algorithm is absent the request is
  /// allowed; `new` always creates the matching state.
  fn should_allow_request( &self ) -> bool
  {
    match self.config.algorithm
    {
      RateLimitingAlgorithm::TokenBucket =>
      {
        if let Some( state ) = &self.token_bucket_state
        {
          let mut bucket = Self::lock_state( state );
          bucket.refill_tokens( self.config.refill_rate, f64::from( self.config.burst_capacity ) );
          bucket.try_consume()
        }
        else
        {
          true
        }
      },
      RateLimitingAlgorithm::SlidingWindow =>
      {
        if let Some( state ) = &self.sliding_window_state
        {
          let mut window = Self::lock_state( state );
          window.cleanup_old_timestamps( Duration::from_millis( self.config.window_duration_ms ) );
          window.try_add_request( self.config.max_requests )
        }
        else
        {
          true
        }
      },
    }
  }

  /// Restores the state of the configured algorithm to its initial condition:
  /// a full bucket or an empty window, with counters zeroed.
  pub fn reset( &self )
  {
    match self.config.algorithm
    {
      RateLimitingAlgorithm::TokenBucket =>
      {
        if let Some( state ) = &self.token_bucket_state
        {
          Self::lock_state( state ).reset( f64::from( self.config.burst_capacity ) );
        }
      },
      RateLimitingAlgorithm::SlidingWindow =>
      {
        if let Some( state ) = &self.sliding_window_state
        {
          Self::lock_state( state ).reset();
        }
      },
    }
  }

  /// Returns a snapshot of the token-bucket state, or `None` when the limiter
  /// holds no token-bucket state.
  #[ must_use ]
  pub fn get_token_bucket_state( &self ) -> Option< TokenBucketState >
  {
    self.token_bucket_state.as_ref().map( | state |
    {
      let bucket = Self::lock_state( state );
      TokenBucketState
      {
        tokens : bucket.tokens,
        last_refill : bucket.last_refill,
        total_requests : bucket.total_requests,
        rate_limited_requests : bucket.rate_limited_requests,
      }
    } )
  }

  /// Returns a snapshot of the sliding-window state, or `None` when the
  /// limiter holds no sliding-window state.
  #[ must_use ]
  pub fn get_sliding_window_state( &self ) -> Option< SlidingWindowState >
  {
    self.sliding_window_state.as_ref().map( | state |
    {
      let window = Self::lock_state( state );
      SlidingWindowState
      {
        request_timestamps : window.request_timestamps.clone(),
        total_requests : window.total_requests,
        rate_limited_requests : window.rate_limited_requests,
      }
    } )
  }

  /// Returns the configuration this limiter was built with.
  #[ must_use ]
  pub fn config( &self ) -> &EnhancedRateLimitingConfig
  {
    &self.config
  }
}
}
#[ cfg( not( feature = "rate_limiting" ) ) ]
pub mod private
{
  /// Placeholder configuration type compiled when the `rate_limiting` feature
  /// is disabled; it carries no settings.
  #[ derive( Debug, Clone, Default ) ]
  pub struct EnhancedRateLimitingConfig;

  impl EnhancedRateLimitingConfig
  {
    /// Returns the placeholder configuration value.
    #[ must_use ]
    pub fn new() -> Self
    {
      Self::default()
    }
  }
}
// With the `rate_limiting` feature enabled, re-export the full rate-limiting
// API from `private` at this module's level.
#[ cfg( feature = "rate_limiting" ) ]
pub use private::
{
EnhancedRateLimitingConfig,
RateLimitingAlgorithm,
TokenBucketState,
SlidingWindowState,
EnhancedRateLimiter,
};
// Without the feature, only the placeholder configuration type exists.
#[ cfg( not( feature = "rate_limiting" ) ) ]
pub use private::EnhancedRateLimitingConfig;
// Lists the items this module hands to the crate's `mod_interface!` macro; the
// list mirrors the feature-dependent `pub use` re-exports.
// NOTE(review): `mod_interface!` is defined elsewhere in the crate — presumably
// `exposed use` publishes these names through the crate's namespace layers;
// confirm against the macro's documentation.
crate ::mod_interface!
{
#[ cfg( feature = "rate_limiting" ) ]
exposed use
{
EnhancedRateLimitingConfig,
RateLimitingAlgorithm,
TokenBucketState,
SlidingWindowState,
EnhancedRateLimiter,
};
#[ cfg( not( feature = "rate_limiting" ) ) ]
exposed use
{
EnhancedRateLimitingConfig,
};
}