use super::{AsyncRead, AsyncWrite, ReadBuf};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Size in bytes of the inline staging buffer used by [`copy`], [`copy_with_progress`] and by
/// each direction of [`copy_bidirectional`].
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
/// Copies every byte from `reader` into `writer`.
///
/// The returned [`Copy`] future resolves to the total number of bytes the writer accepted,
/// once the reader reports end of input (a zero-length read) and the writer has been flushed.
///
/// # Errors
///
/// The future resolves to an error if the reader or writer fails, if the writer accepts zero
/// bytes (`ErrorKind::WriteZero`), or if the ambient `Cx` checkpoint reports cancellation
/// (`ErrorKind::Interrupted`). Polling it again after it has resolved also yields an error.
#[inline]
pub fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> Copy<'a, R, W>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    Copy {
        reader,
        writer,
        buf: [0u8; DEFAULT_BUF_SIZE],
        pos: 0,
        cap: 0,
        total: 0,
        read_done: false,
        need_flush: false,
        completed: false,
    }
}
/// Future returned by [`copy`].
///
/// Bytes travel through the fixed-size `buf`: `buf[pos..cap]` holds data that has been read
/// but not yet accepted by the writer.
pub struct Copy<'a, R: ?Sized, W: ?Sized> {
    /// Source of the bytes.
    reader: &'a mut R,
    /// Destination of the bytes.
    writer: &'a mut W,
    /// Staging buffer, overwritten by each read.
    buf: [u8; DEFAULT_BUF_SIZE],
    /// Set once the reader returned a zero-length read.
    read_done: bool,
    /// Set after a successful write; cleared when a flush completes while the reader is pending.
    need_flush: bool,
    /// Start of the not-yet-written region of `buf`.
    pos: usize,
    /// End of the valid data in `buf`.
    cap: usize,
    /// Bytes accepted by the writer so far.
    total: u64,
    /// Set whenever `poll` returns `Ready`, so later polls fail closed.
    completed: bool,
}
impl<R, W> Future for Copy<'_, R, W>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    type Output = io::Result<u64>;

    /// Drives the copy loop.
    ///
    /// Each round first checkpoints the ambient `Cx` (if any) for cancellation, then does one of
    /// three things:
    /// 1. drains `buf[pos..cap]` into the writer;
    /// 2. after EOF, flushes the writer and resolves;
    /// 3. otherwise refills the buffer from the reader.
    ///
    /// At most 33 rounds run per call. After that the task re-wakes itself and returns
    /// `Pending`, so a fast stream pair cannot monopolise the executor.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let me = self.get_mut();
        if me.completed {
            return Poll::Ready(Err(io::Error::other("Copy future polled after completion")));
        }

        let mut rounds_left: u32 = 33;
        loop {
            let cancelled = crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err());
            if cancelled {
                me.completed = true;
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::Interrupted,
                    "cancelled",
                )));
            }
            if rounds_left == 0 {
                // Cooperative yield: ask to be polled again right away.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            rounds_left -= 1;

            // Phase 1: hand any buffered bytes to the writer.
            if me.pos < me.cap {
                let unwritten = &me.buf[me.pos..me.cap];
                let written = match Pin::new(&mut *me.writer).poll_write(cx, unwritten) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Ok(0)) => {
                        me.completed = true;
                        return Poll::Ready(Err(io::Error::from(io::ErrorKind::WriteZero)));
                    }
                    Poll::Ready(Ok(n)) => n,
                    Poll::Ready(Err(err)) => {
                        me.completed = true;
                        return Poll::Ready(Err(err));
                    }
                };
                me.pos += written;
                me.total += written as u64;
                me.need_flush = true;
                continue;
            }

            // Phase 2: the reader is exhausted and the buffer drained; flush and finish.
            if me.read_done {
                let flushed = match Pin::new(&mut *me.writer).poll_flush(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(res) => res,
                };
                me.completed = true;
                let total = me.total;
                return Poll::Ready(flushed.map(|()| total));
            }

            // Phase 3: refill the buffer from the reader.
            let mut read_buf = ReadBuf::new(&mut me.buf);
            match Pin::new(&mut *me.reader).poll_read(cx, &mut read_buf) {
                Poll::Ready(Ok(())) => {
                    let filled = read_buf.filled().len();
                    if filled == 0 {
                        me.read_done = true;
                    } else {
                        me.pos = 0;
                        me.cap = filled;
                    }
                }
                Poll::Ready(Err(err)) => {
                    me.completed = true;
                    return Poll::Ready(Err(err));
                }
                Poll::Pending => {
                    // No new input yet: push already-written bytes through while we wait.
                    if me.need_flush {
                        if let Poll::Ready(flushed) = Pin::new(&mut *me.writer).poll_flush(cx) {
                            if let Err(err) = flushed {
                                me.completed = true;
                                return Poll::Ready(Err(err));
                            }
                            me.need_flush = false;
                        }
                    }
                    return Poll::Pending;
                }
            }
        }
    }
}
/// Poll-based buffered reader with the same fill/consume shape as `std::io::BufRead`.
pub trait AsyncBufRead: AsyncRead {
    /// Returns the currently buffered bytes without consuming them.
    ///
    /// [`copy_buf`] treats an empty slice as end of input.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>;
    /// Marks `amt` bytes at the front of the buffer returned by `poll_fill_buf` as consumed.
    fn consume(self: Pin<&mut Self>, amt: usize);
}
impl AsyncBufRead for &[u8] {
    /// The slice is its own buffer, so every remaining byte is available immediately.
    fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let remaining: &[u8] = *Pin::into_inner(self);
        Poll::Ready(Ok(remaining))
    }

    /// Drops up to `amt` bytes from the front of the slice, clamped to its length.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let slice = Pin::into_inner(self);
        let skip = amt.min(slice.len());
        *slice = &slice[skip..];
    }
}
impl<T> AsyncBufRead for std::io::Cursor<T>
where
    T: AsRef<[u8]> + Unpin,
{
    /// Returns the bytes from the cursor position to the end of the underlying data.
    ///
    /// A position at or past the end, including one that does not fit in `usize`, yields an
    /// empty slice instead of panicking.
    fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let cursor = Pin::into_inner(self);
        let bytes = cursor.get_ref().as_ref();
        let offset = usize::try_from(cursor.position())
            .map_or(bytes.len(), |position| position.min(bytes.len()));
        Poll::Ready(Ok(&bytes[offset..]))
    }

    /// Advances the cursor by up to `amt` bytes without moving it beyond the end of the data.
    ///
    /// A cursor that is already past the end stays where it is.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let cursor = Pin::into_inner(self);
        let end = cursor.get_ref().as_ref().len() as u64;
        let start = cursor.position();
        let step = (amt as u64).min(end.saturating_sub(start));
        cursor.set_position(start.saturating_add(step));
    }
}
/// Copies every byte from a buffered `reader` into `writer`, writing straight out of the
/// reader's own buffer rather than through an intermediate one.
///
/// The returned [`CopyBuf`] future resolves to the number of bytes written once
/// `poll_fill_buf` yields an empty slice and the writer has been flushed.
///
/// # Errors
///
/// The future resolves to an error in these cases:
/// - the reader or writer fails;
/// - the writer accepts zero bytes (`ErrorKind::WriteZero`);
/// - the ambient `Cx` checkpoint reports cancellation (`ErrorKind::Interrupted`);
/// - it is polled again after resolving.
#[inline]
pub fn copy_buf<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> CopyBuf<'a, R, W>
where
    R: AsyncBufRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    CopyBuf {
        reader,
        writer,
        completed: false,
        need_flush: false,
        read_done: false,
        total: 0,
    }
}
/// Future returned by [`copy_buf`].
pub struct CopyBuf<'a, R: ?Sized, W: ?Sized> {
    /// Buffered source of the bytes.
    reader: &'a mut R,
    /// Destination of the bytes.
    writer: &'a mut W,
    /// Bytes accepted by the writer so far.
    total: u64,
    /// Set once `poll_fill_buf` returned an empty slice.
    read_done: bool,
    /// Set after a successful write; cleared when a flush completes while the reader is pending.
    need_flush: bool,
    /// Set whenever `poll` returns `Ready`, so later polls fail closed.
    completed: bool,
}
impl<R, W> Future for CopyBuf<'_, R, W>
where
    R: AsyncBufRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    type Output = io::Result<u64>;

    /// Drives the buffered copy loop.
    ///
    /// Each round checkpoints the ambient `Cx`. If EOF was already seen, it flushes and
    /// resolves. Otherwise it asks the reader for its buffer, writes as much of it as the writer
    /// accepts, and consumes exactly that many bytes. At most 33 rounds run per call before the
    /// task re-wakes itself and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let me = self.get_mut();
        if me.completed {
            return Poll::Ready(Err(io::Error::other(
                "CopyBuf future polled after completion",
            )));
        }

        let mut rounds_left: u32 = 33;
        loop {
            let cancelled = crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err());
            if cancelled {
                me.completed = true;
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::Interrupted,
                    "cancelled",
                )));
            }
            if rounds_left == 0 {
                // Cooperative yield: ask to be polled again right away.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            rounds_left -= 1;

            if me.read_done {
                let flushed = match Pin::new(&mut *me.writer).poll_flush(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(res) => res,
                };
                me.completed = true;
                let total = me.total;
                return Poll::Ready(flushed.map(|()| total));
            }

            let chunk = match Pin::new(&mut *me.reader).poll_fill_buf(cx) {
                Poll::Ready(Ok(chunk)) => chunk,
                Poll::Ready(Err(err)) => {
                    me.completed = true;
                    return Poll::Ready(Err(err));
                }
                Poll::Pending => {
                    // No new input yet: push already-written bytes through while we wait.
                    if me.need_flush {
                        if let Poll::Ready(flushed) = Pin::new(&mut *me.writer).poll_flush(cx) {
                            if let Err(err) = flushed {
                                me.completed = true;
                                return Poll::Ready(Err(err));
                            }
                            me.need_flush = false;
                        }
                    }
                    return Poll::Pending;
                }
            };

            if chunk.is_empty() {
                me.read_done = true;
                continue;
            }

            let accepted = match Pin::new(&mut *me.writer).poll_write(cx, chunk) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Ok(0)) => {
                    me.completed = true;
                    return Poll::Ready(Err(io::Error::from(io::ErrorKind::WriteZero)));
                }
                Poll::Ready(Ok(n)) => n,
                Poll::Ready(Err(err)) => {
                    me.completed = true;
                    return Poll::Ready(Err(err));
                }
            };
            // Only the bytes the writer took are removed from the reader's buffer.
            Pin::new(&mut *me.reader).consume(accepted);
            me.total += accepted as u64;
            me.need_flush = true;
        }
    }
}
/// Like [`copy`], but calls `on_progress` with the running byte total after every successful
/// write to `writer`.
///
/// The reported total only ever counts bytes the writer actually accepted. If the copy fails
/// part-way, the last value passed to `on_progress` equals the committed prefix length.
///
/// # Errors
///
/// Same conditions as [`copy`].
#[inline]
pub fn copy_with_progress<'a, R, W, F>(
    reader: &'a mut R,
    writer: &'a mut W,
    on_progress: F,
) -> CopyWithProgress<'a, R, W, F>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
    F: FnMut(u64),
{
    CopyWithProgress {
        reader,
        writer,
        on_progress,
        buf: [0u8; DEFAULT_BUF_SIZE],
        pos: 0,
        cap: 0,
        total: 0,
        read_done: false,
        need_flush: false,
        completed: false,
    }
}
/// Future returned by [`copy_with_progress`].
///
/// `buf[pos..cap]` holds data that has been read but not yet accepted by the writer.
pub struct CopyWithProgress<'a, R: ?Sized, W: ?Sized, F> {
    /// Source of the bytes.
    reader: &'a mut R,
    /// Destination of the bytes.
    writer: &'a mut W,
    /// Callback invoked with `total` after every successful write.
    on_progress: F,
    /// Staging buffer, overwritten by each read.
    buf: [u8; DEFAULT_BUF_SIZE],
    /// Set once the reader returned a zero-length read.
    read_done: bool,
    /// Set after a successful write; cleared when a flush completes while the reader is pending.
    need_flush: bool,
    /// Start of the not-yet-written region of `buf`.
    pos: usize,
    /// End of the valid data in `buf`.
    cap: usize,
    /// Bytes accepted by the writer so far.
    total: u64,
    /// Set whenever `poll` returns `Ready`, so later polls fail closed.
    completed: bool,
}
// `on_progress` is never pinned: it is only ever called through a plain `&mut F`, so it is not
// structurally pinned. That makes it sound for the future to be `Unpin` whatever `F` is. The
// other fields (`&mut R`, `&mut W`, a byte array) are `Unpin` anyway.
//
// This lets `poll` use `Pin::get_mut` without demanding `F: Unpin`, matching the bounds that
// `copy_with_progress` already accepts. Previously a `!Unpin` callback produced a future that
// could not be polled.
impl<R: ?Sized, W: ?Sized, F> Unpin for CopyWithProgress<'_, R, W, F> {}

