use crate::error::MultiError;
use crate::types::Message;
use std::collections::HashMap;
use tokio::sync::{mpsc, RwLock};
/// Registry of named agent inboxes.
///
/// Each registered agent name maps to the sending half of a bounded tokio `mpsc` channel;
/// the matching receiver is returned to the caller of `register`.
pub struct Mailbox {
    /// Capacity passed to `mpsc::channel` for every channel created by `register`.
    buffer_size: usize,
    /// Agent name -> sending half of that agent's channel, guarded by an async `RwLock`.
    senders: RwLock<HashMap<String, mpsc::Sender<Message>>>,
}
impl Mailbox {
    /// Creates an empty mailbox whose per-agent channels hold up to `buffer_size` messages.
    ///
    /// `tokio::sync::mpsc::channel` panics when given a capacity of 0. A `buffer_size` of 0 is
    /// therefore raised to 1, instead of deferring a panic to the first `register` call.
    pub fn new(buffer_size: usize) -> Self {
        Self {
            senders: RwLock::new(HashMap::new()),
            buffer_size: buffer_size.max(1),
        }
    }

    /// Registers `name` and returns the receiving half of its newly created channel.
    ///
    /// If `name` was already registered, its previous sender is replaced in the registry. The
    /// previously returned receiver will then see its channel close once no other clones of the
    /// old sender remain (for example, once in-flight sends finish).
    pub async fn register(&self, name: &str) -> mpsc::Receiver<Message> {
        let (tx, rx) = mpsc::channel(self.buffer_size);
        self.senders.write().await.insert(name.to_string(), tx);
        rx
    }

    /// Delivers `msg` to the agent named by `msg.to`, waiting for channel capacity if needed.
    ///
    /// # Errors
    ///
    /// Returns `MultiError::MailboxSend` if no agent is registered under `msg.to`, or if that
    /// agent's receiver has been dropped.
    pub async fn send(&self, msg: Message) -> Result<(), MultiError> {
        // Clone the sender and release the read lock *before* awaiting. `Sender::send` can wait
        // indefinitely for capacity. Holding the guard across that await would block
        // `register`/`unregister`. Tokio's RwLock is fair, so readers queue behind a waiting
        // writer, which means every other `send`/`broadcast` would block too, until the
        // recipient drains its queue.
        let tx = {
            let senders = self.senders.read().await;
            senders
                .get(&msg.to)
                .cloned()
                .ok_or_else(|| MultiError::MailboxSend(format!("no agent '{}'", msg.to)))?
        };
        tx.send(msg)
            .await
            .map_err(|e| MultiError::MailboxSend(e.to_string()))
    }

    /// Sends a copy of `msg` to every registered agent except its sender (`msg.from`).
    ///
    /// Each copy has its `to` field set to the recipient's name. Delivery is best-effort: a
    /// recipient whose receiver has been dropped is skipped, and this always returns `Ok(())`.
    pub async fn broadcast(&self, msg: Message) -> Result<(), MultiError> {
        // Snapshot the recipients so the read lock is not held while awaiting channel
        // capacity (same reasoning as in `send`).
        let recipients: Vec<(String, mpsc::Sender<Message>)> = {
            let senders = self.senders.read().await;
            senders
                .iter()
                .filter(|&(name, _)| *name != msg.from)
                .map(|(name, tx)| (name.clone(), tx.clone()))
                .collect()
        };
        for (name, tx) in recipients {
            let mut m = msg.clone();
            m.to = name;
            // Best-effort: a closed channel means that agent is gone; skip it.
            let _ = tx.send(m).await;
        }
        Ok(())
    }

    /// Removes `name` from the registry, dropping the registry's sender for it.
    ///
    /// Removing a name that is not registered is a no-op.
    pub async fn unregister(&self, name: &str) {
        self.senders.write().await.remove(name);
    }
}
impl Default for Mailbox {
fn default() -> Self {
Self::new(64)
}
}