engram/realtime/
server.rs1use std::collections::HashMap;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7use axum::{
8 extract::{
9 ws::{Message, WebSocket, WebSocketUpgrade},
10 State,
11 },
12 response::IntoResponse,
13 routing::get,
14 Router,
15};
16use futures::{SinkExt, StreamExt};
17use parking_lot::RwLock;
18use tokio::sync::broadcast;
19use uuid::Uuid;
20
21use super::events::{RealtimeEvent, SubscriptionFilter};
22
23pub type ConnectionId = String;
25
26pub struct RealtimeManager {
28 tx: broadcast::Sender<RealtimeEvent>,
30 clients: Arc<RwLock<HashMap<ConnectionId, SubscriptionFilter>>>,
32}
33
34impl RealtimeManager {
35 pub fn new() -> Self {
37 let (tx, _) = broadcast::channel(1000);
38 Self {
39 tx,
40 clients: Arc::new(RwLock::new(HashMap::new())),
41 }
42 }
43
44 pub fn broadcast(&self, event: RealtimeEvent) {
46 let _ = self.tx.send(event);
48 }
49
50 pub fn client_count(&self) -> usize {
52 self.clients.read().len()
53 }
54
55 pub fn subscribe(&self) -> broadcast::Receiver<RealtimeEvent> {
57 self.tx.subscribe()
58 }
59
60 pub fn register_client(&self, id: ConnectionId, filter: SubscriptionFilter) {
62 self.clients.write().insert(id, filter);
63 }
64
65 pub fn unregister_client(&self, id: &str) {
67 self.clients.write().remove(id);
68 }
69
70 pub fn get_client_filter(&self, id: &str) -> Option<SubscriptionFilter> {
72 self.clients.read().get(id).cloned()
73 }
74}
75
76impl Default for RealtimeManager {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl Clone for RealtimeManager {
83 fn clone(&self) -> Self {
84 Self {
85 tx: self.tx.clone(),
86 clients: self.clients.clone(),
87 }
88 }
89}
90
91pub struct RealtimeServer {
93 manager: RealtimeManager,
94 addr: SocketAddr,
95}
96
97impl RealtimeServer {
98 pub fn new(manager: RealtimeManager, port: u16) -> Self {
100 let addr = SocketAddr::from(([0, 0, 0, 0], port));
101 Self { manager, addr }
102 }
103
104 pub fn router(manager: RealtimeManager) -> Router {
106 Router::new()
107 .route("/ws", get(ws_handler))
108 .route("/health", get(health_handler))
109 .with_state(manager)
110 }
111
112 pub async fn start(self) -> std::io::Result<()> {
114 let app = Self::router(self.manager);
115
116 tracing::info!("WebSocket server listening on {}", self.addr);
117
118 let listener = tokio::net::TcpListener::bind(self.addr).await?;
119 axum::serve(listener, app).await?;
120
121 Ok(())
122 }
123}
124
125async fn health_handler(State(manager): State<RealtimeManager>) -> impl IntoResponse {
127 serde_json::json!({
128 "status": "ok",
129 "clients": manager.client_count(),
130 })
131 .to_string()
132}
133
134async fn ws_handler(
136 ws: WebSocketUpgrade,
137 State(manager): State<RealtimeManager>,
138) -> impl IntoResponse {
139 ws.on_upgrade(move |socket| handle_socket(socket, manager))
140}
141
142async fn handle_socket(socket: WebSocket, manager: RealtimeManager) {
144 let connection_id = Uuid::new_v4().to_string();
145 let filter = SubscriptionFilter::default();
146
147 manager.register_client(connection_id.clone(), filter.clone());
148 tracing::info!("Client connected: {}", connection_id);
149
150 let (mut sender, mut receiver) = socket.split();
151 let mut rx = manager.subscribe();
152
153 let conn_id = connection_id.clone();
155 let mgr = manager.clone();
156 let send_task = tokio::spawn(async move {
157 while let Ok(event) = rx.recv().await {
158 if let Some(filter) = mgr.get_client_filter(&conn_id) {
160 if filter.matches(&event) {
161 let json = serde_json::to_string(&event).unwrap_or_default();
162 if sender.send(Message::Text(json)).await.is_err() {
163 break;
164 }
165 }
166 }
167 }
168 });
169
170 let conn_id = connection_id.clone();
172 let mgr = manager.clone();
173 let recv_task = tokio::spawn(async move {
174 while let Some(Ok(msg)) = receiver.next().await {
175 match msg {
176 Message::Text(text) => {
177 if let Ok(new_filter) = serde_json::from_str::<SubscriptionFilter>(&text) {
179 mgr.register_client(conn_id.clone(), new_filter);
180 tracing::debug!("Updated filter for client {}", conn_id);
181 }
182 }
183 Message::Close(_) => {
184 break;
185 }
186 _ => {}
187 }
188 }
189 });
190
191 tokio::select! {
193 _ = send_task => {}
194 _ = recv_task => {}
195 }
196
197 manager.unregister_client(&connection_id);
198 tracing::info!("Client disconnected: {}", connection_id);
199}
200
#[cfg(test)]
mod tests {
    use super::super::events::EventType;
    use super::*;

    #[test]
    fn test_realtime_manager() {
        let manager = RealtimeManager::new();
        assert_eq!(manager.client_count(), 0);

        let id = String::from("test");
        manager.register_client(id.clone(), SubscriptionFilter::default());
        assert_eq!(manager.client_count(), 1);

        manager.unregister_client(&id);
        assert_eq!(manager.client_count(), 0);
    }

    #[test]
    fn test_subscription_filter() {
        let only_created = SubscriptionFilter {
            event_types: Some(vec![EventType::MemoryCreated]),
            memory_ids: None,
            tags: None,
        };

        let created = RealtimeEvent::memory_created(1, String::from("test"));
        assert!(only_created.matches(&created));

        let deleted = RealtimeEvent::memory_deleted(1);
        assert!(!only_created.matches(&deleted));
    }
}