//! WebSocket server for real-time updates (`engram/realtime/server.rs`).
2
3use std::collections::{HashMap, VecDeque};
4use std::net::SocketAddr;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8use axum::{
9    extract::{
10        ws::{Message, WebSocket, WebSocketUpgrade},
11        State,
12    },
13    response::IntoResponse,
14    routing::get,
15    Router,
16};
17use futures::{SinkExt, StreamExt};
18use parking_lot::RwLock;
19use tokio::sync::broadcast;
20use uuid::Uuid;
21
22use super::events::{RealtimeEvent, SubscriptionFilter};
23
/// Identifier of a connected client; `handle_socket` mints these from `Uuid::new_v4()`.
pub type ConnectionId = String;

/// Replay ring-buffer size used by [`RealtimeManager::new`] when no explicit size is given.
const DEFAULT_MAX_BUFFERED_EVENTS: usize = 500;
29
30/// Manages WebSocket connections and SSE subscriptions.
31///
32/// Each event broadcast through [`RealtimeManager::broadcast`] is:
33/// 1. Assigned a monotonically-increasing `seq_id`.
34/// 2. Pushed into an in-memory ring buffer (capacity [`DEFAULT_MAX_BUFFERED_EVENTS`]).
35/// 3. Sent over the tokio broadcast channel for live subscribers.
36///
37/// Clients that reconnect with a `Last-Event-Id` header can call
38/// [`RealtimeManager::get_events_after`] to retrieve buffered events they missed.
39pub struct RealtimeManager {
40    /// Broadcast channel for live delivery
41    tx: broadcast::Sender<RealtimeEvent>,
42    /// Connected clients with their filters
43    clients: Arc<RwLock<HashMap<ConnectionId, SubscriptionFilter>>>,
44    /// Monotonically-increasing sequence counter (starts at 1)
45    next_seq_id: Arc<AtomicU64>,
46    /// In-memory ring buffer for replay
47    buffer: Arc<RwLock<VecDeque<RealtimeEvent>>>,
48    /// Maximum number of events kept in the buffer
49    max_buffered_events: usize,
50}
51
52impl RealtimeManager {
53    /// Create a new realtime manager with the default buffer size (500 events).
54    pub fn new() -> Self {
55        Self::with_buffer_size(DEFAULT_MAX_BUFFERED_EVENTS)
56    }
57
58    /// Create a realtime manager with a custom ring-buffer size.
59    pub fn with_buffer_size(max_buffered_events: usize) -> Self {
60        let (tx, _) = broadcast::channel(1000);
61        Self {
62            tx,
63            clients: Arc::new(RwLock::new(HashMap::new())),
64            next_seq_id: Arc::new(AtomicU64::new(1)),
65            buffer: Arc::new(RwLock::new(VecDeque::with_capacity(
66                max_buffered_events.min(4096),
67            ))),
68            max_buffered_events,
69        }
70    }
71
72    /// Broadcast an event to all matching clients.
73    ///
74    /// The event is stamped with a sequential `seq_id`, pushed into the ring
75    /// buffer, and sent over the broadcast channel.
76    pub fn broadcast(&self, mut event: RealtimeEvent) {
77        // Stamp with sequential ID (fetch-and-increment, wraps at u64::MAX which
78        // is effectively never for any real-world workload).
79        let seq = self.next_seq_id.fetch_add(1, Ordering::Relaxed);
80        event.seq_id = Some(seq);
81
82        // Push into ring buffer, evicting the oldest entry when full.
83        {
84            let mut buf = self.buffer.write();
85            if buf.len() >= self.max_buffered_events {
86                buf.pop_front();
87            }
88            buf.push_back(event.clone());
89        }
90
91        // Deliver to live subscribers (errors are expected when no subscriber
92        // is registered yet — ignore them).
93        let _ = self.tx.send(event);
94    }
95
96    /// Return all buffered events whose `seq_id` is strictly greater than
97    /// `last_seq_id`, in ascending order. Used to replay missed events for
98    /// reconnecting clients.
99    pub fn get_events_after(&self, last_seq_id: u64) -> Vec<RealtimeEvent> {
100        self.buffer
101            .read()
102            .iter()
103            .filter(|e| e.seq_id.is_some_and(|id| id > last_seq_id))
104            .cloned()
105            .collect()
106    }
107
108    /// Return the current value of the sequence counter (next ID to be issued).
109    /// Mainly useful for tests.
110    pub fn current_seq(&self) -> u64 {
111        self.next_seq_id.load(Ordering::Relaxed)
112    }
113
114    /// Get number of connected clients
115    pub fn client_count(&self) -> usize {
116        self.clients.read().len()
117    }
118
119    /// Subscribe to live events
120    pub fn subscribe(&self) -> broadcast::Receiver<RealtimeEvent> {
121        self.tx.subscribe()
122    }
123
124    /// Register a new client
125    pub fn register_client(&self, id: ConnectionId, filter: SubscriptionFilter) {
126        self.clients.write().insert(id, filter);
127    }
128
129    /// Unregister a client
130    pub fn unregister_client(&self, id: &str) {
131        self.clients.write().remove(id);
132    }
133
134    /// Get client filter
135    pub fn get_client_filter(&self, id: &str) -> Option<SubscriptionFilter> {
136        self.clients.read().get(id).cloned()
137    }
138}
139
140impl Default for RealtimeManager {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl Clone for RealtimeManager {
147    fn clone(&self) -> Self {
148        Self {
149            tx: self.tx.clone(),
150            clients: self.clients.clone(),
151            next_seq_id: self.next_seq_id.clone(),
152            buffer: self.buffer.clone(),
153            max_buffered_events: self.max_buffered_events,
154        }
155    }
156}
157
158/// WebSocket server
159pub struct RealtimeServer {
160    manager: RealtimeManager,
161    addr: SocketAddr,
162}
163
164impl RealtimeServer {
165    /// Create a new WebSocket server
166    pub fn new(manager: RealtimeManager, port: u16) -> Self {
167        let addr = SocketAddr::from(([0, 0, 0, 0], port));
168        Self { manager, addr }
169    }
170
171    /// Build the router
172    pub fn router(manager: RealtimeManager) -> Router {
173        Router::new()
174            .route("/ws", get(ws_handler))
175            .route("/health", get(health_handler))
176            .with_state(manager)
177    }
178
179    /// Start the server
180    pub async fn start(self) -> std::io::Result<()> {
181        let app = Self::router(self.manager);
182
183        tracing::info!("WebSocket server listening on {}", self.addr);
184
185        let listener = tokio::net::TcpListener::bind(self.addr).await?;
186        axum::serve(listener, app).await?;
187
188        Ok(())
189    }
190}
191
192/// Health check endpoint
193async fn health_handler(State(manager): State<RealtimeManager>) -> impl IntoResponse {
194    serde_json::json!({
195        "status": "ok",
196        "clients": manager.client_count(),
197    })
198    .to_string()
199}
200
201/// WebSocket upgrade handler
202async fn ws_handler(
203    ws: WebSocketUpgrade,
204    State(manager): State<RealtimeManager>,
205) -> impl IntoResponse {
206    ws.on_upgrade(move |socket| handle_socket(socket, manager))
207}
208
209/// Handle an individual WebSocket connection
210async fn handle_socket(socket: WebSocket, manager: RealtimeManager) {
211    let connection_id = Uuid::new_v4().to_string();
212    let filter = SubscriptionFilter::default();
213
214    manager.register_client(connection_id.clone(), filter.clone());
215    tracing::info!("Client connected: {}", connection_id);
216
217    let (mut sender, mut receiver) = socket.split();
218    let mut rx = manager.subscribe();
219
220    // Task to forward events to client
221    let conn_id = connection_id.clone();
222    let mgr = manager.clone();
223    let send_task = tokio::spawn(async move {
224        while let Ok(event) = rx.recv().await {
225            // Check if event matches client's filter
226            if let Some(filter) = mgr.get_client_filter(&conn_id) {
227                if filter.matches(&event) {
228                    let json = serde_json::to_string(&event).unwrap_or_default();
229                    if sender.send(Message::Text(json)).await.is_err() {
230                        break;
231                    }
232                }
233            }
234        }
235    });
236
237    // Task to handle incoming messages from client
238    let conn_id = connection_id.clone();
239    let mgr = manager.clone();
240    let recv_task = tokio::spawn(async move {
241        while let Some(Ok(msg)) = receiver.next().await {
242            match msg {
243                Message::Text(text) => {
244                    // Try to parse as filter update
245                    if let Ok(new_filter) = serde_json::from_str::<SubscriptionFilter>(&text) {
246                        mgr.register_client(conn_id.clone(), new_filter);
247                        tracing::debug!("Updated filter for client {}", conn_id);
248                    }
249                }
250                Message::Close(_) => {
251                    break;
252                }
253                _ => {}
254            }
255        }
256    });
257
258    // Wait for either task to finish
259    tokio::select! {
260        _ = send_task => {}
261        _ = recv_task => {}
262    }
263
264    manager.unregister_client(&connection_id);
265    tracing::info!("Client disconnected: {}", connection_id);
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect the `seq_id` of every stamped event, preserving iteration order.
    fn seq_ids<'a>(events: impl IntoIterator<Item = &'a RealtimeEvent>) -> Vec<u64> {
        events.into_iter().filter_map(|ev| ev.seq_id).collect()
    }

    #[test]
    fn test_realtime_manager() {
        let mgr = RealtimeManager::new();
        assert_eq!(mgr.client_count(), 0);

        let client = String::from("test");
        mgr.register_client(client.clone(), SubscriptionFilter::default());
        assert_eq!(mgr.client_count(), 1);

        mgr.unregister_client(&client);
        assert_eq!(mgr.client_count(), 0);
    }

    #[test]
    fn test_subscription_filter() {
        use super::super::events::EventType;

        let only_created = SubscriptionFilter {
            event_types: Some(vec![EventType::MemoryCreated]),
            memory_ids: None,
            tags: None,
        };

        assert!(only_created.matches(&RealtimeEvent::memory_created(1, "test".to_string())));
        assert!(!only_created.matches(&RealtimeEvent::memory_deleted(1)));
    }

    // --- Sequential event ID tests ------------------------------------------

    #[test]
    fn test_broadcast_stamps_sequential_ids() {
        let mgr = RealtimeManager::new();
        let _live = mgr.subscribe(); // keep a receiver attached

        let events = [
            RealtimeEvent::memory_created(1, "first".to_string()),
            RealtimeEvent::memory_created(2, "second".to_string()),
            RealtimeEvent::memory_deleted(3),
        ];
        for ev in events {
            mgr.broadcast(ev);
        }

        // The counter starts at 1, so three broadcasts are stamped 1, 2, 3.
        assert_eq!(seq_ids(mgr.buffer.read().iter()), vec![1, 2, 3]);
    }

    #[test]
    fn test_seq_id_starts_at_one() {
        let mgr = RealtimeManager::new();
        assert_eq!(mgr.current_seq(), 1);

        let _live = mgr.subscribe();
        mgr.broadcast(RealtimeEvent::memory_created(1, "hello".to_string()));

        // current_seq reports the next id to be issued.
        assert_eq!(mgr.current_seq(), 2);
    }

    // --- Ring buffer eviction tests -----------------------------------------

    #[test]
    fn test_ring_buffer_evicts_oldest_when_full() {
        const CAP: usize = 3;
        let mgr = RealtimeManager::with_buffer_size(CAP);
        let _live = mgr.subscribe();

        (1..=5i64).for_each(|n| mgr.broadcast(RealtimeEvent::memory_created(n, format!("m{n}"))));

        let buf = mgr.buffer.read();
        assert_eq!(buf.len(), CAP, "buffer should be at capacity");
        // Seq 1 and 2 were pushed out by 4 and 5.
        assert_eq!(seq_ids(buf.iter()), vec![3, 4, 5]);
    }

    #[test]
    fn test_ring_buffer_does_not_exceed_max_size() {
        const CAP: usize = 10;
        let mgr = RealtimeManager::with_buffer_size(CAP);
        let _live = mgr.subscribe();

        for id in 1..=20i64 {
            mgr.broadcast(RealtimeEvent::memory_deleted(id));
        }

        assert_eq!(mgr.buffer.read().len(), CAP);
    }

    // --- Replay / get_events_after tests ------------------------------------

    #[test]
    fn test_get_events_after_returns_correct_subset() {
        let mgr = RealtimeManager::new();
        let _live = mgr.subscribe();

        mgr.broadcast(RealtimeEvent::memory_created(1, "a".to_string())); // seq 1
        mgr.broadcast(RealtimeEvent::memory_created(2, "b".to_string())); // seq 2
        mgr.broadcast(RealtimeEvent::memory_deleted(3)); // seq 3

        let replayed = mgr.get_events_after(1);
        assert_eq!(replayed.len(), 2);
        assert_eq!(seq_ids(&replayed), vec![2, 3]);
    }

    #[test]
    fn test_get_events_after_zero_returns_all() {
        let mgr = RealtimeManager::new();
        let _live = mgr.subscribe();

        for (id, label) in [(1, "x"), (2, "y")] {
            mgr.broadcast(RealtimeEvent::memory_created(id, label.to_string()));
        }

        assert_eq!(mgr.get_events_after(0).len(), 2);
    }

    #[test]
    fn test_get_events_after_last_id_returns_empty() {
        let mgr = RealtimeManager::new();
        let _live = mgr.subscribe();

        mgr.broadcast(RealtimeEvent::memory_created(1, "only".to_string())); // seq 1

        // Asking for events after the newest known id yields nothing.
        assert!(mgr.get_events_after(1).is_empty());
    }

    #[test]
    fn test_get_events_after_large_id_returns_empty() {
        let mgr = RealtimeManager::new();
        let _live = mgr.subscribe();

        mgr.broadcast(RealtimeEvent::memory_created(1, "ev".to_string()));

        assert!(mgr.get_events_after(9999).is_empty());
    }

    // --- Clone shares same state --------------------------------------------

    #[test]
    fn test_clone_shares_buffer() {
        let original = RealtimeManager::new();
        let twin = original.clone();
        let _live = original.subscribe();

        original.broadcast(RealtimeEvent::memory_created(1, "shared".to_string()));

        // The clone observes the event through the shared buffer.
        assert_eq!(twin.buffer.read().len(), 1);
        assert_eq!(twin.get_events_after(0).len(), 1);
    }
}