use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use uuid::Uuid;
/// A single event delivered to [`SSEBroadcaster`] clients over their channels.
///
/// Derives serde `Serialize`/`Deserialize` with no renaming, so the serialized
/// field names are exactly `id`, `event_type` and `data`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SseEvent {
/// Event identifier. NOTE(review): presumably written out as the SSE `id:` field by the HTTP layer — confirm.
pub id: String,
/// Event name. NOTE(review): presumably written out as the SSE `event:` field — confirm.
pub event_type: String,
/// Arbitrary JSON payload carried by the event.
pub data: serde_json::Value,
}
/// Fan-out hub that delivers [`SseEvent`]s to registered clients over unbounded
/// tokio channels.
///
/// Each client is identified by a [`Uuid`]. The caller owns the receiving half
/// of the client's channel; the broadcaster keeps only the sending half. Clients
/// whose receiver has been dropped are pruned the next time a send to them
/// fails, so a missed `unregister_client` call does not leak a map entry.
pub struct SSEBroadcaster {
    /// Sending half of each registered client's channel, keyed by client id.
    clients: Arc<DashMap<Uuid, UnboundedSender<SseEvent>>>,
    /// Number of `broadcast` calls made so far (statistics only).
    event_counter: Arc<AtomicU64>,
}

impl SSEBroadcaster {
    /// Creates a broadcaster with no registered clients and a zeroed event counter.
    pub fn new() -> Self {
        Self {
            clients: Arc::new(DashMap::new()),
            event_counter: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Registers `id` and returns the receiver on which that client gets events.
    ///
    /// Registering an id that is already present replaces its channel. The
    /// previous sender is dropped, which closes the previous receiver. The
    /// client is still counted once.
    pub fn register_client(&self, id: Uuid) -> UnboundedReceiver<SseEvent> {
        let (tx, rx) = unbounded_channel();
        self.clients.insert(id, tx);
        rx
    }

    /// Removes `id` from the broadcaster, closing its channel. Unknown ids are
    /// ignored.
    pub fn unregister_client(&self, id: &Uuid) {
        self.clients.remove(id);
    }

    /// Sends a clone of `event` to every registered client and increments the
    /// broadcast counter.
    ///
    /// Clients whose receiver has been dropped are removed from the map rather
    /// than being retried on every future broadcast.
    pub fn broadcast(&self, event: SseEvent) {
        self.event_counter.fetch_add(1, Ordering::Relaxed);
        // `retain` can drop entries while visiting them. Calling `remove` from
        // inside `iter()` would instead deadlock on the shard lock the iterator
        // holds.
        self.clients.retain(|_, tx| tx.send(event.clone()).is_ok());
    }

    /// Sends `event` to a single client.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `client_id` is not registered, or if that client's
    /// receiver has been dropped. In the latter case the stale entry is also
    /// removed, so it no longer counts as active.
    pub fn send_to_client(&self, client_id: &Uuid, event: SseEvent) -> Result<(), String> {
        let tx = self
            .clients
            .get(client_id)
            .ok_or_else(|| format!("Client {} not found", client_id))?;
        let result = tx
            .send(event)
            .map_err(|e| format!("Failed to send to client: {}", e));
        // Release the read guard before mutating the map. Removing while holding
        // a `Ref` into the same shard would deadlock.
        drop(tx);
        if result.is_err() {
            // Only remove the entry if it is still a closed channel, so a fresh
            // registration under the same id in the meantime survives.
            self.clients.remove_if(client_id, |_, tx| tx.is_closed());
        }
        result
    }

    /// Number of clients currently registered.
    ///
    /// Read from the client map itself, so it cannot drift from the set of
    /// registered ids (e.g. when an id is registered twice).
    pub fn active_clients(&self) -> u64 {
        // usize -> u64 is lossless on every target Rust currently supports.
        self.clients.len() as u64
    }

    /// Number of [`broadcast`](Self::broadcast) calls made so far. Direct sends
    /// via `send_to_client` are not counted.
    pub fn total_events(&self) -> u64 {
        self.event_counter.load(Ordering::Relaxed)
    }
}
impl Default for SSEBroadcaster {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::Barrier;

    /// A single registered client receives a broadcast and can then unregister.
    #[tokio::test]
    async fn test_sse_broadcaster_registration() {
        let broadcaster = SSEBroadcaster::new();
        let client1 = Uuid::new_v4();
        let mut rx1 = broadcaster.register_client(client1);
        assert_eq!(broadcaster.active_clients(), 1);
        let event = SseEvent {
            id: "1".to_string(),
            event_type: "test".to_string(),
            data: serde_json::json!({"message": "hello"}),
        };
        broadcaster.broadcast(event);
        let received = rx1
            .recv()
            .await
            .expect("channel should still be open while registered");
        assert_eq!(received.id, "1");
        assert_eq!(received.event_type, "test");
        broadcaster.unregister_client(&client1);
        assert_eq!(broadcaster.active_clients(), 0);
    }

    /// Many tasks register concurrently; a single broadcast must reach all of
    /// them, and every task then unregisters.
    #[tokio::test]
    async fn test_concurrent_sse_operations() {
        const CLIENTS: usize = 50;
        let broadcaster = Arc::new(SSEBroadcaster::new());
        // Every task plus this one waits on the barrier, so the broadcast below
        // cannot happen before all clients are registered. A fixed sleep only
        // made that likely. A task registering after the broadcast would block
        // on `recv` forever.
        let all_registered = Arc::new(Barrier::new(CLIENTS + 1));
        let mut handles = Vec::with_capacity(CLIENTS);
        for _ in 0..CLIENTS {
            let bc = Arc::clone(&broadcaster);
            let all_registered = Arc::clone(&all_registered);
            handles.push(tokio::spawn(async move {
                let id = Uuid::new_v4();
                let mut rx = bc.register_client(id);
                all_registered.wait().await;
                let event = rx.recv().await;
                assert!(event.is_some());
                bc.unregister_client(&id);
            }));
        }
        all_registered.wait().await;
        broadcaster.broadcast(SseEvent {
            id: "broadcast".to_string(),
            event_type: "test".to_string(),
            data: serde_json::json!({}),
        });
        for handle in handles {
            handle.await.unwrap();
        }
        assert_eq!(broadcaster.active_clients(), 0);
        assert_eq!(broadcaster.total_events(), 1);
    }
}