use crate::frame::{FinalizedFrame, Frame};
use crate::loom::{Arc, AtomicBool, AtomicU32, AtomicWaker, Ordering};
use bytes::{Buf, Bytes};
use std::io;
use std::io::ErrorKind::BrokenPipe;
use std::pin::Pin;
use std::task::{Context, Poll, ready};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader};
use tokio::sync::mpsc;
use tracing::{debug, trace, warn};
/// One multiplexed flow, exposed as a byte stream through `AsyncRead`,
/// `AsyncBufRead` and `AsyncWrite`.
///
/// Incoming payloads arrive on `frame_rx`; outgoing frames are queued on
/// `frame_tx`. Dropping the stream reports `flow_id` on `dropped_ports_tx`.
pub struct MuxStream {
    /// Payloads received for this flow; the channel closing is reported to
    /// readers as EOF.
    pub(super) frame_rx: mpsc::Receiver<Bytes>,
    /// Identifier written into every frame this stream sends.
    pub(super) flow_id: u32,
    /// Destination host as raw bytes (presumably the target requested when the
    /// flow was opened — confirm with the constructor).
    pub dest_host: Bytes,
    /// Destination port (presumably paired with `dest_host`).
    pub dest_port: u16,
    /// Set once a `Finish` frame has been queued; afterwards writes fail with
    /// `BrokenPipe`. Shared via `Arc`, presumably with the multiplexer task.
    pub(super) finish_sent: Arc<AtomicBool>,
    /// Send window: how many more `Push` frames may be sent. Decremented once
    /// per push here; replenished elsewhere (presumably when an `Acknowledge`
    /// arrives — confirm with the multiplexer).
    pub(super) psh_send_remaining: Arc<AtomicU32>,
    /// Frames received since this stream last sent an `Acknowledge`.
    pub(super) psh_recvd_since: u32,
    /// Waker of a writer that is blocked on an empty send window.
    pub(super) writer_waker: Arc<AtomicWaker>,
    /// Unconsumed remainder of the most recently received payload.
    pub(super) buf: Bytes,
    /// Queue of outgoing frames (push, acknowledge, finish).
    pub(super) frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
    /// Receives `flow_id` when this stream is dropped.
    pub(super) dropped_ports_tx: mpsc::UnboundedSender<u32>,
    /// Number of received frames after which an `Acknowledge` is sent.
    pub(super) rwnd_threshold: u32,
}
impl std::fmt::Debug for MuxStream {
    /// Formats the stream for diagnostics.
    ///
    /// Channel endpoints and the writer waker are left out. `flow_id` is
    /// printed as zero-padded hex, and the pending read buffer is shown only
    /// by its length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MuxStream");
        out.field("flow_id", &format_args!("{:08x}", self.flow_id));
        out.field("dest_host", &self.dest_host);
        out.field("dest_port", &self.dest_port);
        out.field("finish_sent", &self.finish_sent);
        out.field("psh_send_remaining", &self.psh_send_remaining);
        out.field("psh_recvd_since", &self.psh_recvd_since);
        out.field("rwnd_threshold", &self.rwnd_threshold);
        out.field("buf.len", &self.buf.len());
        out.finish_non_exhaustive()
    }
}
impl Drop for MuxStream {
    /// Reports this flow's id on `dropped_ports_tx` (presumably so the
    /// multiplexer can forget the flow — confirm with the receiver side).
    fn drop(&mut self) {
        // `drop` cannot fail; if nobody is listening any more there is
        // nothing useful to do with the error.
        let _ = self.dropped_ports_tx.send(self.flow_id);
    }
}
impl AsyncRead for MuxStream {
    /// Reads by filling the internal buffer and copying as much of it as fits
    /// into `buf`; whatever does not fit stays buffered for the next read.
    ///
    /// An empty fill (the frame channel closed) leaves `buf` untouched, which
    /// signals EOF.
    #[tracing::instrument(skip_all, level = "trace")]
    #[inline]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let copied = match self.as_mut().poll_fill_buf(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(available)) => {
                let n = available.len().min(buf.remaining());
                buf.put_slice(&available[..n]);
                n
            }
        };
        self.consume(copied);
        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for MuxStream {
    /// Sends `buf` to the peer as a single `Push` frame.
    ///
    /// Each non-empty write consumes one slot of the send window
    /// (`psh_send_remaining`). When the window is empty the task is parked
    /// until it is woken through `writer_waker`. An empty `buf` completes
    /// immediately with `Ok(0)` and sends nothing.
    ///
    /// # Errors
    /// Returns `BrokenPipe` if the stream has already been shut down or the
    /// outgoing frame channel is closed.
    #[tracing::instrument(skip_all, level = "trace", fields(flow_id = self.flow_id))]
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // A zero-length write must not spend a window slot or put an empty
        // `Push` frame on the wire: the read side of this type treats every
        // received payload as non-empty.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        ready!(self.poll_obtain_write_permission(cx))?;
        let frame = Frame::new_push(self.flow_id, buf).finalize();
        self.frame_tx.send(frame).or(Err(BrokenPipe))?;
        trace!("sent a frame");
        Poll::Ready(Ok(buf.len()))
    }

    /// Frames are handed to `frame_tx` as soon as they are written, so there
    /// is nothing to flush here.
    #[tracing::instrument(skip(_cx), level = "trace", fields(flow_id = self.flow_id))]
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Queues a `Finish` frame the first time it is called; later calls are
    /// no-ops that return `Ok(())`.
    ///
    /// # Errors
    /// Returns `BrokenPipe` if the outgoing frame channel is closed.
    #[tracing::instrument(skip(_cx), level = "trace", fields(flow_id = self.flow_id))]
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(self.shutdown_inner())
    }
}
impl AsyncBufRead for MuxStream {
    /// Returns the unconsumed part of the current payload.
    ///
    /// If that part is empty, waits for the next payload on `frame_rx` and
    /// counts it towards the next `Acknowledge`. A closed channel yields an
    /// empty slice (EOF).
    #[tracing::instrument(skip_all, level = "trace", fields(flow_id = self.flow_id))]
    #[inline]
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        if !self.buf.is_empty() {
            trace!("using the remaining buffer");
            return Poll::Ready(Ok(&self.get_mut().buf));
        }
        trace!("polling the stream");
        let Poll::Ready(received) = self.frame_rx.poll_recv(cx) else {
            return Poll::Pending;
        };
        if let Some(payload) = received {
            debug_assert!(!payload.is_empty());
            self.buf = payload;
            self.increment_psh_recvd_since();
            return Poll::Ready(Ok(&self.get_mut().buf));
        }
        trace!("stream has been closed");
        self.frame_rx.close();
        debug_assert!(self.frame_rx.try_recv().is_err());
        Poll::Ready(Ok(&[]))
    }

    /// Drops the first `amt` bytes of the current payload.
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        self.get_mut().buf.advance(amt);
    }
}
impl MuxStream {
    /// Records the arrival of one incoming payload.
    ///
    /// Once `rwnd_threshold` payloads have arrived since the last
    /// acknowledgement, this queues an `Acknowledge` carrying that count and
    /// resets the counter. Failure to queue the `Acknowledge` is ignored
    /// (best effort).
    #[tracing::instrument(skip_all, level = "trace", fields(count = self.psh_recvd_since + 1))]
    #[inline]
    fn increment_psh_recvd_since(&mut self) {
        trace!("received a frame");
        let new = self.psh_recvd_since + 1;
        self.psh_recvd_since = new;
        if new >= self.rwnd_threshold {
            self.psh_recvd_since = 0;
            trace!("sending `Acknowledge` of {new} frames");
            self.frame_tx
                .send(Frame::new_acknowledge(self.flow_id, new).finalize())
                .ok();
        }
    }

