use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::{AsyncRead, Stream};
use pin_project::pin_project;
use tor_basic_utils::assert_val_impl_trait;
use tor_cell::relaycell::flow_ctrl::XonKbpsEwma;
use crate::stream::StreamTarget;
use crate::util::notify::NotifyReceiver;
/// A reader wrapper that answers drain rate requests on behalf of an inner reader.
///
/// Reads are forwarded to the wrapped reader `R`. Each time the wrapper is polled for a read it
/// also polls the drain rate request stream held in `ctrl`; once a request has been seen, a
/// drain rate update is sent through the [`DrainRateNotifier`] `T` as soon as the inner reader
/// reports (via [`BufferIsEmpty`]) that it is empty.
#[derive(Debug)]
#[pin_project]
pub(crate) struct XonXoffReader<R, T: DrainRateNotifier = StreamTarget> {
/// The request stream we listen on and the notifier we answer with.
#[pin]
ctrl: XonXoffReaderCtrl<T>,
/// The wrapped reader that all reads are forwarded to.
#[pin]
reader: R,
/// `true` once a drain rate request has been received and not yet answered.
pending_drain_rate_update: bool,
}
impl<R, T: DrainRateNotifier> XonXoffReader<R, T> {
pub(crate) fn new(ctrl: XonXoffReaderCtrl<T>, reader: R) -> Self {
Self {
ctrl,
reader,
pending_drain_rate_update: false,
}
}
pub(crate) fn inner(&self) -> &R {
&self.reader
}
pub(crate) fn inner_mut(&mut self) -> &mut R {
&mut self.reader
}
}
impl<R, T> AsyncRead for XonXoffReader<R, T>
where
    R: AsyncRead + BufferIsEmpty,
    T: DrainRateNotifier,
{
    /// Read from the inner reader, and answer any outstanding drain rate request.
    ///
    /// The steps, in order:
    /// 1. Poll the drain rate request stream; if a request is ready, remember it.
    /// 2. Poll the inner reader with `buf`.
    /// 3. If a request is remembered and the inner reader now reports that it is empty,
    ///    send an `XonKbpsEwma::Unlimited` update and forget the request.
    ///
    /// # Errors
    ///
    /// Returns whatever error the inner reader returns. If sending the drain rate update
    /// fails, that error is returned *instead of* the result of the inner read.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        let mut this = self.project();

        // Compile-time check that the request stream is a `FusedStream`, so it is fine to keep
        // polling it on every read.
        assert_val_impl_trait!(
            this.ctrl.drain_rate_request_stream,
            futures::stream::FusedStream,
        );

        // Step 1: has a drain rate update been requested since the last poll?
        let request_stream = this.ctrl.as_mut().project().drain_rate_request_stream;
        if matches!(request_stream.poll_next(cx), Poll::Ready(Some(()))) {
            *this.pending_drain_rate_update = true;
        }

        // Step 2: forward the read itself.
        let read_result = this.reader.as_mut().poll_read(cx, buf);

        // Step 3: answer the request only once nothing is left buffered.
        let should_notify = *this.pending_drain_rate_update && this.reader.is_empty();
        if should_notify {
            this.ctrl
                .drain_rate_notifier
                .notify(XonKbpsEwma::Unlimited)?;
            *this.pending_drain_rate_update = false;
        }

        read_result
    }
}
/// Something that drain rate updates can be sent to.
///
/// [`XonXoffReader`] calls this when it has a drain rate update to report. The trait exists so
/// that the destination can be swapped out (the test module substitutes a channel-backed one).
pub(crate) trait DrainRateNotifier {
/// Report the drain rate `rate`.
///
/// An `Err` return is propagated out of [`XonXoffReader`]'s `poll_read`.
fn notify(&mut self, rate: XonKbpsEwma) -> Result<(), Error>;
}
impl DrainRateNotifier for StreamTarget {
    /// Forward the update to [`StreamTarget::drain_rate_update`], converting its error type
    /// into an I/O [`Error`].
    fn notify(&mut self, rate: XonKbpsEwma) -> Result<(), Error> {
        match self.drain_rate_update(rate) {
            Ok(()) => Ok(()),
            Err(err) => Err(err.into()),
        }
    }
}
/// The control half of an [`XonXoffReader`]: where drain rate requests arrive, and where the
/// resulting drain rate updates are sent.
#[derive(Debug)]
#[pin_project]
pub(crate) struct XonXoffReaderCtrl<T: DrainRateNotifier = StreamTarget> {
/// Stream that yields an item each time a drain rate update is requested.
#[pin]
drain_rate_request_stream: NotifyReceiver<DrainRateRequest>,
/// Destination for the drain rate updates.
drain_rate_notifier: T,
}
impl<T: DrainRateNotifier> XonXoffReaderCtrl<T> {
pub(crate) fn new(
drain_rate_request_stream: NotifyReceiver<DrainRateRequest>,
drain_rate_notifier: T,
) -> Self {
Self {
drain_rate_request_stream,
drain_rate_notifier,
}
}
}
/// Implemented by readers that can report whether they currently hold any data.
///
/// [`XonXoffReader`] uses this to decide when a pending drain rate update may be sent.
pub(crate) trait BufferIsEmpty {
/// Returns `true` if the reader considers itself empty.
///
/// What exactly counts as "empty" is defined by the implementor.
fn is_empty(self: Pin<&mut Self>) -> bool;
}
/// Marker type naming the notification channel (`NotifyReceiver<DrainRateRequest>`) over which
/// drain rate updates are requested. It carries no data.
#[derive(Debug)]
pub(crate) struct DrainRateRequest;
#[cfg(test)]
#[cfg(feature = "flowctl-cc")]
#[cfg(feature = "tokio")]
mod test {
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::string_slice)]
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::stream::flow_ctrl::params::FlowCtrlParameters;
use crate::stream::flow_ctrl::state::{FlowCtrlHooks, StreamRateLimit};
use crate::stream::flow_ctrl::xon_xoff::state::XonXoffFlowCtrl;
use crate::util::notify::NotifySender;
use futures::channel::mpsc::{self, TryRecvError};
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio_crate::io::{DuplexStream, duplex};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
/// Test-only [`DrainRateNotifier`] that pushes every update into an unbounded channel, so the
/// test can inspect which updates were sent.
struct TestingDrainRateUpdates(mpsc::UnboundedSender<XonKbpsEwma>);
impl TestingDrainRateUpdates {
pub(crate) fn new(sender: mpsc::UnboundedSender<XonKbpsEwma>) -> Self {
Self(sender)
}
}
impl DrainRateNotifier for TestingDrainRateUpdates {
    /// Push `rate` into the channel.
    ///
    /// Panics if the channel rejects the message; otherwise always returns `Ok(())`.
    fn notify(&mut self, rate: XonKbpsEwma) -> Result<(), Error> {
        let Self(sender) = self;
        sender.unbounded_send(rate).unwrap();
        Ok(())
    }
}
/// Writer half of a pair created by `with_length()`.
///
/// Every successful write adds the number of bytes written to the shared `length` counter.
#[pin_project::pin_project]
struct WriterWithLength<W> {
/// The wrapped writer.
#[pin]
writer: W,
/// Byte counter shared with the matching `ReaderWithLength`.
length: Arc<AtomicU64>,
}
/// Reader half of a pair created by `with_length()`.
///
/// Every successful read subtracts the number of bytes read from the shared `length` counter,
/// and the reader reports itself empty when that counter is zero.
#[pin_project::pin_project]
struct ReaderWithLength<R> {
/// The wrapped reader.
#[pin]
reader: R,
/// Byte counter shared with the matching `WriterWithLength`.
length: Arc<AtomicU64>,
}
/// Wrap a connected writer/reader pair so that both halves share one byte counter.
///
/// The counter starts at zero; the writer half increments it and the reader half decrements it.
fn with_length<W, R>(writer: W, reader: R) -> (WriterWithLength<W>, ReaderWithLength<R>) {
    let shared = Arc::new(AtomicU64::new(0));
    (
        WriterWithLength {
            writer,
            length: Arc::clone(&shared),
        },
        ReaderWithLength {
            reader,
            length: shared,
        },
    )
}
impl<W> WriterWithLength<W> {
pub(crate) fn len(&self) -> u64 {
self.length.load(Ordering::Acquire)
}
}
impl<R> BufferIsEmpty for ReaderWithLength<R> {
    /// The reader is empty exactly when the shared byte counter is zero.
    fn is_empty(self: Pin<&mut Self>) -> bool {
        let outstanding = self.length.load(Ordering::Acquire);
        outstanding == 0
    }
}
impl<W: AsyncWrite> AsyncWrite for WriterWithLength<W> {
    /// Forward the write, and on success add the number of bytes written to the shared counter.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        let outcome = this.writer.poll_write(cx, buf);
        if let Poll::Ready(Ok(written)) = &outcome {
            let written = u64::try_from(*written).expect("usize should fit into u64");
            this.length.fetch_add(written, Ordering::Release);
        }
        outcome
    }

    /// Forward the flush to the wrapped writer.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.project();
        this.writer.poll_flush(cx)
    }

    /// Forward the close to the wrapped writer.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.project();
        this.writer.poll_close(cx)
    }
}
impl<R: AsyncRead> AsyncRead for ReaderWithLength<R> {
    /// Forward the read, and on success subtract the number of bytes read from the shared
    /// counter.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        let outcome = this.reader.poll_read(cx, buf);
        if let Poll::Ready(Ok(consumed)) = &outcome {
            let consumed = u64::try_from(*consumed).expect("usize should fit into u64");
            this.length.fetch_sub(consumed, Ordering::Release);
        }
        outcome
    }
}
/// Build everything a test needs: a length-tracking writer, an [`XonXoffReader`] around the
/// matching length-tracking reader, the receiving end of the drain rate update channel, and
/// the [`XonXoffFlowCtrl`] state wired to the reader's request stream.
#[allow(clippy::type_complexity)]
fn init_flow_ctrl(
    use_sidechannel_mitigations: bool,
) -> (
    WriterWithLength<Compat<DuplexStream>>,
    XonXoffReader<ReaderWithLength<Compat<DuplexStream>>, TestingDrainRateUpdates>,
    mpsc::UnboundedReceiver<XonKbpsEwma>,
    XonXoffFlowCtrl,
) {
    // Flow control state, holding the sending side of the drain rate request channel.
    let flow_ctrl_params = FlowCtrlParameters::defaults_for_tests();
    // The receiver is bound to a name (not `_`) so that it lives until the end of this function.
    let (rate_limit_tx, _rate_limit_rx) = postage::watch::channel_with(StreamRateLimit::MAX);
    let mut request_tx = NotifySender::new_typed();
    let request_rx = request_tx.subscribe();
    let flow_ctrl = XonXoffFlowCtrl::new(
        Arc::new(flow_ctrl_params),
        use_sidechannel_mitigations,
        rate_limit_tx,
        request_tx,
    );

    // Reader control state; updates go into a channel that the test can inspect.
    let (update_tx, update_rx) = mpsc::unbounded();
    let reader_ctrl = XonXoffReaderCtrl::new(request_rx, TestingDrainRateUpdates::new(update_tx));

    // The data path: an in-memory duplex pipe with both ends sharing a byte counter.
    let (write_half, read_half) = duplex(usize::MAX);
    let (writer, reader) = with_length(write_half.compat_write(), read_half.compat());

    (
        writer,
        XonXoffReader::new(reader_ctrl, reader),
        update_rx,
        flow_ctrl,
    )
}
/// Write `num_bytes` zero bytes to `writer` in chunks of at most 100,000 bytes.
///
/// After each chunk, `flow_ctrl.maybe_send_xoff()` is called with the writer's current length.
/// Returns `true` if any of those calls returned an XOFF.
async fn buffer_incoming_data(
    writer: &mut WriterWithLength<impl AsyncWrite + Unpin>,
    mut num_bytes: usize,
    flow_ctrl: &mut XonXoffFlowCtrl,
) -> bool {
    const MAX_CHUNK: usize = 100_000;

    let mut saw_xoff = false;
    while num_bytes != 0 {
        let chunk_len = std::cmp::min(num_bytes, MAX_CHUNK);
        let chunk = vec![0_u8; chunk_len];
        writer.write_all(&chunk).await.unwrap();
        num_bytes -= chunk_len;

        let buffered = writer.len() as usize;
        if flow_ctrl.maybe_send_xoff(buffered).unwrap().is_some() {
            saw_xoff = true;
        }
    }
    saw_xoff
}
/// Read exactly `num_bytes` bytes from `reader` in chunks of at most 100,000 bytes, discarding
/// the data. Panics if any read fails.
async fn read_incoming_data(mut reader: impl AsyncRead + Unpin, mut num_bytes: usize) {
    const MAX_CHUNK: usize = 100_000;

    while num_bytes != 0 {
        let chunk_len = std::cmp::min(num_bytes, MAX_CHUNK);
        let mut scratch = vec![0_u8; chunk_len];
        reader.read_exact(&mut scratch).await.unwrap();
        num_bytes -= chunk_len;
    }
}
/// A drain rate update is sent only after one was requested AND the buffer has been fully
/// drained, and the resulting update produces an XON.
#[test]
fn drain_rate_update() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (mut writer, mut reader, mut drain_rate_receiver, mut flow_ctrl) =
init_flow_ctrl( true);
// A small amount of buffered data is not expected to produce an XOFF.
let wants_to_send_xoff =
buffer_incoming_data(&mut writer, 10_000, &mut flow_ctrl).await;
assert!(!wants_to_send_xoff);
assert!(!reader.pending_drain_rate_update);
// Drain all of it. Nothing was requested, so no update may have been sent.
read_incoming_data(&mut reader, 10_000).await;
assert!(!reader.pending_drain_rate_update);
assert_eq!(drain_rate_receiver.try_recv(), Err(TryRecvError::Empty));
// A much larger amount of buffered data is expected to produce an XOFF.
let wants_to_send_xoff =
buffer_incoming_data(&mut writer, 800_000, &mut flow_ctrl).await;
assert!(wants_to_send_xoff);
// The reader has not been polled since, so it cannot have seen a request yet.
assert!(!reader.pending_drain_rate_update);
assert_eq!(drain_rate_receiver.try_recv(), Err(TryRecvError::Empty));
// A zero-length read polls the reader without consuming data. The test expects a drain
// rate request to be waiting at this point (presumably queued by the flow control state
// when it produced the XOFF), so the pending flag becomes set. The buffer still holds
// data, so no update is sent.
let _ = reader.read(&mut [0; 0]).await.unwrap();
assert!(reader.pending_drain_rate_update);
assert_eq!(drain_rate_receiver.try_recv(), Err(TryRecvError::Empty));
// Drain most of the buffer: still not empty, so the update stays pending.
read_incoming_data(&mut reader, 700_000).await;
assert!(!Pin::new(reader.inner_mut()).is_empty());
assert!(reader.pending_drain_rate_update);
assert_eq!(drain_rate_receiver.try_recv(), Err(TryRecvError::Empty));
// Drain the remainder: the buffer is now empty, so the update is sent and the flag clears.
read_incoming_data(&mut reader, 100_000).await;
assert!(Pin::new(reader.inner_mut()).is_empty());
assert!(!reader.pending_drain_rate_update);
// The update that was sent carries the "unlimited" rate.
let xon_rate = drain_rate_receiver.try_recv().unwrap();
assert_eq!(xon_rate, XonKbpsEwma::Unlimited);
// Handing that update to the flow control state must yield an XON carrying the same rate.
let xon = flow_ctrl
.maybe_send_xon(xon_rate, writer.len() as usize)
.unwrap()
.unwrap();
assert_eq!(xon.kbps_ewma(), xon_rate);
});
}
/// If the buffer fills up again between the reader sending its drain rate update and the flow
/// control state handling it, no XON is produced and a new drain rate request reaches the
/// reader.
#[test]
fn drain_rate_update_then_buffer_refill() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (mut writer, mut reader, mut drain_rate_receiver, mut flow_ctrl) =
init_flow_ctrl( true);
// Buffer enough data to produce an XOFF.
let wants_to_send_xoff =
buffer_incoming_data(&mut writer, 800_000, &mut flow_ctrl).await;
assert!(wants_to_send_xoff);
// Drain most of it: the reader has seen the request but the buffer is not yet empty.
read_incoming_data(&mut reader, 700_000).await;
assert!(reader.pending_drain_rate_update);
// Drain the rest: the buffer is empty, so the reader sends its update and clears the flag.
read_incoming_data(&mut reader, 100_000).await;
assert!(Pin::new(reader.inner_mut()).is_empty());
assert!(!reader.pending_drain_rate_update);
// Refill the buffer before the update is handled. The test expects no second XOFF here
// (presumably because one was already sent and no XON has followed it -- confirm against
// `XonXoffFlowCtrl`).
let wants_to_send_xoff =
buffer_incoming_data(&mut writer, 800_000, &mut flow_ctrl).await;
assert!(!wants_to_send_xoff);
// The update sent earlier is still sitting in the channel.
let xon_rate = drain_rate_receiver.try_recv().unwrap();
assert_eq!(xon_rate, XonKbpsEwma::Unlimited);
// With the buffer full again, handling the update must not produce an XON.
let xon = flow_ctrl
.maybe_send_xon(xon_rate, writer.len() as usize)
.unwrap();
assert!(xon.is_none());
// The reader only notices a new request when it is next polled; the test expects one to be
// waiting (presumably issued when the XON was declined -- confirm against
// `XonXoffFlowCtrl`).
assert!(!reader.pending_drain_rate_update);
let _ = reader.read(&mut [0; 0]).await.unwrap();
assert!(reader.pending_drain_rate_update);
});
}
}