use super::types::{Event, EventType};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::broadcast;
/// Capacity passed to `broadcast::channel` for every channel of a bus built with
/// [`EventBus::new`] (and therefore [`EventBus::default`]).
const DEFAULT_CHANNEL_CAPACITY: usize = 256;
/// Receiving end handed out by [`EventBus::subscribe_all`] and [`EventBus::subscribe`].
pub struct Subscription {
    /// The broadcast receiver events are delivered to.
    pub receiver: broadcast::Receiver<Event>,
    /// The type this subscription was created for; `None` when it came from
    /// [`EventBus::subscribe_all`].
    pub event_type: Option<EventType>,
}

impl Subscription {
    /// Takes the next buffered event, returning immediately if there is none.
    ///
    /// # Errors
    ///
    /// Passes through the receiver's `TryRecvError` unchanged (`Empty`,
    /// `Lagged(n)`, or `Closed`).
    pub fn try_recv(&mut self) -> Result<Event, broadcast::error::TryRecvError> {
        broadcast::Receiver::try_recv(&mut self.receiver)
    }

    /// Waits until the next event arrives on this subscription.
    ///
    /// # Errors
    ///
    /// Passes through the receiver's `RecvError` unchanged (`Lagged(n)` or
    /// `Closed`).
    pub async fn recv(&mut self) -> Result<Event, broadcast::error::RecvError> {
        broadcast::Receiver::recv(&mut self.receiver).await
    }
}
/// A publish/subscribe hub that fans [`Event`]s out over tokio broadcast channels.
///
/// Each published event is sent to a shared "all events" channel (read by
/// [`EventBus::subscribe_all`] and [`EventBus::subscribe_many`]) and, when one
/// exists, to the per-[`EventType`] channel created by [`EventBus::subscribe`].
/// Clones share the same channels, so a clone moved into another task publishes
/// to the original's subscribers.
#[derive(Clone)]
pub struct EventBus {
    /// Sender for the channel that carries every published event.
    all_sender: broadcast::Sender<Event>,
    /// Per-type senders, created lazily by [`EventBus::subscribe`] and shared
    /// between clones through the `Arc`.
    type_senders: Arc<RwLock<HashMap<EventType, broadcast::Sender<Event>>>>,
    /// Capacity passed to `broadcast::channel` for every channel this bus creates.
    capacity: usize,
}

