1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
use async_codec::{AsyncEncode, AsyncEncodeLen, PollEnc};
use async_codec::PollEnc::{Done, Progress, Pending, Errored};
use futures_core::task::Context;
use futures_io::AsyncWrite;

/// Which of the two wrapped encoders is currently being driven.
///
/// `First` still holds both encoders (the second one is waiting its turn); `Second` holds only
/// the remaining encoder once the first one has finished.
#[derive(Debug)]
enum State<S, T> {
    First(S, T),
    Second(T),
}

/// Wraps two `AsyncEncode`s and encodes them in sequence.
///
/// Everything produced by the first encoder is written before anything produced by the second.
#[derive(Debug)]
pub struct Chain<S, T>(State<S, T>);

impl<S, T> Chain<S, T> {
    /// Create new `Chain` which first encodes the given `S` and then encodes the given `T`.
    pub fn new(first: S, second: T) -> Chain<S, T> {
        Chain(State::First(first, second))
    }
}

impl<S, T> AsyncEncode for Chain<S, T>
    where S: AsyncEncode,
          T: AsyncEncode
{
    fn poll_encode<W: AsyncWrite>(mut self, cx: &mut Context, writer: &mut W) -> PollEnc<Self> {
        match self.0 {
            State::First(first, second) => {
                match first.poll_encode(cx, writer) {
                    Done(written) => {
                        self.0 = State::Second(second);
                        Progress(self, written)
                    }
                    Progress(first, written) => {
                        self.0 = State::First(first, second);
                        Progress(self, written)
                    }
                    Pending(first) => {
                        self.0 = State::First(first, second);
                        Pending(self)
                    }
                    Errored(err) => Errored(err),
                }
            }

            State::Second(second) => {
                match second.poll_encode(cx, writer) {
                    Done(written) => Done(written),
                    Progress(second, written) => {
                        self.0 = State::Second(second);
                        Progress(self, written)
                    }
                    Pending(second) => {
                        self.0 = State::Second(second);
                        Pending(self)
                    }
                    Errored(err) => Errored(err),
                }
            }
        }
    }
}

impl<S, T> AsyncEncodeLen for Chain<S, T>
    where S: AsyncEncodeLen,
          T: AsyncEncodeLen
{
    /// Number of bytes the chain still has to write: whatever the first encoder still owes
    /// (nothing once it has finished) plus everything the second encoder still owes.
    fn remaining_bytes(&self) -> usize {
        let (head_bytes, tail) = match self.0 {
            State::First(ref head, ref tail) => (head.remaining_bytes(), tail),
            State::Second(ref tail) => (0, tail),
        };
        head_bytes + tail.remaining_bytes()
    }
}

#[cfg(test)]
mod tests {
    use atm_io_utils::partial::*;
    use async_ringbuffer::ring_buffer;

    use async_byteorder::{decode_i32_native, decode_u64_native, encode_i32_native,
                          encode_u64_native};
    use super::super::super::testing::test_codec_len;
    use super::super::super::decoder::chain as dec_chain;
    use super::super::super::encoder::chain as enc_chain;

    quickcheck! {
        // Round-trips an (i32, u64) pair through chained encoders/decoders over a ring buffer,
        // with reads and writes split up according to quickcheck-generated partial operations.
        fn codec(buf_size: usize, read_ops: Vec<PartialOp>, write_ops: Vec<PartialOp>, int_0: i32, int_1: u64) -> bool {
            let (sink, source) = ring_buffer(buf_size + 1);
            let sink = PartialWrite::new(sink, write_ops.into_iter());
            let source = PartialRead::new(source, read_ops.into_iter());

            let decoder = dec_chain(decode_i32_native(), decode_u64_native());
            let encoder = enc_chain(encode_i32_native(int_0), encode_u64_native(int_1));
            let outcome = test_codec_len(source, sink, decoder, encoder);

            outcome.1 && (outcome.0).0 == int_0 && (outcome.0).1 == int_1
        }
    }
}