// async_codec_util/encoder/chain.rs

1use async_codec::{AsyncEncode, AsyncEncodeLen, PollEnc};
2use async_codec::PollEnc::{Done, Progress, Pending, Errored};
3use futures_core::task::Context;
4use futures_io::AsyncWrite;
5
/// Internal progress of a [`Chain`].
///
/// `First` still holds both encoders because the second one has not started yet;
/// `Second` is entered once the first encoder has reported `Done`.
#[derive(Debug)]
enum State<S, T> {
    /// The first encoder is still running; the second is waiting its turn.
    First(S, T),
    /// The first encoder has finished; only the second remains.
    Second(T),
}

/// Wraps two `AsyncEncode`s and encodes them in sequence.
///
/// Everything produced by the first encoder is written before anything produced by the
/// second one.
#[derive(Debug)]
pub struct Chain<S, T>(State<S, T>);
13
14impl<S, T> Chain<S, T> {
15    /// Create new `Chain` which first encodes the given `S` and then encodes the given `T`.
16    pub fn new(first: S, second: T) -> Chain<S, T> {
17        Chain(State::First(first, second))
18    }
19}
20
21impl<S, T> AsyncEncode for Chain<S, T>
22    where S: AsyncEncode,
23          T: AsyncEncode
24{
25    fn poll_encode<W: AsyncWrite>(mut self, cx: &mut Context, writer: &mut W) -> PollEnc<Self> {
26        match self.0 {
27            State::First(first, second) => {
28                match first.poll_encode(cx, writer) {
29                    Done(written) => {
30                        self.0 = State::Second(second);
31                        Progress(self, written)
32                    }
33                    Progress(first, written) => {
34                        self.0 = State::First(first, second);
35                        Progress(self, written)
36                    }
37                    Pending(first) => {
38                        self.0 = State::First(first, second);
39                        Pending(self)
40                    }
41                    Errored(err) => Errored(err),
42                }
43            }
44
45            State::Second(second) => {
46                match second.poll_encode(cx, writer) {
47                    Done(written) => Done(written),
48                    Progress(second, written) => {
49                        self.0 = State::Second(second);
50                        Progress(self, written)
51                    }
52                    Pending(second) => {
53                        self.0 = State::Second(second);
54                        Pending(self)
55                    }
56                    Errored(err) => Errored(err),
57                }
58            }
59        }
60    }
61}
62
63impl<S, T> AsyncEncodeLen for Chain<S, T>
64    where S: AsyncEncodeLen,
65          T: AsyncEncodeLen
66{
67    fn remaining_bytes(&self) -> usize {
68        match self.0 {
69            State::First(ref first, ref second) => {
70                first.remaining_bytes() + second.remaining_bytes()
71            }
72            State::Second(ref second) => second.remaining_bytes(),
73        }
74    }
75}
76
// Round-trip property test: encoding an (i32, u64) pair with `encoder::chain` and
// decoding it with `decoder::chain` over a ring buffer with randomized partial
// reads/writes must reproduce the original values.
#[cfg(test)]
mod tests {
    use atm_io_utils::partial::*;
    use async_ringbuffer::ring_buffer;

    use async_byteorder::{decode_i32_native, decode_u64_native, encode_i32_native,
                          encode_u64_native};
    use super::super::super::testing::test_codec_len;
    use super::super::super::decoder::chain as dec_chain;
    use super::super::super::encoder::chain as enc_chain;

    quickcheck! {
        fn codec(buf_size: usize, read_ops: Vec<PartialOp>, write_ops: Vec<PartialOp>, int_0: i32, int_1: u64) -> bool {
            // quickcheck! only accepts plain identifiers as arguments, so the op lists
            // are rebound mutably here to allow draining them.
            let mut read_ops = read_ops;
            let mut write_ops = write_ops;

            // `+ 1` keeps the ring buffer capacity non-zero even when buf_size == 0.
            let (writer, reader) = ring_buffer(buf_size + 1);
            let writer = PartialWrite::new(writer, write_ops.drain(..));
            let reader = PartialRead::new(reader, read_ops.drain(..));

            let decoder = dec_chain(decode_i32_native(), decode_u64_native());
            let encoder = enc_chain(encode_i32_native(int_0), encode_u64_native(int_1));
            let outcome = test_codec_len(reader, writer, decoder, encoder);

            outcome.1 && (outcome.0).0 == int_0 && (outcome.0).1 == int_1
        }
    }
}