use std::io::{self, IoSlice};
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
pub trait AsyncWrite {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>;
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
for buf in bufs {
if !buf.is_empty() {
return self.poll_write(cx, buf);
}
}
Poll::Ready(Ok(0))
}
#[inline]
fn is_write_vectored(&self) -> bool {
false
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
pub trait AsyncWriteVectored: AsyncWrite {
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write_vectored(self, cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
AsyncWrite::is_write_vectored(self)
}
}
impl<W> AsyncWriteVectored for W where W: AsyncWrite + ?Sized {}
impl AsyncWrite for Vec<u8> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut total = 0;
for buf in bufs {
this.extend_from_slice(buf);
total += buf.len();
}
Poll::Ready(Ok(total))
}
#[inline]
fn is_write_vectored(&self) -> bool {
true
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for std::io::Cursor<&mut [u8]> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
use std::io::Write as _;
let this = self.get_mut();
let n = this.write(buf)?;
Poll::Ready(Ok(n))
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for std::io::Cursor<Vec<u8>> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
use std::io::Write as _;
let this = self.get_mut();
let n = this.write(buf)?;
Poll::Ready(Ok(n))
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for std::io::Cursor<Box<[u8]>> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
use std::io::Write as _;
let this = self.get_mut();
let n = this.write(buf)?;
Poll::Ready(Ok(n))
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl<W> AsyncWrite for &mut W
where
W: AsyncWrite + Unpin + ?Sized,
{
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_write(cx, buf)
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
(**self).is_write_vectored()
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_shutdown(cx)
}
}
impl<W> AsyncWrite for Box<W>
where
W: AsyncWrite + Unpin + ?Sized,
{
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_write(cx, buf)
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
(**self).is_write_vectored()
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut **this).poll_shutdown(cx)
}
}
impl<W, P> AsyncWrite for Pin<P>
where
P: DerefMut<Target = W> + Unpin,
W: AsyncWrite + ?Sized,
{
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.get_mut().as_mut().poll_write(cx, buf)
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.get_mut().as_mut().poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
(**self).is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_shutdown(cx)
}
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::*;
use pin_project::pin_project;
use std::marker::PhantomPinned;
use std::task::{Context, Waker};
fn noop_waker() -> Waker {
std::task::Waker::noop().clone()
}
fn init_test(name: &str) {
crate::test_utils::init_test_logging();
crate::test_phase!(name);
}
#[test]
fn write_to_vec() {
init_test("write_to_vec");
let mut output = Vec::new();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut output).poll_write(&mut cx, b"hello");
let ready = matches!(poll, Poll::Ready(Ok(5)));
crate::assert_with_log!(ready, "write 5", true, ready);
crate::assert_with_log!(output == b"hello", "output", b"hello", output);
crate::test_complete!("write_to_vec");
}
#[test]
fn write_empty_to_vec_reports_zero_without_mutation() {
init_test("write_empty_to_vec_reports_zero_without_mutation");
let mut output = b"prefix".to_vec();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut output).poll_write(&mut cx, b"");
let ready = matches!(poll, Poll::Ready(Ok(0)));
crate::assert_with_log!(ready, "write 0", true, ready);
crate::assert_with_log!(output == b"prefix", "output", b"prefix", output);
crate::test_complete!("write_empty_to_vec_reports_zero_without_mutation");
}
#[test]
fn write_to_cursor() {
init_test("write_to_cursor");
let mut buf = [0u8; 8];
let mut cursor = std::io::Cursor::new(&mut buf[..]);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut cursor).poll_write(&mut cx, b"test");
let ready = matches!(poll, Poll::Ready(Ok(4)));
crate::assert_with_log!(ready, "write 4", true, ready);
crate::assert_with_log!(&buf[..4] == b"test", "buf", b"test", &buf[..4]);
crate::test_complete!("write_to_cursor");
}
#[test]
fn flush_and_shutdown_vec() {
init_test("flush_and_shutdown_vec");
let mut output = Vec::new();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut output).poll_flush(&mut cx);
let ready = matches!(poll, Poll::Ready(Ok(())));
crate::assert_with_log!(ready, "flush ready", true, ready);
let poll = Pin::new(&mut output).poll_shutdown(&mut cx);
let ready = matches!(poll, Poll::Ready(Ok(())));
crate::assert_with_log!(ready, "shutdown ready", true, ready);
crate::test_complete!("flush_and_shutdown_vec");
}
#[derive(Default)]
struct DefaultVectoredWriter {
writes: Vec<Vec<u8>>,
}
impl AsyncWrite for DefaultVectoredWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.writes.push(buf.to_vec());
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[test]
fn default_write_vectored_uses_first_non_empty_buffer() {
init_test("default_write_vectored_uses_first_non_empty_buffer");
let mut writer = DefaultVectoredWriter::default();
let bufs = [
IoSlice::new(b""),
IoSlice::new(b"alpha"),
IoSlice::new(b"omega"),
];
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = AsyncWrite::poll_write_vectored(Pin::new(&mut writer), &mut cx, &bufs);
let wrote_first = matches!(poll, Poll::Ready(Ok(5)));
crate::assert_with_log!(
wrote_first,
"first non-empty buffer length",
true,
wrote_first
);
let expected_writes = vec![b"alpha".to_vec()];
crate::assert_with_log!(
writer.writes == expected_writes,
"only first non-empty buffer written",
expected_writes,
writer.writes.clone()
);
crate::assert_with_log!(
!AsyncWrite::is_write_vectored(&writer),
"default vectored capability flag",
false,
AsyncWrite::is_write_vectored(&writer)
);
crate::test_complete!("default_write_vectored_uses_first_non_empty_buffer");
}
#[test]
fn default_write_vectored_empty_buffers_make_no_write_call() {
init_test("default_write_vectored_empty_buffers_make_no_write_call");
let mut writer = DefaultVectoredWriter::default();
let bufs = [IoSlice::new(b""), IoSlice::new(b"")];
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = AsyncWrite::poll_write_vectored(Pin::new(&mut writer), &mut cx, &bufs);
let returned_zero = matches!(poll, Poll::Ready(Ok(0)));
crate::assert_with_log!(
returned_zero,
"empty vectored write returns zero",
true,
returned_zero
);
crate::assert_with_log!(
writer.writes.is_empty(),
"no scalar write for empty buffers",
true,
writer.writes.is_empty()
);
crate::test_complete!("default_write_vectored_empty_buffers_make_no_write_call");
}
#[test]
fn vec_write_vectored_appends_all_buffers_and_reports_total() {
init_test("vec_write_vectored_appends_all_buffers_and_reports_total");
let mut output = Vec::new();
let bufs = [IoSlice::new(b"ab"), IoSlice::new(b""), IoSlice::new(b"cd")];
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = AsyncWrite::poll_write_vectored(Pin::new(&mut output), &mut cx, &bufs);
let wrote_four = matches!(poll, Poll::Ready(Ok(4)));
crate::assert_with_log!(wrote_four, "vectored byte count", true, wrote_four);
crate::assert_with_log!(
output == b"abcd",
"vectored output",
b"abcd",
output.as_slice()
);
crate::assert_with_log!(
AsyncWrite::is_write_vectored(&output),
"vec advertises vectored writes",
true,
AsyncWrite::is_write_vectored(&output)
);
crate::test_complete!("vec_write_vectored_appends_all_buffers_and_reports_total");
}
#[test]
fn write_via_ref() {
init_test("write_via_ref");
let mut output = Vec::new();
let mut writer = &mut output;
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut writer).poll_write(&mut cx, b"via ref");
let ready = matches!(poll, Poll::Ready(Ok(7)));
crate::assert_with_log!(ready, "write 7", true, ready);
crate::assert_with_log!(output == b"via ref", "output", b"via ref", output);
crate::test_complete!("write_via_ref");
}
#[test]
fn write_via_box() {
init_test("write_via_box");
let mut output: Box<Vec<u8>> = Box::default();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut output).poll_write(&mut cx, b"boxed");
let ready = matches!(poll, Poll::Ready(Ok(5)));
crate::assert_with_log!(ready, "write 5", true, ready);
crate::assert_with_log!(*output == b"boxed", "output", b"boxed", *output);
crate::test_complete!("write_via_box");
}
#[pin_project]
struct PinnedWriter<W> {
#[pin]
inner: W,
_pin: PhantomPinned,
}
impl<W> AsyncWrite for PinnedWriter<W>
where
W: AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_shutdown(cx)
}
}
#[test]
fn pin_wrapper_write_supports_non_unpin_inner() {
init_test("pin_wrapper_write_supports_non_unpin_inner");
let mut writer = Box::pin(PinnedWriter {
inner: Vec::<u8>::new(),
_pin: PhantomPinned,
});
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut writer).poll_write(&mut cx, b"ok");
let ready = matches!(poll, Poll::Ready(Ok(2)));
crate::assert_with_log!(ready, "write 2", true, ready);
crate::assert_with_log!(
writer.as_ref().get_ref().inner == b"ok",
"inner output",
b"ok",
writer.as_ref().get_ref().inner
);
crate::test_complete!("pin_wrapper_write_supports_non_unpin_inner");
}
}