noosphere_common/
channel.rs

//! Utility wrapper around [tokio::sync::mpsc] channels that lets multiple
//! producers send messages to a single subscriber, which can respond to each
//! message individually over a per-message [tokio::sync::oneshot] channel.
4
5use core::{fmt, result::Result};
6use tokio;
7use tokio::sync::{mpsc, mpsc::error::SendError, oneshot, oneshot::error::RecvError};
8
9impl std::error::Error for ChannelError {}
10impl fmt::Display for ChannelError {
11    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
12        match self {
13            ChannelError::SendError => write!(fmt, "channel send error"),
14            ChannelError::RecvError => write!(fmt, "channel receiver error"),
15        }
16    }
17}
/// Error type wrapping the tokio sync errors that can occur while exchanging
/// messages. It is kept separate from the user-level `E` error that a
/// processor may send back as (part of) a response.
#[derive(Debug)]
pub enum ChannelError {
    /// An error occurred while sending a message: the request could not be
    /// queued for the [MessageProcessor].
    /// Converted from [tokio::sync::mpsc::error::SendError].
    SendError,
    /// An error occurred while awaiting a response: the [Message] was dropped
    /// without being responded to.
    /// Converted from [tokio::sync::oneshot::error::RecvError].
    RecvError,
}
29
30impl<Q, S, E> From<SendError<Message<Q, S, E>>> for ChannelError {
31    fn from(_: SendError<Message<Q, S, E>>) -> Self {
32        ChannelError::SendError
33    }
34}
35
36impl From<RecvError> for ChannelError {
37    fn from(_: RecvError) -> Self {
38        ChannelError::RecvError
39    }
40}
41
42/// Represents a request to be processed in [MessageProcessor],
43/// sent from the associated [MessageClient].
44pub struct Message<Q, S, E> {
45    /// The initial request the [Message] is wrapping.
46    pub request: Q,
47    sender: oneshot::Sender<Result<S, E>>,
48}
49
50impl<Q, S, E> Message<Q, S, E> {
51    /// Send `response` to the originator of this [Message].
52    /// Each message can only be responded to once.
53    pub fn respond(self, response: Result<S, E>) -> bool {
54        self.sender.send(response).map_or_else(|_| false, |_| true)
55    }
56}
57
58impl<Q: std::fmt::Debug, S, E> fmt::Debug for Message<Q, S, E> {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.debug_struct("Message")
61            .field("request", &self.request)
62            .finish()
63    }
64}
65
66/// Sends requests to the associated `MessageProcessor`.
67///
68/// Instances are created by the
69/// [`message_channel`](message_channel) function.
70#[derive(Debug)]
71pub struct MessageClient<Q, S, E> {
72    tx: mpsc::UnboundedSender<Message<Q, S, E>>,
73}
74
75impl<Q, S, E> MessageClient<Q, S, E> {
76    /// Sends a one-way request to the corresponding receiver. Use
77    /// [MessageClient::send] if the receiver should be able to respond.
78    #[allow(dead_code)]
79    pub fn send_oneshot(&self, request: Q) -> Result<(), ChannelError> {
80        self.send_request_impl(request)
81            .map(|_| Ok(()))
82            .map_err(ChannelError::from)?
83    }
84
85    /// Sends a request to the corresponding receiver where it can be
86    /// responded to.
87    pub async fn send(&self, request: Q) -> Result<Result<S, E>, ChannelError> {
88        let rx = self
89            .send_request_impl(request)
90            .map_err(ChannelError::from)?;
91        rx.await.map_err(|e| e.into())
92    }
93
94    #[allow(clippy::type_complexity)]
95    fn send_request_impl(
96        &self,
97        request: Q,
98    ) -> Result<oneshot::Receiver<Result<S, E>>, SendError<Message<Q, S, E>>> {
99        let (tx, rx) = oneshot::channel::<Result<S, E>>();
100        let message = Message {
101            sender: tx,
102            request,
103        };
104
105        self.tx.send(message).map(|_| rx)
106    }
107}
108
109// Manually implement `Clone` so that the generics do not need
110// also implement.
111impl<Q, S, E> Clone for MessageClient<Q, S, E> {
112    fn clone(&self) -> Self {
113        MessageClient {
114            tx: self.tx.clone(),
115        }
116    }
117}
118
119/// Receives requests from the associated `MessageClient`,
120/// and optionally sends a response.
121///
122/// Instances are created by the [message_channel] function.
123pub struct MessageProcessor<Q, S, E> {
124    rx: mpsc::UnboundedReceiver<Message<Q, S, E>>,
125}
126
127impl<Q, S, E> MessageProcessor<Q, S, E> {
128    /// Awaits until it can return a new message to process, or
129    /// [None] if all senders have been terminated.
130    pub async fn pull_message(&mut self) -> Option<Message<Q, S, E>> {
131        self.rx.recv().await
132    }
133}
134
135/// Creates a pair of bound `MessageClient` and `MessageProcessor`.
136pub fn message_channel<Q, S, E>() -> (MessageClient<Q, S, E>, MessageProcessor<Q, S, E>) {
137    let (tx, rx) = mpsc::unbounded_channel::<Message<Q, S, E>>();
138    let processor = MessageProcessor::<Q, S, E> { rx };
139    let client = MessageClient::<Q, S, E> { tx };
140    (client, processor)
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    enum Request {
        Ping(),
        SetFlag(u32),
        Shutdown(),
        Throw(),
    }

    enum Response {
        Pong(),
        GenericResult(bool),
    }

    struct TestError {
        pub message: String,
    }

    #[tokio::test]
    async fn test_message_channel() -> Result<(), Box<dyn std::error::Error>> {
        let (client, mut processor) = message_channel::<Request, Response, TestError>();

        // Keep the JoinHandle: a failed assertion inside the processor task is
        // only reported as this test's failure when the handle is awaited below.
        let processor_task = tokio::spawn(async move {
            let mut set_flags: usize = 0;

            loop {
                let message = processor.pull_message().await;
                match message {
                    Some(m) => match m.request {
                        Request::Ping() => {
                            let success = m.respond(Ok(Response::Pong()));
                            assert!(success, "receiver not closed");
                        }
                        Request::Throw() => {
                            let success = m.respond(Err(TestError {
                                message: String::from("thrown!"),
                            }));
                            assert!(success, "receiver not closed");
                        }
                        Request::SetFlag(_) => {
                            set_flags += 1;
                            let success = m.respond(Ok(Response::GenericResult(true)));
                            assert!(
                                !success,
                                "one-way requests should not successfully respond."
                            );
                        }
                        Request::Shutdown() => {
                            assert_eq!(set_flags, 10, "One-way requests successfully processed.");
                            let success = m.respond(Ok(Response::GenericResult(true)));
                            assert!(success);
                            return;
                        }
                    },
                    None => panic!("message queue empty"),
                }
            }
        });

        let res = client.send(Request::Ping()).await?;
        // `matches!` only evaluates to a bool; it must be asserted to check anything.
        assert!(
            matches!(res, Ok(Response::Pong())),
            "Ping is answered with Pong."
        );

        for n in 0..10 {
            client.send_oneshot(Request::SetFlag(n))?;
        }

        let res = client.send(Request::Throw()).await?;
        match res {
            Ok(_) => panic!("User Error propagates to client."),
            Err(TestError { message }) => assert_eq!(message, "thrown!"),
        }

        let res = client.send(Request::Shutdown()).await?;
        assert!(
            matches!(res, Ok(Response::GenericResult(true))),
            "successfully shutdown processing thread."
        );

        // Surface any panic from the processor task as a test failure.
        processor_task.await?;
        Ok(())
    }
}
228        Ok(())
229    }
230}