mailbox_processor/
lib.rs

1use async_std::prelude::*;
2use async_std::channel::*;
3use async_std::task;
4use std::fmt::Display;
5use thiserror::Error;
6
7pub struct MailboxProcessor<Msg, ReplyMsg> {
8    message_sender: Sender<(Msg, Option<Sender<ReplyMsg>>)>,
9}
10
/// Requested capacity for a mailbox's message buffer.
///
/// The common traits are derived so the value can be logged, copied around
/// freely (it is at most a single `usize`), and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferSize {
    /// No explicit capacity requested; the consumer picks a fallback value.
    Default,
    /// Use exactly this capacity.
    Size(usize),
}
15
16#[derive(Debug, Error)]
17pub struct MailboxProcessorError {
18    msg: String,
19    #[source]
20    source: Option<Box<dyn std::error::Error + std::marker::Send + Sync + 'static>>,
21}
22impl Display for MailboxProcessorError {
23    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
24        write!(f, "{}", self.msg) } }
25
26impl BufferSize {
27    fn unwrap_or(&self, default_value: usize) -> usize {
28        match self {
29            BufferSize::Default => default_value,
30            BufferSize::Size(x) => *x,
31        }
32
33    }
34}
35
36impl<Msg: std::marker::Send + 'static, ReplyMsg: std::marker::Send + 'static> MailboxProcessor<Msg, ReplyMsg> {
37    pub async fn new<State: std::marker::Send + 'static, F: Future<Output = State> + std::marker::Send>
38    (
39        buffer_size: BufferSize,
40        initial_state: State, 
41        message_processing_function: impl Fn(Msg, State, Option<Sender<ReplyMsg>>) -> F + std::marker::Send + 'static + std::marker::Sync,
42    ) 
43    -> Self {
44        let (s,r) = bounded(buffer_size.unwrap_or(1_000));
45        task::spawn(async move { 
46            let mut state = initial_state;
47            loop {
48                match r.recv().await {
49                    Err(_) => break,  //the channel was closed so bail
50                    Ok((msg, reply_channel)) => {
51                        state = message_processing_function(msg, state, reply_channel).await;
52                    },
53                }
54            }
55        });
56        MailboxProcessor {
57            message_sender: s,
58        }
59    }
60    pub async fn send(&self, msg:Msg) -> Result<ReplyMsg, MailboxProcessorError> {
61        let (s, r) = bounded(1);
62        match self.message_sender.send((msg, Some(s))).await {
63            Err(_) => Err(MailboxProcessorError { msg: "the mailbox channel is closed send back nothing".to_owned(), source: None}),
64            Ok(_) => match r.recv().await {
65                Err(_) => Err(MailboxProcessorError { msg: "the response channel is closed (did you mean to call fire_and_forget() rather than send())".to_owned(), source: None}),
66                Ok(reply_message) => Ok(reply_message),
67            },
68        }
69    }
70    //pub async fn fire_and_forget(&self, msg:Msg) -> Result<(), SendError<(Msg, Option<Sender<ReplyMsg>>)>> {
71    pub async fn fire_and_forget(&self, msg:Msg) -> Result<(), MailboxProcessorError> {
72        self.message_sender.send((msg, None)).await.map_err(|_| MailboxProcessorError {msg: "the mailbox channel is closed send back nothing".to_owned(), source: None})
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a counter mailbox through both `send` and `fire_and_forget`
    /// and checks the running total after every step.
    #[async_std::test]
    async fn mailbox_processor_tests() {
        enum SendMessageTypes {
            Increment(i32),
            GetCurrentCount,
            Decrement(i32),
        }

        let mb = MailboxProcessor::<SendMessageTypes, i32>::new(
            BufferSize::Default,
            0,
            |msg, state, reply_channel| async move {
                // Work out the new counter value first ...
                let next = match msg {
                    SendMessageTypes::Increment(x) => state + x,
                    SendMessageTypes::GetCurrentCount => state,
                    SendMessageTypes::Decrement(x) => state - x,
                };
                // ... then report it to the caller, if a reply was requested.
                if let Some(rc) = reply_channel {
                    rc.send(next).await.unwrap();
                }
                next
            },
        )
        .await;

        // A fresh counter starts at the initial state.
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 0);

        mb.fire_and_forget(SendMessageTypes::Increment(55)).await.unwrap();
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 55);

        mb.fire_and_forget(SendMessageTypes::Increment(55)).await.unwrap();
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 110);

        mb.fire_and_forget(SendMessageTypes::Decrement(10)).await.unwrap();
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 100);

        // `send` on a mutating message replies with the updated value.
        assert_eq!(mb.send(SendMessageTypes::Increment(55)).await.unwrap(), 155);
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 155);
    }
}