    /// Reserves one slot of the send window for an outgoing `Push` frame.
    ///
    /// Returns:
    /// - `Ready(Ok(()))` after decrementing `psh_send_remaining`.
    /// - `Ready(Err(BrokenPipe))` if `Finish` has already been sent.
    /// - `Pending` when the window is empty. The task's waker is then stored
    ///   in `writer_waker`, presumably for whoever replenishes the window to
    ///   wake.
    #[tracing::instrument(skip_all, level = "trace")]
    #[inline]
    fn poll_obtain_write_permission(&self, cx: &Context<'_>) -> Poll<io::Result<()>> {
        if self.finish_sent.load(Ordering::Relaxed) {
            debug!("stream has been closed, returning `BrokenPipe`");
            return Poll::Ready(Err(BrokenPipe.into()));
        }
        let mut waker_registered = false;
        loop {
            let original = self.psh_send_remaining.load(Ordering::Acquire);
            trace!("congestion window: {original}");
            if original == 0 {
                if waker_registered {
                    debug!("waiting for `Acknowledge`");
                    return Poll::Pending;
                }
                // Register first, then look at the window once more. If the
                // window were replenished (and `writer_waker` woken) between
                // the load above and this registration, that wake-up would
                // otherwise be lost and the writer would stall forever.
                self.writer_waker.register(cx.waker());
                waker_registered = true;
                continue;
            }
            if self
                .psh_send_remaining
                .compare_exchange_weak(original, original - 1, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                return Poll::Ready(Ok(()));
            }
            trace!("congestion window race condition, retrying");
        }
    }

