use std::marker::PhantomData;
use async_codec::{AsyncEncode, AsyncEncodeLen, PollEnc};
use async_codec::PollEnc::{Done, Progress, Pending, Errored};
use futures_core::task::Context;
use futures_io::AsyncWrite;
/// Internal progress of a [`Chain`].
enum State<S, T> {
    /// The first encoder is still running; the second is held until it finishes.
    First(S, T),
    /// The first encoder has finished; only the second one remains.
    Second(T),
}

/// An encoder that runs the encoder `S` to completion and then the encoder `T`,
/// writing both into the same writer of type `W`.
///
/// `W` appears only as a type parameter; no writer is ever stored. The marker
/// is therefore `PhantomData<fn() -> W>` rather than `PhantomData<W>`. This keeps
/// `Chain` covariant in `W`, but its auto traits (`Send`, `Sync`, `Unpin`) and its
/// drop-check behavior depend only on `S` and `T`, not on the writer type.
pub struct Chain<W, S, T>(State<S, T>, PhantomData<fn() -> W>);
impl<W, S, T> Chain<W, S, T> {
    /// Builds a chain that encodes `first` completely and then `second`.
    ///
    /// The chain starts in the state where `first` is active and `second`
    /// waits for it to finish.
    pub fn new(first: S, second: T) -> Self {
        let initial = State::First(first, second);
        Chain(initial, PhantomData)
    }
}
impl<W, S, T> AsyncEncode<W> for Chain<W, S, T>
where W: AsyncWrite,
S: AsyncEncode<W>,
T: AsyncEncode<W>
{
fn poll_encode(mut self, cx: &mut Context, writer: &mut W) -> PollEnc<Self> {
match self.0 {
State::First(first, second) => {
match first.poll_encode(cx, writer) {
Done(written) => {
self.0 = State::Second(second);
Progress(self, written)
}
Progress(first, written) => {
self.0 = State::First(first, second);
Progress(self, written)
}
Pending(first) => {
self.0 = State::First(first, second);
Pending(self)
}
Errored(err) => Errored(err),
}
}
State::Second(second) => {
match second.poll_encode(cx, writer) {
Done(written) => Done(written),
Progress(second, written) => {
self.0 = State::Second(second);
Progress(self, written)
}
Pending(second) => {
self.0 = State::Second(second);
Pending(self)
}
Errored(err) => Errored(err),
}
}
}
}
}
/// The remaining length of a `Chain` is the sum of what its encoders still
/// have to write.
impl<W, S, T> AsyncEncodeLen<W> for Chain<W, S, T>
where
    W: AsyncWrite,
    S: AsyncEncodeLen<W>,
    T: AsyncEncodeLen<W>,
{
    /// While the first encoder is active, returns its remaining bytes plus all
    /// of the second encoder's. Once the first encoder is done, returns only the
    /// second encoder's remaining bytes.
    fn remaining_bytes(&self) -> usize {
        match &self.0 {
            State::First(head, tail) => head.remaining_bytes() + tail.remaining_bytes(),
            State::Second(tail) => tail.remaining_bytes(),
        }
    }
}