post_cortex_daemon/daemon/
sse.rs1use dashmap::DashMap;
26use serde::{Deserialize, Serialize};
27use std::sync::Arc;
28use std::sync::atomic::{AtomicU64, Ordering};
29use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
30use uuid::Uuid;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct SseEvent {
35 pub id: String,
37 pub event_type: String,
39 pub data: serde_json::Value,
41}
42
43pub struct SSEBroadcaster {
48 clients: Arc<DashMap<Uuid, UnboundedSender<SseEvent>>>,
50
51 event_counter: Arc<AtomicU64>,
53
54 total_clients: Arc<AtomicU64>,
56}
57
58impl SSEBroadcaster {
59 pub fn new() -> Self {
61 Self {
62 clients: Arc::new(DashMap::new()),
63 event_counter: Arc::new(AtomicU64::new(0)),
64 total_clients: Arc::new(AtomicU64::new(0)),
65 }
66 }
67
68 pub fn register_client(&self, id: Uuid) -> UnboundedReceiver<SseEvent> {
70 let (tx, rx) = unbounded_channel();
71 self.clients.insert(id, tx);
72 self.total_clients.fetch_add(1, Ordering::Relaxed);
73 rx
74 }
75
76 pub fn unregister_client(&self, id: &Uuid) {
78 if self.clients.remove(id).is_some() {
79 self.total_clients.fetch_sub(1, Ordering::Relaxed);
80 }
81 }
82
83 pub fn broadcast(&self, event: SseEvent) {
85 self.event_counter.fetch_add(1, Ordering::Relaxed);
86
87 self.clients.iter().for_each(|entry| {
89 let _ = entry.value().send(event.clone());
91 });
92 }
93
94 pub fn send_to_client(&self, client_id: &Uuid, event: SseEvent) -> Result<(), String> {
96 self.clients
97 .get(client_id)
98 .ok_or_else(|| format!("Client {} not found", client_id))?
99 .send(event)
100 .map_err(|e| format!("Failed to send to client: {}", e))
101 }
102
103 pub fn active_clients(&self) -> u64 {
105 self.total_clients.load(Ordering::Relaxed)
106 }
107
108 pub fn total_events(&self) -> u64 {
110 self.event_counter.load(Ordering::Relaxed)
111 }
112}
113
114impl Default for SSEBroadcaster {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::Barrier;

    /// Builds a small event for tests.
    fn sample_event(id: &str) -> SseEvent {
        SseEvent {
            id: id.to_string(),
            event_type: "test".to_string(),
            data: serde_json::json!({"message": "hello"}),
        }
    }

    #[tokio::test]
    async fn test_sse_broadcaster_registration() {
        let broadcaster = SSEBroadcaster::new();

        let client1 = Uuid::new_v4();
        let mut rx1 = broadcaster.register_client(client1);

        assert_eq!(broadcaster.active_clients(), 1);

        broadcaster.broadcast(sample_event("1"));

        let received = rx1.recv().await.unwrap();
        assert_eq!(received.id, "1");
        assert_eq!(received.event_type, "test");

        broadcaster.unregister_client(&client1);
        assert_eq!(broadcaster.active_clients(), 0);
    }

    #[tokio::test]
    async fn test_concurrent_sse_operations() {
        const CLIENTS: usize = 50;

        let broadcaster = Arc::new(SSEBroadcaster::new());
        // Every client task plus this one must reach the barrier before the
        // broadcast is sent, which guarantees all clients are registered.
        // A fixed sleep only made that likely and could hang the test.
        let ready = Arc::new(Barrier::new(CLIENTS + 1));

        let handles: Vec<_> = (0..CLIENTS)
            .map(|_| {
                let bc = Arc::clone(&broadcaster);
                let ready = Arc::clone(&ready);
                tokio::spawn(async move {
                    let id = Uuid::new_v4();
                    let mut rx = bc.register_client(id);
                    ready.wait().await;

                    let event = rx.recv().await;
                    assert!(event.is_some());

                    bc.unregister_client(&id);
                })
            })
            .collect();

        ready.wait().await;

        broadcaster.broadcast(SseEvent {
            id: "broadcast".to_string(),
            event_type: "test".to_string(),
            data: serde_json::json!({}),
        });

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(broadcaster.active_clients(), 0);
        assert_eq!(broadcaster.total_events(), 1);
    }

    #[tokio::test]
    async fn test_send_to_client_errors() {
        let broadcaster = SSEBroadcaster::new();

        // Unknown client.
        assert!(broadcaster
            .send_to_client(&Uuid::new_v4(), sample_event("a"))
            .is_err());

        // Live client, then the same client after its receiver is dropped.
        let id = Uuid::new_v4();
        let rx = broadcaster.register_client(id);
        assert!(broadcaster.send_to_client(&id, sample_event("b")).is_ok());
        drop(rx);
        assert!(broadcaster.send_to_client(&id, sample_event("c")).is_err());
    }
}