Skip to main content

memfault_ssf/
msg_mailbox.rs

1//
2// Copyright (c) Memfault, Inc.
3// See License.txt for details
4use std::sync::mpsc::{channel, sync_channel, Receiver, Sender, SyncSender};
5
6use crate::{BoundedMailbox, BoundedTaskMailbox, Handler, Mailbox, MailboxError, Message, Service};
7
8// This type alias includes a trait bound on the generic parameter which
9// triggers clippy's `type_alias_bounds` lint. The bound is intentional to
10// tie the alias to `Message` and to keep the surrounding mocks simple.
11// Allow the lint here so the compiler doesn't warn in downstream builds.
12#[allow(type_alias_bounds)]
13type MockReplySender<M: Message> = Receiver<(M, Sender<M::Reply>)>;
14
15/// A `MsgMailbox` only depends on the type of the messages it can contain.
16///
17/// This allows a real separation between the caller and the recipient, they do
18/// not need to know about each other.
19pub struct MsgMailbox<M: Message> {
20    service_mailbox: Box<dyn MsgMailboxT<M>>,
21}
22
23impl<M: Message> MsgMailbox<M> {
24    /// Create a mock msg mailbox. Messages will be kept in a Vec - Do not use this directly but use ServiceMock::new()
25    pub(super) fn mock() -> (Self, MockReplySender<M>) {
26        let (sender, receiver) = channel();
27        let mock = MockMsgMailbox::new(sender);
28        (
29            MsgMailbox {
30                service_mailbox: mock.duplicate(),
31            },
32            receiver,
33        )
34    }
35
36    /// Create a bounded mock msg mailbox. Messages will be kept in a Vec - Do not use this directly but use ServiceMock::new()
37    pub(super) fn bounded_mock(channel_size: usize) -> (Self, MockReplySender<M>) {
38        let (sender, receiver) = sync_channel(channel_size);
39        let mock = BoundedMockMsgMailbox::new(sender);
40        (
41            MsgMailbox {
42                service_mailbox: mock.duplicate(),
43            },
44            receiver,
45        )
46    }
47
48    pub fn send_and_forget(&self, message: M) -> Result<(), MailboxError> {
49        self.service_mailbox.send_and_forget(message)
50    }
51    pub fn send_and_wait_for_reply(&self, message: M) -> Result<M::Reply, MailboxError> {
52        self.service_mailbox.send_and_wait_for_reply(message)
53    }
54}
55
56impl<M: Message> Clone for MsgMailbox<M> {
57    fn clone(&self) -> Self {
58        MsgMailbox {
59            service_mailbox: self.service_mailbox.duplicate(),
60        }
61    }
62}
63
64/// A `MsgMailbox` that will send to all services that have registered with it
65///
66/// This mailbox allows you to do a one-to-many mapping of mailboxes to services.
67/// It's used in the case where we need to notiy several services of the same
68/// message simultaneously.
69#[derive(Clone)]
70pub struct BroadcastMsgMailbox<M: Message + Clone> {
71    msg_mailboxes: Vec<MsgMailbox<M>>,
72}
73
74impl<M: Message + Clone> BroadcastMsgMailbox<M> {
75    pub fn send_and_forget(&self, message: M) -> Result<(), MailboxError> {
76        self.msg_mailboxes
77            .iter()
78            .try_for_each(|mbox| mbox.send_and_forget(message.clone()))
79    }
80
81    pub fn send_and_wait_for_reply(&self, message: M) -> Result<Vec<M::Reply>, MailboxError> {
82        self.msg_mailboxes
83            .iter()
84            .map(|mbox| mbox.send_and_wait_for_reply(message.clone()))
85            .collect()
86    }
87}
88
89impl<M: Message + Clone> From<Vec<MsgMailbox<M>>> for BroadcastMsgMailbox<M> {
90    fn from(msg_mailboxes: Vec<MsgMailbox<M>>) -> Self {
91        Self { msg_mailboxes }
92    }
93}
94
95trait MsgMailboxT<M: Message>: Send + Sync {
96    fn send_and_forget(&self, message: M) -> Result<(), MailboxError>;
97    fn send_and_wait_for_reply(&self, message: M) -> Result<M::Reply, MailboxError>;
98    fn duplicate(&self) -> Box<dyn MsgMailboxT<M>>;
99}
100
/// Lets a service's `Mailbox<S>` stand behind a `MsgMailbox<M>` for every
/// message type `M` that `S` handles.
impl<M, S> MsgMailboxT<M> for Mailbox<S>
where
    S: Service + 'static,
    M: Message,
    S: Handler<M>,
{
    // NOTE(review): `self.send_and_forget(..)` presumably resolves to an
    // inherent `Mailbox::send_and_forget` (Rust tries inherent methods before
    // trait methods), so this delegates rather than recursing into itself.
    // Confirm `Mailbox` keeps an inherent method with this exact name.
    fn send_and_forget(&self, message: M) -> Result<(), MailboxError> {
        self.send_and_forget(message)
    }
    // Same inherent-method delegation as `send_and_forget` above.
    fn send_and_wait_for_reply(&self, message: M) -> Result<M::Reply, MailboxError> {
        self.send_and_wait_for_reply(message)
    }
    // Boxes a clone of this `Mailbox` handle.
    fn duplicate(&self) -> Box<dyn MsgMailboxT<M>> {
        Box::new(self.clone())
    }
}
117
/// Lets a service's `BoundedMailbox<S>` stand behind a `MsgMailbox<M>` for
/// every message type `M` that `S` handles.
impl<M, S> MsgMailboxT<M> for BoundedMailbox<S>
where
    S: Service + 'static,
    M: Message,
    S: Handler<M>,
{
    // NOTE(review): `self.send_and_forget(..)` presumably resolves to an
    // inherent `BoundedMailbox::send_and_forget` (Rust tries inherent methods
    // before trait methods), so this delegates rather than recursing into
    // itself. Confirm `BoundedMailbox` keeps an inherent method with this name.
    fn send_and_forget(&self, message: M) -> Result<(), MailboxError> {
        self.send_and_forget(message)
    }
    // Same inherent-method delegation as `send_and_forget` above.
    fn send_and_wait_for_reply(&self, message: M) -> Result<M::Reply, MailboxError> {
        self.send_and_wait_for_reply(message)
    }
    // Boxes a clone of this `BoundedMailbox` handle.
    fn duplicate(&self) -> Box<dyn MsgMailboxT<M>> {
        Box::new(self.clone())
    }
}
134
/// Lets a service's `BoundedTaskMailbox<S>` stand behind a `MsgMailbox<M>` for
/// every message type `M` that `S` handles.
impl<M, S> MsgMailboxT<M> for BoundedTaskMailbox<S>
where
    S: Service + 'static,
    M: Message,
    S: Handler<M>,
{
    // NOTE(review): `self.send_and_forget(..)` presumably resolves to an
    // inherent `BoundedTaskMailbox::send_and_forget` (Rust tries inherent
    // methods before trait methods), so this delegates rather than recursing
    // into itself. Confirm the inherent method with this name still exists.
    fn send_and_forget(&self, message: M) -> Result<(), MailboxError> {
        self.send_and_forget(message)
    }
    // Same inherent-method delegation as `send_and_forget` above.
    fn send_and_wait_for_reply(&self, message: M) -> Result<M::Reply, MailboxError> {
        self.send_and_wait_for_reply(message)
    }
    // Boxes a clone of this `BoundedTaskMailbox` handle.
    fn duplicate(&self) -> Box<dyn MsgMailboxT<M>> {
        Box::new(self.clone())
    }
}
151
152impl<M, S> From<Mailbox<S>> for MsgMailbox<M>
153where
154    M: Message,
155    S: Service,
156    S: Handler<M>,
157    S: 'static,
158{
159    fn from(mailbox: Mailbox<S>) -> Self {
160        MsgMailbox {
161            service_mailbox: Box::new(mailbox),
162        }
163    }
164}
165
166impl<M, S> From<BoundedMailbox<S>> for MsgMailbox<M>
167where
168    M: Message,
169    S: Service,
170    S: Handler<M>,
171    S: 'static,
172{
173    fn from(mailbox: BoundedMailbox<S>) -> Self {
174        MsgMailbox {
175            service_mailbox: Box::new(mailbox),
176        }
177    }
178}
179
180impl<M, S> From<BoundedTaskMailbox<S>> for MsgMailbox<M>
181where
182    M: Message,
183    S: Service,
184    S: Handler<M>,
185    S: 'static,
186{
187    fn from(mailbox: BoundedTaskMailbox<S>) -> Self {
188        MsgMailbox {
189            service_mailbox: Box::new(mailbox),
190        }
191    }
192}
193
194pub(super) struct MockMsgMailbox<M: Message> {
195    sender: Sender<(M, Sender<M::Reply>)>,
196}
197
198impl<M: Message> MockMsgMailbox<M> {
199    pub fn new(sender: Sender<(M, Sender<M::Reply>)>) -> Self {
200        MockMsgMailbox { sender }
201    }
202}
203
204impl<M: Message> MsgMailboxT<M> for MockMsgMailbox<M> {
205    fn send_and_forget(&self, message: M) -> Result<(), MailboxError> {
206        let (tx, _rx) = channel();
207        if self.sender.send((message, tx)).is_err() {
208            return Err(MailboxError::SendChannelClosed);
209        }
210
211        Ok(())
212    }
213
214    fn send_and_wait_for_reply(&self, message: M) -> Result<M::Reply, MailboxError> {
215        let (tx, rx) = channel();
216
217        if self.sender.send((message, tx)).is_err() {
218            return Err(MailboxError::SendChannelClosed);
219        }
220
221        rx.recv().map_err(|_| MailboxError::NoResponse)
222    }
223
224    fn duplicate(&self) -> Box<dyn MsgMailboxT<M>> {
225        Box::new(MockMsgMailbox {
226            sender: self.sender.clone(),
227        })
228    }
229}
230
231pub(super) struct BoundedMockMsgMailbox<M: Message> {
232    sender: SyncSender<(M, Sender<M::Reply>)>,
233}
234
235impl<M: Message> BoundedMockMsgMailbox<M> {
236    pub fn new(sender: SyncSender<(M, Sender<M::Reply>)>) -> Self {
237        BoundedMockMsgMailbox { sender }
238    }
239}
240
241impl<M: Message> MsgMailboxT<M> for BoundedMockMsgMailbox<M> {
242    fn send_and_forget(&self, message: M) -> Result<(), MailboxError> {
243        let (tx, _rx) = channel();
244        self.sender.try_send((message, tx)).map_err(|e| match e {
245            std::sync::mpsc::TrySendError::Full(_) => MailboxError::SendChannelFull,
246            std::sync::mpsc::TrySendError::Disconnected(_) => MailboxError::SendChannelClosed,
247        })
248    }
249
250    fn send_and_wait_for_reply(&self, message: M) -> Result<M::Reply, MailboxError> {
251        let (tx, rx) = channel();
252        self.sender.try_send((message, tx)).map_err(|e| match e {
253            std::sync::mpsc::TrySendError::Full(_) => MailboxError::SendChannelFull,
254            std::sync::mpsc::TrySendError::Disconnected(_) => MailboxError::SendChannelClosed,
255        })?;
256
257        rx.recv().map_err(|_| MailboxError::NoResponse)
258    }
259
260    fn duplicate(&self) -> Box<dyn MsgMailboxT<M>> {
261        Box::new(BoundedMockMsgMailbox {
262            sender: self.sender.clone(),
263        })
264    }
265}
266
#[cfg(test)]
mod test {
    use std::thread::spawn;

