use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, Notify, mpsc};
use crate::{AgentMessage, AgentPid, MessagingError, MessagingResult};
/// Tunable settings for a single agent mailbox.
///
/// Within this file only `max_messages` is consulted (by the send path); the
/// other fields are stored and exposed through `AgentMailbox::config` but are
/// not read by any code visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MailboxConfig {
    /// Upper bound on the number of queued messages; `0` disables the limit.
    pub max_messages: usize,
    /// Whether message ordering should be preserved.
    /// NOTE(review): not read anywhere in this file; confirm the consumer.
    pub preserve_order: bool,
    /// Whether messages should be persisted.
    /// NOTE(review): not read anywhere in this file; confirm the consumer.
    pub enable_persistence: bool,
    /// Interval associated with statistics collection.
    /// NOTE(review): not read anywhere in this file; confirm the consumer.
    pub stats_interval: std::time::Duration,
}
impl Default for MailboxConfig {
fn default() -> Self {
Self {
max_messages: 0, preserve_order: true,
enable_persistence: false,
stats_interval: std::time::Duration::from_secs(60),
}
}
}
/// Counters and timestamps describing the traffic of one mailbox.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MailboxStats {
    /// Agent these statistics belong to.
    pub agent_id: AgentPid,
    /// Number of times `record_message_received` has been called.
    pub total_messages_received: u64,
    /// Number of times `record_message_processed` has been called.
    pub total_messages_processed: u64,
    /// Received-minus-processed counter (never drops below zero).
    pub current_queue_size: usize,
    /// Largest value `current_queue_size` has reached so far.
    pub max_queue_size_reached: usize,
    /// UTC time of the most recent `record_message_received` call, if any.
    pub last_message_received: Option<DateTime<Utc>>,
    /// UTC time of the most recent `record_message_processed` call, if any.
    pub last_message_processed: Option<DateTime<Utc>>,
    /// Running mean of the durations passed to `record_message_processed`.
    pub average_processing_time: std::time::Duration,
}
impl MailboxStats {
pub fn new(agent_id: AgentPid) -> Self {
Self {
agent_id,
total_messages_received: 0,
total_messages_processed: 0,
current_queue_size: 0,
max_queue_size_reached: 0,
last_message_received: None,
last_message_processed: None,
average_processing_time: std::time::Duration::ZERO,
}
}
pub fn record_message_received(&mut self) {
self.total_messages_received += 1;
self.current_queue_size += 1;
self.max_queue_size_reached = self.max_queue_size_reached.max(self.current_queue_size);
self.last_message_received = Some(Utc::now());
}
pub fn record_message_processed(&mut self, processing_time: std::time::Duration) {
self.total_messages_processed += 1;
self.current_queue_size = self.current_queue_size.saturating_sub(1);
self.last_message_processed = Some(Utc::now());
if self.total_messages_processed == 1 {
self.average_processing_time = processing_time;
} else {
let total_time = self.average_processing_time.as_nanos() as f64
* (self.total_messages_processed - 1) as f64;
let new_average = (total_time + processing_time.as_nanos() as f64)
/ self.total_messages_processed as f64;
self.average_processing_time = std::time::Duration::from_nanos(new_average as u64);
}
}
}
/// Message queue owned by one agent, built on an unbounded tokio channel.
pub struct AgentMailbox {
    /// Identity of the agent this mailbox belongs to.
    agent_id: AgentPid,
    /// Settings supplied at construction time.
    config: MailboxConfig,
    /// Producer half of the channel; cloned into every `MailboxSender`.
    sender: mpsc::UnboundedSender<AgentMessage>,
    /// Consumer half of the channel, behind an async mutex so `&self`
    /// methods can receive.
    receiver: Arc<Mutex<mpsc::UnboundedReceiver<AgentMessage>>>,
    /// Traffic counters updated by the send and receive methods.
    stats: Arc<Mutex<MailboxStats>>,
    /// Signalled by `close` and awaited by `wait_for_shutdown`.
    shutdown_notify: Arc<Notify>,
}
impl AgentMailbox {
pub fn new(agent_id: AgentPid, config: MailboxConfig) -> Self {
let (sender, receiver) = mpsc::unbounded_channel();
let stats = MailboxStats::new(agent_id.clone());
Self {
agent_id: agent_id.clone(),
config,
sender,
receiver: Arc::new(Mutex::new(receiver)),
stats: Arc::new(Mutex::new(stats)),
shutdown_notify: Arc::new(Notify::new()),
}
}
pub fn agent_id(&self) -> &AgentPid {
&self.agent_id
}
pub fn config(&self) -> &MailboxConfig {
&self.config
}
pub async fn send(&self, message: AgentMessage) -> MessagingResult<()> {
if self.config.max_messages > 0 {
let stats = self.stats.lock().await;
if stats.current_queue_size >= self.config.max_messages {
return Err(MessagingError::MailboxFull(self.agent_id.clone()));
}
}
self.sender
.send(message)
.map_err(|_| MessagingError::ChannelClosed(self.agent_id.clone()))?;
{
let mut stats = self.stats.lock().await;
stats.record_message_received();
}
Ok(())
}
pub async fn receive(&self) -> MessagingResult<AgentMessage> {
let start_time = std::time::Instant::now();
let message = {
let mut receiver = self.receiver.lock().await;
receiver
.recv()
.await
.ok_or_else(|| MessagingError::ChannelClosed(self.agent_id.clone()))?
};
{
let mut stats = self.stats.lock().await;
stats.record_message_processed(start_time.elapsed());
}
Ok(message)
}
pub async fn try_receive(&self) -> MessagingResult<Option<AgentMessage>> {
let start_time = std::time::Instant::now();
let message = {
let mut receiver = self.receiver.lock().await;
match receiver.try_recv() {
Ok(message) => Some(message),
Err(mpsc::error::TryRecvError::Empty) => None,
Err(mpsc::error::TryRecvError::Disconnected) => {
return Err(MessagingError::ChannelClosed(self.agent_id.clone()));
}
}
};
if let Some(_message) = &message {
let mut stats = self.stats.lock().await;
stats.record_message_processed(start_time.elapsed());
}
Ok(message)
}
pub async fn receive_timeout(
&self,
timeout: std::time::Duration,
) -> MessagingResult<AgentMessage> {
let start_time = std::time::Instant::now();
let message = tokio::time::timeout(timeout, async {
let mut receiver = self.receiver.lock().await;
receiver
.recv()
.await
.ok_or_else(|| MessagingError::ChannelClosed(self.agent_id.clone()))
})
.await
.map_err(|_| MessagingError::MessageTimeout(self.agent_id.clone()))??;
{
let mut stats = self.stats.lock().await;
stats.record_message_processed(start_time.elapsed());
}
Ok(message)
}
pub async fn stats(&self) -> MailboxStats {
self.stats.lock().await.clone()
}
pub async fn is_empty(&self) -> bool {
let stats = self.stats.lock().await;
stats.current_queue_size == 0
}
pub async fn queue_size(&self) -> usize {
let stats = self.stats.lock().await;
stats.current_queue_size
}
pub fn close(&self) {
self.shutdown_notify.notify_waiters();
}
pub async fn wait_for_shutdown(&self) {
self.shutdown_notify.notified().await;
}
pub fn sender(&self) -> MailboxSender {
MailboxSender {
agent_id: self.agent_id.clone(),
sender: self.sender.clone(),
}
}
}
#[derive(Clone)]
pub struct MailboxSender {
agent_id: AgentPid,
sender: mpsc::UnboundedSender<AgentMessage>,
}
impl MailboxSender {
pub async fn send(&self, message: AgentMessage) -> MessagingResult<()> {
self.sender
.send(message)
.map_err(|_| MessagingError::ChannelClosed(self.agent_id.clone()))
}
pub fn agent_id(&self) -> &AgentPid {
&self.agent_id
}
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
}
/// Registry of agent mailboxes keyed by agent id.
pub struct MailboxManager {
    /// All registered mailboxes, guarded by an async mutex.
    mailboxes: Arc<Mutex<std::collections::HashMap<AgentPid, AgentMailbox>>>,
    /// Configuration cloned into every mailbox this manager creates.
    default_config: MailboxConfig,
}
impl MailboxManager {
    /// Creates an empty registry whose mailboxes will use `default_config`.
    pub fn new(default_config: MailboxConfig) -> Self {
        let registry = std::collections::HashMap::new();
        Self {
            mailboxes: Arc::new(Mutex::new(registry)),
            default_config,
        }
    }

