use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::RwLock;
/// An event emitted during an agent session and delivered through the [`EventBus`].
///
/// Variants that carry arbitrary structured payloads use [`serde_json::Value`]. The bus
/// itself does not interpret events; any variant may be published on any channel.
#[derive(Debug, Clone, PartialEq)]
pub enum AgentSessionEvent {
    /// A chat message.
    Message {
        /// Speaker of the message (free-form; presumably e.g. "user"/"assistant" — TODO confirm).
        role: String,
        /// Message text.
        content: String,
        /// Message timestamp; epoch and unit are not established here — TODO confirm.
        timestamp: u64,
    },
    /// A tool invocation began.
    ToolStart {
        /// Name of the tool being invoked.
        tool_name: String,
        /// Input passed to the tool.
        input: serde_json::Value,
    },
    /// A tool invocation finished.
    ToolEnd {
        /// Name of the tool that ran.
        tool_name: String,
        /// Tool output on success, or an error message on failure.
        output: Result<serde_json::Value, String>,
        /// Elapsed time of the invocation, in milliseconds.
        duration_ms: u64,
    },
    /// An error occurred.
    Error {
        /// Human-readable error description.
        message: String,
        /// Whether the emitter considers the error recoverable.
        recoverable: bool,
    },
    /// A model call began.
    ModelStart {
        /// Identifier of the model being called.
        model_id: String,
    },
    /// A model call finished.
    ModelEnd {
        /// Identifier of the model that was called.
        model_id: String,
        /// Elapsed time of the call, in milliseconds.
        duration_ms: u64,
        /// Tokens consumed by the call, if reported.
        tokens_used: Option<u32>,
    },
    /// Token accounting for a model interaction.
    TokenUsage {
        /// Number of input (prompt) tokens.
        input_tokens: u32,
        /// Number of output (completion) tokens.
        output_tokens: u32,
        /// Number of tokens served from cache, if reported.
        cached_tokens: Option<u32>,
    },
    /// A session started.
    SessionStart {
        /// Identifier of the session.
        session_id: String,
    },
    /// A session ended.
    SessionEnd {
        /// Identifier of the session.
        session_id: String,
        /// Number of messages exchanged in the session.
        total_messages: u32,
    },
    /// The agent started a thinking phase.
    ThinkingStart,
    /// The agent finished a thinking phase.
    ThinkingEnd {
        /// Text produced during the thinking phase.
        thoughts: String,
    },
    /// An incremental piece of streamed output.
    StreamChunk {
        /// The chunk's text.
        content: String,
    },
    /// A tool call was requested.
    ///
    /// NOTE(review): overlaps with `ToolStart`; the distinction is not defined here.
    ToolCall {
        /// Name of the requested tool.
        tool_name: String,
        /// Arguments for the call.
        arguments: serde_json::Value,
    },
    /// A tool call produced a result.
    ///
    /// NOTE(review): overlaps with `ToolEnd`; the distinction is not defined here.
    ToolResult {
        /// Name of the tool that produced the result.
        tool_name: String,
        /// The tool's result.
        result: serde_json::Value,
    },
    /// An application-defined event.
    Custom {
        /// Application-chosen event name.
        name: String,
        /// Arbitrary payload.
        data: serde_json::Value,
    },
}
/// Callback type for asynchronous subscribers.
///
/// Every invocation receives its own clone of the published event. It returns a pinned,
/// boxed, `Send` future, which the bus drives to completion on a spawned Tokio task. The
/// callback itself is shared across threads, hence `Send + Sync`.
pub type EventHandler = Arc<
    dyn Fn(AgentSessionEvent) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = ()> + Send + 'static>,
        > + Send
        + Sync,
>;

/// Callback type for synchronous subscribers.
///
/// The bus calls it directly during publishing, so it should return quickly.
pub type SyncEventHandler = Arc<dyn Fn(AgentSessionEvent) + Send + Sync>;
/// Handle identifying one subscription on an [`EventBus`].
///
/// Pass `channel` and `id` to [`EventBus::unsubscribe`] to remove the handler. Dropping
/// this handle does **not** unsubscribe anything.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Subscriber {
    /// Channel the handler was registered on.
    pub channel: String,
    /// Subscription id assigned by the bus (unique within one bus).
    pub id: u64,
}

impl Subscriber {
    /// Consumes the handle **without** removing the handler from the bus.
    ///
    /// The handle holds no reference to the bus, so it cannot unregister anything by
    /// itself; this method only drops the handle. To stop delivery, call
    /// `bus.unsubscribe(&sub.channel, sub.id).await` instead.
    pub fn unsubscribe(self) {}
}
/// Handler table: channel name -> (subscription id -> handler).
type ChannelMap<H> = HashMap<String, HashMap<u64, H>>;

