use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::RwLock;
/// Thresholds enforced by [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct MessageLimits {
    /// Largest accepted message, in bytes; larger messages are rejected outright.
    pub max_size_bytes: usize,
    /// How many messages a user may send inside one `window_duration`.
    pub max_messages_per_window: u32,
    /// Length of the sliding window over which messages are counted.
    pub window_duration: Duration,
    /// How long a user is banned after exceeding `max_messages_per_window`.
    pub ban_duration: Duration,
}

impl Default for MessageLimits {
    /// 64 KiB messages, at most 10 messages per 1-second window, 60-second ban.
    fn default() -> Self {
        const KIB: usize = 1024;
        let window_duration = Duration::from_secs(1);
        let ban_duration = Duration::from_secs(60);
        MessageLimits {
            max_size_bytes: 64 * KIB,
            max_messages_per_window: 10,
            window_duration,
            ban_duration,
        }
    }
}
/// Rate-limiting bookkeeping for a single user.
#[derive(Debug)]
struct UserRateState {
    /// Timestamps of accepted messages, oldest first; pruned by `clean_old_messages`.
    message_times: Vec<Instant>,
    /// If set and still in the future, the user is banned until this instant.
    banned_until: Option<Instant>,
}

impl UserRateState {
    /// Upper bound applied to ban lengths. `Instant + Duration` panics on
    /// overflow, so absurd requests (e.g. `Duration::MAX` meant as "forever")
    /// are clamped to roughly 100 years instead.
    const MAX_BAN_DURATION: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);

    /// Creates an empty state: no recorded messages and no ban.
    fn new() -> Self {
        Self {
            message_times: Vec::new(),
            banned_until: None,
        }
    }

    /// Returns `true` while a ban is set and has not yet expired.
    fn is_banned(&self) -> bool {
        self.banned_until
            .is_some_and(|until| Instant::now() < until)
    }

    /// Bans the user for `duration` starting now, replacing any existing ban.
    ///
    /// Durations longer than `MAX_BAN_DURATION` are clamped to it so that the
    /// deadline computation cannot overflow and panic.
    fn ban(&mut self, duration: Duration) {
        let duration = duration.min(Self::MAX_BAN_DURATION);
        self.banned_until = Some(Instant::now() + duration);
    }

    /// Drops recorded messages that are not strictly newer than `now - window`.
    fn clean_old_messages(&mut self, window: Duration) {
        // `Instant - Duration` panics if the result would precede the clock's
        // origin (large windows, or shortly after boot). When that happens,
        // every recorded timestamp is necessarily inside the window, so there
        // is nothing to drop.
        if let Some(cutoff) = Instant::now().checked_sub(window) {
            self.message_times.retain(|&time| time > cutoff);
        }
    }

    /// Records an accepted message at the current instant.
    fn record_message(&mut self) {
        self.message_times.push(Instant::now());
    }
}
/// Enforces per-user message size and rate limits, temporarily banning users
/// who exceed the per-window message count.
///
/// Cloning is cheap: every clone shares the same per-user state map through an
/// `Arc`, so bans and recorded messages are visible through all clones. The
/// `MessageLimits` are cloned into each instance.
#[derive(Clone)]
pub struct RateLimiter {
    limits: MessageLimits,
    /// Per-user state keyed by the `user_id` passed to the public methods.
    states: Arc<RwLock<HashMap<String, UserRateState>>>,
}

impl RateLimiter {
    /// Creates a limiter enforcing `limits`, with no per-user state yet.
    pub fn new(limits: MessageLimits) -> Self {
        Self {
            limits,
            states: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Checks whether `user_id` may send a message of `message_size` bytes
    /// and, if so, records it in the user's current window.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - `message_size` exceeds `max_size_bytes` (checked first, before any
    ///   per-user state is touched);
    /// - the user is currently banned;
    /// - the user already sent `max_messages_per_window` messages within
    ///   `window_duration`; the user is then also banned for `ban_duration`.
    pub async fn check_allowed(&self, user_id: &str, message_size: usize) -> Result<(), String> {
        if message_size > self.limits.max_size_bytes {
            return Err(format!(
                "Message too large: {} bytes (max {})",
                message_size, self.limits.max_size_bytes
            ));
        }

        let mut states = self.states.write().await;
        let state = states
            .entry(user_id.to_string())
            .or_insert_with(UserRateState::new);

        if state.is_banned() {
            return Err("You are temporarily banned for sending too many messages".to_string());
        }

        state.clean_old_messages(self.limits.window_duration);
        if state.message_times.len() >= self.limits.max_messages_per_window as usize {
            state.ban(self.limits.ban_duration);
            // Release the lock before logging so a slow tracing subscriber
            // cannot stall other users' checks.
            drop(states);
            tracing::warn!("User {} exceeded rate limit and was banned", user_id);
            // `as_secs_f64` keeps sub-second windows meaningful ("0.5" instead
            // of "0"); whole-second windows still render as integers ("1").
            return Err(format!(
                "Rate limit exceeded: max {} messages per {} seconds",
                self.limits.max_messages_per_window,
                self.limits.window_duration.as_secs_f64()
            ));
        }

        state.record_message();
        Ok(())
    }

    /// Forgets everything about `user_id`: message history and any ban.
    pub async fn reset_user(&self, user_id: &str) {
        self.states.write().await.remove(user_id);
    }

    /// Bans `user_id` for `duration` starting now, creating an entry if the
    /// user is unknown and replacing any existing ban.
    pub async fn ban_user(&self, user_id: &str, duration: Duration) {
        let mut states = self.states.write().await;
        let state = states
            .entry(user_id.to_string())
            .or_insert_with(UserRateState::new);
        state.ban(duration);
    }

    /// Returns `true` if `user_id` has an unexpired ban; unknown users are
    /// never banned.
    pub async fn is_banned(&self, user_id: &str) -> bool {
        let states = self.states.read().await;
        states.get(user_id).is_some_and(|s| s.is_banned())
    }

    /// Lifts any ban on `user_id` while keeping its message history.
    /// Does nothing for unknown users.
    pub async fn unban_user(&self, user_id: &str) {
        let mut states = self.states.write().await;
        if let Some(state) = states.get_mut(user_id) {
            state.banned_until = None;
        }
    }

    /// Removes entries that no longer influence any decision: users who are
    /// not currently banned and have no messages inside the current window.
    ///
    /// Such an entry behaves exactly like a freshly created one, so removing
    /// it is unobservable. Without pruning, the map keeps one entry for every
    /// `user_id` ever seen; call this periodically to bound memory use.
    pub async fn prune_inactive(&self) {
        let window = self.limits.window_duration;
        self.states.write().await.retain(|_, state| {
            state.clean_old_messages(window);
            state.is_banned() || !state.message_times.is_empty()
        });
    }
}