Skip to main content

oxi/
event_bus.rs

1//! Async event bus for pub/sub communication
2//!
3//! Provides a type-safe event system for agent session events.
4
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10/// Event types for the agent session
11#[derive(Debug, Clone)]
12pub enum AgentSessionEvent {
13    /// A new message was received or sent
14    Message {
15        role: String,
16        content: String,
17        timestamp: u64,
18    },
19    /// A tool started executing
20    ToolStart {
21        tool_name: String,
22        input: serde_json::Value,
23    },
24    /// A tool finished executing
25    ToolEnd {
26        tool_name: String,
27        output: Result<serde_json::Value, String>,
28        duration_ms: u64,
29    },
30    /// An error occurred
31    Error {
32        message: String,
33        recoverable: bool,
34    },
35    /// Model started generating a response
36    ModelStart {
37        model_id: String,
38    },
39    /// Model finished generating a response
40    ModelEnd {
41        model_id: String,
42        duration_ms: u64,
43        tokens_used: Option<u32>,
44    },
45    /// Token usage update
46    TokenUsage {
47        input_tokens: u32,
48        output_tokens: u32,
49        cached_tokens: Option<u32>,
50    },
51    /// Session started
52    SessionStart {
53        session_id: String,
54    },
55    /// Session ended
56    SessionEnd {
57        session_id: String,
58        total_messages: u32,
59    },
60    /// Thinking block started
61    ThinkingStart,
62    /// Thinking block ended
63    ThinkingEnd {
64        thoughts: String,
65    },
66    /// Stream chunk received
67    StreamChunk {
68        content: String,
69    },
70    /// Tool call requested
71    ToolCall {
72        tool_name: String,
73        arguments: serde_json::Value,
74    },
75    /// Tool result received
76    ToolResult {
77        tool_name: String,
78        result: serde_json::Value,
79    },
80    /// Custom event from extensions
81    Custom {
82        name: String,
83        data: serde_json::Value,
84    },
85}
86
87/// Async event handler type - returns a pinned boxed future
88pub type EventHandler = Arc<dyn Fn(AgentSessionEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> + Send + Sync>;
89
90/// Sync event handler type (for simpler handlers)
91pub type SyncEventHandler = Arc<dyn Fn(AgentSessionEvent) + Send + Sync>;
92
/// Handle identifying one subscription created by an `EventBus`.
///
/// The handle holds no reference to the bus, so keeping or dropping it has no
/// effect on the subscription. The handler stays registered until
/// `EventBus::unsubscribe(&handle.channel, handle.id)`, `unsubscribe_all`, or
/// `clear` is called on the bus that issued it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Subscriber {
    /// Channel the handler was registered on.
    pub channel: String,
    /// Subscription id assigned by the bus.
    pub id: u64,
}

