use std::collections::HashMap;
use std::time::{Duration, Instant};
/// A rate limit of `limit` attempts within a window of length `duration`.
///
/// The limit is only a description; the code that consumes it decides how the
/// window is measured. A `limit` of zero is allowed and is used to reject every attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimit {
    /// Number of attempts allowed within one window.
    limit: usize,
    /// Length of one window.
    duration: Duration,
}

impl RateLimit {
    /// Creates a rate limit allowing `limit` attempts per `duration`.
    ///
    /// This is a `const fn` so limits can be declared as constants.
    #[must_use]
    pub const fn new(limit: usize, duration: Duration) -> Self {
        Self { limit, duration }
    }
}
/// Attempt bookkeeping for a single rate-limited key.
struct Cooldown {
    /// Allowed attempts recorded in the current window.
    attempt_count: usize,
    /// When the most recent allowed attempt happened; `None` until the first one.
    last_update: Option<Instant>,
}

impl Cooldown {
    /// Creates a cooldown with no recorded attempts.
    fn new() -> Self {
        Self {
            attempt_count: 0,
            last_update: None,
        }
    }

    /// Records one attempt against `rate_limit`.
    ///
    /// Returns `None` if the attempt is allowed, in which case it is counted. Returns
    /// `Some(wait)` if it is rejected, where `wait` is the time left in the current window.
    /// Rejected attempts are not counted and do not move the window.
    ///
    /// A `limit` of zero rejects every attempt with the full `duration`.
    ///
    /// The window is measured from the most recent *allowed* attempt. Every allowed attempt
    /// moves `last_update` forward, so attempts spaced less than `duration` apart keep
    /// accumulating toward the limit. A window counts as expired once at least `duration`
    /// has elapsed, so a rejection never reports a zero wait when `limit > 0`.
    fn update(&mut self, rate_limit: &RateLimit) -> Option<Duration> {
        if rate_limit.limit == 0 {
            return Some(rate_limit.duration);
        }
        let now = Instant::now();
        // Saturate explicitly instead of relying on `duration_since`'s current behaviour.
        // `None` means no attempt has been recorded yet, which acts like an expired window.
        let elapsed = self
            .last_update
            .map(|last| now.saturating_duration_since(last));
        match elapsed {
            // Still inside the window: reject once the limit has been reached.
            Some(elapsed) if elapsed < rate_limit.duration => {
                if self.attempt_count >= rate_limit.limit {
                    // `elapsed < duration`, so this cannot underflow and is never zero.
                    return Some(rate_limit.duration - elapsed);
                }
                self.attempt_count += 1;
            }
            // First attempt, or the previous window has expired: this attempt opens a new one.
            _ => self.attempt_count = 1,
        }
        self.last_update = Some(now);
        None
    }
}
/// Lazily created [`Cooldown`] state per key, sharing one [`RateLimit`].
struct CooldownTracker<K> {
    /// Limit handed to every `get_and` callback.
    rate_limit: RateLimit,
    /// Cooldown state per key. Entries are created on first access; nothing in this
    /// type removes them.
    cooldowns: HashMap<K, Cooldown>,
}

impl<K> CooldownTracker<K>
where
    K: Eq + std::hash::Hash,
{
    /// Creates an empty tracker whose callbacks receive `rate_limit`.
    fn new(rate_limit: RateLimit) -> Self {
        let cooldowns = HashMap::new();
        Self { rate_limit, cooldowns }
    }

    /// Runs `f` with the tracker's limit and the cooldown for `key`.
    ///
    /// If `key` has not been seen before, a fresh cooldown is created first.
    fn get_and<R>(&mut self, key: K, f: impl FnOnce(&RateLimit, &mut Cooldown) -> R) -> R {
        let rate_limit = &self.rate_limit;
        let cooldown = self.cooldowns.entry(key).or_insert_with(Cooldown::new);
        f(rate_limit, cooldown)
    }
}
/// Spam protection made of three independent per-key rate limiters.
///
/// Every `update_*` method records one attempt for its key. It returns `None` if the
/// attempt is allowed, or `Some(wait)` if it is rejected.
pub struct Spam {
    /// Attempts per user id, limited by `user_limit`.
    user_limiter: CooldownTracker<String>,
    /// Attempts per command name; each call supplies the limit to apply.
    global_command_limiter: CooldownTracker<String>,
    /// Failed commands per user id, limited by `failed_command_limit`.
    failed_command_limiter: CooldownTracker<String>,
}

