1#![forbid(unsafe_code)]
5#![warn(
6 anonymous_parameters,
7 clippy::needless_borrow,
8 missing_docs,
9 missing_copy_implementations,
10 missing_debug_implementations,
11 nonstandard_style,
12 rust_2018_idioms,
13 single_use_lifetimes,
14 trivial_casts,
15 trivial_numeric_casts,
16 unreachable_pub,
17 unused_extern_crates,
18 unused_qualifications,
19 variant_size_differences
20)]
21
22use tokio::sync::mpsc::{channel, Receiver, Sender};
23use tokio::sync::oneshot;
24use tokio::time::sleep;
25use tokio::time::{Duration, Instant};
26
27#[derive(Clone, Debug)]
30pub struct RateLimiter {
31 sender: Sender<Message>,
32}
33
34impl RateLimiter {
35 pub fn new(count: usize, duration: Duration) -> Self {
52 let (sender, receiver) = channel(count);
53 RateLimiter::spawn_receiver(receiver, count, duration);
54 Self { sender }
55 }
56
57 pub async fn wait(&self) {
72 let (s, r) = oneshot::channel::<()>();
73 self.sender
74 .send(Message { sender: s })
75 .await
76 .expect("unable to send to arl channel");
77 r.await.expect("unable to read from arl channel");
78 }
79
80 fn spawn_receiver(mut receiver: Receiver<Message>, count: usize, duration: Duration) {
81 tokio::spawn(async move {
82 let mut queue = Vec::with_capacity(count);
83 while let Some(message) = receiver.recv().await {
84 while !queue.is_empty() && queue[0] <= Instant::now() {
85 queue.remove(0);
86 }
87 if queue.len() > count {
88 let alarm = queue.remove(0);
89 sleep(alarm - Instant::now()).await;
90 }
91 message
92 .sender
93 .send(())
94 .expect("unable to send to arl client channel");
95 queue.push(Instant::now() + duration);
96 }
97 });
98 }
99}
100
101#[derive(Debug)]
102struct Message {
103 sender: oneshot::Sender<()>,
104}
105
#[cfg(test)]
mod test {
    use crate::RateLimiter;
    use std::time::Duration;
    use tokio::time::Instant;

    /// Awaits `limiter.wait()` `times` times back to back and reports the
    /// total time spent.
    async fn time_waits(limiter: &RateLimiter, times: usize) -> Duration {
        let begin = Instant::now();
        for _ in 0..times {
            limiter.wait().await;
        }
        begin.elapsed()
    }

    #[tokio::test]
    async fn up_to_limit_execute_quickly() {
        const LIMIT: usize = 10;
        let limiter = RateLimiter::new(LIMIT, Duration::from_secs(60));
        let took = time_waits(&limiter, LIMIT).await;
        assert!(took < Duration::from_millis(10));
    }

    #[tokio::test]
    async fn over_limit_execute_proportionally() {
        const LIMIT: usize = 10;
        const WINDOWS: usize = 3;
        let limiter = RateLimiter::new(LIMIT, Duration::from_secs(1));
        let took = time_waits(&limiter, LIMIT * WINDOWS).await;
        // The first batch is granted immediately; each later batch has to
        // wait for roughly one more window.
        assert!(took > Duration::from_secs(WINDOWS as u64 - 1));
    }
}