use super::rate_limit_strategy::{ClientRateInfo, RateLimitStrategy};
use crate::ToolError;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Per-client token-bucket rate limiter.
///
/// Each client id gets its own bucket. Tokens are refilled continuously at
/// `max_requests / time_window` per second, up to the bucket capacity. Each
/// accepted request consumes one token.
///
/// Cloning shares client state between the clones, because the client map is
/// held behind an [`Arc`].
#[derive(Debug, Clone)]
pub struct TokenBucketStrategy {
    /// Tokens refilled per `time_window`. This is also the bucket capacity when `burst_size` is `None`.
    max_requests: usize,
    /// Period over which `max_requests` tokens are refilled. It is also the window used to prune and
    /// count recent request timestamps.
    time_window: Duration,
    /// Optional bucket capacity that replaces `max_requests` as the maximum number of stored tokens.
    burst_size: Option<usize>,
    /// Per-client bucket state, keyed by client id.
    clients: Arc<DashMap<String, ClientRateInfo>>,
}
impl TokenBucketStrategy {
    /// Creates a limiter that refills `max_requests` tokens per `time_window` for each client.
    ///
    /// No burst size is configured, so each client's bucket holds at most `max_requests` tokens.
    pub fn new(max_requests: usize, time_window: Duration) -> Self {
        let clients = Arc::new(DashMap::new());
        Self {
            clients,
            burst_size: None,
            time_window,
            max_requests,
        }
    }

    /// Creates a limiter like [`TokenBucketStrategy::new`], but with each client's bucket
    /// capacity set to `burst_size` instead of `max_requests`.
    pub fn with_burst(max_requests: usize, time_window: Duration, burst_size: usize) -> Self {
        Self {
            burst_size: Some(burst_size),
            ..Self::new(max_requests, time_window)
        }
    }
}
impl RateLimitStrategy for TokenBucketStrategy {
    /// Admits one request for `client_id` if its token bucket holds at least one token.
    ///
    /// The bucket is created on first use via `ClientRateInfo::new`, passing the bucket
    /// capacity (`burst_size`, or `max_requests` when no burst size is set). Each call
    /// first refills the bucket at `max_requests / time_window` tokens per second for the
    /// time elapsed since the previous refill, capped at that capacity. It then spends
    /// one token on success.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::RateLimited` when less than one token is available. Its
    /// `retry_after` is the estimated wait until the next token. It is `None` when that
    /// wait is not representable as a `Duration`, e.g. `max_requests == 0`, where the
    /// bucket never refills.
    fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
        let now = Instant::now();
        let max_tokens = self.burst_size.unwrap_or(self.max_requests) as f64;
        let refill_rate = self.max_requests as f64 / self.time_window.as_secs_f64();

        let mut entry = self
            .clients
            .entry(client_id.to_string())
            .or_insert_with(|| ClientRateInfo::new(max_tokens));

        // Continuous refill for the time since the last check, never exceeding capacity.
        let elapsed_seconds = now.duration_since(entry.last_refill).as_secs_f64();
        entry.burst_tokens = (entry.burst_tokens + elapsed_seconds * refill_rate).min(max_tokens);
        entry.last_refill = now;

        // Prune timestamps that have left the window. In this strategy they are read only by
        // `get_request_count`.
        entry
            .request_times
            .retain(|&time| now.duration_since(time) < self.time_window);

        let has_burst_token = entry.burst_tokens >= 1.0;
        if !has_burst_token {
            let tokens_needed = 1.0 - entry.burst_tokens;
            // A zero refill rate makes this wait infinite (or NaN), and a very slow rate can
            // exceed `Duration::MAX`. `Duration::from_secs_f64` panics in both cases, so report
            // "no retry hint" instead of crashing the caller.
            let retry_after = Duration::try_from_secs_f64(tokens_needed / refill_rate).ok();
            return Err(ToolError::RateLimited {
                source: None,
                source_message: format!(
                    "Token bucket rate limit: {} requests per {:?}",
                    self.max_requests, self.time_window
                ),
                context: format!("User exceeded rate limit of {} requests", self.max_requests),
                retry_after,
            });
        }

        entry.burst_tokens -= 1.0;
        entry.request_times.push(now);
        Ok(())
    }

    /// Forgets all state for `client_id`; its next request starts with a freshly created bucket.
    fn reset_client(&self, client_id: &str) {
        self.clients.remove(client_id);
    }

    /// Forgets the state of every client.
    fn clear_all(&self) {
        self.clients.clear();
    }

    /// Returns how many accepted requests from `client_id` fall within the last `time_window`.
    ///
    /// Clients with no recorded state report `0`.
    fn get_request_count(&self, client_id: &str) -> usize {
        let now = Instant::now();
        self.clients.get(client_id).map_or(0, |entry| {
            entry
                .request_times
                .iter()
                .filter(|&&time| now.duration_since(time) < self.time_window)
                .count()
        })
    }

    /// Returns the identifier of this strategy, `"TokenBucket"`.
    fn strategy_name(&self) -> &str {
        "TokenBucket"
    }
}