use core::pin::Pin;
use core::task::Context;
use core::time::Duration;
use std::time::Instant;
use async_io::Timer;
use futures::task::Poll;
use futures::{pin_mut, ready, Future, Sink, Stream};
use pin_project_lite::pin_project;
pin_project! {
    /// A `Stream` adapter that makes each item of the wrapped stream wait for
    /// a slot from a [`RateLimiter`] before it is yielded.
    pub struct RateLimitedStream<S> {
        // The wrapped stream; structurally pinned so it can be polled in place.
        #[pin]
        inner: S,
        // Decides whether the next item may proceed immediately or must wait.
        rate_limiter: RateLimiter,
        // Where `poll_next` is in its limiter-then-inner sequence for the
        // current item.
        state: PollState,
    }
}
impl<S> RateLimitedStream<S> {
    /// Wraps `stream` so that every item it produces must first be admitted
    /// by `rate_limiter`.
    ///
    /// No limiter slot is requested until the stream is first polled.
    #[inline]
    pub fn new(stream: S, rate_limiter: RateLimiter) -> Self {
        let state = PollState::Start;
        Self { state, rate_limiter, inner: stream }
    }
}
impl<S: Stream> Stream for RateLimitedStream<S> {
type Item = S::Item;
#[inline]
fn poll_next(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match &mut *this.state {
PollState::Start => {
let sleep = this.rate_limiter.wait();
*this.state = PollState::PollLimiter(sleep);
},
PollState::PollLimiter(sleep) => {
pin_mut!(sleep);
ready!(sleep.poll(ctx));
*this.state = PollState::PollInner;
},
PollState::PollInner => {
let item = ready!(this.inner.poll_next(ctx));
*this.state = PollState::Start;
return Poll::Ready(item);
},
}
}
}
}
pin_project! {
    /// A `Sink` adapter that gates the wrapped sink's readiness (`poll_ready`)
    /// behind a [`RateLimiter`].
    ///
    /// `start_send`, `poll_flush` and `poll_close` are forwarded to the
    /// wrapped sink.
    pub struct RateLimitedSink<S> {
        // The wrapped sink; structurally pinned so it can be polled in place.
        #[pin]
        inner: S,
        // Decides whether the next item may proceed immediately or must wait.
        rate_limiter: RateLimiter,
        // Where `poll_ready` is in its limiter-then-inner sequence.
        state: PollState,
    }
}
impl<S> RateLimitedSink<S> {
    /// Wraps `sink` so that readiness for each item must first be granted by
    /// `rate_limiter`.
    ///
    /// No limiter slot is requested until `poll_ready` is first called.
    #[inline]
    pub fn new(sink: S, rate_limiter: RateLimiter) -> Self {
        let state = PollState::Start;
        Self { state, rate_limiter, inner: sink }
    }
}
impl<T, S: Sink<T>> Sink<T> for RateLimitedSink<S> {
    type Error = S::Error;

    /// Waits for a rate-limiter slot, then for the wrapped sink to be ready.
    ///
    /// A slot is acquired at most once per item. After the limiter lets the
    /// item through, the state stays at `PollInner` until `start_send` spends
    /// the slot. Repeated `poll_ready` calls before a send are allowed by the
    /// `Sink` contract and are made, for example, by `futures::sink::Fanout`.
    /// Such calls only re-poll the wrapped sink instead of using up further
    /// slots.
    #[inline]
    fn poll_ready(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        loop {
            match &mut *this.state {
                PollState::Start => {
                    let sleep = this.rate_limiter.wait();
                    *this.state = PollState::PollLimiter(sleep);
                },
                PollState::PollLimiter(sleep) => {
                    pin_mut!(sleep);
                    ready!(sleep.poll(ctx));
                    *this.state = PollState::PollInner;
                },
                // The slot for the next item is already held, so readiness
                // now depends only on the wrapped sink. The state is left
                // unchanged; `start_send` moves it back to `Start`.
                PollState::PollInner => return this.inner.poll_ready(ctx),
            }
        }
    }