impl EventBus {
    /// Creates a bus whose channels use the module's default capacity.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_CHANNEL_CAPACITY)
    }

    /// Creates a bus whose broadcast channels are all created with `capacity`.
    ///
    /// A subscriber that falls further behind than the channel can buffer loses
    /// the oldest events and observes `Lagged` on its next receive.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 or greater than `usize::MAX / 2`, because
    /// `tokio::sync::broadcast::channel` rejects those values.
    pub fn with_capacity(capacity: usize) -> Self {
        // The initial receiver is discarded; subscribers attach via `subscribe()`.
        let (all_sender, _) = broadcast::channel(capacity);
        Self {
            all_sender,
            type_senders: Arc::new(RwLock::new(HashMap::new())),
            capacity,
        }
    }

    /// Subscribes to every event published after this call.
    pub fn subscribe_all(&self) -> Subscription {
        Subscription {
            receiver: self.all_sender.subscribe(),
            event_type: None,
        }
    }

    /// Subscribes to events of `event_type` published after this call.
    ///
    /// The per-type channel is created on the first subscription for that type
    /// and kept for the lifetime of the bus.
    pub fn subscribe(&self, event_type: EventType) -> Subscription {
        let sender = self.get_or_create_sender(event_type);
        Subscription {
            receiver: sender.subscribe(),
            event_type: Some(event_type),
        }
    }

    /// Subscribes to events whose type is contained in `event_types`.
    ///
    /// The subscription reads the all-events channel and discards non-matching
    /// events when receiving. It therefore counts toward
    /// [`EventBus::subscriber_count`], and its `Lagged` counts include
    /// filtered-out events. An empty slice yields a subscription that never
    /// produces an event.
    pub fn subscribe_many(&self, event_types: &[EventType]) -> FilteredSubscription {
        FilteredSubscription {
            receiver: self.all_sender.subscribe(),
            event_types: event_types.to_vec(),
        }
    }

    /// Publishes `event` to the all-events channel and to its type's channel.
    ///
    /// Delivery is best-effort: a channel with no receivers simply drops the
    /// event, and slow subscribers are not waited for (they lag instead).
    pub fn publish(&self, event: Event) {
        let senders = self.read_senders();
        // Only pay for a clone when the typed channel has a listener; a send
        // with no receivers would fail and drop the copy anyway.
        let typed = senders
            .get(&event.event_type)
            .filter(|sender| sender.receiver_count() > 0);
        match typed {
            Some(typed) => {
                let _ = self.all_sender.send(event.clone());
                let _ = typed.send(event);
            }
            None => {
                let _ = self.all_sender.send(event);
            }
        }
    }

    /// Publishes each event in order, exactly as repeated [`EventBus::publish`] calls.
    pub fn publish_batch(&self, events: Vec<Event>) {
        for event in events {
            self.publish(event);
        }
    }

    /// Number of live receivers on the all-events channel.
    ///
    /// This covers [`EventBus::subscribe_all`] and [`EventBus::subscribe_many`]
    /// subscriptions; typed subscriptions are counted by
    /// [`EventBus::subscriber_count_for`].
    pub fn subscriber_count(&self) -> usize {
        self.all_sender.receiver_count()
    }

    /// Number of live receivers on the typed channel for `event_type`
    /// (0 if no such channel has been created).
    pub fn subscriber_count_for(&self, event_type: EventType) -> usize {
        self.read_senders()
            .get(&event_type)
            .map_or(0, |sender| sender.receiver_count())
    }

    /// Returns the typed sender for `event_type`, creating its channel if needed.
    fn get_or_create_sender(&self, event_type: EventType) -> broadcast::Sender<Event> {
        // Fast path under the shared lock; the guard is dropped at the end of
        // this scope, before the exclusive lock is requested below.
        {
            let senders = self.read_senders();
            if let Some(sender) = senders.get(&event_type) {
                return sender.clone();
            }
        }
        // `entry` re-checks under the write lock: another thread may have
        // created the channel after the read lock was released.
        self.write_senders()
            .entry(event_type)
            .or_insert_with(|| broadcast::channel(self.capacity).0)
            .clone()
    }

    /// Read-locks the per-type sender map, recovering the guard if the lock is
    /// poisoned.
    ///
    /// The only mutation performed under this lock is a single
    /// `entry(..).or_insert_with(..)`, so a poisoned lock still guards a
    /// consistent map. Recovering keeps typed delivery working instead of
    /// silently dropping it (or panicking) after an unrelated panic.
    fn read_senders(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<EventType, broadcast::Sender<Event>>> {
        self.type_senders
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Write-locks the per-type sender map, recovering the guard if the lock is
    /// poisoned (see [`EventBus::read_senders`] for why that is sound).
    fn write_senders(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<EventType, broadcast::Sender<Event>>> {
        self.type_senders
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl Default for EventBus {
fn default() -> Self {
Self::new()
}
}
/// A subscription created by [`EventBus::subscribe_many`] that yields only events
/// whose type is in a fixed set.
///
/// It wraps a receiver on the bus's all-events channel; events of other types
/// are pulled off that receiver and discarded.
pub struct FilteredSubscription {
    /// Receiver on the all-events channel.
    receiver: broadcast::Receiver<Event>,
    /// Event types this subscription lets through; an empty list lets none through.
    event_types: Vec<EventType>,
}

impl FilteredSubscription {
    /// Whether `event` has one of the types this subscription was created for.
    fn accepts(&self, event: &Event) -> bool {
        self.event_types.contains(&event.event_type)
    }

    /// Takes the next buffered matching event without waiting.
    ///
    /// Non-matching events ahead of it are consumed and discarded.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the underlying receiver (`Empty`,
    /// `Lagged(n)`, or `Closed`). A `Lagged(n)` count may include events that
    /// would have been filtered out.
    pub fn try_recv(&mut self) -> Result<Event, broadcast::error::TryRecvError> {
        loop {
            match self.receiver.try_recv() {
                Ok(event) if self.accepts(&event) => return Ok(event),
                Ok(_skipped) => continue,
                Err(error) => return Err(error),
            }
        }
    }

    /// Waits for the next matching event, discarding non-matching ones.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the underlying receiver (`Lagged(n)`
    /// or `Closed`); as with [`FilteredSubscription::try_recv`], lag counts are
    /// in terms of all events, not just matching ones.
    pub async fn recv(&mut self) -> Result<Event, broadcast::error::RecvError> {
        loop {
            let outcome = self.receiver.recv().await;
            match outcome {
                Ok(event) if self.accepts(&event) => return Ok(event),
                Ok(_skipped) => continue,
                Err(error) => return Err(error),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `EventBus`, `Subscription`, and `FilteredSubscription`.
    use super::*;

    #[test]
    fn test_bus_creation() {
        let bus = EventBus::new();
        assert_eq!(bus.subscriber_count(), 0);
    }

    #[test]
    fn test_subscribe_all() {
        let bus = EventBus::new();
        let _sub = bus.subscribe_all();
        assert_eq!(bus.subscriber_count(), 1);
    }

    #[test]
    fn test_subscribe_type() {
        let bus = EventBus::new();
        let _sub = bus.subscribe(EventType::ToolCall);
        assert_eq!(bus.subscriber_count_for(EventType::ToolCall), 1);
        assert_eq!(bus.subscriber_count_for(EventType::AgentStart), 0);
    }

    #[test]
    fn test_publish_and_receive() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe(EventType::ToolCall);
        bus.publish(Event::tool_call("calc", "{}"));
        let event = sub.try_recv().unwrap();
        assert_eq!(event.event_type, EventType::ToolCall);
    }

    #[test]
    fn test_type_filtering() {
        let bus = EventBus::new();
        let mut tool_sub = bus.subscribe(EventType::ToolCall);
        let mut agent_sub = bus.subscribe(EventType::AgentStart);
        bus.publish(Event::tool_call("calc", "{}"));
        assert!(tool_sub.try_recv().is_ok());
        assert!(agent_sub.try_recv().is_err());
    }

    #[test]
    fn test_all_events_subscription() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe_all();
        bus.publish(Event::tool_call("calc", "{}"));
        bus.publish(Event::agent_start("agent-1"));
        assert!(sub.try_recv().is_ok());
        assert!(sub.try_recv().is_ok());
    }

    #[test]
    fn test_publish_batch() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe_all();
        bus.publish_batch(vec![
            Event::tool_call("a", "{}"),
            Event::tool_call("b", "{}"),
        ]);
        assert!(sub.try_recv().is_ok());
        assert!(sub.try_recv().is_ok());
    }

    #[test]
    fn test_publish_without_subscribers_is_noop() {
        // Best-effort delivery: publishing into empty channels must not panic.
        let bus = EventBus::new();
        bus.publish(Event::tool_call("calc", "{}"));
        assert_eq!(bus.subscriber_count(), 0);
        assert_eq!(bus.subscriber_count_for(EventType::ToolCall), 0);
    }

    #[test]
    fn test_typed_subscribers_share_channel_across_clones() {
        let bus = EventBus::new();
        let _first = bus.subscribe(EventType::ToolCall);
        let _second = bus.clone().subscribe(EventType::ToolCall);
        assert_eq!(bus.subscriber_count_for(EventType::ToolCall), 2);
        // Typed subscriptions do not count toward the all-events total.
        assert_eq!(bus.subscriber_count(), 0);
    }

    #[test]
    fn test_subscribe_many_filters_by_type() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe_many(&[EventType::AgentStart]);
        // Filtered subscriptions listen on the all-events channel.
        assert_eq!(bus.subscriber_count(), 1);
        bus.publish(Event::tool_call("calc", "{}"));
        bus.publish(Event::agent_start("agent-1"));
        let event = sub.try_recv().unwrap();
        assert_eq!(event.event_type, EventType::AgentStart);
        assert!(matches!(
            sub.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_subscribe_many_with_no_types_receives_nothing() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe_many(&[]);
        bus.publish(Event::tool_call("calc", "{}"));
        assert!(matches!(
            sub.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_slow_subscriber_lags() {
        let bus = EventBus::with_capacity(1);
        let mut sub = bus.subscribe_all();
        bus.publish(Event::tool_call("a", "{}"));
        bus.publish(Event::tool_call("b", "{}"));
        assert!(matches!(
            sub.try_recv(),
            Err(broadcast::error::TryRecvError::Lagged(_))
        ));
        // After reporting the lag, the receiver resumes at the oldest retained event.
        assert!(sub.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_async_receive() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe(EventType::ToolCall);
        let bus_clone = bus.clone();
        tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            bus_clone.publish(Event::tool_call("calc", "{}"));
        });
        let event = sub.recv().await.unwrap();
        assert_eq!(event.event_type, EventType::ToolCall);
    }

    #[tokio::test]
    async fn test_filtered_async_receive_skips_other_types() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe_many(&[EventType::AgentStart]);
        bus.publish(Event::tool_call("calc", "{}"));
        bus.publish(Event::agent_start("agent-1"));
        let event = sub.recv().await.unwrap();
        assert_eq!(event.event_type, EventType::AgentStart);
    }
}