    /// Queues a `Finish` frame for this flow, at most once.
    ///
    /// `finish_sent` is set before the frame is queued. Later calls, and later
    /// writes, therefore see the stream as closed even if queuing failed.
    ///
    /// # Errors
    /// Returns `BrokenPipe` if the outgoing frame channel is closed.
    #[inline]
    fn shutdown_inner(&self) -> io::Result<()> {
        if self.finish_sent.swap(true, Ordering::AcqRel) {
            return Ok(());
        }
        self.frame_tx
            .send(Frame::new_finish(self.flow_id).finalize())
            .or(Err(BrokenPipe))?;
        Ok(())
    }

    /// Builds a future that copies data both ways between this stream and
    /// `other`, wrapping `other` in a default-capacity `BufReader`.
    #[inline]
    pub fn into_copy_bidirectional<RW>(self, other: RW) -> CopyBidirectional<BufReader<RW>>
    where
        RW: AsyncRead + AsyncWrite + Unpin,
    {
        let other_bufreader = BufReader::new(other);
        self.into_copy_bidirectional_with_buf(other_bufreader)
    }

    /// Like [`MuxStream::into_copy_bidirectional`], but uses `other`'s own
    /// `AsyncBufRead` buffering.
    #[inline]
    pub const fn into_copy_bidirectional_with_buf<BRW>(self, other: BRW) -> CopyBidirectional<BRW>
    where
        BRW: AsyncBufRead + AsyncWrite + Unpin,
    {
        CopyBidirectional {
            us: self,
            other,
            read_state: ReadState::Transferring(0),
            write_state: WriteState::Transferring(0),
        }
    }
}
/// Progress of the mux-stream → `other` direction of [`CopyBidirectional`].
/// Every variant carries the number of bytes copied so far.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ReadState {
    /// Still copying payloads from the mux stream into `other`.
    Transferring(u64),
    /// EOF seen on the mux stream; `other.poll_shutdown` has not completed yet.
    ShuttingDown(u64),
    /// `other` has been shut down; the count is final.
    Done(u64),
}
/// Progress of the `other` → mux-stream direction of [`CopyBidirectional`].
/// Every variant carries the number of bytes forwarded so far.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum WriteState {
    /// Still forwarding data read from `other` as `Push` frames.
    Transferring(u64),
    /// EOF seen on `other` and `Finish` queued; the count is final.
    Done(u64),
}
/// Future that copies data in both directions between a [`MuxStream`] and
/// another buffered stream.
///
/// Created by `MuxStream::into_copy_bidirectional` and
/// `MuxStream::into_copy_bidirectional_with_buf`. Resolves to
/// `(bytes mux → other, bytes other → mux)` once both directions have
/// reached EOF, or resolves with an I/O error.
#[derive(Debug)]
pub struct CopyBidirectional<RW> {
    /// The multiplexed side.
    us: MuxStream,
    /// The peer stream, read through its `AsyncBufRead` buffer.
    other: RW,
    /// State of the mux-stream → `other` direction.
    read_state: ReadState,
    /// State of the `other` → mux-stream direction.
    write_state: WriteState,
}
impl<RW> CopyBidirectional<RW>
where
    RW: AsyncBufRead + AsyncWrite + Unpin,
{
    /// Drives the mux-stream → `other` direction.
    ///
    /// Copies every payload received on `us` into `other`. When `us` reports
    /// EOF, `other` is shut down and the total number of bytes copied is
    /// returned. While `us` has nothing to offer, `other` is flushed so that
    /// data already written to it does not sit in a buffer indefinitely.
    ///
    /// # Errors
    /// Propagates errors from reading `us` and from writing, flushing or
    /// shutting down `other`. A write that accepts zero bytes is reported as
    /// `WriteZero`.
    #[tracing::instrument(skip_all, level = "trace")]
    #[inline]
    fn poll_read_us(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        match self.read_state {
            ReadState::Transferring(mut read_amt) => loop {
                trace!("polling us");
                let new_buf = match Pin::new(&mut self.us).poll_fill_buf(cx) {
                    Poll::Ready(res) => res?,
                    Poll::Pending => {
                        // Nothing more to forward right now: push out what has
                        // already been written to `other` before sleeping. This
                        // lives here (not in the other direction) so it keeps
                        // happening after `other` hits EOF and stops once
                        // `other` has been shut down.
                        trace!("flushing other");
                        ready!(Pin::new(&mut self.other).poll_flush(cx))?;
                        return Poll::Pending;
                    }
                };
                if new_buf.is_empty() {
                    self.read_state = ReadState::ShuttingDown(read_amt);
                    ready!(Pin::new(&mut self.other).poll_shutdown(cx))?;
                    self.read_state = ReadState::Done(read_amt);
                    return Poll::Ready(Ok(read_amt));
                }
                let processed = ready!(Pin::new(&mut self.other).poll_write(cx, new_buf))?;
                if processed == 0 {
                    // A writer that keeps accepting nothing would otherwise
                    // make this loop spin forever on the same buffer.
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "write zero byte into writer",
                    )));
                }
                Pin::new(&mut self.us).consume(processed);
                read_amt += processed as u64;
                self.read_state = ReadState::Transferring(read_amt);
            },
            ReadState::ShuttingDown(read_amt) => {
                ready!(Pin::new(&mut self.other).poll_shutdown(cx))?;
                self.read_state = ReadState::Done(read_amt);
                Poll::Ready(Ok(read_amt))
            }
            ReadState::Done(read_amt) => Poll::Ready(Ok(read_amt)),
        }
    }

    /// Drives the `other` → mux-stream direction.
    ///
    /// Each chunk read from `other` is sent as one `Push` frame, using one
    /// send-window slot per chunk. When `other` reaches EOF, a `Finish` is
    /// queued and the total number of bytes forwarded is returned.
    ///
    /// # Errors
    /// Propagates read errors from `other`. Returns `BrokenPipe` if the mux
    /// stream was already finished or its frame channel is closed.
    #[inline]
    fn poll_write_us(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        match self.write_state {
            WriteState::Transferring(mut written_amt) => loop {
                trace!("polling other");
                let new_buf = ready!(Pin::new(&mut self.other).poll_fill_buf(cx))?;
                if new_buf.is_empty() {
                    self.us.shutdown_inner()?;
                    self.write_state = WriteState::Done(written_amt);
                    return Poll::Ready(Ok(written_amt));
                }
                ready!(self.us.poll_obtain_write_permission(cx))?;
                let frame = Frame::new_push(self.us.flow_id, new_buf).finalize();
                self.us.frame_tx.send(frame).or(Err(BrokenPipe))?;
                let processed = new_buf.len();
                Pin::new(&mut self.other).consume(processed);
                written_amt += processed as u64;
                self.write_state = WriteState::Transferring(written_amt);
            },
            WriteState::Done(written_amt) => Poll::Ready(Ok(written_amt)),
        }
    }
}
impl<RW> Future for CopyBidirectional<RW>
where
    RW: AsyncBufRead + AsyncWrite + Unpin,
{
    /// `(bytes copied mux → other, bytes copied other → mux)`.
    type Output = io::Result<(u64, u64)>;

    /// Polls both directions on every wake-up and completes once both have
    /// finished.
    ///
    /// An error from either direction fails the whole copy immediately, even
    /// while the other direction is still pending. Previously, a write-side
    /// error was dropped whenever the read side was pending.
    #[tracing::instrument(skip_all, level = "trace", fields(flow_id = self.us.flow_id))]
    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `?` on `Poll<io::Result<_>>` returns early on `Ready(Err(_))` and
        // otherwise yields a plain `Poll<u64>`.
        let read = self.poll_read_us(cx)?;
        let written = self.poll_write_us(cx)?;
        let read = ready!(read);
        let written = ready!(written);
        Poll::Ready(Ok((read, written)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Dupe, tests::setup_logging};
    use std::pin::pin;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};

    // Acknowledge threshold for tests that do not exercise flow control directly.
    const DEFAULT_RWND_THRESHOLD: u32 = 4;

    #[tokio::test]
    #[cfg(not(loom))]
    async fn test_mux_stream_read() {
        setup_logging();
        test_mux_stream_read_inner().await;
    }

    #[test]
    #[cfg(loom)]
    fn test_mux_stream_read_loom() {
        loom::model(|| {
            loom::future::block_on(test_mux_stream_read_inner());
        })
    }

    // Reads two payloads through `poll_read`, checks the `Acknowledge` sent
    // after `rwnd_threshold` (2) payloads, then checks EOF once the sender is
    // dropped.
    async fn test_mux_stream_read_inner() {
        let (rx_frame_tx, rx_frame_rx) = mpsc::channel(10);
        let (tx_frame_tx, mut tx_frame_rx) = mpsc::unbounded_channel();
        let (dropped_ports_tx, _) = mpsc::unbounded_channel();
        let stream = MuxStream {
            frame_rx: rx_frame_rx,
            flow_id: 1,
            dest_host: Bytes::new(),
            dest_port: 8080,
            finish_sent: Arc::new(AtomicBool::new(false)),
            psh_send_remaining: Arc::new(AtomicU32::new(2)),
            psh_recvd_since: 0,
            writer_waker: Arc::new(AtomicWaker::new()),
            frame_tx: tx_frame_tx,
            buf: Bytes::new(),
            dropped_ports_tx,
            rwnd_threshold: 2,
        };
        let mut stream = pin!(stream);
        let mut buf = vec![0u8; 5];
        let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
        let waker = futures_util::task::noop_waker();
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_read(&mut cx, &mut read_buf);
            assert!(matches!(rs, Poll::Pending));
        }
        rx_frame_tx.send(Bytes::from("hello")).await.unwrap();
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_read(&mut cx, &mut read_buf);
            assert!(matches!(rs, Poll::Ready(Ok(()))));
            assert_eq!(read_buf.filled().len(), 5);
            assert_eq!(read_buf.filled(), b"hello");
            read_buf.clear();
        }
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_read(&mut cx, &mut read_buf);
            assert!(matches!(rs, Poll::Pending));
        }
        rx_frame_tx.send(Bytes::from("world")).await.unwrap();
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_read(&mut cx, &mut read_buf);
            assert!(matches!(rs, Poll::Ready(Ok(()))));
            assert_eq!(read_buf.filled().len(), 5);
            assert_eq!(read_buf.filled(), b"world");
            read_buf.clear();
        }
        let frame = tx_frame_rx.recv().await.unwrap();
        assert_eq!(frame.opcode().unwrap(), crate::frame::OpCode::Acknowledge);
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 1);
        if let crate::frame::Payload::Acknowledge(ack) = frame.payload {
            assert_eq!(ack, 2);
        } else {
            panic!("Expected an `Acknowledge` frame");
        }
        drop(rx_frame_tx);
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_read(&mut cx, &mut read_buf);
            assert!(matches!(rs, Poll::Ready(Ok(()))));
            assert_eq!(read_buf.filled().len(), 0);
        }
    }

    #[tokio::test]
    #[cfg(not(loom))]
    async fn test_mux_stream_write() {
        setup_logging();
        test_mux_stream_write_inner().await;
    }

    #[test]
    #[cfg(loom)]
    fn test_mux_stream_write_loom() {
        loom::model(|| {
            loom::future::block_on(test_mux_stream_write_inner());
        })
    }

    // Exhausts a send window of 2, checks that the third write parks, then
    // replenishes the window and checks that the write goes through.
    async fn test_mux_stream_write_inner() {
        let (_, rx_frame_rx) = mpsc::channel(DEFAULT_RWND_THRESHOLD as usize);
        let (tx_frame_tx, mut tx_frame_rx) = mpsc::unbounded_channel();
        let (dropped_ports_tx, _) = mpsc::unbounded_channel();
        let stream = MuxStream {
            frame_rx: rx_frame_rx,
            flow_id: 1,
            dest_host: Bytes::new(),
            dest_port: 8080,
            finish_sent: Arc::new(AtomicBool::new(false)),
            psh_send_remaining: Arc::new(AtomicU32::new(2)),
            psh_recvd_since: 0,
            writer_waker: Arc::new(AtomicWaker::new()),
            frame_tx: tx_frame_tx,
            buf: Bytes::new(),
            dropped_ports_tx,
            rwnd_threshold: DEFAULT_RWND_THRESHOLD,
        };
        let mut stream = pin!(stream);
        let waker = futures_util::task::noop_waker();
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_write(&mut cx, b"hello");
            assert!(matches!(rs, Poll::Ready(Ok(5))));
            let rs = stream.as_mut().poll_write(&mut cx, b"world");
            assert!(matches!(rs, Poll::Ready(Ok(5))));
        }
        let frame1 = tx_frame_rx.recv().await.unwrap();
        assert_eq!(frame1.opcode().unwrap(), crate::frame::OpCode::Push);
        let frame1 = Frame::try_from(frame1).unwrap();
        assert_eq!(frame1.id, 1);
        if let crate::frame::Payload::Push(push) = frame1.payload {
            assert_eq!(&push.as_ref(), b"hello");
        } else {
            panic!("Expected a `Push` frame");
        }
        let frame2 = tx_frame_rx.recv().await.unwrap();
        assert_eq!(frame2.opcode().unwrap(), crate::frame::OpCode::Push);
        let frame2 = Frame::try_from(frame2).unwrap();
        assert_eq!(frame2.id, 1);
        if let crate::frame::Payload::Push(push) = frame2.payload {
            assert_eq!(&push.as_ref(), b"world");
        } else {
            panic!("Expected a `Push` frame");
        }
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_write(&mut cx, b"maybe");
            assert!(matches!(rs, Poll::Pending));
        }
        stream.psh_send_remaining.fetch_add(1, Ordering::Release);
        {
            let mut cx = Context::from_waker(&waker);
            let rs = stream.as_mut().poll_write(&mut cx, b"maybe");
            assert!(matches!(rs, Poll::Ready(Ok(5))));
        }
        let frame4 = tx_frame_rx.recv().await.unwrap();
        assert_eq!(frame4.opcode().unwrap(), crate::frame::OpCode::Push);
        let frame4 = Frame::try_from(frame4).unwrap();
        assert_eq!(frame4.id, 1);
        if let crate::frame::Payload::Push(push) = frame4.payload {
            assert_eq!(&push.as_ref(), b"maybe");
        } else {
            panic!("Expected a `Push` frame");
        }
    }

    #[tokio::test]
    #[cfg(not(loom))]
    async fn test_copy_bidirectional_normal() {
        const TX1: Bytes = Bytes::from_static(b"hello from mux");
        const RX1: Bytes = Bytes::from_static(b"hello from other");
        const TX2: Bytes = Bytes::from_static(b"short");
        const RX2: Bytes = Bytes::from_static(b"hello after half-close");
        const RX3: Bytes = Bytes::from_static(b"stout");
        setup_logging();
        let (rx_frame_tx, rx_frame_rx) = mpsc::channel(DEFAULT_RWND_THRESHOLD as usize);
        let (tx_frame_tx, mut tx_frame_rx) = mpsc::unbounded_channel();
        let (dropped_ports_tx, _) = mpsc::unbounded_channel();
        let (other_stream, mut check_side) = tokio::io::duplex(1024);
        let mux_stream = MuxStream {
            frame_rx: rx_frame_rx,
            flow_id: 1,
            dest_host: Bytes::new(),
            dest_port: 8080,
            finish_sent: Arc::new(AtomicBool::new(false)),
            psh_send_remaining: Arc::new(AtomicU32::new(10)),
            psh_recvd_since: 0,
            writer_waker: Arc::new(AtomicWaker::new()),
            frame_tx: tx_frame_tx.clone(),
            buf: Bytes::new(),
            dropped_ports_tx: dropped_ports_tx.clone(),
            rwnd_threshold: DEFAULT_RWND_THRESHOLD,
        };
        let copy_task = tokio::spawn(mux_stream.into_copy_bidirectional(other_stream));
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut buf = [0u8; 14];
        let mut rbuf = ReadBuf::new(&mut buf);
        let rs = Pin::new(&mut check_side).poll_read(&mut cx, &mut rbuf);
        assert!(matches!(rs, Poll::Pending));
        rx_frame_tx.send(TX1.dupe()).await.unwrap();
        let size = check_side.read(&mut buf).await.unwrap();
        assert_eq!(size, TX1.len());
        assert_eq!(&buf[..size], TX1);
        check_side.write_all(&RX1).await.unwrap();
        check_side.flush().await.unwrap();
        let frame = tx_frame_rx.recv().await.unwrap();
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 1);
        if let crate::frame::Payload::Push(push) = frame.payload {
            assert_eq!(push.as_ref(), RX1);
        } else {
            panic!("Expected a `Push` frame");
        }
        rx_frame_tx.send(TX2.dupe()).await.unwrap();
        drop(rx_frame_tx);
        let read = check_side.read(&mut buf).await.unwrap();
        assert_eq!(read, TX2.len());
        assert_eq!(&buf[..read], TX2);
        let m = check_side.read(&mut buf).await.unwrap();
        assert_eq!(m, 0);
        check_side.write_all(&RX2).await.unwrap();
        check_side.flush().await.unwrap();
        let frame = tx_frame_rx.recv().await.unwrap();
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 1);
        if let crate::frame::Payload::Push(push) = frame.payload {
            assert_eq!(push.as_ref(), RX2);
        } else {
            panic!("Expected a `Push` frame");
        }
        check_side.write_all(&RX3).await.unwrap();
        check_side.shutdown().await.unwrap();
        let frame = tx_frame_rx.recv().await.unwrap();
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 1);
        if let crate::frame::Payload::Push(push) = frame.payload {
            assert_eq!(push.as_ref(), RX3);
        } else {
            panic!("Expected a `Push` frame");
        }
        let (bytes_read, bytes_written) = copy_task.await.unwrap().unwrap();
        assert_eq!(bytes_read, (TX1.len() + TX2.len()) as u64);
        assert_eq!(bytes_written, (RX1.len() + RX2.len() + RX3.len()) as u64);
    }

    #[tokio::test]
    #[cfg(not(loom))]
    async fn test_flow_control() {
        const TEST_ACK_THRESHOLD: usize = 5;
        const TEST_ACK_THRESHOLD_U32: u32 = 5;
        assert_eq!(TEST_ACK_THRESHOLD, TEST_ACK_THRESHOLD_U32 as usize);
        setup_logging();
        let (rx_frame_tx, rx_frame_rx) = mpsc::channel(TEST_ACK_THRESHOLD);
        let (tx_frame_tx, mut tx_frame_rx) = mpsc::unbounded_channel();
        let (dropped_ports_tx, _) = mpsc::unbounded_channel();
        let (other_stream, mut check_side) = tokio::io::duplex(1024);
        let mut mux_stream = MuxStream {
            frame_rx: rx_frame_rx,
            flow_id: 1,
            dest_host: Bytes::new(),
            dest_port: 8080,
            finish_sent: Arc::new(AtomicBool::new(false)),
            psh_send_remaining: Arc::new(AtomicU32::new(10)),
            psh_recvd_since: 0,
            writer_waker: Arc::new(AtomicWaker::new()),
            frame_tx: tx_frame_tx.clone(),
            buf: Bytes::new(),
            dropped_ports_tx: dropped_ports_tx.clone(),
            rwnd_threshold: TEST_ACK_THRESHOLD_U32,
        };
        for i in 0..TEST_ACK_THRESHOLD {
            debug!("sending frame {i}");
            rx_frame_tx
                .send(Bytes::from_static(b"hello"))
                .await
                .unwrap();
        }
        tx_frame_rx.try_recv().unwrap_err();
        let mut buf = [0u8; 5 * TEST_ACK_THRESHOLD];
        let n = mux_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(n, 5 * TEST_ACK_THRESHOLD);
        assert_eq!(&buf[..n], b"hello".repeat(TEST_ACK_THRESHOLD).as_slice());
        let frame = tx_frame_rx.recv().await.unwrap();
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 1);
        if let crate::frame::Payload::Acknowledge(ack) = frame.payload {
            assert_eq!(ack, TEST_ACK_THRESHOLD_U32);
        } else {
            panic!("Expected an `Acknowledge` frame");
        }
        let task = tokio::spawn(mux_stream.into_copy_bidirectional(other_stream));
        for i in 0..2 * TEST_ACK_THRESHOLD {
            debug!("sending frame {i}");
            rx_frame_tx
                .send(Bytes::from_static(b"hello"))
                .await
                .unwrap();
        }
        let frame = tx_frame_rx.recv().await.unwrap();
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 1);
        if let crate::frame::Payload::Acknowledge(ack) = frame.payload {
            assert_eq!(ack, TEST_ACK_THRESHOLD_U32);
        } else {
            panic!("Expected an `Acknowledge` frame");
        }
        let frame = tx_frame_rx.recv().await.unwrap();
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 1);
        if let crate::frame::Payload::Acknowledge(ack) = frame.payload {
            assert_eq!(ack, TEST_ACK_THRESHOLD_U32);
        } else {
            panic!("Expected an `Acknowledge` frame");
        }
        let mut buf = [0u8; 5 * 2 * TEST_ACK_THRESHOLD];
        let n = check_side.read_exact(&mut buf).await.unwrap();
        assert_eq!(n, 5 * 2 * TEST_ACK_THRESHOLD);
        assert_eq!(
            &buf[..n],
            b"hello".repeat(2 * TEST_ACK_THRESHOLD).as_slice()
        );
        drop(rx_frame_tx);
        for i in 0..TEST_ACK_THRESHOLD {
            debug!("sending data chunk {i}");
            check_side.write_all(b"hello").await.unwrap();
            check_side.flush().await.unwrap();
        }
        check_side.shutdown().await.unwrap();
        task.await.unwrap().unwrap();
        // The copy task may forward the chunks as one `Push` or as several,
        // depending on how much `BufReader` had buffered when it was polled.
        // So accumulate every payload instead of assuming a single frame of
        // exactly the expected size.
        let mut forwarded: Vec<u8> = Vec::with_capacity(5 * TEST_ACK_THRESHOLD);
        while let Some(frame) = tx_frame_rx.recv().await {
            let frame = Frame::try_from(frame).unwrap();
            assert_eq!(frame.id, 1);
            match frame.payload {
                crate::frame::Payload::Push(push) => {
                    forwarded.extend_from_slice(push.as_ref());
                }
                crate::frame::Payload::Finish => {
                    break;
                }
                _ => panic!("Expected a `Push` or `Finish` frame"),
            }
        }
        assert_eq!(forwarded, b"hello".repeat(TEST_ACK_THRESHOLD));
    }

    #[tokio::test]
    #[cfg(not(loom))]
    async fn test_mux_stream_shutdown() {
        test_mux_stream_shutdown_inner().await;
    }

    #[test]
    #[cfg(loom)]
    fn test_mux_stream_shutdown_loom() {
        loom::model(|| {
            loom::future::block_on(test_mux_stream_shutdown_inner());
        })
    }

    // Shutting down queues a `Finish`; dropping the stream reports its flow id.
    async fn test_mux_stream_shutdown_inner() {
        setup_logging();
        let (_, rx_frame_rx) = mpsc::channel(10);
        let (tx_frame_tx, mut tx_frame_rx) = mpsc::unbounded_channel();
        let (dropped_ports_tx, mut dropped_ports_rx) = mpsc::unbounded_channel();
        let mut stream = MuxStream {
            frame_rx: rx_frame_rx,
            flow_id: 15,
            dest_host: Bytes::new(),
            dest_port: 8080,
            finish_sent: Arc::new(AtomicBool::new(false)),
            psh_send_remaining: Arc::new(AtomicU32::new(2)),
            psh_recvd_since: 0,
            writer_waker: Arc::new(AtomicWaker::new()),
            frame_tx: tx_frame_tx,
            buf: Bytes::new(),
            dropped_ports_tx,
            rwnd_threshold: 2,
        };
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let rs = Pin::new(&mut stream).as_mut().poll_shutdown(&mut cx);
        assert!(matches!(rs, Poll::Ready(Ok(()))));
        let frame = tx_frame_rx.recv().await.unwrap();
        assert_eq!(frame.opcode().unwrap(), crate::frame::OpCode::Finish);
        let frame = Frame::try_from(frame).unwrap();
        assert_eq!(frame.id, 15);
        drop(stream);
        let dropped_port = dropped_ports_rx.recv().await.unwrap();
        assert_eq!(dropped_port, 15);
    }
}