use std::sync::Arc;
use std::time::Duration;
use super::rate_limit_strategy::{FixedWindowStrategy, RateLimitStrategy};
use super::token_bucket::TokenBucketStrategy;
use crate::ToolError;
/// Selects which algorithm [`RateLimiterBuilder::build`] uses for a [`RateLimiter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateLimitStrategyType {
    /// Token bucket, backed by [`TokenBucketStrategy`].
    ///
    /// This is the builder's default, and the only strategy that honours
    /// [`RateLimiterBuilder::burst_size`].
    TokenBucket,
    /// Fixed window, backed by [`FixedWindowStrategy`]. Any configured burst size is ignored.
    FixedWindow,
}
/// Per-client rate limiter that hands every decision to a pluggable
/// [`RateLimitStrategy`].
///
/// The strategy is held in an [`Arc`], so cloning a `RateLimiter` is cheap.
/// Every clone talks to the same strategy instance, and therefore shares
/// whatever per-client state that strategy keeps.
#[derive(Clone)]
pub struct RateLimiter {
    // The algorithm that tracks clients and accepts or rejects their requests.
    strategy: Arc<dyn RateLimitStrategy>,
}

impl std::fmt::Debug for RateLimiter {
    // Reports only the strategy's name, never the strategy's internal state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.strategy.strategy_name();
        let mut out = f.debug_struct("RateLimiter");
        out.field("strategy", &name);
        out.finish()
    }
}
impl RateLimiter {
    /// Creates a limiter backed by a [`TokenBucketStrategy`] configured with
    /// `max_requests` per `time_window` and no explicit burst size.
    ///
    /// This is equivalent to
    /// `RateLimiter::builder().max_requests(max_requests).time_window(time_window).build()`.
    /// Use [`RateLimiter::builder`] to choose another strategy or set a burst
    /// size, or [`RateLimiter::with_strategy`] to plug in a custom one.
    #[must_use]
    pub fn new(max_requests: usize, time_window: Duration) -> Self {
        Self::with_strategy(TokenBucketStrategy::new(max_requests, time_window))
    }

    /// Wraps an arbitrary [`RateLimitStrategy`] implementation.
    #[must_use]
    pub fn with_strategy<S: RateLimitStrategy + 'static>(strategy: S) -> Self {
        Self {
            strategy: Arc::new(strategy),
        }
    }

    /// Returns a [`RateLimiterBuilder`] with every option unset.
    #[must_use = "a builder does nothing until `build` is called"]
    pub fn builder() -> RateLimiterBuilder {
        RateLimiterBuilder::default()
    }

    /// Submits one request from `client_id` to the strategy.
    ///
    /// Each call counts against the client's allowance, so repeated calls
    /// eventually fail.
    ///
    /// # Errors
    ///
    /// Returns the [`ToolError`] produced by the strategy when it rejects the request.
    pub fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
        self.strategy.check_rate_limit(client_id)
    }

    /// Asks the strategy to forget the state it tracks for `client_id`, so the
    /// client's allowance starts afresh.
    pub fn reset_client(&self, client_id: &str) {
        self.strategy.reset_client(client_id)
    }

    /// Asks the strategy to clear the state it tracks for every client.
    pub fn clear_all(&self) {
        self.strategy.clear_all()
    }

    /// Returns the strategy's request count for `client_id`.
    ///
    /// What exactly is counted is defined by the active strategy.
    #[must_use]
    pub fn get_request_count(&self, client_id: &str) -> usize {
        self.strategy.get_request_count(client_id)
    }

    /// Returns the name reported by the active strategy.
    #[must_use]
    pub fn strategy_name(&self) -> &str {
        self.strategy.strategy_name()
    }
}
/// Builder for [`RateLimiter`]; obtain one with [`RateLimiter::builder`].
///
/// Every option is optional. [`RateLimiterBuilder::build`] substitutes a
/// default for anything left unset.
#[derive(Debug, Default, Clone)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct RateLimiterBuilder {
    // Which algorithm to use; `None` means token bucket.
    strategy_type: Option<RateLimitStrategyType>,
    // Maximum requests per `time_window`; `None` means 10.
    max_requests: Option<usize>,
    // Length of the rate-limiting window; `None` means 60 seconds.
    time_window: Option<Duration>,
    // Token-bucket burst capacity; ignored when the fixed-window strategy is chosen.
    burst_size: Option<usize>,
}
impl RateLimiterBuilder {
    /// Request budget used when [`max_requests`](Self::max_requests) is never called.
    const DEFAULT_MAX_REQUESTS: usize = 10;
    /// Window length used when [`time_window`](Self::time_window) is never called.
    const DEFAULT_TIME_WINDOW: Duration = Duration::from_secs(60);

    /// Selects the limiting algorithm. Defaults to
    /// [`RateLimitStrategyType::TokenBucket`].
    pub fn strategy(mut self, strategy: RateLimitStrategyType) -> Self {
        self.strategy_type = Some(strategy);
        self
    }

    /// Sets how many requests are permitted per time window (default: 10).
    pub fn max_requests(mut self, max: usize) -> Self {
        self.max_requests = Some(max);
        self
    }

    /// Sets the length of the time window (default: 60 seconds).
    pub fn time_window(mut self, window: Duration) -> Self {
        self.time_window = Some(window);
        self
    }

