use crate::ToolError;
use std::time::{Duration, Instant};
/// A pluggable rate-limiting policy whose state is keyed by a client identifier.
///
/// Implementors must be `Send + Sync`, so a single strategy value can be shared
/// across threads (e.g. behind an `Arc<dyn RateLimitStrategy>`).
pub trait RateLimitStrategy: Send + Sync {
    /// Decides whether a request from `client_id` may proceed.
    ///
    /// Returns `Err` when the request should be rejected. In
    /// `FixedWindowStrategy`, an `Ok` result also counts the request.
    fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError>;

    /// Reports how many requests are currently recorded for `client_id`.
    fn get_request_count(&self, client_id: &str) -> usize;

    /// Forgets all tracked state for a single client.
    fn reset_client(&self, client_id: &str);

    /// Forgets tracked state for every client.
    fn clear_all(&self);

    /// A short, human-readable name identifying the strategy.
    fn strategy_name(&self) -> &str;
}
/// Per-client bookkeeping made of request timestamps and a token balance.
///
/// NOTE(review): not used by `FixedWindowStrategy`; the field names suggest a
/// sliding-window / token-bucket strategy defined elsewhere — confirm.
#[derive(Clone, Debug)]
pub struct ClientRateInfo {
    /// Recorded request timestamps; empty for a freshly created entry.
    pub request_times: Vec<Instant>,
    /// Token balance; starts at the value given to [`ClientRateInfo::new`].
    pub burst_tokens: f64,
    /// Presumably the last time `burst_tokens` was refilled; initialised to the
    /// construction instant.
    pub last_refill: Instant,
}

impl ClientRateInfo {
    /// Builds an entry with no recorded requests, `initial_tokens` tokens and
    /// `last_refill` set to the current instant.
    pub fn new(initial_tokens: f64) -> Self {
        let created_at = Instant::now();
        Self {
            last_refill: created_at,
            burst_tokens: initial_tokens,
            request_times: Vec::new(),
        }
    }
}
/// Fixed-window rate limiter: each client may make at most `max_requests`
/// requests per window of length `window_duration`.
///
/// Per-client state lives in a concurrent `DashMap` keyed by client id.
/// Within this section, entries are only removed via `reset_client` or
/// `clear_all`.
#[derive(Debug)]
pub struct FixedWindowStrategy {
    /// Maximum number of accepted requests per client per window.
    pub max_requests: usize,
    /// Length of each client's window, measured from its `window_start`.
    pub window_duration: Duration,
    /// Per-client window state, keyed by client id.
    pub clients: dashmap::DashMap<String, FixedWindowClientInfo>,
}
/// Window state tracked for a single client by `FixedWindowStrategy`.
#[derive(Clone, Debug)]
pub struct FixedWindowClientInfo {
    /// Instant at which the client's current window began.
    pub window_start: Instant,
    /// Number of requests accepted so far in the current window.
    pub request_count: usize,
}
impl RateLimitStrategy for FixedWindowStrategy {
    /// Records one request for `client_id`, rejecting it if the client has
    /// already used `max_requests` requests in its current window.
    ///
    /// A client's window starts at its first tracked request. The window is
    /// reset (count back to zero, start moved to "now") by the first call made
    /// at or after `window_start + window_duration`. Rejected requests are not
    /// counted.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::RateLimited` when the limit is reached. Its
    /// `retry_after` is the time remaining until the current window ends.
    fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
        let now = Instant::now();
        let mut entry = self
            .clients
            .entry(client_id.to_string())
            .or_insert_with(|| FixedWindowClientInfo {
                window_start: now,
                request_count: 0,
            });

        // Start a fresh window once the previous one has fully elapsed.
        if now.duration_since(entry.window_start) >= self.window_duration {
            entry.window_start = now;
            entry.request_count = 0;
        }

        if entry.request_count < self.max_requests {
            entry.request_count += 1;
            return Ok(());
        }

        let time_until_reset = self
            .window_duration
            .saturating_sub(now.duration_since(entry.window_start));
        // Release the DashMap entry guard before formatting the error so the
        // map is not held locked while the message strings are allocated.
        drop(entry);

        Err(ToolError::RateLimited {
            source: None,
            source_message: format!(
                "Fixed window rate limit: {} requests per {:?}",
                self.max_requests, self.window_duration
            ),
            context: format!("Exceeded {} requests in current window", self.max_requests),
            retry_after: Some(time_until_reset),
        })
    }

    /// Removes all window state for `client_id`; its next request starts a new window.
    fn reset_client(&self, client_id: &str) {
        self.clients.remove(client_id);
    }

    /// Removes window state for every client.
    fn clear_all(&self) {
        self.clients.clear();
    }

    /// Returns the number of requests accepted in the client's *current* window.
    ///
    /// Returns 0 for an unknown client, or for a client whose window has
    /// already elapsed.
    fn get_request_count(&self, client_id: &str) -> usize {
        self.clients
            .get(client_id)
            // `check_rate_limit` resets an elapsed window to zero on its next
            // call, so report that value instead of the stale stored count.
            .filter(|entry| entry.window_start.elapsed() < self.window_duration)
            .map(|entry| entry.request_count)
            .unwrap_or(0)
    }

    /// Always `"FixedWindow"`.
    fn strategy_name(&self) -> &str {
        "FixedWindow"
    }
}
impl FixedWindowStrategy {
    /// Creates a limiter that accepts at most `max_requests` requests per
    /// client within each `window_duration`, starting with no tracked clients.
    pub fn new(max_requests: usize, window_duration: Duration) -> Self {
        let clients = dashmap::DashMap::new();
        Self {
            clients,
            max_requests,
            window_duration,
        }
    }
}