atm_io_utils/
partial.rs

1//! Helpers to test partial and `Pending` io operations.
2//!
//! Inspired by (and blatantly stealing from) the [partial-io](https://crates.io/crates/partial-io) crate.
4
5use std::cmp::min;
6
7use futures_core::Poll;
8use futures_core::Async::Pending;
9use futures_core::task::Context;
10use futures_io::{AsyncRead, AsyncWrite, Error, IoVec};
11
/// The different operations supported by the partial wrappers.
///
/// A wrapper pulls one `PartialOp` from its iterator for each affected call and uses it to
/// decide how that call behaves. When the iterator yields `None`, the call behaves exactly as
/// for `Unlimited`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PartialOp {
    /// Perform the io operation as normal.
    Unlimited,
    /// Perform the io operation, but limit it to a maximum number of bytes.
    ///
    /// The wrapped object receives the caller's buffer truncated to at most `n` bytes.
    /// Operations that have no byte count (flushing and closing) treat this like `Unlimited`.
    ///
    /// Note that `Limited(0)` hands the wrapped object an empty buffer. For reads, the resulting
    /// `Ok(Async::Ready(0))` is conventionally interpreted as end-of-stream, so only use it
    /// when that is what a test intends.
    Limited(usize),
    /// Emit `Ok(Async::Pending)` and reschedule the task.
    ///
    /// The wrapped object is not called at all; the current task's waker is woken immediately
    /// so the executor polls the task again.
    Pending,
}
22
/// A reader adapter that reshapes every `poll_read` call according to a sequence of
/// `PartialOp`s.
///
/// Each `poll_read` consumes one op from `Ops`:
/// - `Pending` wakes the task and reports `Ok(Async::Pending)`.
/// - `Limited(n)` truncates the buffer to at most `n` bytes.
/// - `Unlimited`, or `None` from the iterator, forwards the call unchanged.
///
/// Writes, when the wrapped type supports them, are forwarded without consulting `Ops`.
#[derive(Debug)]
pub struct PartialRead<R, Ops> {
    // The wrapped object that performs the actual reading.
    reader: R,
    // Source of the per-call `PartialOp`s.
    ops: Ops,
}

impl<R, Ops> PartialRead<R, Ops> {
    /// Wraps `reader`, taking one `PartialOp` from `ops` for every read.
    pub fn new(reader: R, ops: Ops) -> Self {
        Self { reader, ops }
    }

    /// Borrows the wrapped reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    /// Mutably borrows the wrapped reader.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Unwraps this adapter, discarding the remaining ops and returning the reader.
    pub fn into_inner(self) -> R {
        self.reader
    }
}
52
53impl<R, Ops> AsyncRead for PartialRead<R, Ops>
54    where R: AsyncRead,
55          Ops: Iterator<Item = PartialOp>
56{
57    fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> {
58        match self.ops.next() {
59            None |
60            Some(PartialOp::Unlimited) => self.reader.poll_read(cx, buf),
61            Some(PartialOp::Pending) => {
62                cx.waker().wake();
63                Ok(Pending)
64            }
65            Some(PartialOp::Limited(n)) => {
66                let len = min(n, buf.len());
67                self.reader.poll_read(cx, &mut buf[..len])
68            }
69        }
70    }
71}
72
73impl<W, Ops> AsyncWrite for PartialRead<W, Ops>
74    where W: AsyncWrite
75{
76    fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<usize, Error> {
77        self.reader.poll_write(cx, buf)
78    }
79
80    fn poll_flush(&mut self, cx: &mut Context) -> Poll<(), Error> {
81        self.reader.poll_flush(cx)
82    }
83
84    fn poll_close(&mut self, cx: &mut Context) -> Poll<(), Error> {
85        self.reader.poll_close(cx)
86    }
87
88    fn poll_vectored_write(&mut self, cx: &mut Context, vec: &[&IoVec]) -> Poll<usize, Error> {
89        self.reader.poll_vectored_write(cx, vec)
90    }
91}
92
/// Wraps a writer and modifies its write operations according to the given iterator of
/// `PartialOp`s.
///
/// Each call to `poll_write`, `poll_flush` or `poll_close` consumes one op from `Ops`. Reads,
/// when the wrapped type supports them, are forwarded without consulting `Ops`.
#[derive(Debug)]
pub struct PartialWrite<W, Ops> {
    // The wrapped object that performs the actual writing.
    writer: W,
    // Source of the per-call `PartialOp`s.
    ops: Ops,
}