impl Subscriber {
    /// Consume this handle.
    ///
    /// This does **not** remove the handler from the bus, because the handle has
    /// no link back to it. To stop receiving events, call
    /// `EventBus::unsubscribe(&self.channel, self.id)` before discarding the handle.
    pub fn unsubscribe(self) {
        // Intentionally only drops the handle; see the doc comment above.
    }
}
105
106/// Internal subscriber storage
107struct BusInner {
108    subscribers: RwLock<HashMap<String, HashMap<u64, EventHandler>>>,
109    sync_subscribers: RwLock<HashMap<String, HashMap<u64, SyncEventHandler>>>,
110    next_id: RwLock<u64>,
111}
112
113/// Thread-safe async event bus for publish/subscribe pattern
114pub struct EventBus {
115    inner: Arc<BusInner>,
116}
117
118impl Default for EventBus {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124impl EventBus {
125    /// Create a new event bus
126    pub fn new() -> Self {
127        Self {
128            inner: Arc::new(BusInner {
129                subscribers: RwLock::new(HashMap::new()),
130                sync_subscribers: RwLock::new(HashMap::new()),
131                next_id: RwLock::new(0),
132            }),
133        }
134    }
135
136    /// Create a new Arc-wrapped event bus
137    pub fn arc() -> Arc<Self> {
138        Arc::new(Self::new())
139    }
140
141    /// Subscribe to an event channel with an async handler
142    pub async fn subscribe_async<F, Fut>(&self, channel: &str, handler: F) -> Subscriber
143    where
144        F: Fn(AgentSessionEvent) -> Fut + Send + Sync + 'static,
145        Fut: std::future::Future<Output = ()> + Send + 'static,
146    {
147        let mut next_id = self.inner.next_id.write().await;
148        let id = *next_id;
149        *next_id = id + 1;
150        drop(next_id);
151
152        let handler: EventHandler = Arc::new(move |event| {
153            let fut = handler(event);
154            Box::pin(fut)
155        });
156
157        self.inner.subscribers
158            .write()
159            .await
160            .entry(channel.to_string())
161            .or_insert_with(HashMap::new)
162            .insert(id, handler);
163
164        Subscriber {
165            channel: channel.to_string(),
166            id,
167        }
168    }
169
170    /// Subscribe to an event channel with a sync handler
171    pub async fn subscribe_sync(&self, channel: &str, handler: SyncEventHandler) -> Subscriber {
172        let mut next_id = self.inner.next_id.write().await;
173        let id = *next_id;
174        *next_id = id + 1;
175        drop(next_id);
176
177        self.inner.sync_subscribers
178            .write()
179            .await
180            .entry(channel.to_string())
181            .or_insert_with(HashMap::new)
182            .insert(id, handler);
183
184        Subscriber {
185            channel: channel.to_string(),
186            id,
187        }
188    }
189
190    /// Subscribe to an event channel (sync version for convenience)
191    pub fn subscribe(&self, channel: &str, handler: SyncEventHandler) -> Subscriber {
192        let rt = tokio::runtime::Handle::current();
193        rt.block_on(async {
194            self.subscribe_sync(channel, handler).await
195        })
196    }
197
198    /// Publish an event to a channel
199    pub async fn publish(&self, channel: &str, event: AgentSessionEvent) {
200        // Notify sync handlers first
201        {
202            let sync_handlers = self.inner.sync_subscribers.read().await;
203            if let Some(handlers) = sync_handlers.get(channel) {
204                for handler in handlers.values() {
205                    handler(event.clone());
206                }
207            }
208        }
209
210        // Then notify async handlers
211        let handlers: Vec<EventHandler> = {
212            let async_handlers = self.inner.subscribers.read().await;
213            async_handlers
214                .get(channel)
215                .map(|h| h.values().cloned().collect())
216                .unwrap_or_default()
217        };
218
219        for handler in handlers {
220            let event_clone = event.clone();
221            tokio::spawn(async move {
222                handler(event_clone).await;
223            });
224        }
225    }
226
227    /// Unsubscribe a specific handler
228    pub async fn unsubscribe(&self, channel: &str, id: u64) {
229        if let Some(handlers) = self.inner.subscribers.write().await.get_mut(channel) {
230            handlers.remove(&id);
231        }
232        if let Some(handlers) = self.inner.sync_subscribers.write().await.get_mut(channel) {
233            handlers.remove(&id);
234        }
235    }
236
237    /// Unsubscribe all handlers for a channel
238    pub async fn unsubscribe_all(&self, channel: &str) {
239        self.inner.subscribers.write().await.remove(channel);
240        self.inner.sync_subscribers.write().await.remove(channel);
241    }
242
243    /// Clear all subscriptions
244    pub async fn clear(&self) {
245        self.inner.subscribers.write().await.clear();
246        self.inner.sync_subscribers.write().await.clear();
247    }
248
249    /// Get the number of active subscriptions
250    pub async fn subscription_count(&self) -> usize {
251        let async_count: usize = self.inner.subscribers.read().await.values().map(|h| h.len()).sum();
252        let sync_count: usize = self.inner.sync_subscribers.read().await.values().map(|h| h.len()).sum();
253        async_count + sync_count
254    }
255}
256
257/// Builder for creating event buses with predefined channels
258pub struct EventBusBuilder {
259    channels: Vec<String>,
260}
261
262impl EventBusBuilder {
263    pub fn new() -> Self {
264        Self { channels: Vec::new() }
265    }
266
267    pub fn with_channel(mut self, channel: impl Into<String>) -> Self {
268        self.channels.push(channel.into());
269        self
270    }
271
272    pub fn build(self) -> Arc<EventBus> {
273        let bus = EventBus::arc();
274        let _ = self.channels; // Channels are created on-demand
275        bus
276    }
277}
278
279impl Default for EventBusBuilder {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
/// Conventional channel names for session events.
///
/// The bus matches channels by exact string equality, with no pattern matching.
/// In particular, `SESSION` is the literal name `"session:*"`, not a wildcard.
/// Subscribing to it only receives events published to `"session:*"` itself.
pub mod channels {
    /// Channel literally named `"session:*"`. No wildcard expansion is performed.
    pub const SESSION: &str = "session:*";
    /// Channel for message events.
    pub const MESSAGE: &str = "session:message";
    /// Channel for tool events.
    pub const TOOL: &str = "session:tool";
    /// Channel for error events.
    pub const ERROR: &str = "session:error";
    /// Channel for token-usage events.
    pub const TOKEN_USAGE: &str = "session:token_usage";
    /// Channel for model start/end events.
    pub const MODEL: &str = "session:model";
    /// Channel for thinking-block events.
    pub const THINKING: &str = "session:thinking";
    /// Channel for stream-chunk events.
    pub const STREAM: &str = "session:stream";
    /// Channel for custom extension events.
    pub const CUSTOM: &str = "session:custom";
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;
    use tokio::sync::mpsc;
    use tokio::time::{timeout, Duration};

    /// Generous upper bound for a spawned async handler to deliver its event.
    /// It is only reached if delivery is broken.
    const DELIVERY_TIMEOUT: Duration = Duration::from_secs(1);

    /// Sync handler that increments `counter` once per received event.
    fn counting_handler(counter: &Arc<AtomicUsize>) -> SyncEventHandler {
        let counter = Arc::clone(counter);
        Arc::new(move |_event: AgentSessionEvent| {
            counter.fetch_add(1, Ordering::SeqCst);
        })
    }

    #[tokio::test]
    async fn test_subscribe_and_publish() {
        let bus = EventBus::arc();
        let (tx, mut rx) = mpsc::unbounded_channel::<AgentSessionEvent>();

        bus.subscribe_async("test", move |event| {
            let tx = tx.clone();
            async move {
                // A send error only means the test already dropped the receiver.
                let _ = tx.send(event);
            }
        })
        .await;

        bus.publish(
            "test",
            AgentSessionEvent::Error {
                message: "test error".to_string(),
                recoverable: true,
            },
        )
        .await;

        // Async handlers run on spawned tasks. Wait for the actual delivery instead
        // of sleeping for a fixed time.
        let received = timeout(DELIVERY_TIMEOUT, rx.recv())
            .await
            .expect("async handler was not invoked in time")
            .expect("handler channel closed unexpectedly");
        assert!(matches!(
            &received,
            AgentSessionEvent::Error { message, recoverable: true } if message == "test error"
        ));
        assert!(rx.try_recv().is_err(), "event delivered more than once");
    }

    #[tokio::test]
    async fn test_sync_handler() {
        let bus = EventBus::arc();
        let received = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&received);

        bus.subscribe_sync("test", Arc::new(move |event| sink.lock().unwrap().push(event)))
            .await;

        bus.publish(
            "test",
            AgentSessionEvent::SessionStart {
                session_id: "123".to_string(),
            },
        )
        .await;

        // Sync handlers run inline during `publish`, so no waiting is needed.
        let captured = received.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert!(matches!(
            &captured[0],
            AgentSessionEvent::SessionStart { session_id } if session_id == "123"
        ));
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let bus = EventBus::arc();
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(0));

        bus.subscribe_sync("test", counting_handler(&first)).await;
        bus.subscribe_sync("test", counting_handler(&second)).await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;

        assert_eq!(first.load(Ordering::SeqCst), 1);
        assert_eq!(second.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let bus = EventBus::arc();
        let hits = Arc::new(AtomicUsize::new(0));

        let subscriber = bus.subscribe_sync("test", counting_handler(&hits)).await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);

        bus.unsubscribe("test", subscriber.id).await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert_eq!(bus.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_unsubscribe_all() {
        let bus = EventBus::arc();
        let hits = Arc::new(AtomicUsize::new(0));

        bus.subscribe_sync("test", counting_handler(&hits)).await;
        bus.subscribe_async("test", |_| async {}).await;
        bus.subscribe_sync("other", counting_handler(&hits)).await;
        assert_eq!(bus.subscription_count().await, 3);

        bus.unsubscribe_all("test").await;
        assert_eq!(bus.subscription_count().await, 1);

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0);

        bus.publish("other", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_clear() {
        let bus = EventBus::arc();
        let hits = Arc::new(AtomicUsize::new(0));

        bus.subscribe_sync("test", counting_handler(&hits)).await;

        bus.publish("test", AgentSessionEvent::ThinkingStart).await;
        bus.clear().await;
        bus.publish("test", AgentSessionEvent::ThinkingStart).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert_eq!(bus.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_subscription_count() {
        let bus = EventBus::arc();

        assert_eq!(bus.subscription_count().await, 0);

        let sub1 = bus.subscribe_sync("test", Arc::new(|_| {})).await;
        let _sub2 = bus.subscribe_sync("test", Arc::new(|_| {})).await;
        assert_eq!(bus.subscription_count().await, 2);

        // Async subscriptions are counted too.
        let _sub3 = bus.subscribe_async("other", |_| async {}).await;
        assert_eq!(bus.subscription_count().await, 3);

        bus.unsubscribe(&sub1.channel, sub1.id).await;
        assert_eq!(bus.subscription_count().await, 2);
    }

    #[tokio::test]
    async fn test_publish_without_subscribers() {
        let bus = EventBus::arc();
        // Publishing to a channel nobody listens on is a harmless no-op.
        bus.publish("nobody", AgentSessionEvent::ThinkingStart).await;
        assert_eq!(bus.subscription_count().await, 0);
    }
}