#![forbid(unsafe_code)]
#![warn(
anonymous_parameters,
clippy::needless_borrow,
missing_docs,
missing_copy_implementations,
missing_debug_implementations,
nonstandard_style,
rust_2018_idioms,
single_use_lifetimes,
trivial_casts,
trivial_numeric_casts,
unreachable_pub,
unused_extern_crates,
unused_qualifications,
variant_size_differences
)]
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
use tokio::time::sleep;
use tokio::time::{Duration, Instant};
/// An asynchronous rate limiter that lets at most a fixed number of
/// [`RateLimiter::wait`] calls complete within a sliding time window.
///
/// All clones of a `RateLimiter` share the same limit, because they all send
/// their requests to the same background task.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Request channel to the background task that grants permits.
    sender: Sender<Message>,
}
impl RateLimiter {
    /// Creates a limiter that lets at most `count` calls to [`RateLimiter::wait`]
    /// complete within any window of length `duration`.
    ///
    /// A background task is spawned to serve permit requests.
    ///
    /// # Panics
    ///
    /// Panics if `count` is zero. Also panics if called outside of a Tokio
    /// runtime, because it uses `tokio::spawn`.
    pub fn new(count: usize, duration: Duration) -> Self {
        // A zero limit could never grant a permit, and `channel(0)` panics anyway;
        // fail early with a clear message instead.
        assert!(count > 0, "rate limiter count must be greater than zero");
        let (sender, receiver) = channel(count);
        Self::spawn_receiver(receiver, count, duration);
        Self { sender }
    }

    /// Waits until the caller may proceed under the rate limit.
    ///
    /// Requests are handled one at a time by the background task, in the order it
    /// receives them. If this future is dropped before it completes, the pending
    /// request does not consume a permit.
    ///
    /// # Panics
    ///
    /// Panics if the background task is no longer running, for example because
    /// the runtime it was spawned on has shut down.
    pub async fn wait(&self) {
        let (permit_tx, permit_rx) = oneshot::channel::<()>();
        self.sender
            .send(Message { sender: permit_tx })
            .await
            .expect("unable to send to arl channel");
        permit_rx.await.expect("unable to read from arl channel");
    }

    /// Spawns the task that grants permits.
    ///
    /// The task keeps a queue of expiry instants, one for each permit granted
    /// during the last `duration`. Entries are pushed in time order, so the front
    /// is always the next to expire. The queue never holds more than `count`
    /// entries.
    fn spawn_receiver(mut receiver: Receiver<Message>, count: usize, duration: Duration) {
        tokio::spawn(async move {
            // VecDeque gives O(1) pops from the front; Vec::remove(0) is O(n).
            let mut queue = std::collections::VecDeque::with_capacity(count);
            while let Some(message) = receiver.recv().await {
                // Drop permits whose window has already elapsed.
                let now = Instant::now();
                while let Some(&expiry) = queue.front() {
                    if expiry > now {
                        break;
                    }
                    queue.pop_front();
                }
                // `count` permits are still live, so wait for the oldest to expire.
                // (`>=`, not `>`: otherwise `count + 1` permits fit in one window.)
                if queue.len() >= count {
                    if let Some(oldest) = queue.pop_front() {
                        sleep(oldest.saturating_duration_since(Instant::now())).await;
                    }
                }
                // The requester may have stopped waiting (its `wait` future was
                // dropped). That must not kill this task, and since no permit was
                // used, it is not recorded.
                if message.sender.send(()).is_ok() {
                    queue.push_back(Instant::now() + duration);
                }
            }
        });
    }
}
/// A single permit request, sent by [`RateLimiter::wait`] to the background task.
#[derive(Debug)]
struct Message {
    /// One-shot reply channel; the background task signals it when the request
    /// may proceed.
    sender: oneshot::Sender<()>,
}
#[cfg(test)]
mod test {
    use crate::RateLimiter;
    use std::time::Duration;
    use tokio::time::Instant;

    // The first `COUNT` permits in a fresh window are granted without waiting.
    #[tokio::test]
    async fn up_to_limit_execute_quickly() {
        const COUNT: usize = 10;
        let limiter = RateLimiter::new(COUNT, Duration::from_secs(60));
        let began = Instant::now();
        for _ in 0..COUNT {
            limiter.wait().await;
        }
        assert!(began.elapsed() < Duration::from_millis(10));
    }

    // Asking for `CHUNKS` windows' worth of permits must span more than
    // `CHUNKS - 1` windows.
    #[tokio::test]
    async fn over_limit_execute_proportionally() {
        const COUNT: usize = 10;
        const CHUNKS: usize = 3;
        let limiter = RateLimiter::new(COUNT, Duration::from_secs(1));
        let began = Instant::now();
        for _ in 0..COUNT * CHUNKS {
            limiter.wait().await;
        }
        let minimum = Duration::from_secs(CHUNKS as u64 - 1);
        assert!(began.elapsed() > minimum);
    }
}