use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use crate::multi_agent::AgentMessage;
/// Per-topic channel capacity used by [`AgentBus::new`] (and therefore by `Default`).
const DEFAULT_CAPACITY: usize = 256;

/// In-process publish/subscribe bus that fans [`AgentMessage`]s out to named topics.
///
/// Cloning is cheap: every clone shares the same topic table through an `Arc`, so a
/// message published on one clone reaches subscribers obtained from any other clone.
#[derive(Clone)]
pub struct AgentBus {
    /// Topic name -> broadcast sender feeding every subscription on that topic.
    topics: Arc<RwLock<HashMap<String, broadcast::Sender<AgentMessage>>>>,
    /// Capacity handed to `broadcast::channel` whenever a topic is first created.
    capacity: usize,
}

impl Default for AgentBus {
    /// Builds a bus with [`DEFAULT_CAPACITY`]; identical to [`AgentBus::new`].
    fn default() -> Self {
        AgentBus::new()
    }
}
impl AgentBus {
    /// Creates a bus whose topics use [`DEFAULT_CAPACITY`] as their channel capacity.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_CAPACITY)
    }

    /// Creates a bus whose topics use `capacity` as their channel capacity.
    ///
    /// A `capacity` of `0` is raised to `1`, because `broadcast::channel` does not accept a
    /// zero capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        let topics = Arc::new(RwLock::new(HashMap::new()));
        Self { topics, capacity }
    }

    /// Builds the sender for a brand-new topic; the paired receiver is discarded at once.
    fn fresh_sender(&self) -> broadcast::Sender<AgentMessage> {
        let (sender, _) = broadcast::channel(self.capacity);
        sender
    }

    /// Subscribes to `topic`, creating the topic if it does not exist yet.
    ///
    /// The returned [`Subscription`] observes messages published after this call.
    pub async fn subscribe(&self, topic: impl Into<String>) -> Subscription {
        let name: String = topic.into();
        let receiver = {
            let mut table = self.topics.write().await;
            table
                .entry(name.clone())
                .or_insert_with(|| self.fresh_sender())
                .subscribe()
        };
        Subscription {
            rx: receiver,
            topic: name,
        }
    }

    /// Publishes `msg` on `topic` and returns how many subscribers will receive it.
    ///
    /// If the topic does not exist it is created first. When no subscriber is listening the
    /// message is dropped and `0` is returned.
    pub async fn publish(&self, topic: &str, msg: AgentMessage) -> usize {
        // Fast path: an existing topic only needs a shared lock to send.
        {
            let table = self.topics.read().await;
            if let Some(sender) = table.get(topic) {
                return sender.send(msg).unwrap_or_default();
            }
        }

        // Slow path: register the topic. Another task may have created it between the two
        // locks, which the entry API tolerates.
        let mut table = self.topics.write().await;
        let sender = table
            .entry(topic.to_owned())
            .or_insert_with(|| self.fresh_sender());
        sender.send(msg).unwrap_or_default()
    }

    /// Number of topics currently registered on the bus (with or without subscribers).
    pub async fn topic_count(&self) -> usize {
        let table = self.topics.read().await;
        table.len()
    }

    /// Number of live subscriptions on `topic`; `0` when the topic is unknown.
    pub async fn subscriber_count(&self, topic: &str) -> usize {
        let table = self.topics.read().await;
        match table.get(topic) {
            Some(sender) => sender.receiver_count(),
            None => 0,
        }
    }

    /// Removes `topic` from the bus.
    ///
    /// Dropping the topic's sender closes the channel, so existing subscriptions get
    /// [`SubscribeError::Closed`] once any buffered messages have been consumed.
    pub async fn drop_topic(&self, topic: &str) {
        let mut table = self.topics.write().await;
        drop(table.remove(topic));
    }
}
pub struct Subscription {
rx: broadcast::Receiver<AgentMessage>,
topic: String,
}
impl Subscription {
pub fn topic(&self) -> &str {
&self.topic
}
pub async fn recv(&mut self) -> Result<AgentMessage, SubscribeError> {
match self.rx.recv().await {
Ok(m) => Ok(m),
Err(broadcast::error::RecvError::Closed) => Err(SubscribeError::Closed),
Err(broadcast::error::RecvError::Lagged(n)) => Err(SubscribeError::Lagged(n)),
}
}
pub fn try_recv(&mut self) -> Result<AgentMessage, SubscribeError> {
match self.rx.try_recv() {
Ok(m) => Ok(m),
Err(broadcast::error::TryRecvError::Empty) => Err(SubscribeError::Empty),
Err(broadcast::error::TryRecvError::Closed) => Err(SubscribeError::Closed),
Err(broadcast::error::TryRecvError::Lagged(n)) => Err(SubscribeError::Lagged(n)),
}
}
}
/// Failure returned by [`Subscription::recv`] and [`Subscription::try_recv`].
///
/// Every variant is plain data, so the type is `Copy` and comparable. Callers can write
/// `assert_eq!(err, SubscribeError::Empty)` instead of pattern-matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubscribeError {
    /// The topic's channel has been closed and holds no further messages.
    Closed,
    /// The subscription fell behind and messages were skipped; carries the skip count.
    Lagged(u64),
    /// Produced only by `try_recv`: no message is buffered at the moment.
    Empty,
}

