use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{interval, sleep};
use crate::{
AgentMessage, AgentPid, DeliveryConfig, DeliveryGuarantee, DeliveryManager, MailboxManager,
MailboxSender, MessageEnvelope, MessagingError, MessagingResult,
};
/// Configuration for `DefaultMessageRouter`.
#[derive(Debug, Clone)]
pub struct RouterConfig {
/// Settings used to construct the router's `DeliveryManager`. Its `guarantee`
/// field is also consulted when a message is routed.
pub delivery_config: DeliveryConfig,
/// Period of the timer that drives the background retry task.
pub retry_interval: Duration,
/// NOTE(review): not read anywhere in the visible part of this file —
/// presumably meant as a cap on in-flight deliveries; confirm before relying on it.
pub max_concurrent_deliveries: usize,
/// NOTE(review): not read anywhere in the visible part of this file; router
/// statistics are updated regardless of this flag.
pub enable_metrics: bool,
}
impl Default for RouterConfig {
fn default() -> Self {
Self {
delivery_config: DeliveryConfig::default(),
retry_interval: Duration::from_secs(5),
max_concurrent_deliveries: 100,
enable_metrics: true,
}
}
}
/// Counters describing the activity of a `DefaultMessageRouter`.
///
/// All counters are reset to their defaults when the router is shut down.
#[derive(Debug, Default, Clone)]
pub struct RouterStats {
/// Number of `route_message` calls that reached the send step (successful or
/// not). Retries performed by the background task do not change this counter.
pub messages_routed: u64,
/// Number of sends that succeeded, counting both first attempts and retries.
pub messages_delivered: u64,
/// Number of sends that returned an error, counting both first attempts and retries.
pub messages_failed: u64,
/// Size of the agent registry as of the last register/unregister call.
pub active_routes: usize,
/// Number of sends performed by the background retry task (successful or not).
pub retry_attempts: u64,
}
/// Abstraction over a component that delivers `MessageEnvelope`s to registered agents.
///
/// The `Send + Sync` bound lets a router be shared behind `Arc<dyn MessageRouter>`,
/// which is how `MessageSystem` stores it.
#[async_trait]
pub trait MessageRouter: Send + Sync {
/// Routes `envelope` towards its destination agent.
async fn route_message(&self, envelope: MessageEnvelope) -> MessagingResult<()>;
/// Associates `agent_id` with the mailbox `sender` used to deliver its messages.
async fn register_agent(
&self,
agent_id: AgentPid,
sender: MailboxSender,
) -> MessagingResult<()>;
/// Removes the route previously registered for `agent_id`.
async fn unregister_agent(&self, agent_id: &AgentPid) -> MessagingResult<()>;
/// Returns the router's statistics.
async fn get_stats(&self) -> RouterStats;
/// Stops the router.
async fn shutdown(&self) -> MessagingResult<()>;
}
/// `MessageRouter` implementation backed by an in-memory agent registry, a
/// `DeliveryManager` and a background retry task.
pub struct DefaultMessageRouter {
// Configuration supplied at construction; read for the retry interval and the
// delivery guarantee.
config: RouterConfig,
// Registry mapping each agent to the sender side of its mailbox. Shared with
// the retry task.
agents: Arc<RwLock<HashMap<AgentPid, MailboxSender>>>,
// Records each message and its state transitions (in transit, delivered,
// failed, acknowledged) and supplies retry candidates.
delivery_manager: Arc<DeliveryManager>,
// Routing counters. Shared with the retry task.
stats: Arc<Mutex<RouterStats>>,
// Notified by `shutdown` to make the retry task exit its loop.
shutdown_signal: Arc<tokio::sync::Notify>,
}
impl DefaultMessageRouter {
pub fn new(config: RouterConfig) -> Self {
let delivery_manager = Arc::new(DeliveryManager::new(config.delivery_config.clone()));
let router = Self {
config: config.clone(),
agents: Arc::new(RwLock::new(HashMap::new())),
delivery_manager,
stats: Arc::new(Mutex::new(RouterStats::default())),
shutdown_signal: Arc::new(tokio::sync::Notify::new()),
};
router.start_retry_task();
router
}
fn start_retry_task(&self) {
let delivery_manager = Arc::clone(&self.delivery_manager);
let agents = Arc::clone(&self.agents);
let stats = Arc::clone(&self.stats);
let retry_interval = self.config.retry_interval;
let shutdown_signal = Arc::clone(&self.shutdown_signal);
tokio::spawn(async move {
let mut interval = interval(retry_interval);
loop {
tokio::select! {
_ = interval.tick() => {
let candidates = delivery_manager.get_retry_candidates().await;
for mut envelope in candidates {
let delay = delivery_manager.calculate_retry_delay(envelope.attempts);
sleep(delay).await;
let agents_guard = agents.read().await;
if let Some(sender) = agents_guard.get(&envelope.to) {
envelope.increment_attempts();
if let Err(e) = delivery_manager.mark_in_transit(&envelope.id).await {
log::error!("Failed to mark message {} as in transit: {}", envelope.id, e);
continue;
}
let agent_message = AgentMessage::cast(
envelope.from.clone().unwrap_or_else(AgentPid::new),
envelope.payload.clone()
);
match sender.send(agent_message).await {
Ok(()) => {
if let Err(e) = delivery_manager.mark_delivered(&envelope.id).await {
log::error!("Failed to mark message {} as delivered: {}", envelope.id, e);
}
{
let mut stats_guard = stats.lock().await;
stats_guard.retry_attempts += 1;
stats_guard.messages_delivered += 1;
}
}
Err(e) => {
if let Err(mark_err) = delivery_manager.mark_failed(&envelope.id, e.to_string()).await {
log::error!("Failed to mark message {} as failed: {}", envelope.id, mark_err);
}
{
let mut stats_guard = stats.lock().await;
stats_guard.retry_attempts += 1;
stats_guard.messages_failed += 1;
}
}
}
} else {
if let Err(e) = delivery_manager.mark_failed(
&envelope.id,
format!("Agent {} not found", envelope.to)
).await {
log::error!("Failed to mark message {} as failed: {}", envelope.id, e);
}
}
}
}
_ = shutdown_signal.notified() => {
log::info!("Retry task shutting down");
break;
}
}
}
});
}
fn envelope_to_agent_message(
&self,
envelope: &MessageEnvelope,
) -> MessagingResult<AgentMessage> {
let from = envelope.from.clone().unwrap_or_default();
Ok(AgentMessage::cast(from, envelope.payload.clone()))
}
}
#[async_trait]
impl MessageRouter for DefaultMessageRouter {
    /// Records `envelope` with the delivery manager and sends it to the mailbox of
    /// the destination agent.
    ///
    /// On a successful send the message is marked delivered; with the
    /// `AtMostOnce` guarantee it is additionally marked acknowledged right away.
    /// On a failed send the message is marked failed and the send error is returned.
    ///
    /// # Errors
    ///
    /// * `MessagingError::AgentNotFound` if no agent is registered for `envelope.to`.
    ///   The message has already been recorded with the delivery manager at that point.
    /// * Any error returned by the delivery manager bookkeeping calls.
    /// * The error returned by the mailbox sender when the send fails.
    async fn route_message(&self, envelope: MessageEnvelope) -> MessagingResult<()> {
        self.delivery_manager.record_message(&envelope).await?;

        // Clone the sender and drop the registry guard before awaiting anything
        // else: holding the read lock across the send would make a single slow
        // mailbox block every register/unregister call.
        let sender = self
            .agents
            .read()
            .await
            .get(&envelope.to)
            .cloned()
            .ok_or_else(|| MessagingError::AgentNotFound(envelope.to.clone()))?;

        self.delivery_manager.mark_in_transit(&envelope.id).await?;
        let agent_message = self.envelope_to_agent_message(&envelope)?;

        match sender.send(agent_message).await {
            Ok(()) => {
                self.delivery_manager.mark_delivered(&envelope.id).await?;
                if self.config.delivery_config.guarantee == DeliveryGuarantee::AtMostOnce {
                    self.delivery_manager
                        .mark_acknowledged(&envelope.id)
                        .await?;
                }
                {
                    let mut stats = self.stats.lock().await;
                    stats.messages_routed += 1;
                    stats.messages_delivered += 1;
                }
                Ok(())
            }
            Err(e) => {
                self.delivery_manager
                    .mark_failed(&envelope.id, e.to_string())
                    .await?;
                {
                    let mut stats = self.stats.lock().await;
                    stats.messages_routed += 1;
                    stats.messages_failed += 1;
                }
                Err(e)
            }
        }
    }

