1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
use std::marker::PhantomData;
use async_codec::{AsyncEncode, AsyncEncodeLen, PollEnc};
use async_codec::PollEnc::{Done, Progress, Pending, Errored};
use futures_core::task::Context;
use futures_io::AsyncWrite;
/// Internal progress marker for a `Chain`.
enum State<S, T> {
    /// The first encoder is still active; the second one is held until it takes over.
    First(S, T),
    /// The first encoder is gone; only the second encoder remains.
    Second(T),
}

/// Chains two encoders: the first one is polled until it reports completion, then the second.
///
/// The `PhantomData<W>` field only carries the writer type parameter `W`; no writer is stored.
pub struct Chain<W, S, T>(State<S, T>, PhantomData<W>);

impl<W, S, T> Chain<W, S, T> {
    /// Creates a new `Chain` that starts out holding both `first` and `second`,
    /// with `first` being the encoder that is polled initially.
    pub fn new(first: S, second: T) -> Chain<W, S, T> {
        let initial = State::First(first, second);
        Chain(initial, PhantomData)
    }
}
impl<W, S, T> AsyncEncode<W> for Chain<W, S, T>
where W: AsyncWrite,
S: AsyncEncode<W>,
T: AsyncEncode<W>
{
fn poll_encode(mut self, cx: &mut Context, writer: &mut W) -> PollEnc<Self> {
match self.0 {
State::First(first, second) => {
match first.poll_encode(cx, writer) {
Done(written) => {
self.0 = State::Second(second);
Progress(self, written)
}
Progress(first, written) => {
self.0 = State::First(first, second);
Progress(self, written)
}
Pending(first) => {
self.0 = State::First(first, second);
Pending(self)
}
Errored(err) => Errored(err),
}
}
State::Second(second) => {
match second.poll_encode(cx, writer) {
Done(written) => Done(written),
Progress(second, written) => {
self.0 = State::Second(second);
Progress(self, written)
}
Pending(second) => {
self.0 = State::Second(second);
Pending(self)
}
Errored(err) => Errored(err),
}
}
}
}
}
impl<W, S, T> AsyncEncodeLen<W> for Chain<W, S, T>
where W: AsyncWrite,
S: AsyncEncodeLen<W>,
T: AsyncEncodeLen<W>
{
fn remaining_bytes(&self) -> usize {
match self.0 {
State::First(ref first, ref second) => {
first.remaining_bytes() + second.remaining_bytes()
}
State::Second(ref second) => second.remaining_bytes(),
}
}
}
#[cfg(test)]
mod tests {
    use atm_io_utils::partial::*;
    use async_ringbuffer::ring_buffer;
    use async_byteorder::{decode_i32_native, decode_u64_native, encode_i32_native,
                          encode_u64_native};

    use super::super::super::testing::test_codec_len;
    use super::super::super::decoder::chain as dec_chain;
    use super::super::super::encoder::chain as enc_chain;

    quickcheck! {
        // Property: an `i32` followed by a `u64`, written by a chained encoder through a
        // ring buffer wrapped in partial read/write adapters, is read back by the matching
        // chained decoder as the same two values.
        fn codec(buf_size: usize, read_ops: Vec<PartialOp>, write_ops: Vec<PartialOp>, int_0: i32, int_1: u64) -> bool {
            let mut read_ops = read_ops;
            let mut write_ops = write_ops;

            // `+ 1` presumably keeps the ring buffer capacity non-zero - confirm against
            // `async_ringbuffer`.
            let (raw_writer, raw_reader) = ring_buffer(buf_size + 1);
            let writer = PartialWrite::new(raw_writer, write_ops.drain(..));
            let reader = PartialRead::new(raw_reader, read_ops.drain(..));

            let decoder = dec_chain(decode_i32_native(), decode_u64_native());
            let encoder = enc_chain(encode_i32_native(int_0), encode_u64_native(int_1));

            let outcome = test_codec_len(reader, writer, decoder, encoder);

            // `outcome.1` is the boolean reported by `test_codec_len` (presumably its length
            // check - confirm in `testing`); `outcome.0` carries the two decoded values.
            let flag_ok = outcome.1;
            flag_ok && (outcome.0).0 == int_0 && (outcome.0).1 == int_1
        }
    }
}