use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::{Mutex, mpsc};
use crate::error::{ReactError, Result};
/// Payload carried by a [`MailboxMessage`].
///
/// Uses serde's default (externally tagged) enum representation, with variant names
/// renamed to snake_case on the wire (e.g. `TaskAssigned` -> `task_assigned`).
/// NOTE(review): renaming variants or fields changes the serialized form, and
/// reordering variants can break index-based (non-self-describing) serde formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageKind {
    /// Presumably hands `task` to the recipient; `context` is free-form
    /// string key/value data accompanying it.
    TaskAssigned {
        task: String,
        context: HashMap<String, String>,
    },
    /// Outcome report for `task`: a textual `result` plus a `success` flag.
    TaskResult {
        task: String,
        result: String,
        success: bool,
    },
    /// A free-text question addressed to the recipient.
    Query {
        question: String,
    },
    /// A free-text answer. NOTE(review): nothing in this type links it to a
    /// specific `Query`; correlation (if any) is up to the caller.
    QueryResponse {
        answer: String,
    },
    /// A free-text status message.
    Status {
        message: String,
    },
    /// Payload-less cancellation signal.
    Cancelled,
}
/// An addressed message that travels through a [`Mailbox`] channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MailboxMessage {
    /// Sender name/identifier (free-form string; not validated by this type).
    pub from: String,
    /// Recipient name/identifier (free-form string; not validated by this type).
    pub to: String,
    /// The message payload.
    pub kind: MessageKind,
    /// Creation timestamp. `MailboxMessage::new` fills it from
    /// `crate::utils::time::now_secs()` (presumably Unix seconds — confirm).
    pub ts: u64,
}
impl MailboxMessage {
    /// Builds a message from `from` to `to` carrying `kind`, stamped with the
    /// current time as reported by `crate::utils::time::now_secs()`.
    pub fn new(from: impl Into<String>, to: impl Into<String>, kind: MessageKind) -> Self {
        let sender: String = from.into();
        let recipient: String = to.into();
        // The timestamp is taken last, after both address conversions, as before.
        Self {
            ts: crate::utils::time::now_secs(),
            kind,
            to: recipient,
            from: sender,
        }
    }
}
/// Default queue capacity used by `Mailbox::new` (and therefore `Mailbox::default`).
const MAILBOX_CAPACITY: usize = 256;

/// A bounded, in-process message queue backed by a tokio `mpsc` channel.
///
/// The mailbox owns both ends of the channel. The receiver sits behind a tokio
/// `Mutex` so the receiving methods can take `&self`. Additional producers are
/// obtained with `Mailbox::sender`.
pub struct Mailbox {
    /// Sending half; cloned into every `MailboxSender`.
    tx: mpsc::Sender<MailboxMessage>,
    /// Receiving half; locked while a receive is in progress.
    rx: Mutex<mpsc::Receiver<MailboxMessage>>,
}
impl Mailbox {
    /// Creates an empty mailbox whose queue holds up to `MAILBOX_CAPACITY`
    /// pending messages.
    pub fn new() -> Self {
        Self::with_capacity(MAILBOX_CAPACITY)
    }

    /// Creates an empty mailbox whose queue holds up to `capacity` pending messages.
    ///
    /// A `capacity` of 0 is treated as 1: tokio's bounded `mpsc::channel` panics on
    /// a zero buffer, and a single-slot queue is the closest valid interpretation.
    pub fn with_capacity(capacity: usize) -> Self {
        let (tx, rx) = mpsc::channel(capacity.max(1));
        Self {
            tx,
            rx: Mutex::new(rx),
        }
    }

    /// Enqueues `msg`, waiting for a free slot while the queue is full.
    ///
    /// # Errors
    /// Returns `ReactError::Other` if the channel's receiver has been closed.
    /// The receiver lives inside this `Mailbox` and nothing here closes it, so
    /// this is not expected while `self` is alive.
    pub async fn send(&self, msg: MailboxMessage) -> Result<()> {
        self.tx
            .send(msg)
            .await
            .map_err(|e| ReactError::Other(format!("Mailbox send error: {}", e)))
    }

