1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use async_std::prelude::*;
use async_std::channel::*;
use async_std::task;
use std::fmt::Display;
use thiserror::Error;

/// Handle to an actor-style mailbox: messages pushed through it are handled one at a time,
/// in the order they were queued, by a background task created in [`MailboxProcessor::new`].
///
/// `Msg` is the message type the mailbox accepts; `ReplyMsg` is the reply type that
/// [`MailboxProcessor::send`] waits for.
pub struct MailboxProcessor<Msg, ReplyMsg> {
    /// Queue feeding the processing task. Each entry pairs a message with an optional reply
    /// channel: `Some` when the caller awaits a reply (`send`), `None` for `fire_and_forget`.
    message_sender: Sender<(Msg, Option<Sender<ReplyMsg>>)>,
}

/// Cloning produces another handle to the *same* mailbox (only the channel sender is cloned),
/// so several tasks can feed one processor.
///
/// Implemented by hand because `#[derive(Clone)]` would needlessly require `Msg: Clone` and
/// `ReplyMsg: Clone`; the channel sender itself is cloneable for any payload type.
impl<Msg, ReplyMsg> Clone for MailboxProcessor<Msg, ReplyMsg> {
    fn clone(&self) -> Self {
        MailboxProcessor {
            message_sender: self.message_sender.clone(),
        }
    }
}

/// Capacity of the queue between a `MailboxProcessor` handle and its processing task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferSize {
    /// Use the mailbox's built-in default capacity (1,000 messages, chosen in
    /// `MailboxProcessor::new`).
    Default,
    /// Queue at most this many pending messages; senders wait while the queue is full.
    Size(usize),
}

/// Error returned by `MailboxProcessor::send` and `MailboxProcessor::fire_and_forget` when a
/// message or its reply cannot be delivered.
#[derive(Debug, Error)]
pub struct MailboxProcessorError {
    /// Human-readable description of what failed; written out verbatim by the `Display` impl.
    msg: String,
    /// Optional underlying cause, exposed through `std::error::Error::source` via thiserror's
    /// `#[source]` attribute. NOTE(review): every construction visible in this file sets it to
    /// `None`.
    #[source]
    source: Option<Box<dyn std::error::Error + std::marker::Send + Sync + 'static>>,
}
impl Display for MailboxProcessorError {
    /// Writes the stored message text as-is; width/fill flags on the outer formatter are not
    /// applied to it.
    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str(&self.msg)
    }
}

impl BufferSize {
    /// Resolves this setting to a concrete capacity: an explicit `Size` wins, otherwise
    /// `default_value` is used.
    fn unwrap_or(&self, default_value: usize) -> usize {
        if let BufferSize::Size(size) = *self {
            size
        } else {
            default_value
        }
    }
}

