1use std::collections::HashMap;
6use std::fmt::Debug;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10#[derive(Debug, Clone)]
12pub enum AgentSessionEvent {
13 Message {
15 role: String,
16 content: String,
17 timestamp: u64,
18 },
19 ToolStart {
21 tool_name: String,
22 input: serde_json::Value,
23 },
24 ToolEnd {
26 tool_name: String,
27 output: Result<serde_json::Value, String>,
28 duration_ms: u64,
29 },
30 Error {
32 message: String,
33 recoverable: bool,
34 },
35 ModelStart {
37 model_id: String,
38 },
39 ModelEnd {
41 model_id: String,
42 duration_ms: u64,
43 tokens_used: Option<u32>,
44 },
45 TokenUsage {
47 input_tokens: u32,
48 output_tokens: u32,
49 cached_tokens: Option<u32>,
50 },
51 SessionStart {
53 session_id: String,
54 },
55 SessionEnd {
57 session_id: String,
58 total_messages: u32,
59 },
60 ThinkingStart,
62 ThinkingEnd {
64 thoughts: String,
65 },
66 StreamChunk {
68 content: String,
69 },
70 ToolCall {
72 tool_name: String,
73 arguments: serde_json::Value,
74 },
75 ToolResult {
77 tool_name: String,
78 result: serde_json::Value,
79 },
80 Custom {
82 name: String,
83 data: serde_json::Value,
84 },
85}
86
87pub type EventHandler = Arc<dyn Fn(AgentSessionEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> + Send + Sync>;
89
90pub type SyncEventHandler = Arc<dyn Fn(AgentSessionEvent) + Send + Sync>;
92
/// Receipt returned when a handler is registered on an `EventBus`.
///
/// It records the channel and the id the bus assigned to the subscription. To remove
/// the handler, pass both to `EventBus::unsubscribe(channel, id)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Subscriber {
    /// Channel the handler was registered on.
    pub channel: String,
    /// Id assigned by the issuing bus.
    pub id: u64,
}