    /// Registers a new mailbox for `agent_id` and returns a sender for it.
    ///
    /// # Errors
    /// `MessagingError::DuplicateAgent` when a mailbox for `agent_id` is
    /// already registered.
    pub async fn create_mailbox(&self, agent_id: AgentPid) -> MessagingResult<MailboxSender> {
        let mut registry = self.mailboxes.lock().await;
        if registry.contains_key(&agent_id) {
            return Err(MessagingError::DuplicateAgent(agent_id));
        }

        let fresh = AgentMailbox::new(agent_id.clone(), self.default_config.clone());
        let handle = fresh.sender();
        registry.insert(agent_id, fresh);
        Ok(handle)
    }

    /// Returns a new sender for the mailbox of `agent_id`, or `None` when no
    /// such mailbox is registered.
    pub async fn get_mailbox_sender(&self, agent_id: &AgentPid) -> Option<MailboxSender> {
        self.mailboxes
            .lock()
            .await
            .get(agent_id)
            .map(AgentMailbox::sender)
    }

    /// Returns a clone of the mailbox registered for `agent_id`, or `None`
    /// when there is none.
    ///
    /// The returned value is whatever `AgentMailbox`'s `Clone` impl produces.
    pub async fn get_mailbox(&self, agent_id: &AgentPid) -> Option<AgentMailbox> {
        self.mailboxes.lock().await.get(agent_id).cloned()
    }