    /// Forwards `item` to the wrapped sink.
    ///
    /// The slot acquired by the preceding successful `poll_ready` is spent on
    /// this item, so the next `poll_ready` must request a new one.
    #[inline]
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let this = self.project();
        *this.state = PollState::Start;
        this.inner.start_send(item)
    }

    /// Flushes the wrapped sink; flushing is not rate limited.
    #[inline]
    fn poll_flush(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(ctx)
    }

    /// Closes the wrapped sink; closing is not rate limited.
    #[inline]
    fn poll_close(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(ctx)
    }
}
/// Per-item progress of the rate-limited stream/sink wrappers.
enum PollState {
    /// No limiter slot has been requested for the next item yet.
    Start,
    /// A slot was requested; the contained `Sleep` must complete before the
    /// wrapped stream/sink is polled.
    PollLimiter(Sleep),
    /// The limiter admitted the item; the wrapped stream/sink is being polled.
    PollInner,
}
/// Fixed-window rate limiter that admits at most `events_per_window` events
/// per window of length `window_interval`.
///
/// A window opens at the first event after the previous window has expired.
/// When a window is already full, the next window opens at the instant the
/// full one ends, and the deferred event counts as its first event.
#[derive(Debug)]
pub struct RateLimiter {
    /// Length of each rate-limiting window.
    window_interval: Duration,
    /// Maximum number of events admitted per window; `new` asserts it is
    /// non-zero.
    events_per_window: usize,
    /// Start of the window currently being filled; `None` until the first
    /// event.
    current_window_start: Option<Instant>,
    /// Number of events already admitted (or deferred into) the current
    /// window.
    events_processed_in_current_window: usize,
}
impl RateLimiter {
    /// Creates a limiter that admits at most `events_per_window` events in
    /// each window of length `window_interval`.
    ///
    /// # Panics
    ///
    /// Panics if `events_per_window` is zero.
    #[inline]
    pub fn new(window_interval: Duration, events_per_window: usize) -> Self {
        assert!(events_per_window > 0);
        Self {
            window_interval,
            events_per_window,
            current_window_start: None,
            events_processed_in_current_window: 0,
        }
    }

    /// Opens a new window beginning at `start`, counting the event that
    /// triggered it as the window's first.
    #[inline]
    fn start_window(&mut self, start: Instant) {
        self.events_processed_in_current_window = 1;
        self.current_window_start = Some(start);
    }

    /// Reserves a slot for one event and returns what the caller must wait on
    /// before the event may proceed.
    ///
    /// The result is `Sleep::Done` when the event fits into the current (or a
    /// freshly opened) window. Otherwise the event is booked as the first one
    /// of the next window and a sleep until that window starts is returned.
    ///
    /// May panic if `window start + window_interval` is not representable as
    /// an `Instant`.
    #[inline]
    fn wait(&mut self) -> Sleep {
        let now = Instant::now();
        let window_end = match self.current_window_start {
            Some(start) => start + self.window_interval,
            None => {
                // Very first event: open the first window right now.
                self.start_window(now);
                return Sleep::Done;
            },
        };

        if now >= window_end {
            // The previous window has lapsed; begin a new one at `now`.
            self.start_window(now);
            Sleep::Done
        } else if self.events_processed_in_current_window < self.events_per_window {
            self.events_processed_in_current_window += 1;
            Sleep::Done
        } else {
            // Window is full: defer this event to the start of the next one.
            self.start_window(window_end);
            Sleep::until(window_end)
        }
    }
}
pin_project! {
    /// Future returned by `RateLimiter::wait`: either a timer that must fire
    /// before the next event may proceed, or an already-completed wait.
    #[derive(Debug)]
    #[project = SleepProj]
    enum Sleep {
        // Wait until the wrapped async-io timer fires.
        Sleeps {
            #[pin]
            inner: Timer,
        },
        // No waiting required; resolves immediately when polled.
        Done,
    }
}
impl Sleep {
#[inline]
fn until(instant: Instant) -> Self {
Self::Sleeps { inner: Timer::at(instant) }
}
}
impl Future for Sleep {
type Output = ();
#[inline]
fn poll(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
) -> Poll<Self::Output> {
match self.project() {
SleepProj::Sleeps { inner } => inner.poll(ctx).map(|_| ()),
SleepProj::Done => Poll::Ready(()),
}
}
}
// Tests for the rate-limited stream and sink wrappers.
//
// NOTE(review): these tests use real (unpaused) tokio time with sleeps of
// 100-300 ms, so they depend on wall-clock scheduling and may be slow or
// flaky on heavily loaded machines.
#[cfg(test)]
mod tests {
    use async_stream::stream;
    use futures::sink::drain;
    use futures::{FutureExt, SinkExt, StreamExt};
    use tokio::pin;
    use tokio::time::sleep;
    use super::*;

    pin_project! {
        // Test sink wrapper. `poll_ready` is immediately ready the first time.
        // After that it stays pending until `sleep_on_second` has elapsed,
        // measured from that first `poll_ready` call, and it is always ready
        // once the delay is over. Simulates a slow underlying sink.
        struct SleepsOnSecondSink<S> {
            #[pin]
            sink: S,
            state: SleepsOnSecondState,
            sleep_on_second: Duration,
        }
    }

