1use core::{fmt, result::Result};
6use tokio;
7use tokio::sync::{mpsc, mpsc::error::SendError, oneshot, oneshot::error::RecvError};
8
/// Transport-level failure of a request made through a [`MessageClient`].
///
/// This covers problems with the channels themselves. Errors produced by the
/// processor while handling a request are not represented here; they travel back
/// as the `E` in the `Result<S, E>` response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelError {
    /// The request could not be enqueued because the receiving half of the
    /// request channel (owned by the [`MessageProcessor`]) is closed or dropped.
    SendError,
    /// No response arrived because the reply sender was dropped before a
    /// response was sent, e.g. the [`Message`] was dropped without
    /// [`Message::respond`] being called.
    RecvError,
}

impl fmt::Display for ChannelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            ChannelError::SendError => "channel send error",
            ChannelError::RecvError => "channel receiver error",
        };
        f.write_str(text)
    }
}

// No underlying cause is retained, so the default `source()` (None) is correct.
impl std::error::Error for ChannelError {}
29
30impl<Q, S, E> From<SendError<Message<Q, S, E>>> for ChannelError {
31 fn from(_: SendError<Message<Q, S, E>>) -> Self {
32 ChannelError::SendError
33 }
34}
35
36impl From<RecvError> for ChannelError {
37 fn from(_: RecvError) -> Self {
38 ChannelError::RecvError
39 }
40}
41
42pub struct Message<Q, S, E> {
45 pub request: Q,
47 sender: oneshot::Sender<Result<S, E>>,
48}
49
50impl<Q, S, E> Message<Q, S, E> {
51 pub fn respond(self, response: Result<S, E>) -> bool {
54 self.sender.send(response).map_or_else(|_| false, |_| true)
55 }
56}
57
58impl<Q: std::fmt::Debug, S, E> fmt::Debug for Message<Q, S, E> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_struct("Message")
61 .field("request", &self.request)
62 .finish()
63 }
64}
65
66#[derive(Debug)]
71pub struct MessageClient<Q, S, E> {
72 tx: mpsc::UnboundedSender<Message<Q, S, E>>,
73}
74
75impl<Q, S, E> MessageClient<Q, S, E> {
76 #[allow(dead_code)]
79 pub fn send_oneshot(&self, request: Q) -> Result<(), ChannelError> {
80 self.send_request_impl(request)
81 .map(|_| Ok(()))
82 .map_err(ChannelError::from)?
83 }
84
85 pub async fn send(&self, request: Q) -> Result<Result<S, E>, ChannelError> {
88 let rx = self
89 .send_request_impl(request)
90 .map_err(ChannelError::from)?;
91 rx.await.map_err(|e| e.into())
92 }
93
94 #[allow(clippy::type_complexity)]
95 fn send_request_impl(
96 &self,
97 request: Q,
98 ) -> Result<oneshot::Receiver<Result<S, E>>, SendError<Message<Q, S, E>>> {
99 let (tx, rx) = oneshot::channel::<Result<S, E>>();
100 let message = Message {
101 sender: tx,
102 request,
103 };
104
105 self.tx.send(message).map(|_| rx)
106 }
107}
108
109impl<Q, S, E> Clone for MessageClient<Q, S, E> {
112 fn clone(&self) -> Self {
113 MessageClient {
114 tx: self.tx.clone(),
115 }
116 }
117}
118
119pub struct MessageProcessor<Q, S, E> {
124 rx: mpsc::UnboundedReceiver<Message<Q, S, E>>,
125}
126
127impl<Q, S, E> MessageProcessor<Q, S, E> {
128 pub async fn pull_message(&mut self) -> Option<Message<Q, S, E>> {
131 self.rx.recv().await
132 }
133}
134
135pub fn message_channel<Q, S, E>() -> (MessageClient<Q, S, E>, MessageProcessor<Q, S, E>) {
137 let (tx, rx) = mpsc::unbounded_channel::<Message<Q, S, E>>();
138 let processor = MessageProcessor::<Q, S, E> { rx };
139 let client = MessageClient::<Q, S, E> { tx };
140 (client, processor)
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    enum Request {
        Ping(),
        SetFlag(u32),
        Shutdown(),
        Throw(),
    }

    enum Response {
        Pong(),
        GenericResult(bool),
    }

    struct TestError {
        pub message: String,
    }

    /// Drives a processor task through the request/response, one-way, error and
    /// shutdown paths.
    #[tokio::test]
    async fn test_message_channel() -> Result<(), Box<dyn std::error::Error>> {
        let (client, mut processor) = message_channel::<Request, Response, TestError>();

        // Keep the handle so that a failed assertion inside the task is surfaced
        // when it is awaited at the end of the test.
        let processor_task = tokio::spawn(async move {
            let mut set_flags: usize = 0;

            loop {
                let message = processor.pull_message().await;
                match message {
                    Some(m) => match m.request {
                        Request::Ping() => {
                            let success = m.respond(Ok(Response::Pong()));
                            assert!(success, "receiver not closed");
                        }
                        Request::Throw() => {
                            m.respond(Err(TestError {
                                message: String::from("thrown!"),
                            }));
                        }
                        Request::SetFlag(_) => {
                            set_flags += 1;
                            let success = m.respond(Ok(Response::GenericResult(true)));
                            assert!(
                                !success,
                                "one-way requests should not successfully respond."
                            );
                        }
                        Request::Shutdown() => {
                            assert_eq!(set_flags, 10, "One-way requests successfully processed.");
                            let success = m.respond(Ok(Response::GenericResult(true)));
                            assert!(success);
                            return;
                        }
                    },
                    // `pull_message` only yields `None` once every client is dropped.
                    None => panic!("all message clients dropped before shutdown"),
                }
            }
        });

        let res = client.send(Request::Ping()).await?;
        // `matches!` only evaluates to a bool; it must be asserted to check anything.
        assert!(
            matches!(res, Ok(Response::Pong())),
            "Ping is answered with Pong."
        );

        for n in 0..10 {
            client.send_oneshot(Request::SetFlag(n))?;
        }

        let res = client.send(Request::Throw()).await?;
        assert!(
            match res {
                Ok(_) => false,
                Err(TestError { message }) => {
                    assert_eq!(message, String::from("thrown!"));
                    true
                }
            },
            "User Error propagates to client."
        );

        let res = client.send(Request::Shutdown()).await?;
        assert!(
            match res {
                Ok(Response::GenericResult(success)) => success,
                _ => false,
            },
            "successfully shutdown processing thread."
        );

        // Propagates any panic raised inside the processor task as a `JoinError`.
        processor_task.await?;

        Ok(())
    }
}