use std::cmp;
use std::fmt;
use std::io::{self, Read, Write};
use futures::{task, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{make_ops, PartialOp};
/// A wrapper around an [`AsyncWrite`] whose writes follow a script of
/// [`PartialOp`]s.
///
/// Every `write` or `flush` call takes the next operation from the script.
/// Once the script runs out, calls are forwarded to the inner writer as-is.
pub struct PartialAsyncWrite<W> {
    /// Operations not yet consumed; each `write`/`flush` pulls one.
    ops: Box<dyn Iterator<Item = PartialOp> + Send>,
    /// The wrapped writer that ultimately receives data and flushes.
    inner: W,
}
impl<W> PartialAsyncWrite<W>
where
    W: AsyncWrite,
{
    /// Wraps `inner`, scripting its writes with the operations yielded by
    /// `iter`.
    pub fn new<I>(inner: W, iter: I) -> Self
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        let ops = make_ops(iter);
        Self { inner, ops }
    }

    /// Replaces whatever remains of the current script with the operations
    /// yielded by `iter`. Unused operations from the old script are dropped.
    ///
    /// Returns `self` so calls can be chained.
    pub fn set_ops<I>(&mut self, iter: I) -> &mut Self
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        let fresh_ops = make_ops(iter);
        self.ops = fresh_ops;
        self
    }

    /// Returns a mutable reference to the wrapped writer.
    ///
    /// Anything done through this reference bypasses the operation script.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.inner
    }

    /// Consumes the wrapper and returns the inner writer, discarding any
    /// operations that were never used.
    pub fn into_inner(self) -> W {
        let PartialAsyncWrite { inner, .. } = self;
        inner
    }
}
/// Builds the error returned for a scripted `PartialOp::Err(kind)`.
///
/// For `WouldBlock`, the current task is notified first. Otherwise the
/// resulting `NotReady` would leave no wakeup registered, and the task polling
/// this writer could stall forever. Per the futures 0.1 docs, `task::current`
/// panics when called outside of a running task.
fn scripted_error(kind: io::ErrorKind, msg: &'static str) -> io::Error {
    if kind == io::ErrorKind::WouldBlock {
        task::current().notify();
    }
    io::Error::new(kind, msg)
}

impl<W> Write for PartialAsyncWrite<W>
where
    W: Write,
{
    /// Performs the next scripted write against the inner writer.
    ///
    /// - `Limited(n)` passes at most `n` bytes of `buf` to the inner writer.
    /// - `Err(kind)` fails without touching the inner writer.
    /// - `Unlimited`, or an exhausted script, forwards `buf` unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.ops.next() {
            Some(PartialOp::Limited(n)) => {
                let len = cmp::min(n, buf.len());
                self.inner.write(&buf[..len])
            }
            Some(PartialOp::Err(kind)) => Err(scripted_error(
                kind,
                "error during write, generated by partial-io",
            )),
            Some(PartialOp::Unlimited) | None => self.inner.write(buf),
        }
    }

    /// Performs the next scripted flush against the inner writer.
    ///
    /// A flush has no byte count, so only `Err(kind)` changes behaviour. Every
    /// other case flushes the inner writer normally. One operation is consumed
    /// either way.
    fn flush(&mut self) -> io::Result<()> {
        // The match is exhaustive on purpose: adding a `PartialOp` variant
        // should force a decision about how it affects flush.
        match self.ops.next() {
            Some(PartialOp::Err(kind)) => Err(scripted_error(
                kind,
                "error during flush, generated by partial-io",
            )),
            Some(PartialOp::Limited(_)) | Some(PartialOp::Unlimited) | None => {
                self.inner.flush()
            }
        }
    }
}
impl<W> AsyncWrite for PartialAsyncWrite<W>
where
W: AsyncWrite,
{
#[inline]
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
}
}
/// Reads pass straight through to the inner value. The operation script only
/// affects writes and flushes.
impl<W> Read for PartialAsyncWrite<W>
where
    W: AsyncWrite + Read,
{
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.inner, buf)
    }
}
/// Async reads use the pass-through `Read` impl together with the trait's
/// default methods; nothing is overridden here.
impl<W> AsyncRead for PartialAsyncWrite<W>
where
    W: AsyncRead + AsyncWrite,
{
}
/// Shows only the inner writer. The boxed `dyn Iterator` holding the
/// operation script has no `Debug` impl to print.
impl<W> fmt::Debug for PartialAsyncWrite<W>
where
    W: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("PartialAsyncWrite");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::assert_send;
    use std::fs::File;

    /// Compile-time check that the wrapper stays `Send` when its inner writer
    /// is `Send`. This relies on the script iterator being boxed as `+ Send`.
    #[test]
    fn test_sendable() {
        assert_send::<PartialAsyncWrite<File>>();
    }
}