    use super::*;

    #[test]
    fn test_broadcast_mailbox() {
        let (mbox1, rx1) = MsgMailbox::<TestMessage>::mock();
        let (mbox2, rx2) = MsgMailbox::<TestMessage>::mock();

        let broadcast_mbox = BroadcastMsgMailbox::from(vec![mbox1, mbox2]);

        broadcast_mbox.send_and_forget(TestMessage).unwrap();

        assert!(rx1.try_recv().is_ok());
        assert!(rx2.try_recv().is_ok());
    }

    #[test]
    fn test_broadcast_mailbox_send_and_wait() {
        let (mbox1, rx1) = MsgMailbox::<TestMessage>::mock();
        let (mbox2, rx2) = MsgMailbox::<TestMessage>::mock();

        let broadcast_mbox = BroadcastMsgMailbox::from(vec![mbox1, mbox2]);

        let join_handle =
            spawn(move || broadcast_mbox.send_and_wait_for_reply(TestMessage).unwrap());

        let (_, reply_tx1) = rx1.recv().unwrap();
        reply_tx1.send(()).unwrap();
        let (_, reply_tx2) = rx2.recv().unwrap();
        reply_tx2.send(()).unwrap();

        let replies = join_handle.join().unwrap();
        assert_eq!(replies.len(), 2);
    }

