use std::sync::{ Arc, Mutex };
use std::time::{ Duration, Instant };
use std::collections::VecDeque;
use reqwest::{ Client, Method };
use serde::Serialize;
use serde::Deserialize;
use crate::error::Error;
#[ cfg( feature = "logging" ) ]
use tracing::{ warn, debug };
/// Settings that control how a [`RateLimit`] throttles requests.
#[ derive( Debug, Clone ) ]
pub struct RateLimitingConfig
{
  /// Sustained request rate. The token bucket refills at this many tokens per
  /// second; the sliding window spans `bucket_size / requests_per_second` seconds.
  pub requests_per_second : f64,
  /// Burst capacity: the maximum (and initial) number of tokens in the token
  /// bucket, and the maximum number of requests kept inside one sliding window.
  pub bucket_size : usize,
  /// Algorithm name. `"sliding_window"` selects the sliding window; every other
  /// value selects the token bucket.
  pub algorithm : String,
  /// Metrics switch.
  /// NOTE(review): nothing in the visible part of this file reads this flag —
  /// `RateLimit` always records metrics. Confirm the intended use.
  pub enable_metrics : bool,
}
/// Mutable state of the active rate-limiting algorithm.
#[ derive( Debug, Clone ) ]
pub enum RateLimiter
{
  /// Token-bucket state.
  TokenBucket
  {
    /// Tokens currently available; fractional because refill is proportional
    /// to elapsed time.
    tokens : f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill : Instant,
  },
  /// Sliding-window state.
  SlidingWindow
  {
    /// Timestamps of admitted requests, oldest at the front.
    requests : VecDeque< Instant >,
  },
}
/// Counters and gauges describing what a [`RateLimit`] has decided so far.
#[ derive( Debug, Clone ) ]
pub struct RateLimitingMetrics
{
  /// Number of admission checks performed since creation or the last reset.
  pub total_requests : u64,
  /// Number of checks that were denied.
  pub limited_requests : u64,
  /// Number of checks that were admitted.
  pub allowed_requests : u64,
  /// Algorithm name exactly as it was given in the configuration.
  pub current_algorithm : String,
  /// Tokens left after the most recent token-bucket decision; starts at the
  /// configured bucket size.
  pub available_tokens : f64,
  /// Requests held in the sliding window after the most recent sliding-window
  /// decision.
  pub window_requests : usize,
}
/// Rate limiter that pairs a fixed configuration with mutex-protected
/// algorithm state and metrics.
#[ derive( Debug ) ]
pub struct RateLimit
{
  /// Settings supplied at construction time.
  config : RateLimitingConfig,
  /// State of the selected algorithm.
  limiter : Arc< Mutex< RateLimiter > >,
  /// Running counters, updated on every admission check.
  metrics : Arc< Mutex< RateLimitingMetrics > >,
}
impl RateLimit
{
  /// Builds a rate limiter from `config`.
  ///
  /// The algorithm name `"sliding_window"` selects the sliding-window limiter;
  /// any other name selects the token bucket, which starts full
  /// (`bucket_size` tokens). All counters start at zero.
  pub fn new( config : RateLimitingConfig ) -> Self
  {
    let capacity = config.bucket_size as f64;
    let limiter = match config.algorithm.as_str()
    {
      "sliding_window" => RateLimiter::SlidingWindow { requests : VecDeque::new() },
      // Unrecognised names deliberately fall back to the token bucket.
      _ => RateLimiter::TokenBucket { tokens : capacity, last_refill : Instant::now() },
    };
    let metrics = RateLimitingMetrics
    {
      total_requests : 0,
      limited_requests : 0,
      allowed_requests : 0,
      // Only the name is copied; the configuration itself is moved into `Self`.
      current_algorithm : config.algorithm.clone(),
      available_tokens : capacity,
      window_requests : 0,
    };
    Self
    {
      config,
      limiter : Arc::new( Mutex::new( limiter ) ),
      metrics : Arc::new( Mutex::new( metrics ) ),
    }
  }

