Skip to main content

agentlink_sdk/events/
event_loop.rs

//! Generic Event Loop
//!
//! Provides type-safe event handling with generic callbacks.

5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use tokio::sync::{mpsc, RwLock};
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12
13use super::ServerEvent;
14
15
16
17/// Type alias for event callbacks
18pub type EventCallback<T> = Arc<
19    dyn Fn(ServerEvent<T>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
20>;
21
22/// Type-erased callback wrapper
23struct TypeErasedCallback {
24    callback: Arc<dyn Fn(Value) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>,
25}
26
27/// Generic event loop with type-safe callbacks
28pub struct EventLoop {
29    /// Registered callbacks for specific event types
30    callbacks: Arc<RwLock<HashMap<String, Vec<TypeErasedCallback>>>>,
31    /// Channel sender for internal event distribution
32    event_tx: mpsc::UnboundedSender<(String, Value)>,
33    /// Running state
34    running: Arc<RwLock<bool>>,
35}
36
37impl EventLoop {
38    /// Create a new event loop
39    pub fn new() -> (Self, mpsc::UnboundedReceiver<(String, Value)>) {
40        let (tx, rx) = mpsc::unbounded_channel();
41        
42        let event_loop = Self {
43            callbacks: Arc::new(RwLock::new(HashMap::new())),
44            event_tx: tx,
45            running: Arc::new(RwLock::new(false)),
46        };
47        
48        (event_loop, rx)
49    }
50    
51    /// Register a typed callback for a specific event type
52    ///
53    /// # Example
54    /// ```rust,ignore
55    /// event_loop.on_event(EVENT_MESSAGE_RECEIVED, |event: ServerEvent<MessageReceivedData>| {
56    ///     Box::pin(async move {
57    ///         println!("New message: {:?}", event.data.content);
58    ///     })
59    /// }).await;
60    /// ```
61    pub async fn on_event<T, F, Fut>(&self, event_type: &str, callback: F)
62    where
63        T: DeserializeOwned + Send + Sync + 'static,
64        F: Fn(ServerEvent<T>) -> Fut + Send + Sync + 'static,
65        Fut: Future<Output = ()> + Send + 'static,
66    {
67        let wrapped = TypeErasedCallback {
68            callback: Arc::new(move |value: Value| {
69                // Try to deserialize the full event
70                match serde_json::from_value::<ServerEvent<T>>(value.clone()) {
71                    Ok(event) => {
72                        Box::pin(callback(event)) as Pin<Box<dyn Future<Output = ()> + Send>>
73                    }
74                    Err(e) => {
75                        tracing::error!("[EventLoop] Failed to deserialize event: {}", e);
76                        Box::pin(async {}) as Pin<Box<dyn Future<Output = ()> + Send>>
77                    }
78                }
79            }),
80        };
81        
82        let mut callbacks = self.callbacks.write().await;
83        callbacks
84            .entry(event_type.to_string())
85            .or_insert_with(Vec::new)
86            .push(wrapped);
87        
88        tracing::debug!("[EventLoop] Registered callback for event type: {}", event_type);
89    }
90    
91    /// Register a handler function directly (simplified API)
92    ///
93    /// This is a convenience wrapper around `on_event` that allows direct
94    /// registration of async handler functions without wrapping them in closures.
95    ///
96    /// # Example
97    /// ```rust,ignore
98    /// // 定义 handler 函数
99    /// async fn on_message_received(event: ServerEvent<MessageReceivedData>) {
100    ///     println!("New message: {:?}", event.data.content);
101    /// }
102    ///
103    /// // 直接注册,无需包装闭包
104    /// event_loop.on(EVENT_MESSAGE_RECEIVED, on_message_received).await;
105    /// ```
106    pub async fn on<T, F, Fut>(&self, event_type: &'static str, handler: F)
107    where
108        T: DeserializeOwned + Send + Sync + 'static,
109        F: Fn(ServerEvent<T>) -> Fut + Send + Sync + 'static,
110        Fut: Future<Output = ()> + Send + 'static,
111    {
112        self.on_event(event_type, handler).await;
113    }
114    
115    /// Remove all callbacks for a specific event type
116    pub async fn off_event(&self, event_type: &str) {
117        let mut callbacks = self.callbacks.write().await;
118        callbacks.remove(event_type);
119        tracing::info!("[EventLoop] Removed callbacks for event type: {}", event_type);
120    }
121    
122    /// Clear all callbacks
123    pub async fn clear_callbacks(&self) {
124        let mut callbacks = self.callbacks.write().await;
125        callbacks.clear();
126        tracing::info!("[EventLoop] Cleared all callbacks");
127    }
128    
129    /// Process a single event
130    pub async fn process_event(&self, event_type: &str, payload: Value) {
131        let callbacks = self.callbacks.read().await;
132        
133        if let Some(handlers) = callbacks.get(event_type) {
134            for handler in handlers {
135                let callback = handler.callback.clone();
136                let payload = payload.clone();
137                tokio::spawn(async move {
138                    callback(payload).await;
139                });
140            }
141        }
142    }
143    
144    /// Start the event loop
145    pub async fn start(&self, mut event_rx: mpsc::UnboundedReceiver<(String, Value)>) {
146        {
147            let mut running = self.running.write().await;
148            *running = true;
149        }
150        
151        tracing::debug!("[EventLoop] Started");
152        
153        while let Some((event_type, payload)) = event_rx.recv().await {
154            self.process_event(&event_type, payload).await;
155        }
156        
157        {
158            let mut running = self.running.write().await;
159            *running = false;
160        }
161        
162        tracing::info!("[EventLoop] Stopped");
163    }
164    
165    /// Check if the event loop is running
166    pub async fn is_running(&self) -> bool {
167        *self.running.read().await
168    }
169    
170    /// Get event sender
171    pub fn event_sender(&self) -> mpsc::UnboundedSender<(String, Value)> {
172        self.event_tx.clone()
173    }
174}
175
176impl Clone for EventLoop {
177    fn clone(&self) -> Self {
178        Self {
179            callbacks: self.callbacks.clone(),
180            event_tx: self.event_tx.clone(),
181            running: self.running.clone(),
182        }
183    }
184}
185
186impl Default for EventLoop {
187    fn default() -> Self {
188        let (event_loop, _) = Self::new();
189        event_loop
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::{MessageReceivedData, EVENT_MESSAGE_RECEIVED};

    /// Registering typed callbacks stores each one under its event type.
    #[tokio::test]
    async fn test_typed_callback_registration() {
        let (event_loop, _) = EventLoop::new();

        // The callback body is never invoked here; only registration is checked.
        event_loop
            .on_event(
                EVENT_MESSAGE_RECEIVED,
                |event: ServerEvent<MessageReceivedData>| async move {
                    assert_eq!(event.data.content, Some("Hello".to_string()));
                },
            )
            .await;
        event_loop
            .on_event(
                EVENT_MESSAGE_RECEIVED,
                |_event: ServerEvent<MessageReceivedData>| async {},
            )
            .await;

        let count = event_loop
            .callbacks
            .read()
            .await
            .get(EVENT_MESSAGE_RECEIVED)
            .map(Vec::len);
        assert_eq!(count, Some(2));
    }

    /// `off_event` drops every callback for the given event type.
    #[tokio::test]
    async fn test_off_event_removes_callbacks() {
        let (event_loop, _) = EventLoop::new();
        event_loop
            .on_event(
                EVENT_MESSAGE_RECEIVED,
                |_event: ServerEvent<MessageReceivedData>| async {},
            )
            .await;

        event_loop.off_event(EVENT_MESSAGE_RECEIVED).await;

        assert!(!event_loop
            .callbacks
            .read()
            .await
            .contains_key(EVENT_MESSAGE_RECEIVED));
    }

    /// `clear_callbacks` empties the whole table.
    #[tokio::test]
    async fn test_clear_callbacks() {
        let (event_loop, _) = EventLoop::new();
        event_loop
            .on_event(
                EVENT_MESSAGE_RECEIVED,
                |_event: ServerEvent<MessageReceivedData>| async {},
            )
            .await;

        event_loop.clear_callbacks().await;

        assert!(event_loop.callbacks.read().await.is_empty());
    }

    /// Clones share one callback table.
    #[tokio::test]
    async fn test_clones_share_callbacks() {
        let (event_loop, _) = EventLoop::new();
        let handle = event_loop.clone();

        handle
            .on_event(
                EVENT_MESSAGE_RECEIVED,
                |_event: ServerEvent<MessageReceivedData>| async {},
            )
            .await;

        assert!(event_loop
            .callbacks
            .read()
            .await
            .contains_key(EVENT_MESSAGE_RECEIVED));
    }

    /// A fresh loop is not running, and dispatching an unknown event is a no-op.
    #[tokio::test]
    async fn test_initial_state_and_unknown_event() {
        let (event_loop, _) = EventLoop::new();
        assert!(!event_loop.is_running().await);

        event_loop.process_event("no_such_event", Value::Null).await;
        assert!(event_loop.callbacks.read().await.is_empty());
    }
}
218}