impl std::fmt::Display for SubscribeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SubscribeError::Closed => write!(f, "topic closed"),
            SubscribeError::Lagged(n) => write!(f, "subscriber lagged by {n} messages"),
            SubscribeError::Empty => write!(f, "no message available"),
        }
    }
}

impl std::error::Error for SubscribeError {}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::Message;

    /// Builds a message from `"user"` addressed to `"broadcast"` carrying `text`.
    fn msg(text: &str) -> AgentMessage {
        AgentMessage {
            from: "user".into(),
            to: "broadcast".into(),
            content: Message::human(text),
            metadata: serde_json::Value::Null,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn one_publisher_one_subscriber_roundtrip() {
        let bus = AgentBus::new();
        let mut sub = bus.subscribe("planning").await;
        let n = bus.publish("planning", msg("hello")).await;
        assert_eq!(n, 1);
        let got = sub.recv().await.unwrap();
        assert_eq!(got.content.content(), "hello");
    }

    #[tokio::test]
    async fn fanout_to_multiple_subscribers() {
        let bus = AgentBus::new();
        let mut a = bus.subscribe("alerts").await;
        let mut b = bus.subscribe("alerts").await;
        let mut c = bus.subscribe("alerts").await;
        let n = bus.publish("alerts", msg("fire")).await;
        assert_eq!(n, 3);
        assert_eq!(a.recv().await.unwrap().content.content(), "fire");
        assert_eq!(b.recv().await.unwrap().content.content(), "fire");
        assert_eq!(c.recv().await.unwrap().content.content(), "fire");
    }

    #[tokio::test]
    async fn topic_isolation() {
        let bus = AgentBus::new();
        let mut planning = bus.subscribe("planning").await;
        let mut alerts = bus.subscribe("alerts").await;
        bus.publish("planning", msg("plan-msg")).await;
        bus.publish("alerts", msg("alert-msg")).await;
        assert_eq!(planning.recv().await.unwrap().content.content(), "plan-msg");
        assert_eq!(alerts.recv().await.unwrap().content.content(), "alert-msg");
    }

    #[tokio::test]
    async fn publish_with_no_subscribers_drops_silently() {
        let bus = AgentBus::new();
        let n = bus.publish("ghost-topic", msg("nobody home")).await;
        assert_eq!(n, 0);
        // Publishing registers the topic even though nobody is listening.
        assert_eq!(bus.topic_count().await, 1);
    }

    #[tokio::test]
    async fn try_recv_empty_when_no_message() {
        let bus = AgentBus::new();
        let mut sub = bus.subscribe("t").await;
        assert!(matches!(sub.try_recv(), Err(SubscribeError::Empty)));
    }

    #[tokio::test]
    async fn drop_topic_closes_subscribers() {
        let bus = AgentBus::new();
        let mut sub = bus.subscribe("t").await;
        bus.drop_topic("t").await;
        assert!(matches!(sub.recv().await, Err(SubscribeError::Closed)));
    }

    #[tokio::test]
    async fn drop_topic_removes_topic_entry() {
        let bus = AgentBus::new();
        let _sub = bus.subscribe("t").await;
        assert_eq!(bus.topic_count().await, 1);
        bus.drop_topic("t").await;
        assert_eq!(bus.topic_count().await, 0);
        assert_eq!(bus.subscriber_count("t").await, 0);
    }

    #[tokio::test]
    async fn subscriber_count_tracks_active() {
        let bus = AgentBus::new();
        assert_eq!(bus.subscriber_count("t").await, 0);
        let first = bus.subscribe("t").await;
        let _second = bus.subscribe("t").await;
        assert_eq!(bus.subscriber_count("t").await, 2);
        // Dropping a subscription releases its receiver immediately.
        drop(first);
        assert_eq!(bus.subscriber_count("t").await, 1);
        let _third = bus.subscribe("t").await;
        assert_eq!(bus.subscriber_count("t").await, 2);
    }

    #[tokio::test]
    async fn zero_capacity_is_clamped_to_one() {
        let bus = AgentBus::with_capacity(0);
        let mut sub = bus.subscribe("t").await;
        assert_eq!(bus.publish("t", msg("only")).await, 1);
        assert_eq!(sub.recv().await.unwrap().content.content(), "only");
    }

    #[tokio::test]
    async fn lagged_subscribers_get_lagged_error() {
        let bus = AgentBus::with_capacity(2);
        let mut sub = bus.subscribe("t").await;
        for i in 0..10 {
            bus.publish("t", msg(&format!("m{i}"))).await;
        }
        match sub.recv().await {
            Err(SubscribeError::Lagged(n)) => assert!(n > 0),
            other => panic!("expected Lagged, got {other:?}"),
        }
        // After reporting the lag the subscription resumes from the oldest retained message.
        assert!(sub.recv().await.is_ok());
    }
}