    // Phases of `SleepsOnSecondSink`.
    enum SleepsOnSecondState {
        // No `poll_ready` call has happened yet.
        WaitingForFirst,
        // The first `poll_ready` succeeded; later calls are pending until
        // `sleep` completes.
        WaitingForSecond { sleep: Pin<Box<tokio::time::Sleep>> },
        // The delay has elapsed; `poll_ready` is always ready from now on.
        SentSecond,
    }

    impl<S> SleepsOnSecondSink<S> {
        // Wraps `sink`, delaying readiness for the second item by
        // `sleep_on_second`.
        fn new(sink: S, sleep_on_second: Duration) -> Self {
            Self {
                sink,
                state: SleepsOnSecondState::WaitingForFirst,
                sleep_on_second,
            }
        }
    }

    impl<T, S: Sink<T>> Sink<T> for SleepsOnSecondSink<S> {
        type Error = S::Error;

        // Readiness is decided solely by the state machine. The wrapped
        // sink's own `poll_ready` is never called, which is acceptable here
        // because the tests only wrap `drain()`.
        #[inline]
        fn poll_ready(
            self: Pin<&mut Self>,
            ctx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.project();
            match this.state {
                SleepsOnSecondState::WaitingForFirst => {
                    // Start the delay clock at the first readiness check.
                    let sleep = Box::pin(sleep(*this.sleep_on_second));
                    *this.state =
                        SleepsOnSecondState::WaitingForSecond { sleep };
                    Poll::Ready(Ok(()))
                },
                SleepsOnSecondState::WaitingForSecond { ref mut sleep } => {
                    match sleep.as_mut().poll(ctx) {
                        Poll::Ready(()) => {
                            *this.state = SleepsOnSecondState::SentSecond;
                            Poll::Ready(Ok(()))
                        },
                        Poll::Pending => Poll::Pending,
                    }
                },
                SleepsOnSecondState::SentSecond => Poll::Ready(Ok(())),
            }
        }

        // Forwarded to the wrapped sink.
        #[inline]
        fn start_send(
            self: Pin<&mut Self>,
            item: T,
        ) -> Result<(), Self::Error> {
            let this = self.project();
            this.sink.start_send(item)
        }

        // Forwarded to the wrapped sink.
        #[inline]
        fn poll_flush(
            self: Pin<&mut Self>,
            ctx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.project();
            this.sink.poll_flush(ctx)
        }

        // Forwarded to the wrapped sink.
        #[inline]
        fn poll_close(
            self: Pin<&mut Self>,
            ctx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let this = self.project();
            this.sink.poll_close(ctx)
        }
    }

    // Wraps `sink` in a `RateLimitedSink` with a fresh limiter.
    fn rate_limited_sink<T, S: Sink<T>>(
        sink: S,
        window_interval: Duration,
        events_per_window: usize,
    ) -> RateLimitedSink<S> {
        let rate_limiter =
            RateLimiter::new(window_interval, events_per_window);
        RateLimitedSink::new(sink, rate_limiter)
    }

    // The very first send is admitted without waiting.
    #[tokio::test]
    async fn rate_limited_sink_first_call() {
        let sink = rate_limited_sink(drain(), Duration::from_secs(1), 1);
        pin!(sink);
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
    }

    // Up to `events_per_window` sends in one window complete immediately.
    #[tokio::test]
    async fn rate_limited_sink_first_n_calls() {
        let sink = rate_limited_sink(drain(), Duration::from_millis(100), 3);
        pin!(sink);
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
    }

    // The (n+1)-th send in a window stays pending until the window has
    // elapsed. The dropped send future leaves the limiter's timer stored in
    // the sink's state, and the next send resumes it.
    #[tokio::test]
    async fn rate_limited_sink_n_plus_one_call() {
        let window_duration = Duration::from_millis(100);
        let sink = rate_limited_sink(drain(), window_duration, 3);
        pin!(sink);
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
        assert!(sink.send(()).now_or_never().is_none());
        sleep(window_duration).await;
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
    }

    // The wrapper must also propagate the inner sink's `Pending`. The limiter
    // (5 per 200 ms) never blocks here, so the pending second send comes
    // solely from `SleepsOnSecondSink`.
    #[tokio::test]
    async fn rate_limited_sink_wait_for_underlying() {
        let wait_for_second = Duration::from_millis(100);
        let sleep_on_second =
            SleepsOnSecondSink::new(drain(), wait_for_second);
        let sink =
            rate_limited_sink(sleep_on_second, Duration::from_millis(200), 5);
        pin!(sink);
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
        let second_send = sink.send(());
        pin!(second_send);
        for _ in 0..10 {
            assert!(second_send.as_mut().now_or_never().is_none());
        }
        sleep(wait_for_second).await;
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
        assert_eq!(sink.send(()).now_or_never().unwrap(), Ok(()));
    }