/// Shared state behind every `EventBus` handle.
///
/// The async table, the sync table and the id counter each sit behind their own Tokio
/// `RwLock`, so they are locked independently of one another.
struct BusInner {
    /// Async handlers; their futures are spawned as tasks when an event is published.
    subscribers: RwLock<ChannelMap<EventHandler>>,
    /// Sync handlers; called directly when an event is published.
    sync_subscribers: RwLock<ChannelMap<SyncEventHandler>>,
    /// Next subscription id to hand out (one counter shared by both tables).
    next_id: RwLock<u64>,
}
/// Channel-based publish/subscribe hub for [`AgentSessionEvent`]s.
///
/// All state lives behind an [`Arc`], so cloning an `EventBus` is cheap. Every clone is
/// another handle to the *same* bus and sees the same subscribers; it is not a copy.
#[derive(Clone)]
pub struct EventBus {
    inner: Arc<BusInner>,
}

impl Default for EventBus {
    /// Equivalent to [`EventBus::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl EventBus {
    /// Creates an empty bus with no subscribers.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(BusInner {
                subscribers: RwLock::new(HashMap::new()),
                sync_subscribers: RwLock::new(HashMap::new()),
                next_id: RwLock::new(0),
            }),
        }
    }

    /// Creates an empty bus already wrapped in an [`Arc`].
    pub fn arc() -> Arc<Self> {
        Arc::new(Self::new())
    }

    /// Registers an async `handler` on `channel` and returns its subscription handle.
    ///
    /// On every [`EventBus::publish`] to `channel`, the handler is called with a clone of
    /// the event, and the future it returns is spawned as its own Tokio task.
    pub async fn subscribe_async<F, Fut>(&self, channel: &str, handler: F) -> Subscriber
    where
        F: Fn(AgentSessionEvent) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        let id = self.allocate_id().await;
        // Erase the concrete future type so handlers of different types share one table.
        let handler: EventHandler = Arc::new(move |event| {
            let fut = handler(event);
            Box::pin(fut)
        });
        self.inner
            .subscribers
            .write()
            .await
            .entry(channel.to_string())
            .or_default()
            .insert(id, handler);
        Subscriber {
            channel: channel.to_string(),
            id,
        }
    }

    /// Registers a synchronous `handler` on `channel` from async code.
    ///
    /// The handler runs inline inside [`EventBus::publish`], so it should return quickly.
    pub async fn subscribe_sync(&self, channel: &str, handler: SyncEventHandler) -> Subscriber {
        let id = self.allocate_id().await;
        self.inner
            .sync_subscribers
            .write()
            .await
            .entry(channel.to_string())
            .or_default()
            .insert(id, handler);
        Subscriber {
            channel: channel.to_string(),
            id,
        }
    }

    /// Registers a synchronous `handler` on `channel` from blocking (non-async) code.
    ///
    /// Behaves like [`EventBus::subscribe_sync`], but acquires the internal locks with
    /// Tokio's `blocking_write`. It therefore works on plain OS threads (with or without
    /// a Tokio runtime) and inside `spawn_blocking`. The previous
    /// `Handle::current().block_on` approach panicked on threads without a runtime.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context; use
    /// [`EventBus::subscribe_sync`] there instead.
    pub fn subscribe(&self, channel: &str, handler: SyncEventHandler) -> Subscriber {
        let id = {
            let mut next_id = self.inner.next_id.blocking_write();
            let id = *next_id;
            *next_id = id + 1;
            id
        };
        self.inner
            .sync_subscribers
            .blocking_write()
            .entry(channel.to_string())
            .or_default()
            .insert(id, handler);
        Subscriber {
            channel: channel.to_string(),
            id,
        }
    }

    /// Delivers `event` to every handler subscribed to exactly `channel`.
    ///
    /// Channel names are compared by exact string equality; no wildcard expansion is done.
    ///
    /// - Synchronous handlers are called inline, in ascending subscription-id order,
    ///   before this method returns.
    /// - Async handlers are each spawned as an independent Tokio task (fire-and-forget).
    ///   This method does not wait for them, and a panic inside one is not propagated here.
    ///
    /// Handlers are copied out of the tables first, so no internal lock is held while
    /// handler code runs. A slow sync handler therefore cannot stall subscribe/unsubscribe
    /// calls made from other tasks.
    pub async fn publish(&self, channel: &str, event: AgentSessionEvent) {
        // The read guard is a temporary and is released at the end of this statement.
        let sync_handlers = Self::snapshot(&*self.inner.sync_subscribers.read().await, channel);
        for handler in &sync_handlers {
            handler(event.clone());
        }

        let async_handlers = Self::snapshot(&*self.inner.subscribers.read().await, channel);
        for handler in async_handlers {
            let event = event.clone();
            tokio::spawn(async move {
                handler(event).await;
            });
        }
    }

    /// Removes the handler with subscription `id` from `channel`, whether it was
    /// registered as async or sync.
    ///
    /// Unknown channels or ids are ignored. When a channel's last handler is removed, the
    /// channel entry is dropped too, so short-lived channels do not accumulate.
    pub async fn unsubscribe(&self, channel: &str, id: u64) {
        Self::remove_handler(&mut *self.inner.subscribers.write().await, channel, id);
        Self::remove_handler(&mut *self.inner.sync_subscribers.write().await, channel, id);
    }

    /// Removes every async and sync handler registered on `channel`.
    pub async fn unsubscribe_all(&self, channel: &str) {
        self.inner.subscribers.write().await.remove(channel);
        self.inner.sync_subscribers.write().await.remove(channel);
    }

    /// Removes every handler on every channel.
    pub async fn clear(&self) {
        self.inner.subscribers.write().await.clear();
        self.inner.sync_subscribers.write().await.clear();
    }

    /// Returns the total number of registered handlers (async plus sync) across all channels.
    pub async fn subscription_count(&self) -> usize {
        let async_count: usize = self
            .inner
            .subscribers
            .read()
            .await
            .values()
            .map(|h| h.len())
            .sum();
        let sync_count: usize = self
            .inner
            .sync_subscribers
            .read()
            .await
            .values()
            .map(|h| h.len())
            .sum();
        async_count + sync_count
    }

    /// Hands out the next subscription id; ids are unique across async and sync handlers.
    async fn allocate_id(&self) -> u64 {
        let mut next_id = self.inner.next_id.write().await;
        let id = *next_id;
        *next_id = id + 1;
        id
    }

    /// Clones the handlers registered on `channel` out of `table`, in ascending
    /// subscription-id order, so they can be invoked after the lock is released.
    fn snapshot<H: Clone>(table: &HashMap<String, HashMap<u64, H>>, channel: &str) -> Vec<H> {
        let mut entries: Vec<(u64, H)> = table
            .get(channel)
            .map(|handlers| handlers.iter().map(|(id, h)| (*id, h.clone())).collect())
            .unwrap_or_default();
        entries.sort_unstable_by_key(|(id, _)| *id);
        entries.into_iter().map(|(_, h)| h).collect()
    }

    /// Removes `id` from `channel` in `table`, dropping the channel entry once it is empty.
    fn remove_handler<H>(table: &mut HashMap<String, HashMap<u64, H>>, channel: &str, id: u64) {
        if let Some(handlers) = table.get_mut(channel) {
            handlers.remove(&id);
            if handlers.is_empty() {
                table.remove(channel);
            }
        }
    }
}
/// Fluent constructor for an [`EventBus`].
///
/// Channel names recorded with [`EventBusBuilder::with_channel`] are collected but are
/// currently **not used** by [`EventBusBuilder::build`]. The bus creates a channel's entry
/// lazily on its first subscription, so the built bus is the same as [`EventBus::arc`].
#[derive(Debug, Clone, Default)]
pub struct EventBusBuilder {
    channels: Vec<String>,
}

