asteroid_mq/protocol/node/edge/connection/
tokio_channel.rs1use asteroid_mq_model::{
2 connection::{EdgeConnectionError, EdgeConnectionErrorKind, EdgeNodeConnection},
3 EdgePayload,
4};
5use futures_util::{Sink, Stream};
6use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
7impl TokioChannelSocket {
8 pub fn pair() -> (TokioChannelSocket, TokioChannelSocket) {
9 let (tx_a_to_b, rx_a_to_b) = unbounded_channel();
10 let (tx_b_to_a, rx_b_to_a) = unbounded_channel();
11 (
12 TokioChannelSocket {
13 tx: tx_a_to_b,
14 rx: rx_b_to_a,
15 count: 0,
16 report_every: 1000,
17 },
18 TokioChannelSocket {
19 tx: tx_b_to_a,
20 rx: rx_a_to_b,
21 count: 0,
22 report_every: 1000,
23 },
24 )
25 }
26}
27
28pin_project_lite::pin_project! {
29 pub struct TokioChannelSocket {
30 #[pin]
31 tx: UnboundedSender<EdgePayload>,
32 #[pin]
33 rx: UnboundedReceiver<EdgePayload>,
34 count: usize,
35 report_every: usize,
36 }
37}
38
39impl Stream for TokioChannelSocket {
40 type Item = Result<EdgePayload, EdgeConnectionError>;
41 fn poll_next(
42 self: std::pin::Pin<&mut Self>,
43 cx: &mut std::task::Context<'_>,
44 ) -> std::task::Poll<Option<Self::Item>> {
45 let mut this = self.project();
46 let message = futures_util::ready!(this.rx.poll_recv(cx));
47 std::task::Poll::Ready(message.map(Ok))
48 }
49}
50
51impl Sink<EdgePayload> for TokioChannelSocket {
52 type Error = EdgeConnectionError;
53 fn poll_close(
54 self: std::pin::Pin<&mut Self>,
55 _cx: &mut std::task::Context<'_>,
56 ) -> std::task::Poll<Result<(), Self::Error>> {
57 std::task::Poll::Ready(Ok(()))
58 }
59 fn poll_flush(
60 self: std::pin::Pin<&mut Self>,
61 _cx: &mut std::task::Context<'_>,
62 ) -> std::task::Poll<Result<(), Self::Error>> {
63 std::task::Poll::Ready(Ok(()))
64 }
65 fn poll_ready(
66 self: std::pin::Pin<&mut Self>,
67 _cx: &mut std::task::Context<'_>,
68 ) -> std::task::Poll<Result<(), Self::Error>> {
69 std::task::Poll::Ready(Ok(()))
70 }
71 fn start_send(self: std::pin::Pin<&mut Self>, item: EdgePayload) -> Result<(), Self::Error> {
72 let this = self.project();
73 let send_result = this.tx.send(item);
74 if send_result.is_err() {
75 return Err(EdgeConnectionError::new(
76 EdgeConnectionErrorKind::Closed,
77 "tokio channel closed",
78 ));
79 } else {
80 *this.count = this.count.wrapping_add_signed(1);
81 if *this.count % *this.report_every == 0 {
82 tracing::info!(count = *this.count, "tokio channel send count");
83 }
84 }
85 Ok(())
86 }
87}
88impl EdgeNodeConnection for TokioChannelSocket {}