use crate::agent::AgentMetadata;
use crate::message::AgentMessage;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{RwLock, broadcast};
/// Routing mode for a message on an [`AgentBus`]; also used as the key of an agent's
/// per-mode channel table.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CommunicationMode {
    /// Direct channel between two agents; the payload is the id of the peer agent.
    PointToPoint(String),
    /// The single bus-wide channel shared by all agents.
    Broadcast,
    /// Topic-based fan-out; the payload is the topic name.
    PubSub(String),
}
/// Outbound senders owned by one agent, keyed by the mode each channel serves.
type ModeChannels = HashMap<CommunicationMode, broadcast::Sender<Vec<u8>>>;

/// In-process message bus connecting agents via tokio broadcast channels.
///
/// Cloning is cheap: the tables sit behind `Arc`s and the broadcast sender is a handle,
/// so every clone observes and mutates the same shared state.
#[derive(Clone)]
pub struct AgentBus {
    /// Agent id -> that agent's channels, one per registered mode.
    agent_channels: Arc<RwLock<HashMap<String, ModeChannels>>>,
    /// Topic name -> ids of the agents subscribed to it.
    topic_subscribers: Arc<RwLock<HashMap<String, HashSet<String>>>>,
    /// Bus-wide channel used for [`CommunicationMode::Broadcast`].
    broadcast_channel: broadcast::Sender<Vec<u8>>,
}
impl AgentBus {
pub async fn new() -> anyhow::Result<Self> {
let (broadcast_sender, _) = broadcast::channel(100);
Ok(Self {
agent_channels: Arc::new(RwLock::new(HashMap::new())),
topic_subscribers: Arc::new(RwLock::new(HashMap::new())),
broadcast_channel: broadcast_sender,
})
}
pub async fn register_channel(
&self,
agent_metadata: &AgentMetadata,
mode: CommunicationMode,
) -> anyhow::Result<()> {
let id = &agent_metadata.id;
let mut agent_channels = self.agent_channels.write().await;
let entry = agent_channels.entry(id.clone()).or_default();
if matches!(mode, CommunicationMode::Broadcast) {
return Ok(());
}
if entry.contains_key(&mode) {
return Ok(());
}
let (sender, _) = broadcast::channel(100);
entry.insert(mode.clone(), sender);
if let CommunicationMode::PubSub(topic) = &mode {
let mut topic_subs = self.topic_subscribers.write().await;
topic_subs
.entry(topic.clone())
.or_default()
.insert(id.clone());
}
Ok(())
}
pub async fn send_message(
&self,
sender_id: &str,
mode: CommunicationMode,
message: &AgentMessage,
) -> anyhow::Result<()> {
let message_bytes = bincode::serialize(message)?;
match mode {
CommunicationMode::PointToPoint(receiver_id) => {
let agent_channels = self.agent_channels.read().await;
let Some(receiver_channels) = agent_channels.get(&receiver_id) else {
return Err(anyhow::anyhow!("Receiver agent {} not found", receiver_id));
};
let Some(channel) =
receiver_channels.get(&CommunicationMode::PointToPoint(sender_id.to_string()))
else {
return Err(anyhow::anyhow!(
"Receiver {} has no point-to-point channel with sender {}",
receiver_id,
sender_id
));
};
channel.send(message_bytes)?;
}
CommunicationMode::Broadcast => {
self.broadcast_channel.send(message_bytes)?;
}
CommunicationMode::PubSub(ref topic) => {
let topic_subs = self.topic_subscribers.read().await;
let subscribers = topic_subs
.get(topic)
.ok_or_else(|| anyhow::anyhow!("No subscribers for topic: {}", topic))?;
let agent_channels = self.agent_channels.read().await;
for sub_id in subscribers {
let Some(channels) = agent_channels.get(sub_id) else {
continue;
};
let Some(channel) = channels.get(&mode) else {
continue;
};
channel.send(message_bytes.clone())?;
}
}
}
Ok(())
}
pub async fn receive_message(
&self,
id: &str,
mode: CommunicationMode,
) -> anyhow::Result<Option<AgentMessage>> {
let agent_channels = self.agent_channels.read().await;
if matches!(mode, CommunicationMode::Broadcast) {
let mut receiver = self.broadcast_channel.subscribe();
match receiver.recv().await {
Ok(data) => {
let message = bincode::deserialize(&data)?;
Ok(Some(message))
}
Err(_) => Ok(None),
}
} else {
let Some(channels) = agent_channels.get(id) else {
return Ok(None);
};
let Some(channel) = channels.get(&mode) else {
return Ok(None);
};
let mut receiver = channel.subscribe();
match receiver.recv().await {
Ok(data) => {
let message = bincode::deserialize(&data)?;
Ok(Some(message))
}
Err(_) => Ok(None),
}
}
}
pub async fn unsubscribe_topic(&self, id: &str, topic: &str) -> anyhow::Result<()> {
let mut topic_subs = self.topic_subscribers.write().await;
if let Some(subscribers) = topic_subs.get_mut(topic) {
subscribers.remove(id);
if subscribers.is_empty() {
topic_subs.remove(topic);
}
}
Ok(())
}
}