1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
use std::{
    io::ErrorKind,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use async_broadcast::Receiver as ActiveReceiver;
use async_channel::Receiver;
use futures_core::{ready, stream, Future};
use futures_util::{
    future::{select, Either},
    StreamExt,
};
use static_assertions::assert_impl_all;

use crate::{Connection, Error, Message, Result};

/// A [`stream::Stream`] implementation that yields [`Message`] items.
///
/// You can convert a [`Connection`] (or a `&Connection`, which is cloned) into this type via
/// [`From`].
///
/// **NOTE**: You must ensure a `MessageStream` is continuously polled or you will experience hangs.
/// If you don't need to continuously poll the `MessageStream` but need to keep it around for later
/// use, keep the connection around and convert it into a `MessageStream` when needed. The
/// conversion is not an expensive operation so you don't need to worry about performance, unless
/// you do it very frequently. If you need to convert back and forth frequently, you may want to
/// consider keeping both a connection and stream around.
#[derive(Clone, Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct MessageStream {
    /// Active receiving end for incoming messages, taken from the connection.
    msg_receiver: ActiveReceiver<Arc<Message>>,
    /// Receiving end for errors reported by the connection.
    error_receiver: Receiver<Error>,
}

// Compile-time guarantee that the stream can be moved across threads, shared, and polled
// without pinning.
assert_impl_all!(MessageStream: Send, Sync, Unpin);

impl stream::Stream for MessageStream {
    type Item = Result<Arc<Message>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let msg_fut = this.msg_receiver.next();
        let err_fut = this.error_receiver.next();
        let mut select_fut = select(msg_fut, err_fut);

        match ready!(Pin::new(&mut select_fut).poll(cx)) {
            Either::Left((msg, _)) => Poll::Ready(msg.map(Ok)),
            Either::Right((error, _)) => Poll::Ready(
                error
                    .map(|e| match &e {
                        Error::Io(io_error) => {
                            let kind = io_error.kind();
                            if kind == ErrorKind::UnexpectedEof || kind == ErrorKind::BrokenPipe {
                                None
                            } else {
                                Some(Err(e))
                            }
                        }
                        _ => Some(Err(e)),
                    })
                    .flatten(),
            ),
        }
    }
}

impl From<Connection> for MessageStream {
    /// Consumes `conn`, activating its message receiver and taking over its error receiver.
    fn from(conn: Connection) -> Self {
        // Field initializers run in written order: activation happens before the error
        // receiver is moved out, matching a statement-by-statement construction.
        Self {
            msg_receiver: conn.msg_receiver.activate(),
            error_receiver: conn.error_receiver,
        }
    }
}

impl From<&Connection> for MessageStream {
    fn from(conn: &Connection) -> Self {
        Self::from(conn.clone())
    }
}