    /// A failed send aborts the broadcast: earlier mailboxes already got the
    /// message and later ones are never contacted.
    #[test]
    fn test_broadcast_mailbox_stops_at_first_closed_mailbox() {
        let (mbox1, rx1) = MsgMailbox::<TestMessage>::mock();
        let (mbox2, rx2) = MsgMailbox::<TestMessage>::mock();
        let (mbox3, rx3) = MsgMailbox::<TestMessage>::mock();
        drop(rx2);

        let broadcast_mbox = BroadcastMsgMailbox::from(vec![mbox1, mbox2, mbox3]);

        assert!(matches!(
            broadcast_mbox.send_and_forget(TestMessage),
            Err(MailboxError::SendChannelClosed)
        ));
        assert!(rx1.try_recv().is_ok());
        assert!(rx3.try_recv().is_err());
    }

    /// The unbounded mock reports `NoResponse` when the reply sender is
    /// dropped without answering.
    #[test]
    fn test_mock_no_response_when_reply_dropped() {
        let (mbox, rx) = MsgMailbox::<TestMessage>::mock();

        let join_handle = spawn(move || {
            matches!(
                mbox.send_and_wait_for_reply(TestMessage),
                Err(MailboxError::NoResponse)
            )
        });

        let (_, reply_tx) = rx.recv().unwrap();
        drop(reply_tx);

        assert!(join_handle.join().unwrap());
    }

    /// The bounded mock never blocks: it reports `SendChannelFull` while the
    /// buffer is full and `SendChannelClosed` once the receiver is gone.
    #[test]
    fn test_bounded_mock_full_and_closed() {
        let (mbox, rx) = MsgMailbox::<TestMessage>::bounded_mock(1);

        assert!(mbox.send_and_forget(TestMessage).is_ok());
        assert!(matches!(
            mbox.send_and_forget(TestMessage),
            Err(MailboxError::SendChannelFull)
        ));

        // Draining the buffered message frees the slot again.
        assert!(rx.try_recv().is_ok());
        drop(rx);

        assert!(matches!(
            mbox.send_and_forget(TestMessage),
            Err(MailboxError::SendChannelClosed)
        ));
    }

    #[derive(Clone)]
    struct TestMessage;

    impl Message for TestMessage {
        type Reply = ();
    }
}