Skip to main content

engram/realtime/
server.rs

1//! WebSocket server for real-time updates
2
3use std::collections::HashMap;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7use axum::{
8    extract::{
9        ws::{Message, WebSocket, WebSocketUpgrade},
10        State,
11    },
12    response::IntoResponse,
13    routing::get,
14    Router,
15};
16use futures::{SinkExt, StreamExt};
17use parking_lot::RwLock;
18use tokio::sync::broadcast;
19use uuid::Uuid;
20
21use super::events::{RealtimeEvent, SubscriptionFilter};
22
/// Connection ID
///
/// Key under which a connected client's subscription filter is stored.
/// Within this file the value is a freshly generated UUID v4 rendered as a
/// string, created once per accepted WebSocket connection.
pub type ConnectionId = String;
25
/// Manages WebSocket connections
///
/// Bundles the sending half of the event broadcast channel with a
/// lock-protected map from connection ID to that client's subscription
/// filter. Cloning the manager clones the sender and the `Arc`, so all
/// clones observe the same channel and the same client map.
pub struct RealtimeManager {
    /// Broadcast channel for events; every subscriber obtained through
    /// `subscribe` receives each event sent on it.
    tx: broadcast::Sender<RealtimeEvent>,
    /// Connected clients with their filters, keyed by connection ID.
    clients: Arc<RwLock<HashMap<ConnectionId, SubscriptionFilter>>>,
}
33
34impl RealtimeManager {
35    /// Create a new realtime manager
36    pub fn new() -> Self {
37        let (tx, _) = broadcast::channel(1000);
38        Self {
39            tx,
40            clients: Arc::new(RwLock::new(HashMap::new())),
41        }
42    }
43
44    /// Broadcast an event to all matching clients
45    pub fn broadcast(&self, event: RealtimeEvent) {
46        // The broadcast channel handles delivery to all subscribers
47        let _ = self.tx.send(event);
48    }
49
50    /// Get number of connected clients
51    pub fn client_count(&self) -> usize {
52        self.clients.read().len()
53    }
54
55    /// Subscribe to events
56    pub fn subscribe(&self) -> broadcast::Receiver<RealtimeEvent> {
57        self.tx.subscribe()
58    }
59
60    /// Register a new client
61    pub fn register_client(&self, id: ConnectionId, filter: SubscriptionFilter) {
62        self.clients.write().insert(id, filter);
63    }
64
65    /// Unregister a client
66    pub fn unregister_client(&self, id: &str) {
67        self.clients.write().remove(id);
68    }
69
70    /// Get client filter
71    pub fn get_client_filter(&self, id: &str) -> Option<SubscriptionFilter> {
72        self.clients.read().get(id).cloned()
73    }
74}
75
76impl Default for RealtimeManager {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl Clone for RealtimeManager {
83    fn clone(&self) -> Self {
84        Self {
85            tx: self.tx.clone(),
86            clients: self.clients.clone(),
87        }
88    }
89}
90
/// WebSocket server
///
/// Pairs a [`RealtimeManager`] with the socket address the HTTP/WebSocket
/// listener is bound to when the server is started.
pub struct RealtimeServer {
    /// Manager handed to the router as shared handler state.
    manager: RealtimeManager,
    /// Address the listener binds to.
    addr: SocketAddr,
}
96
97impl RealtimeServer {
98    /// Create a new WebSocket server
99    pub fn new(manager: RealtimeManager, port: u16) -> Self {
100        let addr = SocketAddr::from(([0, 0, 0, 0], port));
101        Self { manager, addr }
102    }
103
104    /// Build the router
105    pub fn router(manager: RealtimeManager) -> Router {
106        Router::new()
107            .route("/ws", get(ws_handler))
108            .route("/health", get(health_handler))
109            .with_state(manager)
110    }
111
112    /// Start the server
113    pub async fn start(self) -> std::io::Result<()> {
114        let app = Self::router(self.manager);
115
116        tracing::info!("WebSocket server listening on {}", self.addr);
117
118        let listener = tokio::net::TcpListener::bind(self.addr).await?;
119        axum::serve(listener, app).await?;
120
121        Ok(())
122    }
123}
124
125/// Health check endpoint
126async fn health_handler(State(manager): State<RealtimeManager>) -> impl IntoResponse {
127    serde_json::json!({
128        "status": "ok",
129        "clients": manager.client_count(),
130    })
131    .to_string()
132}
133
134/// WebSocket upgrade handler
135async fn ws_handler(
136    ws: WebSocketUpgrade,
137    State(manager): State<RealtimeManager>,
138) -> impl IntoResponse {
139    ws.on_upgrade(move |socket| handle_socket(socket, manager))
140}
141
142/// Handle an individual WebSocket connection
143async fn handle_socket(socket: WebSocket, manager: RealtimeManager) {
144    let connection_id = Uuid::new_v4().to_string();
145    let filter = SubscriptionFilter::default();
146
147    manager.register_client(connection_id.clone(), filter.clone());
148    tracing::info!("Client connected: {}", connection_id);
149
150    let (mut sender, mut receiver) = socket.split();
151    let mut rx = manager.subscribe();
152
153    // Task to forward events to client
154    let conn_id = connection_id.clone();
155    let mgr = manager.clone();
156    let send_task = tokio::spawn(async move {
157        while let Ok(event) = rx.recv().await {
158            // Check if event matches client's filter
159            if let Some(filter) = mgr.get_client_filter(&conn_id) {
160                if filter.matches(&event) {
161                    let json = serde_json::to_string(&event).unwrap_or_default();
162                    if sender.send(Message::Text(json)).await.is_err() {
163                        break;
164                    }
165                }
166            }
167        }
168    });
169
170    // Task to handle incoming messages from client
171    let conn_id = connection_id.clone();
172    let mgr = manager.clone();
173    let recv_task = tokio::spawn(async move {
174        while let Some(Ok(msg)) = receiver.next().await {
175            match msg {
176                Message::Text(text) => {
177                    // Try to parse as filter update
178                    if let Ok(new_filter) = serde_json::from_str::<SubscriptionFilter>(&text) {
179                        mgr.register_client(conn_id.clone(), new_filter);
180                        tracing::debug!("Updated filter for client {}", conn_id);
181                    }
182                }
183                Message::Close(_) => {
184                    break;
185                }
186                _ => {}
187            }
188        }
189    });
190
191    // Wait for either task to finish
192    tokio::select! {
193        _ = send_task => {}
194        _ = recv_task => {}
195    }
196
197    manager.unregister_client(&connection_id);
198    tracing::info!("Client disconnected: {}", connection_id);
199}
200
#[cfg(test)]
mod tests {
    use super::super::events::EventType;
    use super::*;

    /// Registering and unregistering a client is reflected in the count.
    #[test]
    fn test_realtime_manager() {
        let mgr = RealtimeManager::new();
        assert_eq!(mgr.client_count(), 0);

        let id = String::from("test");
        mgr.register_client(id.clone(), SubscriptionFilter::default());
        assert_eq!(mgr.client_count(), 1);

        mgr.unregister_client(&id);
        assert_eq!(mgr.client_count(), 0);
    }

    /// A filter restricted to one event type accepts that type and rejects
    /// another.
    #[test]
    fn test_subscription_filter() {
        let only_created = SubscriptionFilter {
            event_types: Some(vec![EventType::MemoryCreated]),
            memory_ids: None,
            tags: None,
        };

        let created = RealtimeEvent::memory_created(1, "test".to_string());
        let deleted = RealtimeEvent::memory_deleted(1);

        assert!(only_created.matches(&created));
        assert!(!only_created.matches(&deleted));
    }
}