1use std::collections::{HashMap, VecDeque};
4use std::net::SocketAddr;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8use axum::{
9 extract::{
10 ws::{Message, WebSocket, WebSocketUpgrade},
11 State,
12 },
13 response::IntoResponse,
14 routing::get,
15 Router,
16};
17use futures::{SinkExt, StreamExt};
18use parking_lot::RwLock;
19use tokio::sync::broadcast;
20use uuid::Uuid;
21
22use super::events::{RealtimeEvent, SubscriptionFilter};
23
/// Identifier under which a connected client is tracked by the manager.
///
/// The WebSocket handler in this file generates these from a random UUID.
pub type ConnectionId = String;

/// Replay-buffer capacity used when a manager is built via `RealtimeManager::new`.
const DEFAULT_MAX_BUFFERED_EVENTS: usize = 500;
29
30pub struct RealtimeManager {
40 tx: broadcast::Sender<RealtimeEvent>,
42 clients: Arc<RwLock<HashMap<ConnectionId, SubscriptionFilter>>>,
44 next_seq_id: Arc<AtomicU64>,
46 buffer: Arc<RwLock<VecDeque<RealtimeEvent>>>,
48 max_buffered_events: usize,
50}
51
52impl RealtimeManager {
53 pub fn new() -> Self {
55 Self::with_buffer_size(DEFAULT_MAX_BUFFERED_EVENTS)
56 }
57
58 pub fn with_buffer_size(max_buffered_events: usize) -> Self {
60 let (tx, _) = broadcast::channel(1000);
61 Self {
62 tx,
63 clients: Arc::new(RwLock::new(HashMap::new())),
64 next_seq_id: Arc::new(AtomicU64::new(1)),
65 buffer: Arc::new(RwLock::new(VecDeque::with_capacity(
66 max_buffered_events.min(4096),
67 ))),
68 max_buffered_events,
69 }
70 }
71
72 pub fn broadcast(&self, mut event: RealtimeEvent) {
77 let seq = self.next_seq_id.fetch_add(1, Ordering::Relaxed);
80 event.seq_id = Some(seq);
81
82 {
84 let mut buf = self.buffer.write();
85 if buf.len() >= self.max_buffered_events {
86 buf.pop_front();
87 }
88 buf.push_back(event.clone());
89 }
90
91 let _ = self.tx.send(event);
94 }
95
96 pub fn get_events_after(&self, last_seq_id: u64) -> Vec<RealtimeEvent> {
100 self.buffer
101 .read()
102 .iter()
103 .filter(|e| e.seq_id.is_some_and(|id| id > last_seq_id))
104 .cloned()
105 .collect()
106 }
107
108 pub fn current_seq(&self) -> u64 {
111 self.next_seq_id.load(Ordering::Relaxed)
112 }
113
114 pub fn client_count(&self) -> usize {
116 self.clients.read().len()
117 }
118
119 pub fn subscribe(&self) -> broadcast::Receiver<RealtimeEvent> {
121 self.tx.subscribe()
122 }
123
124 pub fn register_client(&self, id: ConnectionId, filter: SubscriptionFilter) {
126 self.clients.write().insert(id, filter);
127 }
128
129 pub fn unregister_client(&self, id: &str) {
131 self.clients.write().remove(id);
132 }
133
134 pub fn get_client_filter(&self, id: &str) -> Option<SubscriptionFilter> {
136 self.clients.read().get(id).cloned()
137 }
138}
139
140impl Default for RealtimeManager {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl Clone for RealtimeManager {
147 fn clone(&self) -> Self {
148 Self {
149 tx: self.tx.clone(),
150 clients: self.clients.clone(),
151 next_seq_id: self.next_seq_id.clone(),
152 buffer: self.buffer.clone(),
153 max_buffered_events: self.max_buffered_events,
154 }
155 }
156}
157
158pub struct RealtimeServer {
160 manager: RealtimeManager,
161 addr: SocketAddr,
162}
163
164impl RealtimeServer {
165 pub fn new(manager: RealtimeManager, port: u16) -> Self {
167 let addr = SocketAddr::from(([0, 0, 0, 0], port));
168 Self { manager, addr }
169 }
170
171 pub fn router(manager: RealtimeManager) -> Router {
173 Router::new()
174 .route("/ws", get(ws_handler))
175 .route("/health", get(health_handler))
176 .with_state(manager)
177 }
178
179 pub async fn start(self) -> std::io::Result<()> {
181 let app = Self::router(self.manager);
182
183 tracing::info!("WebSocket server listening on {}", self.addr);
184
185 let listener = tokio::net::TcpListener::bind(self.addr).await?;
186 axum::serve(listener, app).await?;
187
188 Ok(())
189 }
190}
191
192async fn health_handler(State(manager): State<RealtimeManager>) -> impl IntoResponse {
194 serde_json::json!({
195 "status": "ok",
196 "clients": manager.client_count(),
197 })
198 .to_string()
199}
200
201async fn ws_handler(
203 ws: WebSocketUpgrade,
204 State(manager): State<RealtimeManager>,
205) -> impl IntoResponse {
206 ws.on_upgrade(move |socket| handle_socket(socket, manager))
207}
208
209async fn handle_socket(socket: WebSocket, manager: RealtimeManager) {
211 let connection_id = Uuid::new_v4().to_string();
212 let filter = SubscriptionFilter::default();
213
214 manager.register_client(connection_id.clone(), filter.clone());
215 tracing::info!("Client connected: {}", connection_id);
216
217 let (mut sender, mut receiver) = socket.split();
218 let mut rx = manager.subscribe();
219
220 let conn_id = connection_id.clone();
222 let mgr = manager.clone();
223 let send_task = tokio::spawn(async move {
224 while let Ok(event) = rx.recv().await {
225 if let Some(filter) = mgr.get_client_filter(&conn_id) {
227 if filter.matches(&event) {
228 let json = serde_json::to_string(&event).unwrap_or_default();
229 if sender.send(Message::Text(json)).await.is_err() {
230 break;
231 }
232 }
233 }
234 }
235 });
236
237 let conn_id = connection_id.clone();
239 let mgr = manager.clone();
240 let recv_task = tokio::spawn(async move {
241 while let Some(Ok(msg)) = receiver.next().await {
242 match msg {
243 Message::Text(text) => {
244 if let Ok(new_filter) = serde_json::from_str::<SubscriptionFilter>(&text) {
246 mgr.register_client(conn_id.clone(), new_filter);
247 tracing::debug!("Updated filter for client {}", conn_id);
248 }
249 }
250 Message::Close(_) => {
251 break;
252 }
253 _ => {}
254 }
255 }
256 });
257
258 tokio::select! {
260 _ = send_task => {}
261 _ = recv_task => {}
262 }
263
264 manager.unregister_client(&connection_id);
265 tracing::info!("Client disconnected: {}", connection_id);
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequence ids of the events currently held in the replay buffer,
    /// oldest first.
    fn buffered_seq_ids(manager: &RealtimeManager) -> Vec<u64> {
        manager
            .buffer
            .read()
            .iter()
            .filter_map(|event| event.seq_id)
            .collect()
    }

    /// Sequence ids of an already collected list of events.
    fn seq_ids_of(events: &[RealtimeEvent]) -> Vec<u64> {
        events.iter().filter_map(|event| event.seq_id).collect()
    }

    /// Registering and unregistering a client is reflected in the count.
    #[test]
    fn test_realtime_manager() {
        let manager = RealtimeManager::new();
        assert_eq!(manager.client_count(), 0);

        manager.register_client("test".to_string(), SubscriptionFilter::default());
        assert_eq!(manager.client_count(), 1);

        manager.unregister_client("test");
        assert_eq!(manager.client_count(), 0);
    }

    /// A filter restricted to one event type accepts only that type.
    #[test]
    fn test_subscription_filter() {
        let filter = SubscriptionFilter {
            event_types: Some(vec![super::super::events::EventType::MemoryCreated]),
            memory_ids: None,
            tags: None,
        };

        let created = RealtimeEvent::memory_created(1, "test".to_string());
        assert!(filter.matches(&created));

        let deleted = RealtimeEvent::memory_deleted(1);
        assert!(!filter.matches(&deleted));
    }

    /// Broadcast events are stamped 1, 2, 3, ... in call order.
    #[test]
    fn test_broadcast_stamps_sequential_ids() {
        let manager = RealtimeManager::new();
        let _rx = manager.subscribe();

        manager.broadcast(RealtimeEvent::memory_created(1, "first".to_string()));
        manager.broadcast(RealtimeEvent::memory_created(2, "second".to_string()));
        manager.broadcast(RealtimeEvent::memory_deleted(3));

        assert_eq!(buffered_seq_ids(&manager), vec![1, 2, 3]);
    }

    /// `current_seq` reports the id the next event will receive.
    #[test]
    fn test_seq_id_starts_at_one() {
        let manager = RealtimeManager::new();
        assert_eq!(manager.current_seq(), 1);

        let _rx = manager.subscribe();
        manager.broadcast(RealtimeEvent::memory_created(1, "hello".to_string()));
        assert_eq!(manager.current_seq(), 2);
    }

    /// Once the buffer is full, the oldest events are dropped first.
    #[test]
    fn test_ring_buffer_evicts_oldest_when_full() {
        let capacity = 3;
        let manager = RealtimeManager::with_buffer_size(capacity);
        let _rx = manager.subscribe();

        (1..=5i64).for_each(|id| {
            manager.broadcast(RealtimeEvent::memory_created(id, format!("m{id}")));
        });

        assert_eq!(
            manager.buffer.read().len(),
            capacity,
            "buffer should be at capacity"
        );
        assert_eq!(buffered_seq_ids(&manager), vec![3, 4, 5]);
    }

    /// The buffer never holds more than the configured number of events.
    #[test]
    fn test_ring_buffer_does_not_exceed_max_size() {
        let capacity = 10;
        let manager = RealtimeManager::with_buffer_size(capacity);
        let _rx = manager.subscribe();

        (1..=20i64).for_each(|id| {
            manager.broadcast(RealtimeEvent::memory_deleted(id));
        });

        assert_eq!(manager.buffer.read().len(), capacity);
    }

    /// Replay returns exactly the events newer than the given id.
    #[test]
    fn test_get_events_after_returns_correct_subset() {
        let manager = RealtimeManager::new();
        let _rx = manager.subscribe();

        manager.broadcast(RealtimeEvent::memory_created(1, "a".to_string()));
        manager.broadcast(RealtimeEvent::memory_created(2, "b".to_string()));
        manager.broadcast(RealtimeEvent::memory_deleted(3));

        let replayed = manager.get_events_after(1);
        assert_eq!(replayed.len(), 2);
        assert_eq!(seq_ids_of(&replayed), vec![2, 3]);
    }

    /// Replaying from id 0 yields every buffered event.
    #[test]
    fn test_get_events_after_zero_returns_all() {
        let manager = RealtimeManager::new();
        let _rx = manager.subscribe();

        manager.broadcast(RealtimeEvent::memory_created(1, "x".to_string()));
        manager.broadcast(RealtimeEvent::memory_created(2, "y".to_string()));

        let replayed = manager.get_events_after(0);
        assert_eq!(replayed.len(), 2);
    }

    /// Replaying from the newest id yields nothing.
    #[test]
    fn test_get_events_after_last_id_returns_empty() {
        let manager = RealtimeManager::new();
        let _rx = manager.subscribe();

        manager.broadcast(RealtimeEvent::memory_created(1, "only".to_string()));

        let replayed = manager.get_events_after(1);
        assert!(replayed.is_empty());
    }

    /// Replaying from an id beyond anything issued yields nothing.
    #[test]
    fn test_get_events_after_large_id_returns_empty() {
        let manager = RealtimeManager::new();
        let _rx = manager.subscribe();

        manager.broadcast(RealtimeEvent::memory_created(1, "ev".to_string()));

        let replayed = manager.get_events_after(9999);
        assert!(replayed.is_empty());
    }

    /// A clone sees events broadcast through the original manager.
    #[test]
    fn test_clone_shares_buffer() {
        let manager = RealtimeManager::new();
        let cloned = manager.clone();
        let _rx = manager.subscribe();

        manager.broadcast(RealtimeEvent::memory_created(1, "shared".to_string()));

        assert_eq!(cloned.buffer.read().len(), 1);

        let replayed = cloned.get_events_after(0);
        assert_eq!(replayed.len(), 1);
    }
}