impl<R, W, F> Future for CopyWithProgress<'_, R, W, F>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
    F: FnMut(u64),
{
    type Output = io::Result<u64>;

    /// Drives the copy loop, reporting progress after every successful write.
    ///
    /// Each round checkpoints the ambient `Cx`, then does one of three things:
    /// 1. drains `buf[pos..cap]` into the writer, calling `on_progress(total)` afterwards;
    /// 2. after EOF, flushes the writer and resolves;
    /// 3. otherwise refills the buffer.
    ///
    /// At most 33 rounds run per call before the task re-wakes itself and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        if this.completed {
            return Poll::Ready(Err(io::Error::other(
                "CopyWithProgress future polled after completion",
            )));
        }
        let mut steps = 0;
        loop {
            // Fail closed if the ambient task context has been cancelled.
            if crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
                this.completed = true;
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::Interrupted,
                    "cancelled",
                )));
            }
            // Cooperative yield so a fast stream pair cannot starve the executor.
            if steps > 32 {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            steps += 1;

            if this.pos < this.cap {
                match Pin::new(&mut *this.writer).poll_write(cx, &this.buf[this.pos..this.cap]) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(err)) => {
                        this.completed = true;
                        return Poll::Ready(Err(err));
                    }
                    Poll::Ready(Ok(0)) => {
                        this.completed = true;
                        return Poll::Ready(Err(io::Error::from(io::ErrorKind::WriteZero)));
                    }
                    Poll::Ready(Ok(n)) => {
                        this.pos += n;
                        this.total += n as u64;
                        // Report only bytes the writer has actually accepted.
                        (this.on_progress)(this.total);
                        this.need_flush = true;
                        continue;
                    }
                }
            }

            if this.read_done {
                match Pin::new(&mut *this.writer).poll_flush(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(err)) => {
                        this.completed = true;
                        return Poll::Ready(Err(err));
                    }
                    Poll::Ready(Ok(())) => {
                        this.completed = true;
                        return Poll::Ready(Ok(this.total));
                    }
                }
            }

            let mut read_buf = ReadBuf::new(&mut this.buf);
            match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
                Poll::Pending => {
                    // No new input yet: push already-written bytes through while we wait.
                    if this.need_flush {
                        match Pin::new(&mut *this.writer).poll_flush(cx) {
                            Poll::Pending => return Poll::Pending,
                            Poll::Ready(Err(err)) => {
                                this.completed = true;
                                return Poll::Ready(Err(err));
                            }
                            Poll::Ready(Ok(())) => {
                                this.need_flush = false;
                            }
                        }
                    }
                    return Poll::Pending;
                }
                Poll::Ready(Err(err)) => {
                    this.completed = true;
                    return Poll::Ready(Err(err));
                }
                Poll::Ready(Ok(())) => {
                    let n = read_buf.filled().len();
                    if n == 0 {
                        this.read_done = true;
                    } else {
                        this.pos = 0;
                        this.cap = n;
                    }
                }
            }
        }
    }
}
/// Copies data in both directions between `a` and `b`, interleaving the two directions within
/// a single task.
///
/// When a direction's reader reaches EOF, its remaining buffered bytes are written and the
/// opposite stream is flushed (if needed) and then shut down. The returned
/// [`CopyBidirectional`] future resolves to `(bytes a -> b, bytes b -> a)` once both directions
/// have finished.
///
/// # Errors
///
/// The first I/O error from either direction resolves the future with that error. A writer
/// that accepts zero bytes yields `ErrorKind::WriteZero`, and `Cx` cancellation yields
/// `ErrorKind::Interrupted`.
#[inline]
pub fn copy_bidirectional<'a, A, B>(a: &'a mut A, b: &'a mut B) -> CopyBidirectional<'a, A, B>
where
    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    CopyBidirectional {
        a,
        b,
        a_to_b: TransferState::default(),
        a_to_b_buf: [0u8; DEFAULT_BUF_SIZE],
        a_to_b_total: 0,
        b_to_a: TransferState::default(),
        b_to_a_buf: [0u8; DEFAULT_BUF_SIZE],
        b_to_a_total: 0,
        completed: false,
    }
}
/// Progress of one direction of a [`CopyBidirectional`].
#[derive(Default)]
struct TransferState {
    /// This direction's reader returned a zero-length read.
    read_done: bool,
    /// `poll_shutdown` on this direction's writer has completed.
    shutdown_done: bool,
    /// Bytes were written since the last completed flush.
    need_flush: bool,
    /// Start of the not-yet-written region of this direction's buffer.
    pos: usize,
    /// End of the valid data in this direction's buffer.
    cap: usize,
}
/// Future returned by [`copy_bidirectional`].
pub struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
    /// First stream: read for the `a -> b` direction, written for `b -> a`.
    a: &'a mut A,
    /// Second stream: read for the `b -> a` direction, written for `a -> b`.
    b: &'a mut B,
    /// Staging buffer for bytes read from `a`.
    a_to_b_buf: [u8; DEFAULT_BUF_SIZE],
    /// Staging buffer for bytes read from `b`.
    b_to_a_buf: [u8; DEFAULT_BUF_SIZE],
    /// State of the `a -> b` direction.
    a_to_b: TransferState,
    /// State of the `b -> a` direction.
    b_to_a: TransferState,
    /// Bytes accepted by `b` so far.
    a_to_b_total: u64,
    /// Bytes accepted by `a` so far.
    b_to_a_total: u64,
    /// Set whenever `poll` returns `Ready`, so later polls fail closed.
    completed: bool,
}
/// Per-poll work budget for [`CopyBidirectional`].
///
/// Every round in which either direction made progress is charged two units (one per
/// direction), so this allows 32 productive rounds before the future re-wakes itself and
/// returns `Pending`.
const YIELD_BUDGET: usize = 2 * 32;
/// Outcome of advancing one direction of a [`CopyBidirectional`] by a single step.
enum TransferResult {
    /// EOF was reached, the buffer drained, the writer flushed and shut down; nothing left to do.
    Done,
    /// The step could not advance because an inner poll returned `Pending`; that stream is
    /// expected to have registered the waker.
    Pending,
    /// State changed: bytes were read or written, or a flush/shutdown completed.
    Progress,
    /// An I/O error occurred; the whole bidirectional copy fails with it.
    Error(io::Error),
}
impl<A, B> CopyBidirectional<'_, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
fn step_a_to_b(&mut self, cx: &mut Context<'_>) -> TransferResult {
let state = &mut self.a_to_b;
if state.pos < state.cap {
match Pin::new(&mut *self.b).poll_write(cx, &self.a_to_b_buf[state.pos..state.cap]) {
Poll::Pending => return TransferResult::Pending,
Poll::Ready(Err(err)) => return TransferResult::Error(err),
Poll::Ready(Ok(0)) => {
return TransferResult::Error(io::Error::from(io::ErrorKind::WriteZero));
}
Poll::Ready(Ok(n)) => {
state.pos += n;
self.a_to_b_total += n as u64;
state.need_flush = true;
return TransferResult::Progress;
}
}
}
if state.read_done {
if state.need_flush {
match Pin::new(&mut *self.b).poll_flush(cx) {
Poll::Pending => return TransferResult::Pending,
Poll::Ready(Err(err)) => return TransferResult::Error(err),
Poll::Ready(Ok(())) => {
state.need_flush = false;
return TransferResult::Progress;
}
}
}
if !state.shutdown_done {
match Pin::new(&mut *self.b).poll_shutdown(cx) {
Poll::Pending => return TransferResult::Pending,
Poll::Ready(Err(err)) => return TransferResult::Error(err),
Poll::Ready(Ok(())) => {
state.shutdown_done = true;
return TransferResult::Progress;
}
}
}
return TransferResult::Done;
}
state.pos = 0;
state.cap = 0;
let mut read_buf = ReadBuf::new(&mut self.a_to_b_buf);
match Pin::new(&mut *self.a).poll_read(cx, &mut read_buf) {
Poll::Pending => {
if state.need_flush {
match Pin::new(&mut *self.b).poll_flush(cx) {
Poll::Ready(Ok(())) => {
state.need_flush = false;
TransferResult::Progress
}
Poll::Ready(Err(e)) => TransferResult::Error(e),
Poll::Pending => TransferResult::Pending,
}
} else {
TransferResult::Pending
}
}
Poll::Ready(Err(err)) => TransferResult::Error(err),
Poll::Ready(Ok(())) => {
let n = read_buf.filled().len();
if n == 0 {
state.read_done = true;
}
state.cap = n;
TransferResult::Progress
}
}
}
fn step_b_to_a(&mut self, cx: &mut Context<'_>) -> TransferResult {
let state = &mut self.b_to_a;
if state.pos < state.cap {
match Pin::new(&mut *self.a).poll_write(cx, &self.b_to_a_buf[state.pos..state.cap]) {
Poll::Pending => return TransferResult::Pending,
Poll::Ready(Err(err)) => return TransferResult::Error(err),
Poll::Ready(Ok(0)) => {
return TransferResult::Error(io::Error::from(io::ErrorKind::WriteZero));
}
Poll::Ready(Ok(n)) => {
state.pos += n;
self.b_to_a_total += n as u64;
state.need_flush = true;
return TransferResult::Progress;
}
}
}
if state.read_done {
if state.need_flush {
match Pin::new(&mut *self.a).poll_flush(cx) {
Poll::Pending => return TransferResult::Pending,
Poll::Ready(Err(err)) => return TransferResult::Error(err),
Poll::Ready(Ok(())) => {
state.need_flush = false;
return TransferResult::Progress;
}
}
}
if !state.shutdown_done {
match Pin::new(&mut *self.a).poll_shutdown(cx) {
Poll::Pending => return TransferResult::Pending,
Poll::Ready(Err(err)) => return TransferResult::Error(err),
Poll::Ready(Ok(())) => {
state.shutdown_done = true;
return TransferResult::Progress;
}
}
}
return TransferResult::Done;
}
state.pos = 0;
state.cap = 0;
let mut read_buf = ReadBuf::new(&mut self.b_to_a_buf);
match Pin::new(&mut *self.b).poll_read(cx, &mut read_buf) {
Poll::Pending => {
if state.need_flush {
match Pin::new(&mut *self.a).poll_flush(cx) {
Poll::Ready(Ok(())) => {
state.need_flush = false;
TransferResult::Progress
}
Poll::Ready(Err(e)) => TransferResult::Error(e),
Poll::Pending => TransferResult::Pending,
}
} else {
TransferResult::Pending
}
}
Poll::Ready(Err(err)) => TransferResult::Error(err),
Poll::Ready(Ok(())) => {
let n = read_buf.filled().len();
if n == 0 {
state.read_done = true;
}
state.cap = n;
TransferResult::Progress
}
}
}
}
impl<A, B> Future for CopyBidirectional<'_, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<(u64, u64)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if this.completed {
return Poll::Ready(Err(io::Error::other(
"CopyBidirectional future polled after completion",
)));
}
let mut steps = 0;
loop {
if crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
this.completed = true;
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Interrupted,
"cancelled",
)));
}
if steps >= YIELD_BUDGET {
cx.waker().wake_by_ref();
return Poll::Pending;
}
steps += 1;
let mut made_progress = false;
match this.step_a_to_b(cx) {
TransferResult::Progress => made_progress = true,
TransferResult::Error(e) => {
this.completed = true;
return Poll::Ready(Err(e));
}
TransferResult::Done | TransferResult::Pending => {}
}
match this.step_b_to_a(cx) {
TransferResult::Progress => made_progress = true,
TransferResult::Error(e) => {
this.completed = true;
return Poll::Ready(Err(e));
}
TransferResult::Done | TransferResult::Pending => {}
}
if made_progress {
steps += 1;
} else {
let a_to_b_done = this.a_to_b.read_done
&& this.a_to_b.pos >= this.a_to_b.cap
&& this.a_to_b.shutdown_done;
let b_to_a_done = this.b_to_a.read_done
&& this.b_to_a.pos >= this.b_to_a.cap
&& this.b_to_a.shutdown_done;
if a_to_b_done && b_to_a_done {
this.completed = true;
return Poll::Ready(Ok((this.a_to_b_total, this.b_to_a_total)));
}
return Poll::Pending;
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::task::{Context, Waker};

    fn noop_waker() -> Waker {
        std::task::Waker::noop().clone()
    }

    /// Polls `fut` with a no-op waker up to 1024 times, returning its output if it resolves.
    fn poll_ready<F: Future>(fut: &mut Pin<&mut F>) -> Option<F::Output> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        for _ in 0..1024 {
            if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
                return Some(output);
            }
        }
        None
    }

    /// Drives `fut` to a successful result, polls it once more, and asserts that the extra poll
    /// fails closed with a "polled after completion" error.
    fn assert_repoll_fails_closed<F, T>(fut: &mut Pin<&mut F>)
    where
        F: Future<Output = io::Result<T>>,
        T: std::fmt::Debug,
    {
        let first = poll_ready(fut).expect("future did not resolve");
        assert!(first.is_ok(), "first poll should succeed");
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(Err(e)) => {
                let msg = e.to_string();
                let ok = msg.contains("polled after completion");
                crate::assert_with_log!(ok, "error message", "polled after completion", msg);
            }
            other => panic!("expected Ready(Err), got {other:?}"),
        }
    }

    fn init_test(name: &str) {
        crate::test_utils::init_test_logging();
        crate::test_phase!(name);
    }

    #[test]
    fn copy_small_data() {
        init_test("copy_small_data");
        let mut reader: &[u8] = b"hello world";
        let mut writer = Vec::new();
        let mut fut = copy(&mut reader, &mut writer);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 11, "bytes", 11, n);
        crate::assert_with_log!(writer == b"hello world", "writer", b"hello world", writer);
        crate::test_complete!("copy_small_data");
    }

    #[test]
    fn copy_empty_data() {
        init_test("copy_empty_data");
        let mut reader: &[u8] = b"";
        let mut writer = Vec::new();
        let mut fut = copy(&mut reader, &mut writer);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 0, "bytes", 0, n);
        let empty = writer.is_empty();
        crate::assert_with_log!(empty, "writer empty", true, empty);
        crate::test_complete!("copy_empty_data");
    }

    #[test]
    fn copy_large_data() {
        init_test("copy_large_data");
        let data: Vec<u8> = (0u32..32768).map(|i| (i % 256) as u8).collect();
        let mut reader: &[u8] = &data;
        let mut writer = Vec::new();
        let mut fut = copy(&mut reader, &mut writer);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 32768, "bytes", 32768, n);
        crate::assert_with_log!(writer == data, "writer", data, writer);
        crate::test_complete!("copy_large_data");
    }

    /// Writer that accepts `prefix_len` bytes in total, then fails every write with
    /// `ErrorKind::Interrupted`.
    struct InterruptingWriter {
        written: Vec<u8>,
        remaining_before_interrupt: usize,
    }

    impl InterruptingWriter {
        fn new(prefix_len: usize) -> Self {
            Self {
                written: Vec::new(),
                remaining_before_interrupt: prefix_len,
            }
        }
    }

    impl AsyncWrite for InterruptingWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            if this.remaining_before_interrupt == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::Interrupted,
                    "writer interrupted",
                )));
            }
            let to_write = this.remaining_before_interrupt.min(buf.len());
            this.written.extend_from_slice(&buf[..to_write]);
            this.remaining_before_interrupt -= to_write;
            Poll::Ready(Ok(to_write))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    fn copy_partial_write_interrupt_preserves_committed_prefix() {
        init_test("copy_partial_write_interrupt_preserves_committed_prefix");
        let data: Vec<u8> = (0u32..16384).map(|i| (i % 251) as u8).collect();
        let committed_prefix_len = 5000usize;
        let mut reader: &[u8] = &data;
        let mut writer = InterruptingWriter::new(committed_prefix_len);
        let mut fut = copy(&mut reader, &mut writer);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .expect_err("copy should stop on interrupted writer");
        crate::assert_with_log!(
            err.kind() == io::ErrorKind::Interrupted,
            "error kind",
            io::ErrorKind::Interrupted,
            err.kind()
        );
        crate::assert_with_log!(
            writer.written.len() == committed_prefix_len,
            "committed prefix len",
            committed_prefix_len,
            writer.written.len()
        );
        crate::assert_with_log!(
            writer.written == data[..committed_prefix_len],
            "committed prefix data",
            &data[..committed_prefix_len],
            writer.written
        );
        crate::test_complete!("copy_partial_write_interrupt_preserves_committed_prefix");
    }

    #[test]
    fn copy_with_progress_interrupt_reports_only_committed_prefix() {
        init_test("copy_with_progress_interrupt_reports_only_committed_prefix");
        let data: Vec<u8> = (0u32..16384).map(|i| ((i * 7) % 253) as u8).collect();
        let committed_prefix_len = 4097usize;
        let mut reader: &[u8] = &data;
        let mut writer = InterruptingWriter::new(committed_prefix_len);
        let mut progress = Vec::new();
        let mut fut = copy_with_progress(&mut reader, &mut writer, |total| progress.push(total));
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .expect_err("copy_with_progress should stop on interrupted writer");
        crate::assert_with_log!(
            err.kind() == io::ErrorKind::Interrupted,
            "error kind",
            io::ErrorKind::Interrupted,
            err.kind()
        );
        let last_progress = progress.last().copied().unwrap_or_default() as usize;
        crate::assert_with_log!(
            last_progress == committed_prefix_len,
            "last progress equals committed prefix",
            committed_prefix_len,
            last_progress
        );
        crate::assert_with_log!(
            writer.written == data[..committed_prefix_len],
            "committed prefix data",
            &data[..committed_prefix_len],
            writer.written
        );
        crate::test_complete!("copy_with_progress_interrupt_reports_only_committed_prefix");
    }

    #[test]
    fn copy_with_progress_tracks_bytes() {
        init_test("copy_with_progress_tracks_bytes");
        let mut reader: &[u8] = b"hello world";
        let mut writer = Vec::new();
        let mut progress_calls = Vec::new();
        let mut fut = copy_with_progress(&mut reader, &mut writer, |total| {
            progress_calls.push(total);
        });
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 11, "bytes", 11, n);
        crate::assert_with_log!(writer == b"hello world", "writer", b"hello world", writer);
        let empty = progress_calls.is_empty();
        crate::assert_with_log!(!empty, "progress calls", false, empty);
        let last = *progress_calls.last().unwrap();
        crate::assert_with_log!(last == 11, "last progress", 11, last);
        crate::test_complete!("copy_with_progress_tracks_bytes");
    }

    #[test]
    fn copy_buf_reads_from_slice() {
        init_test("copy_buf_reads_from_slice");
        let mut reader: &[u8] = b"hello buffer";
        let mut writer = Vec::new();
        let mut fut = copy_buf(&mut reader, &mut writer);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 12, "bytes", 12, n);
        crate::assert_with_log!(writer == b"hello buffer", "writer", b"hello buffer", writer);
        let empty = reader.is_empty();
        crate::assert_with_log!(empty, "reader empty", true, empty);
        crate::test_complete!("copy_buf_reads_from_slice");
    }

    #[test]
    fn copy_buf_reads_from_cursor() {
        init_test("copy_buf_reads_from_cursor");
        let data = b"cursor data";
        let mut reader = std::io::Cursor::new(data);
        let mut writer = Vec::new();
        let mut fut = copy_buf(&mut reader, &mut writer);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 11, "bytes", 11, n);
        crate::assert_with_log!(writer == data, "writer", data, writer);
        crate::test_complete!("copy_buf_reads_from_cursor");
    }

    /// In-memory duplex stream: reads come from `read_data`, writes append to `written`.
    struct TestDuplex {
        read_data: Vec<u8>,
        read_pos: usize,
        written: Vec<u8>,
        shutdown_called: bool,
    }

    impl TestDuplex {
        fn new(read_data: &[u8]) -> Self {
            Self {
                read_data: read_data.to_vec(),
                read_pos: 0,
                written: Vec::new(),
                shutdown_called: false,
            }
        }
    }

    impl AsyncRead for TestDuplex {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            if this.read_pos >= this.read_data.len() {
                return Poll::Ready(Ok(()));
            }
            let to_copy = std::cmp::min(this.read_data.len() - this.read_pos, buf.remaining());
            buf.put_slice(&this.read_data[this.read_pos..this.read_pos + to_copy]);
            this.read_pos += to_copy;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for TestDuplex {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            this.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.get_mut().shutdown_called = true;
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    fn copy_bidirectional_basic() {
        init_test("copy_bidirectional_basic");
        let mut a = TestDuplex::new(b"from A");
        let mut b = TestDuplex::new(b"from B");
        let mut fut = copy_bidirectional(&mut a, &mut b);
        let mut fut = Pin::new(&mut fut);
        let (a_to_b, b_to_a) = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(a_to_b == 6, "a_to_b", 6, a_to_b);
        crate::assert_with_log!(b_to_a == 6, "b_to_a", 6, b_to_a);
        crate::assert_with_log!(b.written == b"from A", "b written", b"from A", b.written);
        crate::assert_with_log!(a.written == b"from B", "a written", b"from B", a.written);
        crate::test_complete!("copy_bidirectional_basic");
    }

    #[test]
    fn copy_bidirectional_propagates_shutdown() {
        init_test("copy_bidirectional_propagates_shutdown");
        let mut a = TestDuplex::new(b"from A");
        let mut b = TestDuplex::new(b"from B");
        let mut fut = copy_bidirectional(&mut a, &mut b);
        let mut fut = Pin::new(&mut fut);
        let _ = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(a.shutdown_called, "a shutdown", true, a.shutdown_called);
        crate::assert_with_log!(b.shutdown_called, "b shutdown", true, b.shutdown_called);
        crate::test_complete!("copy_bidirectional_propagates_shutdown");
    }

    #[test]
    fn copy_bidirectional_asymmetric() {
        init_test("copy_bidirectional_asymmetric");
        let mut a = TestDuplex::new(b"short");
        let mut b = TestDuplex::new(b"this is a longer message");
        let mut fut = copy_bidirectional(&mut a, &mut b);
        let mut fut = Pin::new(&mut fut);
        let (a_to_b, b_to_a) = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(a_to_b == 5, "a_to_b", 5, a_to_b);
        crate::assert_with_log!(b_to_a == 24, "b_to_a", 24, b_to_a);
        crate::assert_with_log!(b.written == b"short", "b written", b"short", b.written);
        crate::assert_with_log!(
            a.written == b"this is a longer message",
            "a written",
            b"this is a longer message",
            a.written
        );
        crate::test_complete!("copy_bidirectional_asymmetric");
    }

    #[test]
    fn copy_bidirectional_empty() {
        init_test("copy_bidirectional_empty");
        let mut a = TestDuplex::new(b"");
        let mut b = TestDuplex::new(b"");
        let mut fut = copy_bidirectional(&mut a, &mut b);
        let mut fut = Pin::new(&mut fut);
        let (a_to_b, b_to_a) = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(a_to_b == 0, "a_to_b", 0, a_to_b);
        crate::assert_with_log!(b_to_a == 0, "b_to_a", 0, b_to_a);
        crate::test_complete!("copy_bidirectional_empty");
    }

    /// Duplex whose first `poll_shutdown` returns `Pending` (after waking the task) and whose
    /// later shutdown polls succeed.
    struct DeferredShutdownDuplex {
        inner: TestDuplex,
        shutdown_poll_count: usize,
    }

    impl DeferredShutdownDuplex {
        fn new(read_data: &[u8]) -> Self {
            Self {
                inner: TestDuplex::new(read_data),
                shutdown_poll_count: 0,
            }
        }
    }

    impl AsyncRead for DeferredShutdownDuplex {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for DeferredShutdownDuplex {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().inner).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            this.shutdown_poll_count += 1;
            if this.shutdown_poll_count <= 1 {
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                this.inner.shutdown_called = true;
                Poll::Ready(Ok(()))
            }
        }
    }

    #[test]
    fn copy_bidirectional_waits_for_shutdown_completion() {
        init_test("copy_bidirectional_waits_for_shutdown_completion");
        let mut a = DeferredShutdownDuplex::new(b"hello");
        let mut b = DeferredShutdownDuplex::new(b"world");
        let mut fut = copy_bidirectional(&mut a, &mut b);
        let mut fut = Pin::new(&mut fut);
        let (a_to_b, b_to_a) = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(a_to_b == 5, "a_to_b", 5, a_to_b);
        crate::assert_with_log!(b_to_a == 5, "b_to_a", 5, b_to_a);
        let a_shut = a.inner.shutdown_called;
        let b_shut = b.inner.shutdown_called;
        crate::assert_with_log!(a_shut, "a shutdown done", true, a_shut);
        crate::assert_with_log!(b_shut, "b shutdown done", true, b_shut);
        crate::test_complete!("copy_bidirectional_waits_for_shutdown_completion");
    }

    #[test]
    fn copy_bidirectional_yields_on_fast_streams() {
        /// Stream that always fills reads with zeros and accepts every write in full.
        struct InfiniteStream;

        impl AsyncRead for InfiniteStream {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                let space = buf.remaining();
                let zeros = vec![0u8; space];
                buf.put_slice(&zeros);
                Poll::Ready(Ok(()))
            }
        }

        impl AsyncWrite for InfiniteStream {
            fn poll_write(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                Poll::Ready(Ok(()))
            }

            fn poll_shutdown(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<io::Result<()>> {
                Poll::Ready(Ok(()))
            }
        }

        init_test("copy_bidirectional_yields_on_fast_streams");
        let mut a = InfiniteStream;
        let mut b = InfiniteStream;
        let mut fut = copy_bidirectional(&mut a, &mut b);
        let mut fut = Pin::new(&mut fut);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll_result = fut.as_mut().poll(&mut cx);
        let is_pending = matches!(poll_result, Poll::Pending);
        crate::assert_with_log!(is_pending, "poll result is pending", true, is_pending);
        crate::test_complete!("copy_bidirectional_yields_on_fast_streams");
    }

    #[test]
    fn copy_repoll_after_completion_fails_closed() {
        init_test("copy_repoll_after_completion_fails_closed");
        let mut reader: &[u8] = b"data";
        let mut writer = Vec::new();
        let mut fut = copy(&mut reader, &mut writer);
        let mut pinned = Pin::new(&mut fut);
        assert_repoll_fails_closed(&mut pinned);
        crate::test_complete!("copy_repoll_after_completion_fails_closed");
    }

    #[test]
    fn copy_buf_repoll_after_completion_fails_closed() {
        init_test("copy_buf_repoll_after_completion_fails_closed");
        let mut reader: &[u8] = b"data";
        let mut writer = Vec::new();
        let mut fut = copy_buf(&mut reader, &mut writer);
        let mut pinned = Pin::new(&mut fut);
        assert_repoll_fails_closed(&mut pinned);
        crate::test_complete!("copy_buf_repoll_after_completion_fails_closed");
    }

    #[test]
    fn copy_with_progress_repoll_after_completion_fails_closed() {
        init_test("copy_with_progress_repoll_after_completion_fails_closed");
        let mut reader: &[u8] = b"data";
        let mut writer = Vec::new();
        let mut fut = copy_with_progress(&mut reader, &mut writer, |_| {});
        let mut pinned = Pin::new(&mut fut);
        assert_repoll_fails_closed(&mut pinned);
        crate::test_complete!("copy_with_progress_repoll_after_completion_fails_closed");
    }

    #[test]
    fn copy_bidirectional_repoll_after_completion_fails_closed() {
        init_test("copy_bidirectional_repoll_after_completion_fails_closed");
        let mut a = TestDuplex::new(b"hello");
        let mut b = TestDuplex::new(b"world");
        let mut fut = copy_bidirectional(&mut a, &mut b);
        let mut pinned = Pin::new(&mut fut);
        assert_repoll_fails_closed(&mut pinned);
        crate::test_complete!("copy_bidirectional_repoll_after_completion_fails_closed");
    }
}