use crate::{futures_util::FuturesOps, PartialOp};
use futures::{io, prelude::*};
use pin_project::pin_project;
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
/// Wrapper around a writer whose `AsyncWrite` calls (the `futures` traits, plus
/// the tokio traits when the `tokio1` feature is enabled) are routed through a
/// configurable sequence of [`PartialOp`]s, so individual writes can be cut short
/// (and, judging by the error strings handed to `FuturesOps`, failed on purpose).
///
/// `AsyncRead`, `AsyncBufRead` and `AsyncSeek` calls are forwarded to the wrapped
/// value unchanged.
#[pin_project]
pub struct PartialAsyncWrite<W> {
/// The wrapped value. `#[pin]` makes this field structurally pinned, so pin
/// projections yield `Pin<&mut W>`.
#[pin]
inner: W,
/// Operation sequence consulted on every write, flush and close/shutdown poll.
ops: FuturesOps,
}
/// Construction, reconfiguration, and access to the wrapped writer.
impl<W> PartialAsyncWrite<W> {
    /// Wraps `inner` so that its writes, flushes and closes are gated by the
    /// [`PartialOp`]s yielded by `iter`.
    ///
    /// The iterator is stored in a `FuturesOps`, which decides how each op is
    /// applied to an individual poll.
    pub fn new<I>(inner: W, iter: I) -> Self
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        let ops = FuturesOps::new(iter);
        Self { inner, ops }
    }

    /// Replaces the stored operation sequence with `iter` (via
    /// `FuturesOps::replace`) and returns `self` so calls can be chained.
    ///
    /// Use [`pin_set_ops`](Self::pin_set_ops) when only a pinned reference is
    /// available.
    pub fn set_ops<I>(&mut self, iter: I) -> &mut Self
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        let ops = &mut self.ops;
        ops.replace(iter);
        self
    }

    /// Pinned counterpart of [`set_ops`](Self::set_ops): replaces the operation
    /// sequence through a pin projection and hands the pinned reference back.
    pub fn pin_set_ops<I>(mut self: Pin<&mut Self>, iter: I) -> Pin<&mut Self>
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        // Reborrow through `as_mut` so `self` is still available to return.
        self.as_mut().project().ops.replace(iter);
        self
    }

    /// Returns a shared reference to the wrapped writer.
    pub fn get_ref(&self) -> &W {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped writer.
    pub fn get_mut(&mut self) -> &mut W {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a pinned mutable reference to the wrapped writer.
    ///
    /// `inner` is declared `#[pin]`, so this is a structural pin projection.
    pub fn pin_get_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
        let projection = self.project();
        projection.inner
    }

    /// Consumes the wrapper and returns the wrapped writer; the stored
    /// operation sequence is dropped.
    pub fn into_inner(self) -> W {
        let Self { inner, .. } = self;
        inner
    }
}
/// `futures` write path: each poll is handed to `ops`, which decides whether the
/// inner writer sees the whole request, a truncated one, or whether an error
/// carrying the given message is produced instead.
impl<W> AsyncWrite for PartialAsyncWrite<W>
where
    W: AsyncWrite,
{
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let this = self.project();
        let (writer, ops) = (this.inner, this.ops);
        let requested = buf.len();
        ops.poll_impl(
            cx,
            move |cx, limit| {
                // `Some(n)` caps the write at the first `n` bytes; `None` forwards all of
                // `buf`. NOTE(review): `n` is assumed to be <= `requested` (clamped
                // inside `FuturesOps`); otherwise the slice would panic. TODO confirm.
                let chunk = match limit {
                    Some(n) => &buf[..n],
                    None => buf,
                };
                writer.poll_write(cx, chunk)
            },
            requested,
            "error during poll_write, generated by partial-io",
        )
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.project();
        let (writer, ops) = (this.inner, this.ops);
        // Flushing carries no length, so `ops` can only let it through or fail it.
        ops.poll_impl_no_limit(
            cx,
            move |cx| writer.poll_flush(cx),
            "error during poll_flush, generated by partial-io",
        )
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.project();
        let (writer, ops) = (this.inner, this.ops);
        ops.poll_impl_no_limit(
            cx,
            move |cx| writer.poll_close(cx),
            "error during poll_close, generated by partial-io",
        )
    }
}
/// Reads are not subject to the partial-op schedule: both read entry points are
/// forwarded unchanged to the wrapped value.
impl<W> AsyncRead for PartialAsyncWrite<W>
where
    W: AsyncRead,
{
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let reader = self.project().inner;
        reader.poll_read(cx, buf)
    }

    #[inline]
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &mut [io::IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let reader = self.project().inner;
        reader.poll_read_vectored(cx, bufs)
    }
}
/// Buffered reads go straight to the wrapped value; `ops` is never consulted.
impl<W> AsyncBufRead for PartialAsyncWrite<W>
where
    W: AsyncBufRead,
{
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
        let reader = self.project().inner;
        reader.poll_fill_buf(cx)
    }

    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let reader = self.project().inner;
        reader.consume(amt)
    }
}
/// Seeking is forwarded unchanged to the wrapped value.
impl<W> AsyncSeek for PartialAsyncWrite<W>
where
    W: AsyncSeek,
{
    #[inline]
    fn poll_seek(
        self: Pin<&mut Self>,
        cx: &mut Context,
        pos: io::SeekFrom,
    ) -> Poll<io::Result<u64>> {
        let seeker = self.project().inner;
        seeker.poll_seek(cx, pos)
    }
}
#[cfg(feature = "tokio1")]
mod tokio_impl {
    //! Tokio trait implementations for `PartialAsyncWrite`, compiled only with the
    //! `tokio1` feature. Writes, flushes and shutdowns are gated by `ops`; read,
    //! buffered-read and seek calls are forwarded to the wrapped value unchanged.