    // Builds a `RateLimitedStream` over the items of `iter`, using
    // `tokio_stream::iter` as the inner stream.
    fn rate_limited_stream<T, I: IntoIterator<Item = T>>(
        iter: I,
        window_interval: Duration,
        events_per_window: usize,
    ) -> RateLimitedStream<tokio_stream::Iter<I::IntoIter>> {
        let stream = tokio_stream::iter(iter);
        let rate_limiter =
            RateLimiter::new(window_interval, events_per_window);
        RateLimitedStream::new(stream, rate_limiter)
    }

    // The very first item is yielded without waiting.
    #[tokio::test]
    async fn rate_limited_stream_first_call() {
        let stream = rate_limited_stream([1], Duration::from_secs(1), 1);
        pin!(stream);
        assert_eq!(stream.next().now_or_never().unwrap(), Some(1));
    }

    // Up to `events_per_window` items in one window are yielded immediately
    // and in order.
    #[tokio::test]
    async fn rate_limited_stream_first_n_calls() {
        let stream =
            rate_limited_stream([1, 2, 3], Duration::from_millis(100), 3);
        pin!(stream);
        let one = stream.next().now_or_never().unwrap().unwrap();
        assert_eq!(one, 1);
        let two = stream.next().now_or_never().unwrap().unwrap();
        assert_eq!(two, 2);
        let three = stream.next().now_or_never().unwrap().unwrap();
        assert_eq!(three, 3);
    }

    // The (n+1)-th item in a window is pending until the window elapses.
    #[tokio::test]
    async fn rate_limited_stream_n_plus_one_call() {
        let window_duration = Duration::from_millis(100);
        let stream = rate_limited_stream([1, 2, 3, 4], window_duration, 3);
        pin!(stream);
        let _ = stream.next().now_or_never().unwrap().unwrap();
        let _ = stream.next().now_or_never().unwrap().unwrap();
        let _ = stream.next().now_or_never().unwrap().unwrap();
        assert!(stream.next().now_or_never().is_none());
        sleep(window_duration).await;
        let four = stream.next().now_or_never().unwrap().unwrap();
        assert_eq!(four, 4);
    }

    // The limiter (5 per 200 ms) never blocks here. The pending polls after
    // the first item come only from the inner stream's own delay, and the end
    // of the stream (`None`) is propagated.
    #[tokio::test]
    async fn rate_limited_stream_wait_for_underlying_only() {
        let wait_for_two = Duration::from_millis(100);
        let stream = stream! {
            yield 1;
            sleep(wait_for_two).await;
            yield 2;
            yield 3;
        };
        let rate_limiter = RateLimiter::new(Duration::from_millis(200), 5);
        let stream = RateLimitedStream::new(stream, rate_limiter);
        pin!(stream);
        let _one = stream.next().now_or_never().unwrap().unwrap();
        for _ in 0..10 {
            assert!(stream.next().now_or_never().is_none());
        }
        sleep(wait_for_two).await;
        let _two = stream.next().now_or_never().unwrap().unwrap();
        let _three = stream.next().now_or_never().unwrap().unwrap();
        assert_eq!(stream.next().now_or_never().unwrap(), None);
    }

    // Both delays apply in sequence. After items 1 and 2, the limiter
    // (2 per `window`) blocks item 3 until the window ends. Only then is the
    // inner stream polled, which starts its own `wait_for_3` sleep. Item 3
    // therefore appears only after roughly `window + wait_for_3`.
    #[tokio::test]
    async fn rate_limited_stream_wait_for_rate_limiter_and_underlying() {
        let wait_for_3 = Duration::from_millis(200);
        let stream = stream! {
            yield 1;
            yield 2;
            sleep(wait_for_3).await;
            yield 3;
        };
        let window = wait_for_3 / 2;
        let rate_limiter = RateLimiter::new(window, 2);
        let stream = RateLimitedStream::new(stream, rate_limiter);
        pin!(stream);
        let _one = stream.next().now_or_never().unwrap().unwrap();
        let _two = stream.next().now_or_never().unwrap().unwrap();
        // Blocked by the rate limiter.
        for _ in 0..10 {
            assert!(stream.next().now_or_never().is_none());
        }
        sleep(window).await;
        // Limiter has released; the inner stream's sleep has just started.
        for _ in 0..10 {
            assert!(stream.next().now_or_never().is_none());
        }
        sleep(wait_for_3 / 2).await;
        // Inner sleep is only about half done.
        for _ in 0..10 {
            assert!(stream.next().now_or_never().is_none());
        }
        sleep(wait_for_3 / 2).await;
        let _three = stream.next().now_or_never().unwrap().unwrap();
    }
}