use std::collections::{HashMap, VecDeque};
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// A request budget: at most `num` requests per `period`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Rate {
    /// Maximum number of requests admitted per `period`.
    pub num: u32,
    /// Length of the window the budget applies to.
    pub period: Duration,
}

impl Rate {
    /// Builds a rate directly from its parts.
    ///
    /// No validation is performed: unlike [`Rate::parse`], a zero `num` (or a
    /// zero `period`) is accepted as given.
    pub fn new(num: u32, period: Duration) -> Self {
        Self { num, period }
    }

    /// Parses a rate written as `"<count>/<period>"`, e.g. `"10/min"` or `"5 / s"`.
    ///
    /// Whitespace around the whole string and around either part is ignored.
    /// A bare count such as `"10"` is read as per-second. Period names are
    /// case-insensitive:
    ///
    /// * second: `s`, `sec`, `secs`, `second`, `seconds`
    /// * minute: `m`, `min`, `mins`, `minute`, `minutes`
    /// * hour: `h`, `hr`, `hrs`, `hour`, `hours`
    /// * day: `d`, `day`, `days`
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the string is empty, the count is
    /// not a positive `u32`, or the period name is not recognised.
    pub fn parse(s: &str) -> Result<Self, String> {
        let s = s.trim();
        if s.is_empty() {
            return Err("empty rate string".to_string());
        }
        let (num_part, period_part) = match s.split_once('/') {
            Some((n, p)) => (n.trim(), p.trim()),
            // No explicit period: default to per-second.
            None => (s, "sec"),
        };
        let num: u32 = num_part
            .parse()
            .map_err(|_| format!("invalid rate count `{num_part}` in `{s}`"))?;
        if num == 0 {
            return Err(format!("rate count must be positive in `{s}`"));
        }
        let period = match period_part.to_ascii_lowercase().as_str() {
            "s" | "sec" | "secs" | "second" | "seconds" => Duration::from_secs(1),
            "m" | "min" | "mins" | "minute" | "minutes" => Duration::from_secs(60),
            "h" | "hr" | "hrs" | "hour" | "hours" => Duration::from_secs(3600),
            "d" | "day" | "days" => Duration::from_secs(86_400),
            other => return Err(format!("unknown rate period `{other}` in `{s}`")),
        };
        Ok(Self { num, period })
    }
}

/// Enables `"10/min".parse::<Rate>()`; equivalent to [`Rate::parse`].
impl std::str::FromStr for Rate {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s)
    }
}
/// Outcome of asking the limiter whether one request may proceed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RateDecision {
    /// Whether the request was admitted.
    pub allowed: bool,
    /// Suggested wait before retrying; `None` when the request was admitted.
    pub retry_after: Option<Duration>,
    /// The configured maximum number of requests per period.
    pub limit: u32,
    /// Requests still available in the current window after this one; `0` when denied.
    pub remaining: u32,
}
#[derive(Debug)]
pub struct RateLimiter {
rate: Rate,
buckets: Mutex<HashMap<String, VecDeque<Instant>>>,
}
impl RateLimiter {
pub fn new(rate: Rate) -> Self {
Self {
rate,
buckets: Mutex::new(HashMap::new()),
}
}
pub fn rate(&self) -> Rate {
self.rate
}
pub fn check(&self, key: &str) -> RateDecision {
self.check_at(key, Instant::now())
}
pub fn check_at(&self, key: &str, now: Instant) -> RateDecision {
let window = self.rate.period;
let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
let entries = buckets.entry(key.to_string()).or_default();
while let Some(front) = entries.front() {
match now.checked_duration_since(*front) {
Some(age) if age >= window => {
entries.pop_front();
}
_ => break,
}
}
let count = entries.len() as u32;
if count < self.rate.num {
entries.push_back(now);
RateDecision {
allowed: true,
retry_after: None,
limit: self.rate.num,
remaining: self.rate.num - count - 1,
}
} else {
let retry_after = entries
.front()
.and_then(|oldest| now.checked_duration_since(*oldest))
.map(|age| window.saturating_sub(age))
.unwrap_or(window);
RateDecision {
allowed: false,
retry_after: Some(retry_after),
limit: self.rate.num,
remaining: 0,
}
}
}
pub fn clear(&self, key: &str) {
let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
buckets.remove(key);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a whole-second duration.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[test]
    fn parse_each_period() {
        let cases = [
            ("1/sec", 1),
            ("1/s", 1),
            ("1/second", 1),
            ("1/min", 60),
            ("1/hour", 3600),
            ("1/day", 86_400),
        ];
        for (input, expected) in cases {
            assert_eq!(
                Rate::parse(input).unwrap().period,
                secs(expected),
                "period parsed from `{input}`"
            );
        }
    }

    #[test]
    fn parse_rejects_garbage() {
        for bad in ["", "oops", "10/fortnight", "0/sec", "abc/min"] {
            assert!(Rate::parse(bad).is_err(), "`{bad}` should not parse");
        }
    }

    #[test]
    fn third_request_in_window_denied() {
        let limiter = RateLimiter::new(Rate::parse("2/min").unwrap());
        let start = Instant::now();

        let first = limiter.check_at("a", start);
        assert!(first.allowed);
        assert_eq!(first.remaining, 1);

        let second = limiter.check_at("a", start + secs(1));
        assert!(second.allowed);
        assert_eq!(second.remaining, 0);

        let third = limiter.check_at("a", start + secs(2));
        assert!(!third.allowed);
        assert_eq!(third.retry_after, Some(secs(58)));
    }

    #[test]
    fn distinct_keys_are_independent() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let start = Instant::now();
        assert!(limiter.check_at("a", start).allowed);
        assert!(limiter.check_at("b", start).allowed);
        assert!(!limiter.check_at("a", start).allowed);
    }

    #[test]
    fn allowed_again_after_window_elapses() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let start = Instant::now();
        let outcomes: Vec<bool> = [0, 30, 61]
            .iter()
            .map(|&offset| limiter.check_at("a", start + secs(offset)).allowed)
            .collect();
        assert_eq!(outcomes, [true, false, true]);
    }

    #[test]
    fn clear_forgets_a_key() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let start = Instant::now();
        assert!(limiter.check_at("a", start).allowed);
        assert!(!limiter.check_at("a", start).allowed);
        limiter.clear("a");
        assert!(limiter.check_at("a", start).allowed);
        // Clearing a key that was never checked must be a harmless no-op.
        limiter.clear("never-seen");
    }
}