    use super::PartialAsyncWrite;
    use std::{
        io::{self, SeekFrom},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

    impl<W> AsyncWrite for PartialAsyncWrite<W>
    where
        W: AsyncWrite,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.project();
            let (writer, ops) = (this.inner, this.ops);
            let requested = buf.len();
            ops.poll_impl(
                cx,
                move |cx, limit| {
                    // `Some(n)` caps the write at the first `n` bytes; `None` forwards
                    // all of `buf`. NOTE(review): `n` is assumed to be <= `requested`
                    // (clamped inside `FuturesOps`). TODO confirm.
                    let chunk = match limit {
                        Some(n) => &buf[..n],
                        None => buf,
                    };
                    writer.poll_write(cx, chunk)
                },
                requested,
                "error during poll_write, generated by partial-io",
            )
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            let this = self.project();
            let (writer, ops) = (this.inner, this.ops);
            ops.poll_impl_no_limit(
                cx,
                move |cx| writer.poll_flush(cx),
                "error during poll_flush, generated by partial-io",
            )
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            let this = self.project();
            let (writer, ops) = (this.inner, this.ops);
            ops.poll_impl_no_limit(
                cx,
                move |cx| writer.poll_shutdown(cx),
                "error during poll_shutdown, generated by partial-io",
            )
        }
    }

    impl<W> AsyncRead for PartialAsyncWrite<W>
    where
        W: AsyncRead,
    {
        #[inline]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let reader = self.project().inner;
            reader.poll_read(cx, buf)
        }
    }

    impl<W> AsyncBufRead for PartialAsyncWrite<W>
    where
        W: AsyncBufRead,
    {
        #[inline]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
            let reader = self.project().inner;
            reader.poll_fill_buf(cx)
        }

        #[inline]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            let reader = self.project().inner;
            reader.consume(amt)
        }
    }

    impl<W> AsyncSeek for PartialAsyncWrite<W>
    where
        W: AsyncSeek,
    {
        #[inline]
        fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
            let seeker = self.project().inner;
            seeker.start_seek(position)
        }

        #[inline]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let seeker = self.project().inner;
            seeker.poll_complete(cx)
        }
    }
}
impl<W: fmt::Debug> fmt::Debug for PartialAsyncWrite<W> {
    /// Formats the wrapped value only; the operation sequence is not included in
    /// the output.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("PartialAsyncWrite");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::assert_send;

    /// Checks that a `PartialAsyncWrite` wrapping a `std::fs::File` is `Send`.
    /// `assert_send` is assumed to be a compile-time `T: Send` check from the
    /// crate's shared test helpers (TODO confirm).
    #[test]
    fn test_sendable() {
        assert_send::<PartialAsyncWrite<std::fs::File>>();
    }
}