    /// Adds `agent_id` to the registry and refreshes `active_routes`.
    ///
    /// # Errors
    ///
    /// Returns `MessagingError::DuplicateAgent` if the agent is already registered.
    async fn register_agent(
        &self,
        agent_id: AgentPid,
        sender: MailboxSender,
    ) -> MessagingResult<()> {
        let mut agents = self.agents.write().await;
        if agents.contains_key(&agent_id) {
            return Err(MessagingError::DuplicateAgent(agent_id));
        }
        agents.insert(agent_id.clone(), sender);
        {
            let mut stats = self.stats.lock().await;
            stats.active_routes = agents.len();
        }
        log::info!("Registered agent {} for message routing", agent_id);
        Ok(())
    }

    /// Removes `agent_id` from the registry and refreshes `active_routes`.
    ///
    /// # Errors
    ///
    /// Returns `MessagingError::AgentNotFound` if the agent is not registered.
    async fn unregister_agent(&self, agent_id: &AgentPid) -> MessagingResult<()> {
        let mut agents = self.agents.write().await;
        if agents.remove(agent_id).is_none() {
            return Err(MessagingError::AgentNotFound(agent_id.clone()));
        }
        {
            let mut stats = self.stats.lock().await;
            stats.active_routes = agents.len();
        }
        log::info!("Unregistered agent {} from message routing", agent_id);
        Ok(())
    }