impl<Msg: std::marker::Send + 'static, ReplyMsg: std::marker::Send + 'static>
    MailboxProcessor<Msg, ReplyMsg>
{
    /// Spawns a background task that owns `initial_state` and processes queued messages
    /// one at a time.
    ///
    /// For each message, `message_processing_function` receives the message, the current state
    /// and an optional reply channel (`Some` for [`send`](Self::send), `None` for
    /// [`fire_and_forget`](Self::fire_and_forget)). The state it resolves to replaces the
    /// current state before the next message is taken. The task stops once the message channel
    /// is closed and drained, i.e. after every handle to this mailbox has been dropped.
    ///
    /// `buffer_size` bounds the number of queued messages (`BufferSize::Default` = 1,000).
    /// A size of 0 is raised to 1, because a zero-capacity bounded channel cannot be created.
    pub async fn new<
        State: std::marker::Send + 'static,
        F: Future<Output = State> + std::marker::Send,
    >(
        buffer_size: BufferSize,
        initial_state: State,
        message_processing_function: impl Fn(Msg, State, Option<Sender<ReplyMsg>>) -> F
            + std::marker::Send
            + 'static
            + std::marker::Sync,
    ) -> Self {
        // `bounded(0)` panics, so clamp `BufferSize::Size(0)` up to the smallest valid capacity.
        let capacity = buffer_size.unwrap_or(1_000).max(1);
        let (message_sender, message_receiver) = bounded(capacity);
        task::spawn(async move {
            let mut state = initial_state;
            // `recv` only fails once the channel is closed and empty, which ends the task.
            while let Ok((msg, reply_channel)) = message_receiver.recv().await {
                state = message_processing_function(msg, state, reply_channel).await;
            }
        });
        MailboxProcessor { message_sender }
    }

    /// Queues `msg` together with a reply channel and waits for the processing function's reply.
    ///
    /// # Errors
    /// Returns a [`MailboxProcessorError`] if the processing task is no longer receiving
    /// messages, or if the reply channel is dropped without a reply being sent (e.g. the
    /// processing function ignores its reply channel).
    pub async fn send(&self, msg: Msg) -> Result<ReplyMsg, MailboxProcessorError> {
        // One slot is enough: the processing function is expected to reply at most once.
        let (reply_sender, reply_receiver) = bounded(1);
        self.message_sender
            .send((msg, Some(reply_sender)))
            .await
            .map_err(|_| Self::channel_error("the mailbox channel is closed send back nothing"))?;
        reply_receiver.recv().await.map_err(|_| {
            Self::channel_error(
                "the response channel is closed (did you mean to call fire_and_forget() rather than send())",
            )
        })
    }

    /// Queues `msg` without waiting for a reply; the processing function receives `None` as
    /// its reply channel.
    ///
    /// # Errors
    /// Returns a [`MailboxProcessorError`] if the processing task is no longer receiving
    /// messages.
    pub async fn fire_and_forget(&self, msg: Msg) -> Result<(), MailboxProcessorError> {
        self.message_sender
            .send((msg, None))
            .await
            .map_err(|_| Self::channel_error("the mailbox channel is closed send back nothing"))
    }

    /// Builds a `MailboxProcessorError` carrying `msg` and no underlying source.
    fn channel_error(msg: &str) -> MailboxProcessorError {
        MailboxProcessorError {
            msg: msg.to_owned(),
            source: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::OptionFuture;

    /// Drives a counter mailbox through fire-and-forget and request/reply messages, checking
    /// the counter value observed after each step.
    #[async_std::test]
    async fn mailbox_processor_tests() {
        enum SendMessageTypes {
            Increment(i32),
            GetCurrentCount,
            Decrement(i32),
        }

        let mb = MailboxProcessor::<SendMessageTypes, i32>::new(
            BufferSize::Default,
            0,
            |msg, state, reply_channel| async move {
                // Every message replies (when a reply channel is present) with the state it
                // leaves behind, and that state becomes the new counter value.
                let next_state = match msg {
                    SendMessageTypes::Increment(x) => state + x,
                    SendMessageTypes::GetCurrentCount => state,
                    SendMessageTypes::Decrement(x) => state - x,
                };
                OptionFuture::from(
                    reply_channel.map(|rc| async move { rc.send(next_state).await.unwrap() }),
                )
                .await;
                next_state
            },
        )
        .await;

        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 0);

        // Each fire-and-forget update is followed by a request/reply read of the counter.
        let updates = vec![
            (SendMessageTypes::Increment(55), 55),
            (SendMessageTypes::Increment(55), 110),
            (SendMessageTypes::Decrement(10), 100),
        ];
        for (update, expected) in updates {
            mb.fire_and_forget(update).await.unwrap();
            assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), expected);
        }

        // Sending an update with `send` returns the post-update value directly.
        assert_eq!(mb.send(SendMessageTypes::Increment(55)).await.unwrap(), 155);
        assert_eq!(mb.send(SendMessageTypes::GetCurrentCount).await.unwrap(), 155);
    }
}