1use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use crate::trace::trace_lazy;
11
/// Kinds of events a [`StreamEmitter`] can dispatch; used as the key of its listener map.
///
/// This is a fieldless enum, so it derives `Copy` in addition to `Clone`: it can be passed
/// by value (as the emitter's `on`/`emit`/`off` methods take it) and reused afterwards
/// without an explicit `.clone()`.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum EventType {
    /// Standard-output chunk (presumably from a child process — TODO confirm with emitters).
    Stdout,
    /// Standard-error chunk (presumably from a child process — TODO confirm with emitters).
    Stderr,
    /// Generic data event; the producer is not visible from this module.
    Data,
    /// End-of-stream notification (meaning assumed from the name — TODO confirm).
    End,
    /// Exit notification; the tests in this file pair it with `EventData::ExitCode`.
    Exit,
    /// Error notification.
    Error,
    /// Spawn notification (meaning assumed from the name — TODO confirm).
    Spawn,
}
30
31impl std::fmt::Display for EventType {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 EventType::Stdout => write!(f, "stdout"),
35 EventType::Stderr => write!(f, "stderr"),
36 EventType::Data => write!(f, "data"),
37 EventType::End => write!(f, "end"),
38 EventType::Exit => write!(f, "exit"),
39 EventType::Error => write!(f, "error"),
40 EventType::Spawn => write!(f, "spawn"),
41 }
42 }
43}
44
/// Payload handed to every listener of an emitted event.
///
/// Cloned once per listener on emission, hence the `Clone` derive.
#[derive(Debug, Clone)]
pub enum EventData {
    /// Free-form text; the tests in this file use it with `Stdout`/`Stderr` events.
    String(String),
    /// Integer exit code; the tests in this file pair it with `EventType::Exit`.
    ExitCode(i32),
    /// Text tagged with a type name. The vocabulary of `data_type` is not defined
    /// here — TODO confirm against the code that emits it.
    TypedData { data_type: String, data: String },
    /// A complete command result (presumably sent when a command finishes — TODO confirm).
    Result(crate::CommandResult),
    /// Error payload as text (presumably a human-readable message — TODO confirm).
    Error(String),
    /// No payload.
    None,
}
61
62type Listener = Arc<dyn Fn(EventData) + Send + Sync>;
64
/// Event emitter that maps each [`EventType`] to an ordered list of listeners.
///
/// The map sits behind a tokio `RwLock`: registering/removing listeners takes the
/// write lock, while emitting and counting take the read lock.
pub struct StreamEmitter {
    /// Registered listeners keyed by event; each `Vec` is in registration order.
    listeners: RwLock<HashMap<EventType, Vec<Listener>>>,
}
71
72impl Default for StreamEmitter {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl StreamEmitter {
79 pub fn new() -> Self {
81 StreamEmitter {
82 listeners: RwLock::new(HashMap::new()),
83 }
84 }
85
86 pub async fn on<F>(&self, event: EventType, listener: F)
101 where
102 F: Fn(EventData) + Send + Sync + 'static,
103 {
104 trace_lazy("StreamEmitter", || {
105 format!("on() called for event: {}", event)
106 });
107
108 let mut listeners = self.listeners.write().await;
109 listeners.entry(event).or_default().push(Arc::new(listener));
110 }
111
112 pub async fn once<F>(&self, event: EventType, listener: F)
116 where
117 F: Fn(EventData) + Send + Sync + 'static,
118 {
119 trace_lazy("StreamEmitter", || {
120 format!("once() called for event: {}", event)
121 });
122
123 let called = Arc::new(std::sync::atomic::AtomicBool::new(false));
125 let called_clone = called.clone();
126
127 let once_listener = move |data: EventData| {
128 if !called_clone.swap(true, std::sync::atomic::Ordering::SeqCst) {
129 listener(data);
130 }
131 };
132
133 self.on(event, once_listener).await;
134 }
135
136 pub async fn emit(&self, event: EventType, data: EventData) {
142 let listeners = self.listeners.read().await;
143
144 if let Some(event_listeners) = listeners.get(&event) {
145 trace_lazy("StreamEmitter", || {
146 format!(
147 "Emitting event {} to {} listeners",
148 event,
149 event_listeners.len()
150 )
151 });
152
153 for listener in event_listeners {
154 listener(data.clone());
155 }
156 }
157 }
158
159 pub async fn off(&self, event: EventType) {
164 trace_lazy("StreamEmitter", || {
165 format!("off() called for event: {}", event)
166 });
167
168 let mut listeners = self.listeners.write().await;
169 listeners.remove(&event);
170 }
171
172 pub async fn listener_count(&self, event: &EventType) -> usize {
174 let listeners = self.listeners.read().await;
175 listeners.get(event).map(|v| v.len()).unwrap_or(0)
176 }
177
178 pub async fn remove_all_listeners(&self) {
180 trace_lazy("StreamEmitter", || "Removing all listeners".to_string());
181 let mut listeners = self.listeners.write().await;
182 listeners.clear();
183 }
184}
185
186impl std::fmt::Debug for StreamEmitter {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 f.debug_struct("StreamEmitter")
189 .field("listeners", &"<RwLock<HashMap<...>>>")
190 .finish()
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_emit_basic() {
        let emitter = StreamEmitter::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        emitter
            .on(EventType::Stdout, move |_| {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        emitter
            .emit(EventType::Stdout, EventData::String("test".to_string()))
            .await;

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_once() {
        let emitter = StreamEmitter::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        emitter
            .once(EventType::Exit, move |_| {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        emitter.emit(EventType::Exit, EventData::ExitCode(0)).await;
        // The second emission must be swallowed by the once-wrapper.
        emitter.emit(EventType::Exit, EventData::ExitCode(0)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_off() {
        let emitter = StreamEmitter::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        emitter
            .on(EventType::Stdout, move |_| {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        emitter.off(EventType::Stdout).await;
        emitter
            .emit(EventType::Stdout, EventData::String("test".to_string()))
            .await;

        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_off_only_affects_target_event() {
        let emitter = StreamEmitter::new();
        let stdout_counter = Arc::new(AtomicUsize::new(0));
        let stderr_counter = Arc::new(AtomicUsize::new(0));
        let stdout_clone = stdout_counter.clone();
        let stderr_clone = stderr_counter.clone();

        emitter
            .on(EventType::Stdout, move |_| {
                stdout_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;
        emitter
            .on(EventType::Stderr, move |_| {
                stderr_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        emitter.off(EventType::Stdout).await;
        assert_eq!(emitter.listener_count(&EventType::Stdout).await, 0);
        assert_eq!(emitter.listener_count(&EventType::Stderr).await, 1);

        emitter
            .emit(EventType::Stdout, EventData::String("out".to_string()))
            .await;
        emitter
            .emit(EventType::Stderr, EventData::String("err".to_string()))
            .await;

        assert_eq!(stdout_counter.load(Ordering::SeqCst), 0);
        assert_eq!(stderr_counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_listener_count() {
        let emitter = StreamEmitter::new();

        assert_eq!(emitter.listener_count(&EventType::Stdout).await, 0);

        emitter.on(EventType::Stdout, |_| {}).await;
        assert_eq!(emitter.listener_count(&EventType::Stdout).await, 1);

        emitter.on(EventType::Stdout, |_| {}).await;
        assert_eq!(emitter.listener_count(&EventType::Stdout).await, 2);
    }

    #[tokio::test]
    async fn test_remove_all_listeners() {
        let emitter = StreamEmitter::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let stdout_clone = counter.clone();
        let exit_clone = counter.clone();

        emitter
            .on(EventType::Stdout, move |_| {
                stdout_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;
        emitter
            .on(EventType::Exit, move |_| {
                exit_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        emitter.remove_all_listeners().await;
        assert_eq!(emitter.listener_count(&EventType::Stdout).await, 0);
        assert_eq!(emitter.listener_count(&EventType::Exit).await, 0);

        emitter
            .emit(EventType::Stdout, EventData::String("test".to_string()))
            .await;
        emitter.emit(EventType::Exit, EventData::ExitCode(1)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_multiple_events() {
        let emitter = StreamEmitter::new();
        let stdout_counter = Arc::new(AtomicUsize::new(0));
        let stderr_counter = Arc::new(AtomicUsize::new(0));

        let stdout_clone = stdout_counter.clone();
        let stderr_clone = stderr_counter.clone();

        emitter
            .on(EventType::Stdout, move |_| {
                stdout_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        emitter
            .on(EventType::Stderr, move |_| {
                stderr_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        emitter
            .emit(EventType::Stdout, EventData::String("out".to_string()))
            .await;
        emitter
            .emit(EventType::Stderr, EventData::String("err".to_string()))
            .await;

        assert_eq!(stdout_counter.load(Ordering::SeqCst), 1);
        assert_eq!(stderr_counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_emit_delivers_payload_to_every_listener() {
        let emitter = StreamEmitter::new();
        let received = Arc::new(std::sync::Mutex::new(Vec::new()));

        // Two listeners: every listener, not just the first, must see the payload.
        for _ in 0..2 {
            let sink = received.clone();
            emitter
                .on(EventType::Data, move |data| {
                    if let EventData::String(text) = data {
                        sink.lock().unwrap().push(text);
                    }
                })
                .await;
        }

        emitter
            .emit(EventType::Data, EventData::String("chunk".to_string()))
            .await;

        assert_eq!(*received.lock().unwrap(), vec!["chunk".to_string(); 2]);
    }

    #[tokio::test]
    async fn test_emit_without_listeners_is_noop() {
        let emitter = StreamEmitter::default();

        emitter.emit(EventType::Spawn, EventData::None).await;

        assert_eq!(emitter.listener_count(&EventType::Spawn).await, 0);
    }

    #[test]
    fn test_event_type_display() {
        let cases = [
            (EventType::Stdout, "stdout"),
            (EventType::Stderr, "stderr"),
            (EventType::Data, "data"),
            (EventType::End, "end"),
            (EventType::Exit, "exit"),
            (EventType::Error, "error"),
            (EventType::Spawn, "spawn"),
        ];

        for (event, expected) in cases.iter() {
            assert_eq!(event.to_string(), *expected);
        }
    }
}