impl Spam {
    /// Creates the three limiters.
    ///
    /// `global_command_limit` is stored as the command tracker's default. However,
    /// [`Spam::update_global_command_cooldown`] applies the limit passed to each call instead.
    pub fn new(
        user_limit: RateLimit,
        global_command_limit: RateLimit,
        failed_command_limit: RateLimit,
    ) -> Self {
        let user_limiter = CooldownTracker::new(user_limit);
        let global_command_limiter = CooldownTracker::new(global_command_limit);
        let failed_command_limiter = CooldownTracker::new(failed_command_limit);
        Self {
            user_limiter,
            global_command_limiter,
            failed_command_limiter,
        }
    }

    /// Records an attempt by `user_id` against the user limit.
    pub fn update_user_cooldown(&mut self, user_id: &str) -> Option<Duration> {
        Self::record_attempt(&mut self.user_limiter, user_id)
    }

    /// Records a use of `command_name` against `rate_limit`.
    ///
    /// Cooldown state is keyed by command name only, so it is shared by every caller
    /// using the same name.
    pub fn update_global_command_cooldown(
        &mut self,
        command_name: &str,
        rate_limit: &RateLimit,
    ) -> Option<Duration> {
        let key = command_name.to_owned();
        // The tracker's stored limit is ignored: the caller chooses the limit per command.
        self.global_command_limiter
            .get_and(key, |_stored_limit, cooldown| cooldown.update(rate_limit))
    }

    /// Records a failed command by `user_id` against the failed-command limit.
    pub fn update_failed_command_cooldown(&mut self, user_id: &str) -> Option<Duration> {
        Self::record_attempt(&mut self.failed_command_limiter, user_id)
    }

    /// Records one attempt for `key` in `tracker`, using the tracker's own limit.
    fn record_attempt(tracker: &mut CooldownTracker<String>, key: &str) -> Option<Duration> {
        tracker.get_and(key.to_owned(), |limit, cooldown| cooldown.update(limit))
    }
}
impl Default for Spam {
    /// Builds a `Spam` with these limits:
    ///
    /// - users: 1 attempt per 5 seconds;
    /// - global commands: 1 attempt per 5 seconds (stored as the command tracker's default);
    /// - failed commands: 2 attempts per 30 seconds.
    fn default() -> Self {
        let five_seconds = Duration::from_secs(5);
        let thirty_seconds = Duration::from_secs(30);
        let user_limit = RateLimit::new(1, five_seconds);
        let global_command_limit = RateLimit::new(1, five_seconds);
        let failed_command_limit = RateLimit::new(2, thirty_seconds);
        Self::new(user_limit, global_command_limit, failed_command_limit)
    }
}
#[cfg(test)]
mod test {
    use super::{RateLimit, Spam};
    use std::thread::sleep;
    use std::time::Duration;

    const USER_RATE_LIMIT: RateLimit = RateLimit::new(3, Duration::from_millis(50));
    const GLOBAL_COMMAND_RATE_LIMIT: RateLimit = USER_RATE_LIMIT;
    const FAILED_COMMAND_RATE_LIMIT: RateLimit = RateLimit::new(2, Duration::from_millis(100));
    const USER_A: &str = "user_a";
    const USER_B: &str = "user_b";
    const CMD_A: &str = "cmd_a";
    const CMD_B: &str = "cmd_b";

    /// Builds a `Spam` configured with the short test limits above.
    fn new_spam() -> Spam {
        Spam::new(
            USER_RATE_LIMIT,
            GLOBAL_COMMAND_RATE_LIMIT,
            FAILED_COMMAND_RATE_LIMIT,
        )
    }