    /// Waits for and returns the next message.
    ///
    /// The receiver lock is held for the whole wait, so concurrent `recv` calls
    /// are served one at a time. Because this mailbox keeps its own sender alive
    /// and never closes the receiver, the channel never reports disconnection.
    /// While `self` is alive this therefore does not return `None`; it keeps
    /// waiting until a message arrives.
    pub async fn recv(&self) -> Option<MailboxMessage> {
        let mut rx = self.rx.lock().await;
        rx.recv().await
    }

    /// Returns the next queued message without waiting.
    ///
    /// Returns `None` if the queue is empty, or if another task currently holds
    /// the receiver (for example a pending `recv`).
    pub async fn try_recv(&self) -> Option<MailboxMessage> {
        // `try_lock` rather than `lock().await`: a task parked in `recv` holds the
        // receiver lock until a message arrives, so awaiting the lock here could
        // block this non-blocking call indefinitely.
        let mut rx = self.rx.try_lock().ok()?;
        rx.try_recv().ok()
    }

    /// Returns a cloneable handle that can send into this mailbox from other tasks.
    pub fn sender(&self) -> MailboxSender {
        MailboxSender {
            tx: self.tx.clone(),
        }
    }

    /// Reports whether the receiving side of the channel has been closed.
    ///
    /// NOTE(review): the receiver is owned by this `Mailbox` and nothing in this
    /// type closes it, so this returns `false` for as long as `self` is alive.
    pub fn is_closed(&self) -> bool {
        self.tx.is_closed()
    }
}
/// A cloneable, send-only handle into a [`Mailbox`], obtained via `Mailbox::sender`.
#[derive(Clone)]
pub struct MailboxSender {
    /// Clone of the owning mailbox's sending half.
    tx: mpsc::Sender<MailboxMessage>,
}
impl MailboxSender {
    /// Delivers `msg` to the owning mailbox, waiting while its queue is full.
    ///
    /// # Errors
    /// Returns `ReactError::Other` if the channel's receiver is gone, which
    /// presumably means the `Mailbox` this handle came from has been dropped.
    pub async fn send(&self, msg: MailboxMessage) -> Result<()> {
        if let Err(send_err) = self.tx.send(msg).await {
            return Err(ReactError::Other(format!(
                "MailboxSender send error: {}",
                send_err
            )));
        }
        Ok(())
    }
}
impl Default for Mailbox {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a `TaskAssigned` message through `send`/`recv` and checks its fields.
    #[tokio::test]
    async fn test_mailbox_send_recv() {
        let mailbox = Mailbox::new();
        let task_kind = MessageKind::TaskAssigned {
            task: "do stuff".into(),
            context: HashMap::new(),
        };
        mailbox
            .send(MailboxMessage::new("leader", "worker", task_kind))
            .await
            .unwrap();

        let got = mailbox.recv().await.unwrap();
        assert_eq!(got.from, "leader");
        assert_eq!(got.to, "worker");
        if let MessageKind::TaskAssigned { task, .. } = got.kind {
            assert_eq!(task, "do stuff");
        } else {
            panic!("Wrong message kind");
        }
    }

    /// A handle from `Mailbox::sender` delivers into the same queue that `recv` reads.
    #[tokio::test]
    async fn test_mailbox_sender() {
        let mailbox = Mailbox::new();
        let handle = mailbox.sender();
        let status = MailboxMessage::new(
            "a",
            "b",
            MessageKind::Status {
                message: "ok".into(),
            },
        );
        handle.send(status).await.unwrap();

        let got = mailbox.recv().await.unwrap();
        assert_eq!(got.from, "a");
    }

    /// `try_recv` on a fresh mailbox reports that nothing is queued.
    #[tokio::test]
    async fn test_mailbox_try_recv_empty() {
        let mailbox = Mailbox::new();
        assert!(mailbox.try_recv().await.is_none());
    }
}