1use async_std::prelude::*;
2use async_std::channel::*;
3use async_std::task;
4use std::fmt::Display;
5use thiserror::Error;
6
7pub struct MailboxProcessor<Msg, ReplyMsg> {
8 message_sender: Sender<(Msg, Option<Sender<ReplyMsg>>)>,
9}
10
/// Capacity of the queue that buffers messages waiting to be processed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferSize {
    /// Use the default capacity chosen by `MailboxProcessor::new`.
    Default,
    /// Use an explicit capacity, in number of queued messages.
    Size(usize),
}
15
16#[derive(Debug, Error)]
17pub struct MailboxProcessorError {
18 msg: String,
19 #[source]
20 source: Option<Box<dyn std::error::Error + std::marker::Send + Sync + 'static>>,
21}
22impl Display for MailboxProcessorError {
23 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
24 write!(f, "{}", self.msg) } }
25
26impl BufferSize {
27 fn unwrap_or(&self, default_value: usize) -> usize {
28 match self {
29 BufferSize::Default => default_value,
30 BufferSize::Size(x) => *x,
31 }
32
33 }
34}
35
36impl<Msg: std::marker::Send + 'static, ReplyMsg: std::marker::Send + 'static> MailboxProcessor<Msg, ReplyMsg> {
37 pub async fn new<State: std::marker::Send + 'static, F: Future<Output = State> + std::marker::Send>
38 (
39 buffer_size: BufferSize,
40 initial_state: State,
41 message_processing_function: impl Fn(Msg, State, Option<Sender<ReplyMsg>>) -> F + std::marker::Send + 'static + std::marker::Sync,
42 )
43 -> Self {
44 let (s,r) = bounded(buffer_size.unwrap_or(1_000));
45 task::spawn(async move {
46 let mut state = initial_state;
47 loop {
48 match r.recv().await {
49 Err(_) => break, Ok((msg, reply_channel)) => {
51 state = message_processing_function(msg, state, reply_channel).await;
52 },
53 }
54 }
55 });
56 MailboxProcessor {
57 message_sender: s,
58 }
59 }
60 pub async fn send(&self, msg:Msg) -> Result<ReplyMsg, MailboxProcessorError> {
61 let (s, r) = bounded(1);
62 match self.message_sender.send((msg, Some(s))).await {
63 Err(_) => Err(MailboxProcessorError { msg: "the mailbox channel is closed send back nothing".to_owned(), source: None}),
64 Ok(_) => match r.recv().await {
65 Err(_) => Err(MailboxProcessorError { msg: "the response channel is closed (did you mean to call fire_and_forget() rather than send())".to_owned(), source: None}),
66 Ok(reply_message) => Ok(reply_message),
67 },
68 }
69 }
70 pub async fn fire_and_forget(&self, msg:Msg) -> Result<(), MailboxProcessorError> {
72 self.message_sender.send((msg, None)).await.map_err(|_| MailboxProcessorError {msg: "the mailbox channel is closed send back nothing".to_owned(), source: None})
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a counter actor through queries, increments and decrements.
    /// The sequence mixes reply-less (`fire_and_forget`) and replying (`send`) calls.
    #[async_std::test]
    async fn mailbox_processor_tests() {
        enum SendMessageTypes {
            Increment(i32),
            GetCurrentCount,
            Decrement(i32),
        }

        let mb = MailboxProcessor::<SendMessageTypes, i32>::new(
            BufferSize::Default,
            0,
            |msg, state, reply_channel| async move {
                // Every message answers (when asked) with the state it leaves behind.
                let next_state = match msg {
                    SendMessageTypes::Increment(delta) => state + delta,
                    SendMessageTypes::GetCurrentCount => state,
                    SendMessageTypes::Decrement(delta) => state - delta,
                };
                if let Some(reply_sender) = reply_channel {
                    reply_sender.send(next_state).await.unwrap();
                }
                next_state
            },
        )
        .await;

        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 0);

        mb.fire_and_forget(SendMessageTypes::Increment(55)).await.unwrap();
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 55);

        mb.fire_and_forget(SendMessageTypes::Increment(55)).await.unwrap();
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 110);

        mb.fire_and_forget(SendMessageTypes::Decrement(10)).await.unwrap();
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 100);

        assert_eq!(mb.send(SendMessageTypes::Increment(55)).await.unwrap(), 155);
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 155);
    }
}