impl<W, Ops> PartialWrite<W, Ops> {
    /// Create a new `PartialWrite`, wrapping the given `W` and modifying its io operations via the
    /// given `Ops`.
    pub fn new(writer: W, ops: Ops) -> PartialWrite<W, Ops> {
        PartialWrite { writer, ops }
    }

    /// Gets a reference to the underlying `W`.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Gets a mutable reference to the underlying `W`.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Consumes this `PartialWrite`, returning the underlying writer.
    ///
    /// Any ops not yet consumed are dropped.
    pub fn into_inner(self) -> W {
        self.writer
    }
}
122
123impl<W, Ops> AsyncWrite for PartialWrite<W, Ops>
124    where W: AsyncWrite,
125          Ops: Iterator<Item = PartialOp>
126{
127    fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<usize, Error> {
128        match self.ops.next() {
129            None |
130            Some(PartialOp::Unlimited) => self.writer.poll_write(cx, buf),
131            Some(PartialOp::Pending) => {
132                cx.waker().wake();
133                Ok(Pending)
134            }
135            Some(PartialOp::Limited(n)) => {
136                let len = min(n, buf.len());
137                self.writer.poll_write(cx, &buf[..len])
138            }
139        }
140    }
141
142    fn poll_flush(&mut self, cx: &mut Context) -> Poll<(), Error> {
143        match self.ops.next() {
144            Some(PartialOp::Pending) => {
145                cx.waker().wake();
146                Ok(Pending)
147            }
148            _ => self.writer.poll_flush(cx),
149        }
150    }
151
152    fn poll_close(&mut self, cx: &mut Context) -> Poll<(), Error> {
153        match self.ops.next() {
154            Some(PartialOp::Pending) => {
155                cx.waker().wake();
156                Ok(Pending)
157            }
158            _ => self.writer.poll_close(cx),
159        }
160    }
161}
162
163impl<W, Ops> AsyncRead for PartialWrite<W, Ops>
164    where W: AsyncRead
165{
166    fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> {
167        self.writer.poll_read(cx, buf)
168    }
169}
170
#[cfg(feature = "quickcheck")]
mod qs {
    use super::*;

    use quickcheck::{Arbitrary, Gen, empty_shrinker};

    /// Random `PartialOp`s for property tests.
    ///
    /// A single `g.next_f32()` roll picks the variant. Below 0.2 gives `Pending`, below 0.4
    /// gives `Unlimited`, and anything else gives `Limited`. That is roughly 20/20/60, assuming
    /// the roll is uniform on [0, 1).
    ///
    /// `Limited` values come from `g.gen_range(1, g.size())`, or are `1` when the size is at
    /// most 1, so a zero limit is never generated. Only `Limited` ops shrink: via
    /// `usize::shrink`, with zero filtered out.
    impl Arbitrary for PartialOp {
        fn arbitrary<G: Gen>(g: &mut G) -> Self {
            let roll = g.next_f32();
            if roll < 0.2 {
                return PartialOp::Pending;
            }
            if roll < 0.4 {
                return PartialOp::Unlimited;
            }
            if g.size() <= 1 {
                return PartialOp::Limited(1);
            }
            let upper = g.size();
            PartialOp::Limited(g.gen_range(1, upper))
        }

        fn shrink(&self) -> Box<Iterator<Item = Self>> {
            if let PartialOp::Limited(limit) = *self {
                // A zero limit would turn into an empty read/write, so never shrink to it.
                let smaller = limit.shrink().filter(|&k| k > 0);
                Box::new(smaller.map(PartialOp::Limited))
            } else {
                empty_shrinker()
            }
        }
    }
}
204
205#[cfg(feature = "quickcheck")]
206pub use self::qs::*;