impl EventBusBuilder {
    /// Creates a builder with no channels recorded.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `channel` as a channel the caller intends to use.
    ///
    /// Returns the builder, so dropping the result discards the builder.
    #[must_use]
    pub fn with_channel(mut self, channel: impl Into<String>) -> Self {
        self.channels.push(channel.into());
        self
    }

    /// Builds a fresh, empty, shareable bus.
    #[must_use]
    pub fn build(self) -> Arc<EventBus> {
        // Recorded channel names are informational only for now; see the type-level docs.
        drop(self.channels);
        EventBus::arc()
    }
}
/// Conventional channel names for session events.
///
/// `EventBus` matches channel names by exact string equality, and any event may be
/// published on any channel. The pairing of names with event kinds below is a convention
/// that the bus does not enforce.
pub mod channels {
    /// The channel literally named `"session:*"`.
    ///
    /// Despite the `*`, the bus performs **no wildcard expansion**. Subscribers here only
    /// receive events published to `"session:*"` itself, not events published to the
    /// other `session:` channels below.
    pub const SESSION: &str = "session:*";
    /// Intended for chat-message events.
    pub const MESSAGE: &str = "session:message";
    /// Intended for tool-related events.
    pub const TOOL: &str = "session:tool";
    /// Intended for error events.
    pub const ERROR: &str = "session:error";
    /// Intended for token-usage events.
    pub const TOKEN_USAGE: &str = "session:token_usage";
    /// Intended for model start/end events.
    pub const MODEL: &str = "session:model";
    /// Intended for thinking start/end events.
    pub const THINKING: &str = "session:thinking";
    /// Intended for streamed output chunks.
    pub const STREAM: &str = "session:stream";
    /// Intended for application-defined events.
    pub const CUSTOM: &str = "session:custom";
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::{timeout, Duration};

