use std::pin::Pin;
use std::task::{Context, Poll};
use futures::AsyncWrite;
use futures::io::Error;
use futures::stream::FusedStream;
use tor_rtcompat::SleepProvider;
use super::writer::{RateLimitedWriter, RateLimitedWriterConfig};
/// A [`RateLimitedWriter`] whose rate-limit configuration can be changed while it is in use.
///
/// New [`RateLimitedWriterConfig`] values are taken from the `updates` stream and applied to
/// the inner writer when `poll_write` is called.
#[derive(educe::Educe)]
#[educe(Debug)]
#[pin_project::pin_project]
pub(crate) struct DynamicRateLimitedWriter<W: AsyncWrite, S, P: SleepProvider> {
    /// The wrapped rate-limited writer that performs the actual writes.
    #[pin]
    writer: RateLimitedWriter<W, P>,
    /// Stream of configuration updates to apply to `writer`.
    ///
    /// Left out of the `Debug` output (presumably because `S` has no `Debug` bound).
    #[educe(Debug(ignore))]
    #[pin]
    updates: S,
}
impl<W, S, P> DynamicRateLimitedWriter<W, S, P>
where
W: AsyncWrite,
P: SleepProvider,
{
pub(crate) fn new(writer: RateLimitedWriter<W, P>, updates: S) -> Self {
Self { writer, updates }
}
pub(crate) fn inner(&self) -> &W {
self.writer.inner()
}
}
impl<W, S, P> AsyncWrite for DynamicRateLimitedWriter<W, S, P>
where
    W: AsyncWrite,
    S: FusedStream<Item = RateLimitedWriterConfig>,
    P: SleepProvider,
{
    /// Applies every immediately-available configuration update, then writes through the
    /// inner [`RateLimitedWriter`].
    ///
    /// Updates are applied in the order the stream yields them, each stamped with the sleep
    /// provider's current time. The write is therefore attempted with the most recent
    /// configuration. When the update stream returns `Pending`, it registers `cx`'s waker,
    /// so this task is woken when a new configuration arrives.
    ///
    /// # Errors
    ///
    /// Returns an error if more than 100,000 updates are drained in a single call. This
    /// guards against an update stream that never returns `Pending`. Otherwise it returns
    /// the inner writer's result.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        /// Maximum number of updates applied in one call before giving up with an error.
        const MAX_UPDATES_PER_POLL: usize = 100_000;

        let mut this = self.project();

        let mut applied: usize = 0;
        // `FusedStream::is_terminated` reports when the stream should no longer be polled.
        // The trait does not promise that polling afterwards is harmless, so check it before
        // every poll. A finished update stream is then simply ignored.
        while !this.updates.is_terminated() {
            let Poll::Ready(Some(config)) = this.updates.as_mut().poll_next(cx) else {
                break;
            };
            let now = this.writer.sleep_provider().now();
            this.writer.adjust(now, &config);

            applied += 1;
            if applied > MAX_UPDATES_PER_POLL {
                const MSG: &str =
                    "possible infinite loop in `DynamicRateLimitedWriter::poll_write`";
                // Use a format string so the text becomes the event's message. A bare
                // identifier would instead be recorded as a field named `MSG`.
                tracing::debug!("{}", MSG);
                return Poll::Ready(Err(Error::other(MSG)));
            }
        }

        this.writer.poll_write(cx, buf)
    }

    /// Flushes the inner writer. Pending configuration updates are not applied here.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.project().writer.poll_flush(cx)
    }

    /// Closes the inner writer. Pending configuration updates are not applied here.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.project().writer.poll_close(cx)
    }
}
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;
    use std::io::Result as IoResult;
    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    /// Lets a [`DynamicRateLimitedWriter`] be used as a tokio `AsyncWrite`.
    ///
    /// Each call is forwarded through a `tokio_util` compat adapter that wraps the futures
    /// `AsyncWrite` implementation.
    impl<W, S, P> TokioAsyncWrite for DynamicRateLimitedWriter<W, S, P>
    where
        W: AsyncWrite,
        S: FusedStream<Item = RateLimitedWriterConfig>,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            // Calls are fully qualified. The adapter may implement both the tokio and the
            // futures `AsyncWrite` traits, so plain method syntax could be ambiguous.
            let mut adapter = FuturesAsyncWriteCompatExt::compat_write(self);
            TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = FuturesAsyncWriteCompatExt::compat_write(self);
            TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            // The compat adapter translates tokio's shutdown into the futures-side close.
            let mut adapter = FuturesAsyncWriteCompatExt::compat_write(self);
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
        }
    }
}
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use std::num::NonZero;
    use std::time::Duration;
    use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
    use tor_rtcompat::SpawnExt;
    #[cfg(feature = "tokio")]
    use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

    /// Checks that a `DynamicRateLimitedWriter` follows rate updates that alternate between
    /// "off" (rate 0) and increasing non-zero rates.
    #[cfg(feature = "tokio")]
    #[test]
    fn alternating_on_off() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            // Drive the mock clock forward in 1 ms steps (8000 steps, i.e. 8 s of simulated
            // time). Tasks run until stalled before each step.
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for _ in 0..8_000 {
                    rt_clone.progress_until_stalled().await;
                    rt_clone.advance_by(Duration::from_millis(1)).await;
                }
            })
            .unwrap();

            // Start with the writer fully blocked (rate 0, burst 0).
            let config = RateLimitedWriterConfig {
                rate: 0,
                burst: 0,
                wake_when_bytes_available: NonZero::new(10).unwrap(),
            };

            // In-memory pipe: the rate-limited writer feeds `reader`.
            let (writer, reader) = tokio_crate::io::duplex(/* max_buf_size= */ 1000);
            let writer = writer.compat_write();
            let mut reader = reader.compat();
            let writer = RateLimitedWriter::new(writer, &config, rt.clone());

            // Rate updates reach the writer through this channel.
            let (mut rate_tx, rate_rx) = futures::channel::mpsc::unbounded();
            let mut writer = DynamicRateLimitedWriter::new(writer, rate_rx);

            // Time between rate updates, and between the reader's checks below.
            const UPDATE_INTERVAL: Duration = Duration::from_millis(841);

            // Every UPDATE_INTERVAL, switch to the next rate in the sequence.
            // Burst is set equal to the rate.
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for rate in [100, 0, 200, 0, 400, 0] {
                    rt_clone.sleep(UPDATE_INTERVAL).await;
                    let mut config = config.clone();
                    config.rate = rate;
                    config.burst = rate;
                    // The channel is unbounded, so the send completes immediately.
                    rate_tx.send(config).now_or_never().unwrap().unwrap();
                }
            })
            .unwrap();

            // Write 100-byte chunks as fast as the rate limit allows, until a write fails.
            rt.spawn(async move {
                while writer.write(&[0; 100]).await.is_ok() {}
            })
            .unwrap();

            // Each check reads without waiting. `None` means nothing was readable at that
            // instant; `Some(n)` is the number of bytes read.
            let res_unwrap = Result::unwrap;
            let mut buf = vec![0; 1000];
            let buf = &mut buf;

            // Offset the reader by 1 ms, presumably so that each check happens just after
            // the corresponding rate update rather than at the same instant.
            rt.sleep(Duration::from_millis(1)).await;

            // Expected byte counts are slightly below rate * 0.841 s. The exact values depend
            // on `RateLimitedWriter`'s accounting (e.g. the 10-byte wake threshold), which is
            // not visible in this file.

            // Interval with rate 0: nothing written.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));
            // Interval with rate 100.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(80), reader.read(buf).now_or_never().map(res_unwrap));
            // Interval with rate 0.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));
            // Interval with rate 200.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(160), reader.read(buf).now_or_never().map(res_unwrap));
            // Interval with rate 0.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));
            // Interval with rate 400.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(330), reader.read(buf).now_or_never().map(res_unwrap));
            // Final interval with rate 0.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));
        });
    }
}