use std::cmp;
use std::fmt;
use std::io::{self, Read, Write};
use futures::{task, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{make_ops, PartialOp};
/// A reader wrapper that applies a scripted sequence of `PartialOp`s to the
/// reads performed on an inner value.
///
/// Each call to `read` on the wrapper pulls the next operation from `ops` and
/// uses it to decide whether the read is shortened, replaced by an error, or
/// passed through untouched.
pub struct PartialAsyncRead<R> {
// The wrapped value that actually services reads (and writes, if supported).
inner: R,
// Scripted operations, consumed one per `read` call. Boxed as a trait object
// so the wrapper's type does not depend on the iterator type; the `Send`
// bound keeps the wrapper `Send` whenever `R` is.
ops: Box<dyn Iterator<Item = PartialOp> + Send>,
}
impl<R: AsyncRead> PartialAsyncRead<R> {
    /// Wraps `inner`, scripting its reads with the operations yielded by `iter`.
    ///
    /// The iterator is converted into the stored operation sequence via
    /// `make_ops`.
    pub fn new<I: IntoIterator<Item = PartialOp> + 'static>(inner: R, iter: I) -> Self
    where
        I::IntoIter: Send,
    {
        let ops = make_ops(iter);
        Self { inner, ops }
    }

    /// Replaces the current operation sequence with the one produced by `iter`.
    ///
    /// Any operations left over from the previous sequence are dropped.
    /// Returns `self` so that calls can be chained.
    pub fn set_ops<I: IntoIterator<Item = PartialOp> + 'static>(&mut self, iter: I) -> &mut Self
    where
        I::IntoIter: Send,
    {
        let replacement = make_ops(iter);
        self.ops = replacement;
        self
    }

    /// Returns a shared reference to the wrapped value.
    pub fn get_ref(&self) -> &R {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped value.
    pub fn get_mut(&mut self) -> &mut R {
        let Self { inner, .. } = self;
        inner
    }

    /// Consumes the wrapper and hands back the wrapped value, discarding any
    /// remaining scripted operations.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }
}
impl<R: AsyncRead> Read for PartialAsyncRead<R> {
    /// Performs one read, shaped by the next scripted operation.
    ///
    /// * `Limited(n)` — reads into at most the first `n` bytes of `buf`.
    /// * `Err(kind)` — skips the inner reader and returns an error of that kind.
    /// * `Unlimited`, or an exhausted script — forwards the read unchanged.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Pass-through and truncated reads return straight from the match;
        // only an injected error kind falls through to the code below.
        let kind = match self.ops.next() {
            None | Some(PartialOp::Unlimited) => return self.inner.read(buf),
            Some(PartialOp::Limited(limit)) => {
                // Never slice past the caller's buffer, whatever the limit is.
                let capped = cmp::min(limit, buf.len());
                return self.inner.read(&mut buf[..capped]);
            }
            Some(PartialOp::Err(kind)) => kind,
        };

        if kind == io::ErrorKind::WouldBlock {
            // The inner reader was never consulted, so nothing else will wake
            // the current task; wake it here so it gets polled again.
            // NOTE(review): presumably this must run inside a futures task
            // context — confirm against the futures 0.1 `task` docs.
            task::park().unpark();
        }

        let injected = io::Error::new(kind, "error during read, generated by partial-io");
        Err(injected)
    }
}
// Empty impl: every `AsyncRead` method keeps the trait's default
// implementation, so nothing here is forwarded to `inner` directly.
impl<R> AsyncRead for PartialAsyncRead<R> where R: AsyncRead {}
/// Writes are not scripted: when the wrapped value is also writable, both
/// `write` and `flush` are forwarded to it unchanged.
impl<R: AsyncRead + Write> Write for PartialAsyncRead<R> {
    /// Forwards the write to the wrapped value.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.inner, buf)
    }

    /// Forwards the flush to the wrapped value.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
impl<R> AsyncWrite for PartialAsyncRead<R>
where
R: AsyncRead + AsyncWrite,
{
#[inline]
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
}
}
/// Hand-written `Debug`: only `inner` is printed, because the boxed `ops`
/// iterator carries no `Debug` bound.
impl<R: fmt::Debug> fmt::Debug for PartialAsyncRead<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("PartialAsyncRead");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
// Unit tests for `PartialAsyncRead`.
#[cfg(test)]
mod tests {
use super::*;
use std::fs::File;
use crate::tests::assert_send;
// Checks the wrapper against the crate's `assert_send` helper, using `File`
// as the wrapped type.
// NOTE(review): presumably `assert_send::<T>()` only compiles when `T: Send`,
// making this a compile-time check — confirm against the helper's definition.
#[test]
fn test_sendable() {
assert_send::<PartialAsyncRead<File>>();
}
}