1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
//! An in-memory channel for testing purposes. Allows sending items and errors to a receiver.

use futures_core::{Stream, Poll, Async, Never};
use futures_core::task::Context;
use futures_sink::Sink;
use futures_channel::mpsc::{channel, Sender, Receiver};

/// Creates a bounded in-memory test channel holding at most `capacity` pending values.
///
/// `I` is the type of items sent over the channel, `E` is the type of errors sent over the
/// channel. Values are sent as `Result<I, E>` through the returned [`TestSender`] and come out
/// of the returned [`TestReceiver`] as stream items (`Ok`) or stream errors (`Err`).
///
/// # Panics
/// Panics if `capacity` is 0.
pub fn test_channel<I, E>(capacity: usize) -> (TestSender<I, E>, TestReceiver<I, E>) {
    assert!(capacity > 0, "TestChannel must have capacity greater than 0");
    // A futures mpsc channel holds `buffer + number_of_senders` messages. Exactly one sender
    // handle is created here, so the buffer is one less than the requested capacity.
    let (tx, rx) = channel(capacity - 1);
    (TestSender::new(tx), TestReceiver::new(rx))
}

/// The transmission end of a test channel.
///
/// Each value sent is a `Result<I, E>`. The paired `TestReceiver` yields `Ok(item)` as a stream
/// item and `Err(err)` as a stream error.
///
/// This is built upon `futures::channel::mpsc::Sender` and panics if the underlying `Sender`
/// emits an error.
pub struct TestSender<I, E>(Sender<Result<I, E>>);

impl<I, E> TestSender<I, E> {
    /// Wraps the sending half of an mpsc channel.
    fn new(sender: Sender<Result<I, E>>) -> TestSender<I, E> {
        TestSender(sender)
    }
}

/// Forwards every sink operation to the wrapped mpsc `Sender`.
///
/// Any error from the underlying sender aborts the test with a panic. That is why the public
/// error type is the uninhabited `Never`.
impl<I, E> Sink for TestSender<I, E> {
    type SinkItem = Result<I, E>;
    type SinkError = Never;

    fn poll_ready(&mut self, cx: &mut Context) -> Poll<(), Self::SinkError> {
        Ok(unwrap_send_result(self.0.poll_ready(cx)))
    }

    fn start_send(&mut self, item: Self::SinkItem) -> Result<(), Self::SinkError> {
        Ok(unwrap_send_result(self.0.start_send(item)))
    }

    fn poll_flush(&mut self, cx: &mut Context) -> Poll<(), Self::SinkError> {
        Ok(unwrap_send_result(self.0.poll_flush(cx)))
    }

    fn poll_close(&mut self, cx: &mut Context) -> Poll<(), Self::SinkError> {
        Ok(unwrap_send_result(self.0.poll_close(cx)))
    }
}

/// Returns the success value of an operation on the underlying mpsc `Sender`.
///
/// Panics with the error's `Debug` representation if the operation failed.
fn unwrap_send_result<T, SendErr: std::fmt::Debug>(outcome: Result<T, SendErr>) -> T {
    outcome.unwrap_or_else(|err| panic!("TestSender got a send error: {:?}", err))
}

/// The receiving end of a test channel.
///
/// Each `Result<I, E>` sent by the paired `TestSender` is unpacked when polled. `Ok(item)`
/// becomes a stream item and `Err(err)` becomes a stream error.
pub struct TestReceiver<I, E>(Receiver<Result<I, E>>);

impl<I, E> TestReceiver<I, E> {
    /// Wraps the receiving half of an mpsc channel.
    fn new(receiver: Receiver<Result<I, E>>) -> Self {
        TestReceiver(receiver)
    }
}

/// Unpacks the `Result<I, E>` values coming out of the underlying mpsc `Receiver`.
///
/// `Ok` values become stream items and `Err` values become stream errors.
impl<I, E> Stream for TestReceiver<I, E> {
    type Item = I;
    type Error = E;

    fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Self::Item>, Self::Error> {
        match self.0.poll_next(cx) {
            Ok(Async::Pending) => Ok(Async::Pending),
            Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
            // `Ok(item)` -> `Ok(Ready(Some(item)))`; `Err(err)` is passed through unchanged.
            Ok(Async::Ready(Some(sent))) => sent.map(|item| Async::Ready(Some(item))),
            // NOTE(review): the mpsc receiver presumably cannot fail (its error type looks
            // uninhabited), so this arm should never run.
            Err(_) => unreachable!(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use futures::{SinkExt, StreamExt, FutureExt};
    use futures::sink::close;
    use futures::stream::iter_ok;
    use futures::executor::block_on;

    // Sends a mixed sequence of `Ok`/`Err` values through the test channel and checks that the
    // receiver reproduces the same sequence, in order, with `Err`s surfaced as stream errors.
    #[test]
    fn it_works() {
        // Capacity 2 is smaller than the 5 values sent, so the sender has to wait for the
        // receiver to drain the buffer (presumably intended to exercise backpressure).
        let (sender, receiver) = test_channel(2);

        // Send every value, then close the sink so the receiving stream can end and `collect`
        // below can complete.
        let send_stuff = sender
            .send_all(iter_ok::<_, Never>(vec![Ok(0), Ok(1), Err(0), Ok(2), Err(1)]))
            .and_then(|(sender, _)| close(sender).map(|_| ()));

        // `then` turns every stream error back into an `Ok(Err(..))` item, so `collect` gathers
        // the full sequence instead of stopping at the first error.
        let receive_stuff = receiver
            .then(|result| match result {
                      Ok(foo) => Ok(Ok(foo)),
                      Err(err) => Ok(Err(err)),
                  })
            .collect()
            .map(|results| {
                     assert_eq!(results, vec![Ok(0), Ok(1), Err(0), Ok(2), Err(1)]);
                 });

        // Drive both halves together; the real check is the `assert_eq!` inside `map` above.
        assert!(block_on(receive_stuff.join(send_stuff)).is_ok());
    }
}