impl Subscriber {
    /// Consumes this receipt **without** removing the subscription.
    ///
    /// `Subscriber` holds no reference to the bus that issued it, so this method cannot
    /// reach the handler maps and the handler stays registered. Use
    /// `EventBus::unsubscribe(&self.channel, self.id)` to actually remove it.
    pub fn unsubscribe(self) {
        // NOTE(review): intentionally a no-op, kept for API compatibility. Making it
        // effective would require storing a (weak) handle to the bus in this struct.
    }
}
105
106struct BusInner {
108 subscribers: RwLock<HashMap<String, HashMap<u64, EventHandler>>>,
109 sync_subscribers: RwLock<HashMap<String, HashMap<u64, SyncEventHandler>>>,
110 next_id: RwLock<u64>,
111}
112
113pub struct EventBus {
115 inner: Arc<BusInner>,
116}
117
118impl Default for EventBus {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl EventBus {
125 pub fn new() -> Self {
127 Self {
128 inner: Arc::new(BusInner {
129 subscribers: RwLock::new(HashMap::new()),
130 sync_subscribers: RwLock::new(HashMap::new()),
131 next_id: RwLock::new(0),
132 }),
133 }
134 }
135
136 pub fn arc() -> Arc<Self> {
138 Arc::new(Self::new())
139 }
140
141 pub async fn subscribe_async<F, Fut>(&self, channel: &str, handler: F) -> Subscriber
143 where
144 F: Fn(AgentSessionEvent) -> Fut + Send + Sync + 'static,
145 Fut: std::future::Future<Output = ()> + Send + 'static,
146 {
147 let mut next_id = self.inner.next_id.write().await;
148 let id = *next_id;
149 *next_id = id + 1;
150 drop(next_id);
151
152 let handler: EventHandler = Arc::new(move |event| {
153 let fut = handler(event);
154 Box::pin(fut)
155 });
156
157 self.inner.subscribers
158 .write()
159 .await
160 .entry(channel.to_string())
161 .or_insert_with(HashMap::new)
162 .insert(id, handler);
163
164 Subscriber {
165 channel: channel.to_string(),
166 id,
167 }
168 }
169
170 pub async fn subscribe_sync(&self, channel: &str, handler: SyncEventHandler) -> Subscriber {
172 let mut next_id = self.inner.next_id.write().await;
173 let id = *next_id;
174 *next_id = id + 1;
175 drop(next_id);
176
177 self.inner.sync_subscribers
178 .write()
179 .await
180 .entry(channel.to_string())
181 .or_insert_with(HashMap::new)
182 .insert(id, handler);
183
184 Subscriber {
185 channel: channel.to_string(),
186 id,
187 }
188 }
189
190 pub fn subscribe(&self, channel: &str, handler: SyncEventHandler) -> Subscriber {
192 let rt = tokio::runtime::Handle::current();
193 rt.block_on(async {
194 self.subscribe_sync(channel, handler).await
195 })
196 }
197
198 pub async fn publish(&self, channel: &str, event: AgentSessionEvent) {
200 {
202 let sync_handlers = self.inner.sync_subscribers.read().await;
203 if let Some(handlers) = sync_handlers.get(channel) {
204 for handler in handlers.values() {
205 handler(event.clone());
206 }
207 }
208 }
209
210 let handlers: Vec<EventHandler> = {
212 let async_handlers = self.inner.subscribers.read().await;
213 async_handlers
214 .get(channel)
215 .map(|h| h.values().cloned().collect())
216 .unwrap_or_default()
217 };
218
219 for handler in handlers {
220 let event_clone = event.clone();
221 tokio::spawn(async move {
222 handler(event_clone).await;
223 });
224 }
225 }
226
227 pub async fn unsubscribe(&self, channel: &str, id: u64) {
229 if let Some(handlers) = self.inner.subscribers.write().await.get_mut(channel) {
230 handlers.remove(&id);
231 }
232 if let Some(handlers) = self.inner.sync_subscribers.write().await.get_mut(channel) {
233 handlers.remove(&id);
234 }
235 }
236
237 pub async fn unsubscribe_all(&self, channel: &str) {
239 self.inner.subscribers.write().await.remove(channel);
240 self.inner.sync_subscribers.write().await.remove(channel);
241 }
242
243 pub async fn clear(&self) {
245 self.inner.subscribers.write().await.clear();
246 self.inner.sync_subscribers.write().await.clear();
247 }
248
249 pub async fn subscription_count(&self) -> usize {
251 let async_count: usize = self.inner.subscribers.read().await.values().map(|h| h.len()).sum();
252 let sync_count: usize = self.inner.sync_subscribers.read().await.values().map(|h| h.len()).sum();
253 async_count + sync_count
254 }
255}
256
257pub struct EventBusBuilder {
259 channels: Vec<String>,
260}
261
262impl EventBusBuilder {
263 pub fn new() -> Self {
264 Self { channels: Vec::new() }
265 }
266
267 pub fn with_channel(mut self, channel: impl Into<String>) -> Self {
268 self.channels.push(channel.into());
269 self
270 }
271
272 pub fn build(self) -> Arc<EventBus> {
273 let bus = EventBus::arc();
274 let _ = self.channels; bus
276 }
277}
278
279impl Default for EventBusBuilder {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
/// Well-known channel names for agent-session events.
///
/// These are plain strings. The event bus looks channels up by exact string
/// equality, so `SESSION` (`"session:*"`) is a literal channel name, not a wildcard.
pub mod channels {
    /// Channel literally named `"session:*"`.
    pub const SESSION: &str = "session:*";

    /// Channel for message events.
    pub const MESSAGE: &str = "session:message";
    /// Channel for streamed output chunks.
    pub const STREAM: &str = "session:stream";
    /// Channel for thinking-phase events.
    pub const THINKING: &str = "session:thinking";

    /// Channel for tool events.
    pub const TOOL: &str = "session:tool";
    /// Channel for model events.
    pub const MODEL: &str = "session:model";
    /// Channel for token-usage events.
    pub const TOKEN_USAGE: &str = "session:token_usage";

    /// Channel for error events.
    pub const ERROR: &str = "session:error";
    /// Channel for application-defined events.
    pub const CUSTOM: &str = "session:custom";
}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long a test waits for a spawned async handler to run.
    const ASYNC_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(1);

    #[tokio::test]
    async fn test_subscribe_and_publish() {
        let bus = EventBus::arc();
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<AgentSessionEvent>();

        bus.subscribe_async("test", move |event| {
            let tx = tx.clone();
            async move {
                // The receiver outlives the bus in this test; a send error is irrelevant.
                let _ = tx.send(event);
            }
        })
        .await;

        let event = AgentSessionEvent::Error {
            message: "test error".to_string(),
            recoverable: true,
        };
        bus.publish("test", event).await;

        // Wait for the spawned handler instead of sleeping a fixed interval (flaky).
        let received = tokio::time::timeout(ASYNC_TIMEOUT, rx.recv())
            .await
            .expect("async handler did not run in time")
            .expect("handler channel closed unexpectedly");
        match received {
            AgentSessionEvent::Error {
                message,
                recoverable,
            } => {
                assert_eq!(message, "test error");
                assert!(recoverable);
            }
            other => panic!("unexpected event: {:?}", other),
        }
        assert!(rx.try_recv().is_err(), "handler ran more than once");
    }

    #[tokio::test]
    async fn test_sync_handler() {
        let bus = EventBus::arc();
        let received = Arc::new(std::sync::Mutex::new(Vec::new()));
        let received_clone = received.clone();

        bus.subscribe_sync(
            "test",
            Arc::new(move |event| {
                received_clone.lock().unwrap().push(event);
            }),
        )
        .await;

        let event = AgentSessionEvent::SessionStart {
            session_id: "123".to_string(),
        };
        bus.publish("test", event).await;

        // Sync handlers run inline during `publish`, so no waiting is needed.
        let captured = received.lock().unwrap();
        assert_eq!(captured.len(), 1);
        match &captured[0] {
            AgentSessionEvent::SessionStart { session_id } => assert_eq!(session_id, "123"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let bus = EventBus::arc();
        let count1 = Arc::new(std::sync::Mutex::new(0));
        let count2 = Arc::new(std::sync::Mutex::new(0));
        let count1_clone = count1.clone();
        let count2_clone = count2.clone();

        bus.subscribe_sync(
            "test",
            Arc::new(move |_| {
                *count1_clone.lock().unwrap() += 1;
            }),
        )
        .await;
        bus.subscribe_sync(
            "test",
            Arc::new(move |_| {
                *count2_clone.lock().unwrap() += 1;
            }),
        )
        .await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;

        assert_eq!(*count1.lock().unwrap(), 1);
        assert_eq!(*count2.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn test_publish_is_channel_scoped() {
        let bus = EventBus::arc();
        let hits = Arc::new(std::sync::Mutex::new(0));
        let hits_clone = hits.clone();

        bus.subscribe_sync(
            "a",
            Arc::new(move |_| {
                *hits_clone.lock().unwrap() += 1;
            }),
        )
        .await;

        bus.publish("b", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(*hits.lock().unwrap(), 0);

        bus.publish("a", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(*hits.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let bus = EventBus::arc();
        let received = Arc::new(std::sync::Mutex::new(Vec::new()));
        let received_clone = received.clone();

        let subscriber = bus
            .subscribe_sync(
                "test",
                Arc::new(move |_| {
                    received_clone.lock().unwrap().push(1);
                }),
            )
            .await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(received.lock().unwrap().len(), 1);

        bus.unsubscribe("test", subscriber.id).await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(received.lock().unwrap().len(), 1);
        assert_eq!(bus.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_unsubscribe_all_only_affects_channel() {
        let bus = EventBus::arc();

        bus.subscribe_sync("a", Arc::new(|_| {})).await;
        bus.subscribe_async("a", |_| async {}).await;
        bus.subscribe_sync("b", Arc::new(|_| {})).await;
        assert_eq!(bus.subscription_count().await, 3);

        bus.unsubscribe_all("a").await;
        assert_eq!(bus.subscription_count().await, 1);
    }

    #[tokio::test]
    async fn test_ids_unique_across_handler_kinds() {
        let bus = EventBus::arc();

        let sync_sub = bus.subscribe_sync("test", Arc::new(|_| {})).await;
        let async_sub = bus.subscribe_async("test", |_| async {}).await;

        assert_ne!(sync_sub.id, async_sub.id);
        assert_eq!(sync_sub.channel, "test");
        assert_eq!(async_sub.channel, "test");
    }

    #[tokio::test]
    async fn test_clear() {
        let bus = EventBus::arc();
        let received = Arc::new(std::sync::Mutex::new(Vec::new()));
        let received_clone = received.clone();

        bus.subscribe_sync(
            "test",
            Arc::new(move |_| {
                received_clone.lock().unwrap().push(1);
            }),
        )
        .await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;

        bus.clear().await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;

        assert_eq!(received.lock().unwrap().len(), 1);
        assert_eq!(bus.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_subscription_count() {
        let bus = EventBus::arc();

        assert_eq!(bus.subscription_count().await, 0);

        let _sub1 = bus.subscribe_sync("test", Arc::new(|_| {})).await;
        let _sub2 = bus.subscribe_sync("test", Arc::new(|_| {})).await;

        assert_eq!(bus.subscription_count().await, 2);
    }
}