Skip to main content

asteroid_mq/protocol/node/edge/connection/tokio_channel.rs

1use asteroid_mq_model::{
2    connection::{EdgeConnectionError, EdgeConnectionErrorKind, EdgeNodeConnection},
3    EdgePayload,
4};
5use futures_util::{Sink, Stream};
6use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
7impl TokioChannelSocket {
8    pub fn pair() -> (TokioChannelSocket, TokioChannelSocket) {
9        let (tx_a_to_b, rx_a_to_b) = unbounded_channel();
10        let (tx_b_to_a, rx_b_to_a) = unbounded_channel();
11        (
12            TokioChannelSocket {
13                tx: tx_a_to_b,
14                rx: rx_b_to_a,
15                count: 0,
16                report_every: 1000,
17            },
18            TokioChannelSocket {
19                tx: tx_b_to_a,
20                rx: rx_a_to_b,
21                count: 0,
22                report_every: 1000,
23            },
24        )
25    }
26}
27
28pin_project_lite::pin_project! {
29    pub struct TokioChannelSocket {
30        #[pin]
31        tx: UnboundedSender<EdgePayload>,
32        #[pin]
33        rx: UnboundedReceiver<EdgePayload>,
34        count: usize,
35        report_every: usize,
36    }
37}
38
39impl Stream for TokioChannelSocket {
40    type Item = Result<EdgePayload, EdgeConnectionError>;
41    fn poll_next(
42        self: std::pin::Pin<&mut Self>,
43        cx: &mut std::task::Context<'_>,
44    ) -> std::task::Poll<Option<Self::Item>> {
45        let mut this = self.project();
46        let message = futures_util::ready!(this.rx.poll_recv(cx));
47        std::task::Poll::Ready(message.map(Ok))
48    }
49}
50
51impl Sink<EdgePayload> for TokioChannelSocket {
52    type Error = EdgeConnectionError;
53    fn poll_close(
54        self: std::pin::Pin<&mut Self>,
55        _cx: &mut std::task::Context<'_>,
56    ) -> std::task::Poll<Result<(), Self::Error>> {
57        std::task::Poll::Ready(Ok(()))
58    }
59    fn poll_flush(
60        self: std::pin::Pin<&mut Self>,
61        _cx: &mut std::task::Context<'_>,
62    ) -> std::task::Poll<Result<(), Self::Error>> {
63        std::task::Poll::Ready(Ok(()))
64    }
65    fn poll_ready(
66        self: std::pin::Pin<&mut Self>,
67        _cx: &mut std::task::Context<'_>,
68    ) -> std::task::Poll<Result<(), Self::Error>> {
69        std::task::Poll::Ready(Ok(()))
70    }
71    fn start_send(self: std::pin::Pin<&mut Self>, item: EdgePayload) -> Result<(), Self::Error> {
72        let this = self.project();
73        let send_result = this.tx.send(item);
74        if send_result.is_err() {
75            return Err(EdgeConnectionError::new(
76                EdgeConnectionErrorKind::Closed,
77                "tokio channel closed",
78            ));
79        } else {
80            *this.count = this.count.wrapping_add_signed(1);
81            if *this.count % *this.report_every == 0 {
82                tracing::info!(count = *this.count, "tokio channel send count");
83            }
84        }
85        Ok(())
86    }
87}
/// Lets a `TokioChannelSocket` be used wherever an `EdgeNodeConnection` is expected.
/// No trait items are overridden here.
// NOTE(review): presumably the trait's requirements are satisfied by the `Stream` and
// `Sink` impls in this file — confirm against the `EdgeNodeConnection` definition.
impl EdgeNodeConnection for TokioChannelSocket {}