    /// Upper bound on how long a test waits for a spawned async handler to run.
    const DELIVERY_TIMEOUT: Duration = Duration::from_secs(1);

    /// Builds a sync handler that increments `counter` once per delivered event.
    fn counting_handler(counter: &Arc<AtomicUsize>) -> SyncEventHandler {
        let counter = Arc::clone(counter);
        Arc::new(move |_| {
            counter.fetch_add(1, Ordering::SeqCst);
        })
    }

    #[tokio::test]
    async fn test_subscribe_and_publish() {
        let bus = EventBus::arc();
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        bus.subscribe_async("test", move |event| {
            let tx = tx.clone();
            async move {
                // The receiver outlives the bus in this test, so the send cannot fail.
                let _ = tx.send(event);
            }
        })
        .await;
        let event = AgentSessionEvent::Error {
            message: "test error".to_string(),
            recoverable: true,
        };
        bus.publish("test", event).await;
        // Wait for the spawned handler deterministically instead of sleeping a fixed time.
        let received = timeout(DELIVERY_TIMEOUT, rx.recv())
            .await
            .expect("async handler was not invoked in time")
            .expect("handler channel closed unexpectedly");
        match received {
            AgentSessionEvent::Error {
                message,
                recoverable,
            } => {
                assert_eq!(message, "test error");
                assert!(recoverable);
            }
            other => panic!("unexpected event delivered: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_sync_handler() {
        let bus = EventBus::arc();
        let received = Arc::new(std::sync::Mutex::new(Vec::new()));
        let received_clone = Arc::clone(&received);
        bus.subscribe_sync(
            "test",
            Arc::new(move |event| {
                received_clone.lock().unwrap().push(event);
            }),
        )
        .await;
        let event = AgentSessionEvent::SessionStart {
            session_id: "123".to_string(),
        };
        bus.publish("test", event).await;
        // Sync handlers run inline inside `publish`, so no waiting is needed.
        let captured = received.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert!(matches!(
            &captured[0],
            AgentSessionEvent::SessionStart { session_id } if session_id == "123"
        ));
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let bus = EventBus::arc();
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(0));
        bus.subscribe_sync("test", counting_handler(&first)).await;
        bus.subscribe_sync("test", counting_handler(&second)).await;
        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(first.load(Ordering::SeqCst), 1);
        assert_eq!(second.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_publish_other_channel_is_ignored() {
        let bus = EventBus::arc();
        let hits = Arc::new(AtomicUsize::new(0));
        bus.subscribe_sync("test", counting_handler(&hits)).await;
        bus.publish("other", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let bus = EventBus::arc();
        let hits = Arc::new(AtomicUsize::new(0));
        let subscriber = bus.subscribe_sync("test", counting_handler(&hits)).await;
        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        bus.unsubscribe(&subscriber.channel, subscriber.id).await;
        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert_eq!(bus.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_clear() {
        let bus = EventBus::arc();
        let hits = Arc::new(AtomicUsize::new(0));
        bus.subscribe_sync("test", counting_handler(&hits)).await;
        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        bus.clear().await;
        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert_eq!(bus.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_subscription_count() {
        let bus = EventBus::arc();
        assert_eq!(bus.subscription_count().await, 0);
        let sub1 = bus.subscribe_sync("test", Arc::new(|_| {})).await;
        let _sub2 = bus.subscribe_sync("test", Arc::new(|_| {})).await;
        assert_eq!(bus.subscription_count().await, 2);
        bus.unsubscribe(&sub1.channel, sub1.id).await;
        assert_eq!(bus.subscription_count().await, 1);
    }

    #[tokio::test]
    async fn test_blocking_subscribe() {
        let bus = EventBus::arc();
        let hits = Arc::new(AtomicUsize::new(0));
        let handler = counting_handler(&hits);
        let bus_clone = Arc::clone(&bus);
        // `subscribe` is the blocking entry point, so call it off the async executor.
        let subscriber =
            tokio::task::spawn_blocking(move || bus_clone.subscribe("test", handler))
                .await
                .expect("blocking subscribe panicked");
        assert_eq!(subscriber.channel, "test");
        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }
}