    /// Sets the burst capacity of the token-bucket strategy.
    ///
    /// Only [`RateLimitStrategyType::TokenBucket`] uses this value. It is
    /// silently ignored when building a [`RateLimitStrategyType::FixedWindow`]
    /// limiter.
    pub fn burst_size(mut self, size: usize) -> Self {
        self.burst_size = Some(size);
        self
    }

    /// Builds the configured [`RateLimiter`], filling in defaults for unset options.
    #[must_use]
    pub fn build(self) -> RateLimiter {
        let max_requests = self.max_requests.unwrap_or(Self::DEFAULT_MAX_REQUESTS);
        let time_window = self.time_window.unwrap_or(Self::DEFAULT_TIME_WINDOW);
        let strategy_type = self
            .strategy_type
            .unwrap_or(RateLimitStrategyType::TokenBucket);

        let strategy: Arc<dyn RateLimitStrategy> = match (strategy_type, self.burst_size) {
            (RateLimitStrategyType::TokenBucket, Some(burst_size)) => Arc::new(
                TokenBucketStrategy::with_burst(max_requests, time_window, burst_size),
            ),
            (RateLimitStrategyType::TokenBucket, None) => {
                Arc::new(TokenBucketStrategy::new(max_requests, time_window))
            }
            // The fixed-window strategy takes no burst parameter, so any
            // configured `burst_size` is deliberately dropped here.
            (RateLimitStrategyType::FixedWindow, _) => {
                Arc::new(FixedWindowStrategy::new(max_requests, time_window))
            }
        };
        RateLimiter { strategy }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Timing notes: `thread::sleep` may oversleep, so every timing-based
    // assertion below leaves at least ~50 ms of slack on the side where
    // oversleeping would flip the result.

    #[test]
    fn test_rate_limiter_allows_requests_within_limit() {
        let limiter = RateLimiter::new(3, Duration::from_secs(1));
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_rate_limiter_blocks_requests_over_limit() {
        let limiter = RateLimiter::new(2, Duration::from_secs(1));
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_rate_limiter_with_different_clients() {
        let limiter = RateLimiter::new(1, Duration::from_secs(1));
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user2").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
        assert!(limiter.check_rate_limit("user2").is_err());
    }

    #[test]
    fn test_rate_limiter_builder() {
        let limiter = RateLimiter::builder()
            .max_requests(5)
            .time_window(Duration::from_secs(10))
            .burst_size(2)
            .build();
        // The burst size, not `max_requests`, caps the initial allowance.
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_fixed_window_strategy_blocks_requests_over_limit() {
        // A long window keeps the three checks inside a single window;
        // `burst_size` is ignored for this strategy.
        let limiter = RateLimiter::builder()
            .strategy(RateLimitStrategyType::FixedWindow)
            .max_requests(2)
            .time_window(Duration::from_secs(60))
            .burst_size(10)
            .build();
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_reset_client() {
        let limiter = RateLimiter::new(1, Duration::from_secs(1));
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
        limiter.reset_client("user1");
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_clear_all() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user2").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
        assert!(limiter.check_rate_limit("user2").is_err());
        limiter.clear_all();
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user2").is_ok());
    }

    #[test]
    fn test_clones_share_state() {
        // Clones share the same `Arc`-held strategy, so usage through one
        // handle is visible through the other.
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let clone = limiter.clone();
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(clone.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_time_based_token_replenishment() {
        // 10 tokens per 2 s: one token every 200 ms.
        let limiter = RateLimiter::new(10, Duration::from_secs(2));
        for _ in 0..10 {
            assert!(limiter.check_rate_limit("user1").is_ok());
        }
        assert!(limiter.check_rate_limit("user1").is_err());
        // ~300 ms refills ~1.5 tokens. Exactly one request fits unless the
        // sleep overshoots by 100 ms or more.
        thread::sleep(Duration::from_millis(300));
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_burst_size_cap() {
        let limiter = RateLimiter::builder()
            .max_requests(5)
            .time_window(Duration::from_secs(1))
            .burst_size(3)
            .build();
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
        // 5 tokens/s: 250 ms yields at least one token; oversleeping only adds more.
        thread::sleep(Duration::from_millis(250));
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_token_accumulation_capped() {
        // 10 tokens per 500 ms: one token every 50 ms, capped at a burst of 5.
        let limiter = RateLimiter::builder()
            .max_requests(10)
            .time_window(Duration::from_millis(500))
            .burst_size(5)
            .build();
        // Drain the bucket first, so the test does not depend on when the
        // client's state is created.
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("user1").is_ok());
        }
        assert!(limiter.check_rate_limit("user1").is_err());
        // 400 ms would refill 8 tokens without the cap; only 5 may be kept.
        thread::sleep(Duration::from_millis(400));
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("user1").is_ok());
        }
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_fractional_token_replenishment() {
        let limiter = RateLimiter::new(1, Duration::from_secs(1));
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
        // Half a token is not enough for a request.
        thread::sleep(Duration::from_millis(500));
        assert!(limiter.check_rate_limit("user1").is_err());
        // A further 600 ms brings the total past one full token.
        thread::sleep(Duration::from_millis(600));
        assert!(limiter.check_rate_limit("user1").is_ok());
    }
}