postage 0.5.0

An async channel library
Documentation
use std::pin::Pin;

use atomic::{Atomic, Ordering};

use crate::stream::{PollRecv, Stream};
use crate::Context;
use pin_project::pin_project;

/// Progress marker for `ChainStream`: records which inner stream `poll_recv` polls next.
#[derive(Copy, Clone)]
enum State {
    /// The left stream has not yet reported `PollRecv::Closed`, so it is polled first.
    Left,
    /// The left stream has closed; polling goes straight to the right stream.
    Right,
    /// The right stream has closed as well; every further poll returns `PollRecv::Closed`.
    Closed,
}

/// Stream adapter that chains two streams: `poll_recv` forwards to `left` until it
/// reports `PollRecv::Closed`, then forwards to `right`, and reports
/// `PollRecv::Closed` itself once `right` has closed too.
#[pin_project]
pub struct ChainStream<Left, Right> {
    // Which stream is polled next; `new` initialises it to `State::Left`.
    state: Atomic<State>,
    // First stream; polled until it reports `PollRecv::Closed`.
    #[pin]
    left: Left,
    // Second stream; polled only after `left` has closed.
    #[pin]
    right: Right,
}

impl<Left, Right> ChainStream<Left, Right>
where
    Left: Stream,
    Right: Stream<Item = Left::Item>,
{
    pub fn new(left: Left, right: Right) -> Self {
        Self {
            state: Atomic::new(State::Left),
            left,
            right,
        }
    }
}

impl<Left, Right> Stream for ChainStream<Left, Right>
where
    Left: Stream,
    Right: Stream<Item = Left::Item>,
{
    type Item = Left::Item;

    /// Polls the chain for its next item.
    ///
    /// While the left stream is still open its result is forwarded as-is. When
    /// the left stream reports `PollRecv::Closed`, the chain switches to the
    /// right stream and polls it within the same call. Once the right stream
    /// reports `PollRecv::Closed`, the chain records that and answers
    /// `PollRecv::Closed` on every later poll without touching either stream.
    fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollRecv<Self::Item> {
        let projected = self.project();

        // Phase 1: decide, from the recorded state, whether the left stream
        // still has to be consulted before the right one.
        match projected.state.load(Ordering::Acquire) {
            // Both streams already finished; nothing left to poll.
            State::Closed => return PollRecv::Closed,
            // Left stream finished on an earlier call; go straight to the right one.
            State::Right => {}
            State::Left => match projected.left.poll_recv(cx) {
                PollRecv::Ready(item) => return PollRecv::Ready(item),
                PollRecv::Pending => return PollRecv::Pending,
                // Left is exhausted: remember that and fall through so the
                // right stream is polled during this same call.
                PollRecv::Closed => projected.state.store(State::Right, Ordering::Release),
            },
        }

        // Phase 2: the left stream is closed, so the right stream decides the result.
        match projected.right.poll_recv(cx) {
            PollRecv::Ready(item) => PollRecv::Ready(item),
            PollRecv::Pending => PollRecv::Pending,
            PollRecv::Closed => {
                // Record the terminal state so later polls short-circuit.
                projected.state.store(State::Closed, Ordering::Release);
                PollRecv::Closed
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::test::stream::*;
    use crate::{
        stream::{PollRecv, Stream},
        Context,
    };

    use super::ChainStream;

    /// Items scripted for the left stream arrive first, then the right stream's
    /// item, then `Closed`.
    #[test]
    fn chain() {
        let mut cx = Context::empty();
        let mut chained = ChainStream::new(
            from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(2)]),
            from_poll_iter(vec![PollRecv::Ready(3)]),
        );

        let expected_polls = vec![
            PollRecv::Ready(1),
            PollRecv::Ready(2),
            PollRecv::Ready(3),
            PollRecv::Closed,
        ];
        for expected in expected_polls {
            assert_eq!(expected, Pin::new(&mut chained).poll_recv(&mut cx));
        }
    }

    /// A `Pending` from the left stream is passed through; the right stream's
    /// item is only seen on the following poll.
    // NOTE(review): relies on the scripted test stream reporting `Closed` once its
    // list is used up — confirm against `from_poll_iter`.
    #[test]
    fn waits_for_right() {
        let mut cx = Context::empty();
        let mut chained = ChainStream::new(
            from_poll_iter(vec![PollRecv::Pending]),
            from_poll_iter(vec![PollRecv::Ready(1)]),
        );

        let expected_polls = vec![PollRecv::Pending, PollRecv::Ready(1), PollRecv::Closed];
        for expected in expected_polls {
            assert_eq!(expected, Pin::new(&mut chained).poll_recv(&mut cx));
        }
    }

    /// After both streams have reported `Closed`, later items they might still
    /// produce are never surfaced.
    #[test]
    fn ignores_after_close() {
        let mut cx = Context::empty();
        let mut chained = ChainStream::new(
            from_poll_iter(vec![PollRecv::Closed, PollRecv::Ready(1)]),
            from_poll_iter(vec![PollRecv::Closed, PollRecv::Ready(2)]),
        );

        // Three polls in a row must all report `Closed`.
        for _ in 0..3 {
            assert_eq!(PollRecv::Closed, Pin::new(&mut chained).poll_recv(&mut cx));
        }
    }
}