    /// Drives `cooldown_update` through two windows of `rate_limit`.
    ///
    /// In each window, `limit` attempts must be allowed and the next one rejected.
    /// Sleeping for `duration` between windows must reset the allowance.
    fn test_cooldown(
        rate_limit: &RateLimit,
        mut cooldown_update: impl FnMut() -> Option<Duration>,
    ) {
        // Start from an expired window regardless of earlier calls for this key.
        sleep(rate_limit.duration);
        for i in 0..rate_limit.limit {
            assert!(
                cooldown_update().is_none(),
                "Expected no cooldown on attempt {}, but cooldown was returned",
                i + 1
            );
        }
        let remaining_cooldown = cooldown_update()
            .expect("Expected cooldown after hitting the rate limit, but none was returned");
        assert!(
            remaining_cooldown <= rate_limit.duration,
            "Cooldown should be less than or equal to the rate limit duration"
        );
        sleep(rate_limit.duration);
        for i in 0..rate_limit.limit {
            assert!(
                cooldown_update().is_none(),
                "Expected no cooldown after reset, but cooldown was returned on attempt {}",
                i + 1
            );
        }
        assert!(
            cooldown_update().is_some(),
            "Expected cooldown to trigger again after hitting rate limit"
        );
    }

    /// Uses up `rate_limit` through `cooldown_update` without waiting.
    fn exhaust(rate_limit: &RateLimit, mut cooldown_update: impl FnMut() -> Option<Duration>) {
        for i in 0..rate_limit.limit {
            assert!(
                cooldown_update().is_none(),
                "Expected attempt {} to be allowed",
                i + 1
            );
        }
        assert!(
            cooldown_update().is_some(),
            "Expected cooldown after hitting the rate limit"
        );
    }

    #[test]
    fn test_user_rate_limiter() {
        let mut spam = new_spam();
        test_cooldown(&USER_RATE_LIMIT, || spam.update_user_cooldown(USER_A));
        test_cooldown(&USER_RATE_LIMIT, || spam.update_user_cooldown(USER_B));
        test_cooldown(&USER_RATE_LIMIT, || spam.update_user_cooldown(USER_A));
    }

    #[test]
    fn test_global_command_rate_limiter() {
        let mut spam = new_spam();
        test_cooldown(&GLOBAL_COMMAND_RATE_LIMIT, || {
            spam.update_global_command_cooldown(CMD_A, &GLOBAL_COMMAND_RATE_LIMIT)
        });
    }

    #[test]
    fn test_failed_command_rate_limiter() {
        let mut spam = new_spam();
        test_cooldown(&FAILED_COMMAND_RATE_LIMIT, || {
            spam.update_failed_command_cooldown(USER_A)
        });
        test_cooldown(&FAILED_COMMAND_RATE_LIMIT, || {
            spam.update_failed_command_cooldown(USER_B)
        });
    }

    #[test]
    fn test_limiters_and_keys_are_independent() {
        let mut spam = new_spam();
        exhaust(&USER_RATE_LIMIT, || spam.update_user_cooldown(USER_A));
        exhaust(&GLOBAL_COMMAND_RATE_LIMIT, || {
            spam.update_global_command_cooldown(CMD_A, &GLOBAL_COMMAND_RATE_LIMIT)
        });
        // Other keys in the same limiter still have their full allowance.
        assert!(spam.update_user_cooldown(USER_B).is_none());
        assert!(spam
            .update_global_command_cooldown(CMD_B, &GLOBAL_COMMAND_RATE_LIMIT)
            .is_none());
        // Exhausting the user limiter does not touch the failed-command limiter.
        assert!(spam.update_failed_command_cooldown(USER_A).is_none());
    }

    #[test]
    fn test_zero_limit() {
        let rate_limit = RateLimit::new(0, Duration::from_millis(50));
        let mut spam = Spam::new(rate_limit, rate_limit, rate_limit);
        for _ in 0..5 {
            assert!(
                spam.update_user_cooldown(USER_A).is_some(),
                "Expected cooldown when limit is 0, but got none"
            )
        }
    }
}