#[ cfg( feature = "rate_limiting" ) ]
mod rate_limiting_tests
{
use api_openai::
{
error ::Result,
};
use std::
{
sync ::{ Arc, Mutex },
time ::Instant,
collections ::VecDeque,
};
use core::time::Duration;
use serde::{ Serialize, Deserialize };
use tokio::time::sleep;
#[ derive( Debug, Clone, PartialEq, Serialize, Deserialize ) ]
pub enum RateLimitingAlgorithm
{
TokenBucket,
SlidingWindow,
}
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct EnhancedRateLimitingConfig
{
pub max_requests : u32,
pub window_duration_ms : u64,
pub burst_capacity : u32,
pub refill_rate : f64,
pub algorithm : RateLimitingAlgorithm,
pub timeout_ms : u64,
pub per_endpoint : bool,
}
impl Default for EnhancedRateLimitingConfig
{
fn default() -> Self
{
Self
{
max_requests : 100,
window_duration_ms : 60000, burst_capacity : 10,
refill_rate : 1.66, algorithm : RateLimitingAlgorithm::TokenBucket,
timeout_ms : 5000,
per_endpoint : false,
}
}
}
impl EnhancedRateLimitingConfig
{
#[ must_use ]
pub fn new() -> Self
{
Self::default()
}
#[ must_use ]
pub fn with_max_requests( mut self, max_requests : u32 ) -> Self
{
self.max_requests = max_requests;
self
}
#[ must_use ]
pub fn with_window_duration( mut self, duration_ms : u64 ) -> Self
{
self.window_duration_ms = duration_ms;
self
}
#[ must_use ]
pub fn with_burst_capacity( mut self, capacity : u32 ) -> Self
{
self.burst_capacity = capacity;
self
}
#[ must_use ]
pub fn with_refill_rate( mut self, rate : f64 ) -> Self
{
self.refill_rate = rate;
self
}
#[ must_use ]
pub fn with_algorithm( mut self, algorithm : RateLimitingAlgorithm ) -> Self
{
self.algorithm = algorithm;
self
}
#[ must_use ]
pub fn with_timeout( mut self, timeout_ms : u64 ) -> Self
{
self.timeout_ms = timeout_ms;
self
}
#[ must_use ]
pub fn with_per_endpoint( mut self, per_endpoint : bool ) -> Self
{
self.per_endpoint = per_endpoint;
self
}
pub fn validate( &self ) -> core::result::Result< (), String >
{
if self.max_requests == 0
{
return Err( "max_requests must be greater than 0".to_string() );
}
if self.window_duration_ms == 0
{
return Err( "window_duration_ms must be greater than 0".to_string() );
}
if self.burst_capacity == 0
{
return Err( "burst_capacity must be greater than 0".to_string() );
}
if self.refill_rate <= 0.0
{
return Err( "refill_rate must be greater than 0".to_string() );
}
if self.timeout_ms == 0
{
return Err( "timeout_ms must be greater than 0".to_string() );
}
Ok( () )
}
}
#[ derive( Debug ) ]
pub struct TokenBucketState
{
pub tokens : f64,
pub last_refill : Instant,
pub total_requests : u64,
pub rate_limited_requests : u64,
}
impl TokenBucketState
{
#[ must_use ]
pub fn new( initial_tokens : f64 ) -> Self
{
Self
{
tokens : initial_tokens,
last_refill : Instant::now(),
total_requests : 0,
rate_limited_requests : 0,
}
}
pub fn refill_tokens( &mut self, refill_rate : f64, burst_capacity : f64 )
{
let now = Instant::now();
let elapsed = now.duration_since( self.last_refill ).as_secs_f64();
let tokens_to_add = elapsed * refill_rate;
self.tokens = ( self.tokens + tokens_to_add ).min( burst_capacity );
self.last_refill = now;
}
pub fn try_consume( &mut self ) -> bool
{
self.total_requests += 1;
if self.tokens >= 1.0
{
self.tokens -= 1.0;
true
}
else
{
self.rate_limited_requests += 1;
false
}
}
}
#[ derive( Debug ) ]
pub struct SlidingWindowState
{
pub request_timestamps : VecDeque< Instant >,
pub total_requests : u64,
pub rate_limited_requests : u64,
}
impl Default for SlidingWindowState
{
fn default() -> Self
{
Self::new()
}
}
impl SlidingWindowState
{
#[ must_use ]
pub fn new() -> Self
{
Self
{
request_timestamps : VecDeque::new(),
total_requests : 0,
rate_limited_requests : 0,
}
}
pub fn cleanup_old_timestamps( &mut self, window_duration : Duration )
{
let cutoff_time = Instant::now().checked_sub( window_duration ).unwrap();
while let Some( &front_time ) = self.request_timestamps.front()
{
if front_time < cutoff_time
{
self.request_timestamps.pop_front();
}
else
{
break;
}
}
}
pub fn try_add_request( &mut self, max_requests : u32 ) -> bool
{
self.total_requests += 1;
if self.request_timestamps.len() < max_requests as usize
{
self.request_timestamps.push_back( Instant::now() );
true
}
else
{
self.rate_limited_requests += 1;
false
}
}
}
#[ derive( Debug ) ]
pub struct EnhancedRateLimiter
{
config : EnhancedRateLimitingConfig,
token_bucket_state : Option< Arc< Mutex< TokenBucketState > > >,
sliding_window_state : Option< Arc< Mutex< SlidingWindowState > > >,
}
impl EnhancedRateLimiter
{
pub fn new( config : EnhancedRateLimitingConfig ) -> core::result::Result< Self, String >
{
config.validate()?;
let ( token_bucket_state, sliding_window_state ) = match config.algorithm
{
RateLimitingAlgorithm::TokenBucket =>
{
let state = TokenBucketState::new( f64::from(config.burst_capacity) );
( Some( Arc::new( Mutex::new( state ) ) ), None )
},
RateLimitingAlgorithm::SlidingWindow =>
{
let state = SlidingWindowState::new();
( None, Some( Arc::new( Mutex::new( state ) ) ) )
},
};
Ok( Self
{
config,
token_bucket_state,
sliding_window_state,
} )
}
pub async fn execute< F, Fut, T >( &self, operation : F ) -> Result< T >
where
F : Fn() -> Fut,
Fut : core::future::Future< Output = Result< T > >,
{
if !self.should_allow_request().await?
{
return Err( error_tools::untyped::Error::msg( "Rate limit exceeded - request rejected" ) );
}
operation().await
}
async fn should_allow_request( &self ) -> Result< bool >
{
tokio ::task::yield_now().await;
match self.config.algorithm
{
RateLimitingAlgorithm::TokenBucket =>
{
if let Some( state ) = &self.token_bucket_state
{
let mut bucket = state.lock().unwrap();
bucket.refill_tokens( self.config.refill_rate, f64::from(self.config.burst_capacity) );
Ok( bucket.try_consume() )
}
else
{
Ok( true ) }
},
RateLimitingAlgorithm::SlidingWindow =>
{
if let Some( state ) = &self.sliding_window_state
{
let mut window = state.lock().unwrap();
window.cleanup_old_timestamps( Duration::from_millis( self.config.window_duration_ms ) );
Ok( window.try_add_request( self.config.max_requests ) )
}
else
{
Ok( true ) }
}
}
}
#[ must_use ]
pub fn get_token_bucket_state( &self ) -> Option< TokenBucketState >
{
if let Some( state ) = &self.token_bucket_state
{
let bucket = state.lock().unwrap();
Some( TokenBucketState
{
tokens : bucket.tokens,
last_refill : bucket.last_refill,
total_requests : bucket.total_requests,
rate_limited_requests : bucket.rate_limited_requests,
} )
}
else
{
None
}
}
#[ must_use ]
pub fn get_sliding_window_state( &self ) -> Option< SlidingWindowState >
{
if let Some( state ) = &self.sliding_window_state
{
let window = state.lock().unwrap();
Some( SlidingWindowState
{
request_timestamps : window.request_timestamps.clone(),
total_requests : window.total_requests,
rate_limited_requests : window.rate_limited_requests,
} )
}
else
{
None
}
}
#[ must_use ]
pub fn config( &self ) -> &EnhancedRateLimitingConfig
{
&self.config
}
}
struct MockHttpClient
{
call_count : Arc< Mutex< u32 > >,
}
impl MockHttpClient
{
fn new() -> Self
{
Self
{
call_count : Arc::new( Mutex::new( 0 ) ),
}
}
async fn make_request( &self ) -> Result< String >
{
tokio ::task::yield_now().await;
let mut count = self.call_count.lock().unwrap();
*count += 1;
let call_number = *count;
drop( count );
Ok( format!( "response_{call_number}" ) )
}
#[ allow(dead_code) ]
fn get_call_count( &self ) -> u32
{
*self.call_count.lock().unwrap()
}
}
#[ tokio::test ]
async fn test_rate_limiting_config_default_values()
{
let config = EnhancedRateLimitingConfig::default();
assert_eq!( config.max_requests, 100 );
assert_eq!( config.window_duration_ms, 60000 );
assert_eq!( config.burst_capacity, 10 );
assert!((config.refill_rate - 1.66).abs() < f64::EPSILON);
assert_eq!( config.algorithm, RateLimitingAlgorithm::TokenBucket );
assert_eq!( config.timeout_ms, 5000 );
assert!( !config.per_endpoint );
}
#[ tokio::test ]
async fn test_rate_limiting_config_builder_pattern()
{
let config = EnhancedRateLimitingConfig::new()
.with_max_requests( 50 )
.with_window_duration( 30000 )
.with_burst_capacity( 5 )
.with_refill_rate( 0.83 )
.with_algorithm( RateLimitingAlgorithm::SlidingWindow )
.with_timeout( 3000 )
.with_per_endpoint( true );
assert_eq!( config.max_requests, 50 );
assert_eq!( config.window_duration_ms, 30000 );
assert_eq!( config.burst_capacity, 5 );
assert!((config.refill_rate - 0.83).abs() < f64::EPSILON);
assert_eq!( config.algorithm, RateLimitingAlgorithm::SlidingWindow );
assert_eq!( config.timeout_ms, 3000 );
assert!( config.per_endpoint );
}
#[ tokio::test ]
async fn test_rate_limiting_config_validation()
{
let valid_config = EnhancedRateLimitingConfig::default();
assert!( valid_config.validate().is_ok() );
let invalid_config = EnhancedRateLimitingConfig::default().with_max_requests( 0 );
assert!( invalid_config.validate().is_err() );
let invalid_config = EnhancedRateLimitingConfig::default().with_window_duration( 0 );
assert!( invalid_config.validate().is_err() );
let invalid_config = EnhancedRateLimitingConfig::default().with_burst_capacity( 0 );
assert!( invalid_config.validate().is_err() );
let invalid_config = EnhancedRateLimitingConfig::default().with_refill_rate( 0.0 );
assert!( invalid_config.validate().is_err() );
let invalid_config = EnhancedRateLimitingConfig::default().with_timeout( 0 );
assert!( invalid_config.validate().is_err() );
}
#[ tokio::test ]
async fn test_token_bucket_initial_state()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::TokenBucket )
.with_burst_capacity( 5 );
let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
let state = rate_limiter.get_token_bucket_state().unwrap();
assert!((state.tokens - 5.0).abs() < f64::EPSILON);
assert_eq!( state.total_requests, 0 );
assert_eq!( state.rate_limited_requests, 0 );
}
#[ tokio::test ]
async fn test_token_bucket_token_consumption()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::TokenBucket )
.with_burst_capacity( 3 )
.with_refill_rate( 0.1 ); let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
let mock_client = MockHttpClient::new();
for i in 1..=3
{
let result = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result.is_ok() );
assert_eq!( result.unwrap(), format!( "response_{i}" ) );
}
let result = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result.is_err() );
assert!( result.unwrap_err().to_string().contains( "Rate limit exceeded" ) );
let state = rate_limiter.get_token_bucket_state().unwrap();
assert!( state.tokens < 1.0 ); assert_eq!( state.total_requests, 4 );
assert_eq!( state.rate_limited_requests, 1 );
}
#[ tokio::test ]
async fn test_token_bucket_refill_timing()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::TokenBucket )
.with_burst_capacity( 1 )
.with_refill_rate( 10.0 ); let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
let mock_client = MockHttpClient::new();
let result = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result.is_ok() );
let result = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result.is_err() );
sleep( Duration::from_millis( 100 ) ).await;
let result = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result.is_ok() );
}
#[ tokio::test ]
async fn test_sliding_window_initial_state()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::SlidingWindow )
.with_max_requests( 5 );
let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
let state = rate_limiter.get_sliding_window_state().unwrap();
assert_eq!( state.request_timestamps.len(), 0 );
assert_eq!( state.total_requests, 0 );
assert_eq!( state.rate_limited_requests, 0 );
}
#[ tokio::test ]
async fn test_sliding_window_request_tracking()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::SlidingWindow )
.with_max_requests( 3 )
.with_window_duration( 1000 ); let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
let mock_client = MockHttpClient::new();
for i in 1..=3
{
let result = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result.is_ok() );
assert_eq!( result.unwrap(), format!( "response_{i}" ) );
}
let result = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result.is_err() );
assert!( result.unwrap_err().to_string().contains( "Rate limit exceeded" ) );
let state = rate_limiter.get_sliding_window_state().unwrap();
assert_eq!( state.request_timestamps.len(), 3 );
assert_eq!( state.total_requests, 4 );
assert_eq!( state.rate_limited_requests, 1 );
}
#[ tokio::test ]
async fn test_sliding_window_cleanup()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::SlidingWindow )
.with_max_requests( 2 )
.with_window_duration( 100 ); let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
let mock_client = MockHttpClient::new();
let result1 = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result1.is_ok() );
let result2 = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result2.is_ok() );
let result3 = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result3.is_err() );
sleep( Duration::from_millis( 150 ) ).await;
let result4 = rate_limiter.execute( || mock_client.make_request() ).await;
assert!( result4.is_ok() );
}
#[ tokio::test ]
async fn test_rate_limiter_thread_safety()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::TokenBucket )
.with_burst_capacity( 10 )
.with_refill_rate( 1.0 );
let rate_limiter = Arc::new( EnhancedRateLimiter::new( config ).unwrap() );
let mock_client = MockHttpClient::new();
let rate_limiter_clone = rate_limiter.clone();
let handle = tokio::spawn( async move {
rate_limiter_clone.execute( || async { Ok( "concurrent success" ) } ).await
} );
let result = rate_limiter.execute( || mock_client.make_request() ).await;
let concurrent_result = handle.await.unwrap();
assert!( result.is_ok() );
assert!( concurrent_result.is_ok() );
}
#[ tokio::test ]
async fn test_rate_limiter_zero_overhead_when_disabled()
{
let config = EnhancedRateLimitingConfig::default();
assert!( config.validate().is_ok() );
let rate_limiter = EnhancedRateLimiter::new( config );
assert!( rate_limiter.is_ok() );
}
#[ tokio::test ]
async fn test_token_bucket_burst_capacity_enforcement()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::TokenBucket )
.with_burst_capacity( 3 )
.with_refill_rate( 100.0 ); let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
sleep( Duration::from_millis( 100 ) ).await;
let state = rate_limiter.get_token_bucket_state().unwrap();
assert!( state.tokens <= 3.0 );
}
#[ tokio::test ]
async fn test_rate_limiter_metrics_tracking()
{
let config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::TokenBucket )
.with_burst_capacity( 2 )
.with_refill_rate( 0.1 );
let rate_limiter = EnhancedRateLimiter::new( config ).unwrap();
let mock_client = MockHttpClient::new();
let _result1 = rate_limiter.execute( || mock_client.make_request() ).await;
let _result2 = rate_limiter.execute( || mock_client.make_request() ).await;
let _result3 = rate_limiter.execute( || mock_client.make_request() ).await;
let state = rate_limiter.get_token_bucket_state().unwrap();
assert_eq!( state.total_requests, 3 );
assert_eq!( state.rate_limited_requests, 1 );
}
#[ tokio::test ]
async fn test_sliding_window_vs_token_bucket_behavior()
{
let token_bucket_config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::TokenBucket )
.with_burst_capacity( 2 )
.with_refill_rate( 0.1 );
let token_bucket = EnhancedRateLimiter::new( token_bucket_config ).unwrap();
let sliding_window_config = EnhancedRateLimitingConfig::new()
.with_algorithm( RateLimitingAlgorithm::SlidingWindow )
.with_max_requests( 2 )
.with_window_duration( 1000 );
let sliding_window = EnhancedRateLimiter::new( sliding_window_config ).unwrap();
let mock_client = MockHttpClient::new();
assert!( token_bucket.execute( || mock_client.make_request() ).await.is_ok() );
assert!( sliding_window.execute( || mock_client.make_request() ).await.is_ok() );
assert!( token_bucket.get_token_bucket_state().is_some() );
assert!( token_bucket.get_sliding_window_state().is_none() );
assert!( sliding_window.get_token_bucket_state().is_none() );
assert!( sliding_window.get_sliding_window_state().is_some() );
}
}
#[ cfg( not( feature = "rate_limiting" ) ) ]
mod no_rate_limiting_tests
{
#[ tokio::test ]
async fn test_zero_overhead_when_rate_limiting_disabled()
{
assert!( true, "Rate limiting feature is disabled - zero overhead ensured" );
}
}
#[ cfg( feature = "rate_limiting" ) ]
pub use rate_limiting_tests::
{
EnhancedRateLimitingConfig,
RateLimitingAlgorithm,
TokenBucketState,
SlidingWindowState,
EnhancedRateLimiter,
};