use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::trace::trace_lazy;
/// Kinds of events that listeners can subscribe to.
///
/// The [`std::fmt::Display`] implementation renders each variant as its
/// lowercase name (for example `EventType::Stdout` becomes `stdout`).
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum EventType {
    /// Rendered as `stdout`.
    Stdout,
    /// Rendered as `stderr`.
    Stderr,
    /// Rendered as `data`.
    Data,
    /// Rendered as `end`.
    End,
    /// Rendered as `exit`.
    Exit,
    /// Rendered as `error`.
    Error,
    /// Rendered as `spawn`.
    Spawn,
}

impl std::fmt::Display for EventType {
    /// Writes the lowercase event name straight to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the label first so there is a single write call below.
        let name = match self {
            EventType::Stdout => "stdout",
            EventType::Stderr => "stderr",
            EventType::Data => "data",
            EventType::End => "end",
            EventType::Exit => "exit",
            EventType::Error => "error",
            EventType::Spawn => "spawn",
        };
        f.write_str(name)
    }
}
/// Payload passed to a listener when an event is emitted.
///
/// The emitter does not tie a payload variant to a particular event type;
/// which variant accompanies which event is up to the caller of `emit`.
#[derive(Debug, Clone)]
pub enum EventData {
/// A plain text payload (the tests in this file send it with `Stdout`/`Stderr` events).
String(String),
/// An integer exit code (the tests in this file send it with `Exit` events).
ExitCode(i32),
/// A string payload tagged with a string type label.
// NOTE(review): the meaning of `data_type` values is not visible here — confirm with producers.
TypedData { data_type: String, data: String },
/// Wraps the crate's `CommandResult`.
// NOTE(review): presumably the final outcome of a command — confirm against `crate::CommandResult`.
Result(crate::CommandResult),
/// A string payload for the error case; presumably a human-readable message — TODO confirm.
Error(String),
/// No payload.
None,
}
/// A reference-counted callback that receives an owned [`EventData`].
///
/// The `Send + Sync` bounds let the same callback be shared across threads.
type Listener = Arc<dyn Fn(EventData) + Send + Sync>;
/// An event emitter that maps each [`EventType`] to a list of listeners.
///
/// The listener table sits behind an async `RwLock`, so every method takes
/// `&self`.
pub struct StreamEmitter {
// Listeners per event type, stored in the order they were registered.
listeners: RwLock<HashMap<EventType, Vec<Listener>>>,
}
impl Default for StreamEmitter {
fn default() -> Self {
Self::new()
}
}
impl StreamEmitter {
pub fn new() -> Self {
StreamEmitter {
listeners: RwLock::new(HashMap::new()),
}
}
pub async fn on<F>(&self, event: EventType, listener: F)
where
F: Fn(EventData) + Send + Sync + 'static,
{
trace_lazy("StreamEmitter", || {
format!("on() called for event: {}", event)
});
let mut listeners = self.listeners.write().await;
listeners.entry(event).or_default().push(Arc::new(listener));
}
pub async fn once<F>(&self, event: EventType, listener: F)
where
F: Fn(EventData) + Send + Sync + 'static,
{
trace_lazy("StreamEmitter", || {
format!("once() called for event: {}", event)
});
let called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let called_clone = called.clone();
let once_listener = move |data: EventData| {
if !called_clone.swap(true, std::sync::atomic::Ordering::SeqCst) {
listener(data);
}
};
self.on(event, once_listener).await;
}
pub async fn emit(&self, event: EventType, data: EventData) {
let listeners = self.listeners.read().await;
if let Some(event_listeners) = listeners.get(&event) {
trace_lazy("StreamEmitter", || {
format!(
"Emitting event {} to {} listeners",
event,
event_listeners.len()
)
});
for listener in event_listeners {
listener(data.clone());
}
}
}
pub async fn off(&self, event: EventType) {
trace_lazy("StreamEmitter", || {
format!("off() called for event: {}", event)
});
let mut listeners = self.listeners.write().await;
listeners.remove(&event);
}
pub async fn listener_count(&self, event: &EventType) -> usize {
let listeners = self.listeners.read().await;
listeners.get(event).map(|v| v.len()).unwrap_or(0)
}
pub async fn remove_all_listeners(&self) {
trace_lazy("StreamEmitter", || "Removing all listeners".to_string());
let mut listeners = self.listeners.write().await;
listeners.clear();
}
}
impl std::fmt::Debug for StreamEmitter {
    /// Prints a fixed placeholder for the listener table, since the stored
    /// closures have no `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StreamEmitter");
        builder.field("listeners", &"<RwLock<HashMap<...>>>");
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Returns a fresh shared hit counter starting at zero.
    fn new_counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    /// Builds a listener that ignores its payload and adds one to `hits` per call.
    fn counting_listener(hits: &Arc<AtomicUsize>) -> impl Fn(EventData) + Send + Sync + 'static {
        let hits = Arc::clone(hits);
        move |_| {
            hits.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Shorthand for a string payload.
    fn text(value: &str) -> EventData {
        EventData::String(value.to_string())
    }

    /// A registered listener is called once per emit.
    #[tokio::test]
    async fn test_emit_basic() {
        let emitter = StreamEmitter::new();
        let hits = new_counter();

        emitter
            .on(EventType::Stdout, counting_listener(&hits))
            .await;
        emitter.emit(EventType::Stdout, text("test")).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    /// A `once` listener fires only for the first of two emits.
    #[tokio::test]
    async fn test_once() {
        let emitter = StreamEmitter::new();
        let hits = new_counter();

        emitter
            .once(EventType::Exit, counting_listener(&hits))
            .await;
        emitter.emit(EventType::Exit, EventData::ExitCode(0)).await;
        emitter.emit(EventType::Exit, EventData::ExitCode(0)).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    /// After `off`, emitting the event reaches no listener.
    #[tokio::test]
    async fn test_off() {
        let emitter = StreamEmitter::new();
        let hits = new_counter();

        emitter
            .on(EventType::Stdout, counting_listener(&hits))
            .await;
        emitter.off(EventType::Stdout).await;
        emitter.emit(EventType::Stdout, text("test")).await;

        assert_eq!(hits.load(Ordering::SeqCst), 0);
    }

    /// `listener_count` starts at zero and grows by one per registration.
    #[tokio::test]
    async fn test_listener_count() {
        let emitter = StreamEmitter::new();
        assert_eq!(emitter.listener_count(&EventType::Stdout).await, 0);

        for expected in 1..=2 {
            emitter.on(EventType::Stdout, |_| {}).await;
            assert_eq!(emitter.listener_count(&EventType::Stdout).await, expected);
        }
    }

    /// Listeners only receive the event type they were registered for.
    #[tokio::test]
    async fn test_multiple_events() {
        let emitter = StreamEmitter::new();
        let stdout_hits = new_counter();
        let stderr_hits = new_counter();

        emitter
            .on(EventType::Stdout, counting_listener(&stdout_hits))
            .await;
        emitter
            .on(EventType::Stderr, counting_listener(&stderr_hits))
            .await;

        emitter.emit(EventType::Stdout, text("out")).await;
        emitter.emit(EventType::Stderr, text("err")).await;

        assert_eq!(stdout_hits.load(Ordering::SeqCst), 1);
        assert_eq!(stderr_hits.load(Ordering::SeqCst), 1);
    }
}