  /// Decides whether one more request may proceed right now and records the
  /// decision in the metrics.
  ///
  /// * Token bucket: refills `elapsed * requests_per_second` tokens (capped at
  ///   `bucket_size`) and admits the request if at least one token is available.
  /// * Sliding window: drops timestamps older than
  ///   `bucket_size / requests_per_second` seconds and admits the request if
  ///   fewer than `bucket_size` remain.
  ///
  /// A negative or NaN `requests_per_second` is treated as `0.0`: the token
  /// bucket never refills and sliding-window entries never expire.
  ///
  /// Returns `true` when the request is admitted, `false` when it is limited.
  ///
  /// # Panics
  ///
  /// Panics if one of the internal mutexes is poisoned.
  pub async fn should_allow_request( &self ) -> bool
  {
    // Lock order (limiter, then metrics) is the same as in `reset`.
    let mut limiter = self.limiter.lock().expect( "rate limiter state mutex poisoned" );
    let mut metrics = self.metrics.lock().expect( "rate limiter metrics mutex poisoned" );
    metrics.total_requests += 1;

    // `f64::max` returns the non-NaN operand, so NaN and negative rates both
    // become 0.0 instead of draining tokens or producing an invalid duration.
    let rate = self.config.requests_per_second.max( 0.0 );
    let capacity = self.config.bucket_size;
    let now = Instant::now();

    let allowed = match &mut *limiter
    {
      RateLimiter::TokenBucket { tokens, last_refill } =>
      {
        let elapsed = now.duration_since( *last_refill ).as_secs_f64();
        *tokens = ( *tokens + elapsed * rate ).min( capacity as f64 );
        *last_refill = now;

        let granted = *tokens >= 1.0;
        if granted
        {
          *tokens -= 1.0;
        }
        else
        {
          #[ cfg( feature = "logging" ) ]
          debug!( "Rate limit exceeded : {} tokens available", *tokens );
        }
        metrics.available_tokens = *tokens;
        granted
      },
      RateLimiter::SlidingWindow { requests } =>
      {
        // With a zero rate the length is infinite (or NaN when the capacity is
        // also zero). `from_secs_f64` would panic on that and poison both
        // mutexes, so fall back to a window that never expires.
        let window = Duration::try_from_secs_f64( 1.0 / rate * capacity as f64 )
        .unwrap_or( Duration::MAX );

        // Evict timestamps that have aged out of the window, oldest first.
        while let Some( &oldest ) = requests.front()
        {
          if now.duration_since( oldest ) <= window
          {
            break;
          }
          requests.pop_front();
        }

        let granted = requests.len() < capacity;
        if granted
        {
          requests.push_back( now );
        }
        else
        {
          #[ cfg( feature = "logging" ) ]
          debug!( "Rate limit exceeded : {} requests in window", requests.len() );
        }
        metrics.window_requests = requests.len();
        granted
      },
    };

    if allowed
    {
      metrics.allowed_requests += 1;
    }
    else
    {
      metrics.limited_requests += 1;
    }
    allowed
  }

  /// Returns a snapshot (clone) of the current metrics.
  ///
  /// # Panics
  ///
  /// Panics if the metrics mutex is poisoned.
  pub fn get_metrics( &self ) -> RateLimitingMetrics
  {
    self.metrics.lock().expect( "rate limiter metrics mutex poisoned" ).clone()
  }

  /// Restores the limiter to its initial state and zeroes the request counters.
  ///
  /// The token bucket is refilled to `bucket_size`; the sliding window is
  /// emptied.
  ///
  /// # Panics
  ///
  /// Panics if one of the internal mutexes is poisoned.
  pub fn reset( &self )
  {
    // Lock order (limiter, then metrics) is the same as in `should_allow_request`.
    let mut limiter = self.limiter.lock().expect( "rate limiter state mutex poisoned" );
    let mut metrics = self.metrics.lock().expect( "rate limiter metrics mutex poisoned" );

    match &mut *limiter
    {
      RateLimiter::TokenBucket { tokens, last_refill } =>
      {
        *tokens = self.config.bucket_size as f64;
        *last_refill = Instant::now();
        metrics.available_tokens = *tokens;
      },
      RateLimiter::SlidingWindow { requests } =>
      {
        requests.clear();
        metrics.window_requests = 0;
      },
    }

    metrics.total_requests = 0;
    metrics.limited_requests = 0;
    metrics.allowed_requests = 0;
  }
}
pub async fn execute_with_rate_limiting< T, R >
(
client : &Client,
method : Method,
url : &str,
api_key : &str,
body : Option< &T >,
config : &super::HttpConfig,
rate_limiter : Option< &RateLimit >,
)
-> Result< R, Error >
where
T: Serialize,
R: for< 'de > Deserialize< 'de >,
{
let Some( rl ) = rate_limiter else {
return super::execute( client, method, url, api_key, body, config ).await;
};
if !rl.should_allow_request().await
{
#[ cfg( feature = "logging" ) ]
warn!( "Request rate limited" );
return Err( Error::RateLimited(
"Rate limit exceeded".to_string()
) );
}
super ::execute( client, method, url, api_key, body, config ).await
}