    /// Unregisters the mailbox of `agent_id` and calls `close` on it.
    ///
    /// # Errors
    /// `MessagingError::AgentNotFound` when no mailbox is registered for
    /// `agent_id`.
    pub async fn remove_mailbox(&self, agent_id: &AgentPid) -> MessagingResult<()> {
        match self.mailboxes.lock().await.remove(agent_id) {
            Some(removed) => {
                removed.close();
                Ok(())
            }
            None => Err(MessagingError::AgentNotFound(agent_id.clone())),
        }
    }

    /// Returns the ids of all registered agents (in the map's iteration
    /// order).
    pub async fn list_agents(&self) -> Vec<AgentPid> {
        self.mailboxes.lock().await.keys().cloned().collect()
    }

    /// Collects a statistics snapshot from every registered mailbox.
    ///
    /// The registry lock is held while each mailbox's stats are read.
    pub async fn get_all_stats(&self) -> Vec<MailboxStats> {
        let registry = self.mailboxes.lock().await;
        let mut snapshots = Vec::with_capacity(registry.len());
        for entry in registry.values() {
            let snapshot = entry.stats().await;
            snapshots.push(snapshot);
        }
        snapshots
    }
}
impl Clone for AgentMailbox {
fn clone(&self) -> Self {
AgentMailbox::new(self.agent_id.clone(), self.config.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A freshly built mailbox reports its owner and an empty queue.
    #[tokio::test]
    async fn test_mailbox_creation() {
        let owner = AgentPid::new();
        let mailbox = AgentMailbox::new(owner.clone(), MailboxConfig::default());

        assert_eq!(mailbox.agent_id(), &owner);
        assert!(mailbox.is_empty().await);
        assert_eq!(mailbox.queue_size().await, 0);
    }

    /// One message goes in, the same message comes out, and the queue size
    /// tracks both steps.
    #[tokio::test]
    async fn test_message_send_receive() {
        let owner = AgentPid::new();
        let mailbox = AgentMailbox::new(owner.clone(), MailboxConfig::default());

        mailbox
            .send(AgentMessage::cast(owner.clone(), "test payload"))
            .await
            .unwrap();
        assert!(!mailbox.is_empty().await);
        assert_eq!(mailbox.queue_size().await, 1);

        let delivered = mailbox.receive().await.unwrap();
        assert_eq!(delivered.from(), Some(&owner));
        assert!(mailbox.is_empty().await);
        assert_eq!(mailbox.queue_size().await, 0);
    }

    /// Sending then receiving one message is reflected in every counter and
    /// timestamp.
    #[tokio::test]
    async fn test_mailbox_stats() {
        let owner = AgentPid::new();
        let mailbox = AgentMailbox::new(owner.clone(), MailboxConfig::default());

        mailbox
            .send(AgentMessage::cast(owner.clone(), "test"))
            .await
            .unwrap();
        let _delivered = mailbox.receive().await.unwrap();

        let snapshot = mailbox.stats().await;
        assert_eq!(snapshot.total_messages_received, 1);
        assert_eq!(snapshot.total_messages_processed, 1);
        assert_eq!(snapshot.current_queue_size, 0);
        assert!(snapshot.last_message_received.is_some());
        assert!(snapshot.last_message_processed.is_some());
    }

    /// Receiving from an empty mailbox with a deadline yields a timeout error.
    #[tokio::test]
    async fn test_mailbox_timeout() {
        let owner = AgentPid::new();
        let mailbox = AgentMailbox::new(owner.clone(), MailboxConfig::default());

        let outcome = mailbox.receive_timeout(Duration::from_millis(100)).await;
        assert!(matches!(outcome, Err(MessagingError::MessageTimeout(_))));
    }

    /// The manager registers, lists, and removes a mailbox.
    #[tokio::test]
    async fn test_mailbox_manager() {
        let manager = MailboxManager::new(MailboxConfig::default());
        let owner = AgentPid::new();

        let handle = manager.create_mailbox(owner.clone()).await.unwrap();
        assert_eq!(handle.agent_id(), &owner);

        let registered = manager.list_agents().await;
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0], owner);

        manager.remove_mailbox(&owner).await.unwrap();
        assert!(manager.list_agents().await.is_empty());
    }

    /// A mailbox capped at two messages rejects the third with `MailboxFull`.
    #[tokio::test]
    async fn test_bounded_mailbox() {
        let owner = AgentPid::new();
        let capped = MailboxConfig {
            max_messages: 2,
            ..Default::default()
        };
        let mailbox = AgentMailbox::new(owner.clone(), capped);

        for payload in ["msg1", "msg2"] {
            mailbox
                .send(AgentMessage::cast(owner.clone(), payload))
                .await
                .unwrap();
        }

        let overflow = mailbox
            .send(AgentMessage::cast(owner.clone(), "msg3"))
            .await;
        assert!(matches!(overflow, Err(MessagingError::MailboxFull(_))));
    }
}