    /// Returns a snapshot (clone) of the current statistics.
    async fn get_stats(&self) -> RouterStats {
        self.stats.lock().await.clone()
    }

    /// Signals the retry task to stop, clears the agent registry and resets the
    /// statistics to their defaults.
    async fn shutdown(&self) -> MessagingResult<()> {
        log::info!("Shutting down message router");
        // `notify_one` stores a permit when no task is currently waiting, so the
        // single retry task still sees the shutdown request if it is busy with a
        // sweep (or has not started polling yet). `notify_waiters` stores nothing
        // and would silently lose the signal in those cases.
        self.shutdown_signal.notify_one();
        {
            let mut agents = self.agents.write().await;
            agents.clear();
        }
        {
            let mut stats = self.stats.lock().await;
            *stats = RouterStats::default();
        }
        Ok(())
    }
}
/// Facade that pairs a message router with the mailbox manager owning the
/// agents' mailboxes.
pub struct MessageSystem {
// Router used to deliver envelopes; held as a trait object.
router: Arc<dyn MessageRouter>,
// Creates, looks up and removes the per-agent mailboxes.
mailbox_manager: Arc<MailboxManager>,
}
impl MessageSystem {
pub fn new(router_config: RouterConfig) -> Self {
let router = Arc::new(DefaultMessageRouter::new(router_config));
let mailbox_config = crate::MailboxConfig::default();
let mailbox_manager = Arc::new(MailboxManager::new(mailbox_config));
Self {
router,
mailbox_manager,
}
}
pub async fn register_agent(&self, agent_id: AgentPid) -> MessagingResult<()> {
let sender = self
.mailbox_manager
.create_mailbox(agent_id.clone())
.await?;
self.router.register_agent(agent_id, sender).await?;
Ok(())
}
pub async fn unregister_agent(&self, agent_id: &AgentPid) -> MessagingResult<()> {
self.router.unregister_agent(agent_id).await?;
self.mailbox_manager.remove_mailbox(agent_id).await?;
Ok(())
}
pub async fn send_message(&self, envelope: MessageEnvelope) -> MessagingResult<()> {
self.router.route_message(envelope).await
}
pub async fn get_mailbox(&self, agent_id: &AgentPid) -> Option<crate::AgentMailbox> {
self.mailbox_manager.get_mailbox(agent_id).await
}
pub async fn get_stats(&self) -> (RouterStats, Vec<crate::MailboxStats>) {
let router_stats = self.router.get_stats().await;
let mailbox_stats = self.mailbox_manager.get_all_stats().await;
(router_stats, mailbox_stats)
}
pub async fn shutdown(&self) -> MessagingResult<()> {
self.router.shutdown().await?;
let agents = self.mailbox_manager.list_agents().await;
for agent_id in agents {
let _ = self.mailbox_manager.remove_mailbox(&agent_id).await;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DeliveryOptions;

    /// Builds a mailbox for `owner` with the default mailbox configuration and
    /// returns it together with its sender. The mailbox must be kept alive by the
    /// caller for as long as the sender is used.
    fn mailbox_with_sender(owner: &AgentPid) -> (crate::AgentMailbox, MailboxSender) {
        let mailbox_config = crate::MailboxConfig::default();
        let mailbox = crate::AgentMailbox::new(owner.clone(), mailbox_config);
        let sender = mailbox.sender();
        (mailbox, sender)
    }

    /// Builds the fixed test envelope addressed to `recipient`.
    fn test_envelope(recipient: &AgentPid) -> MessageEnvelope {
        MessageEnvelope::new(
            recipient.clone(),
            "test_message".to_string(),
            serde_json::json!({"data": "test"}),
            DeliveryOptions::default(),
        )
    }

    /// Registering and unregistering an agent is reflected in `active_routes`.
    #[tokio::test]
    async fn test_router_registration() {
        let router = DefaultMessageRouter::new(RouterConfig::default());
        let agent_id = AgentPid::new();
        let (_mailbox, sender) = mailbox_with_sender(&agent_id);

        router
            .register_agent(agent_id.clone(), sender)
            .await
            .unwrap();
        assert_eq!(router.get_stats().await.active_routes, 1);

        router.unregister_agent(&agent_id).await.unwrap();
        assert_eq!(router.get_stats().await.active_routes, 0);
    }

    /// A routed message is counted and arrives in the agent's mailbox.
    #[tokio::test]
    async fn test_message_routing() {
        let router = DefaultMessageRouter::new(RouterConfig::default());
        let agent_id = AgentPid::new();
        let (mailbox, sender) = mailbox_with_sender(&agent_id);
        router
            .register_agent(agent_id.clone(), sender)
            .await
            .unwrap();

        router
            .route_message(test_envelope(&agent_id))
            .await
            .unwrap();

        let stats = router.get_stats().await;
        assert_eq!(stats.messages_routed, 1);
        assert_eq!(stats.messages_delivered, 1);

        let received = mailbox.receive().await.unwrap();
        assert!(received.from().is_some());
    }

    /// End-to-end send through `MessageSystem` updates router and mailbox stats.
    #[tokio::test]
    async fn test_message_system() {
        let system = MessageSystem::new(RouterConfig::default());
        let agent_id = AgentPid::new();
        system.register_agent(agent_id.clone()).await.unwrap();

        system
            .send_message(test_envelope(&agent_id))
            .await
            .unwrap();

        let (router_stats, mailbox_stats) = system.get_stats().await;
        assert_eq!(router_stats.messages_delivered, 1);
        assert_eq!(mailbox_stats.len(), 1);
    }

    /// Routing to an unregistered agent yields `AgentNotFound`.
    #[tokio::test]
    async fn test_agent_not_found() {
        let router = DefaultMessageRouter::new(RouterConfig::default());
        let agent_id = AgentPid::new();

        let result = router.route_message(test_envelope(&agent_id)).await;

        assert!(result.is_err());
        assert!(matches!(result, Err(MessagingError::AgentNotFound(_))));
    }

    /// Registering the same agent twice yields `DuplicateAgent`.
    #[tokio::test]
    async fn test_duplicate_registration() {
        let router = DefaultMessageRouter::new(RouterConfig::default());
        let agent_id = AgentPid::new();
        let (_mailbox, sender) = mailbox_with_sender(&agent_id);

        router
            .register_agent(agent_id.clone(), sender.clone())
            .await
            .unwrap();
        let result = router.register_agent(agent_id, sender).await;

        assert!(result.is_err());
        assert!(matches!(result, Err(MessagingError::DuplicateAgent(_))));
    }
}