use std::collections::HashMap;
use std::hash::Hash;
use std::time::{Instant, Duration};
/// Throttling state tracked for a single id.
#[derive(Debug)]
enum LimitStatus {
    /// No outstanding usage is recorded for the id.
    Unlimited,
    /// Usage is outstanding: it was first recorded at the `Instant` and is
    /// considered fully drained once the `Duration` has elapsed since then.
    ClearsAt(Instant, Duration),
}

/// A per-id rate limiter.
///
/// Every [`increment`](RateLimiter::increment) pushes the moment at which an
/// id's usage is drained further into the future by `interval / allowed`.
/// An id is reported as limited while more than one whole `interval` of
/// usage is still outstanding, so a burst of `allowed` increments passes and
/// the next one is limited.
#[derive(Debug)]
pub struct RateLimiter<Id: Hash + Eq> {
    /// Current state of every id that has been incremented at least once.
    data: HashMap<Id, LimitStatus>,
    /// Amount of outstanding time added by a single increment.
    incr_step: Duration,
    /// Outstanding time above which an id counts as limited.
    limit_after: Duration,
}

impl<Id: Hash + Eq> RateLimiter<Id> {
    /// Creates a limiter that tolerates `allowed` increments per `interval`
    /// seconds for each id.
    ///
    /// The per-increment step is computed with integer nanosecond
    /// arithmetic, so rates above 1000 increments per second do not collapse
    /// to a zero step (which would disable limiting altogether).
    ///
    /// Degenerate inputs: an `interval` of `0` yields a limiter that never
    /// limits; an `allowed` of `0` (with a non-zero interval) yields a step
    /// so large that an id is limited from its first increment on.
    pub fn new(allowed: u64, interval: u64) -> Self {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        let incr_step = match (interval, allowed) {
            (0, _) => Duration::new(0, 0),
            (_, 0) => Duration::from_millis(u64::MAX),
            _ => {
                let total_nanos = u128::from(interval) * NANOS_PER_SEC / u128::from(allowed);
                // Lossless casts: the quotient is at most `interval` (a u64)
                // and the remainder is below one billion.
                Duration::new(
                    (total_nanos / NANOS_PER_SEC) as u64,
                    (total_nanos % NANOS_PER_SEC) as u32,
                )
            }
        };

        RateLimiter {
            data: HashMap::new(),
            incr_step,
            limit_after: Duration::new(interval, 0),
        }
    }

    /// Reports whether `id` is currently limited.
    ///
    /// Returns `Some(wait)` when the id is over its budget, where `wait` is
    /// how long it takes (without further increments) until it is no longer
    /// limited. Returns `None` for unknown ids and ids within budget.
    ///
    /// As a side effect, an id whose usage has completely drained is reset
    /// to the unlimited state.
    pub fn is_limited(&mut self, id: &Id) -> Option<Duration> {
        let status = self.data.get_mut(id)?;
        let (start, duration) = match *status {
            LimitStatus::ClearsAt(start, duration) => (start, duration),
            LimitStatus::Unlimited => return None,
        };

        let elapsed = start.elapsed();
        if elapsed > duration {
            *status = LimitStatus::Unlimited;
            return None;
        }

        // `elapsed <= duration` was established above, so this cannot underflow.
        let remaining = duration - elapsed;
        if remaining > self.limit_after {
            Some(remaining - self.limit_after)
        } else {
            None
        }
    }

    /// Records one unit of usage for `id`.
    ///
    /// If the id still has outstanding usage, its drain time is extended by
    /// one step (saturating instead of panicking on overflow). If the id is
    /// new, unlimited, or its previous usage has already drained, a fresh
    /// window starting now is opened; extending a stale window would
    /// silently discard the increment.
    pub fn increment(&mut self, id: Id) {
        let step = self.incr_step;
        let now = Instant::now();
        let item = self.data.entry(id).or_insert(LimitStatus::Unlimited);

        *item = match *item {
            LimitStatus::ClearsAt(start, duration) if now.duration_since(start) <= duration => {
                let extended = duration
                    .checked_add(step)
                    .unwrap_or_else(|| Duration::new(u64::MAX, 999_999_999));
                LimitStatus::ClearsAt(start, extended)
            }
            LimitStatus::ClearsAt(..) | LimitStatus::Unlimited => {
                LimitStatus::ClearsAt(now, step)
            }
        };
    }
}
#[cfg(test)]
mod tests {
    use super::RateLimiter;
    use std::thread;
    use std::time::Duration;

    /// A burst of ten hits fits a 10-per-second budget; the eleventh is limited.
    #[test]
    fn test_rate_limiter() {
        const KEY: u8 = 1;
        let mut limiter = RateLimiter::<u8>::new(10, 1);

        (0..10).for_each(|_| {
            limiter.increment(KEY);
            assert!(limiter.is_limited(&KEY).is_none());
        });

        limiter.increment(KEY);
        assert!(limiter.is_limited(&KEY).is_some());
    }

    /// Exercises draining over real time. Ignored by default because it
    /// sleeps for roughly three seconds in total.
    #[test]
    #[ignore]
    fn test_rate_limiter_slow() {
        const KEY: u8 = 1;
        let half_second = Duration::from_millis(500);
        let mut limiter = RateLimiter::<u8>::new(2, 1);

        // Hits spaced half a second apart never exceed a 2-per-second budget.
        for _ in 0..4 {
            limiter.increment(KEY);
            assert!(limiter.is_limited(&KEY).is_none());
            thread::sleep(half_second);
        }

        // Two back-to-back hits are still within budget...
        for _ in 0..2 {
            limiter.increment(KEY);
            assert!(limiter.is_limited(&KEY).is_none());
        }

        // ...but a third immediate hit is limited,
        limiter.increment(KEY);
        assert!(limiter.is_limited(&KEY).is_some());

        // and waiting a second lifts the limit again.
        thread::sleep(Duration::from_secs(1));
        assert!(limiter.is_limited(&KEY).is_none());
    }
}