hyperstack_server/websocket/
server.rs

1use crate::bus::BusManager;
2use crate::cache::EntityCache;
3use crate::view::ViewIndex;
4use crate::websocket::client_manager::ClientManager;
5use crate::websocket::frame::{Frame, Mode};
6use crate::websocket::subscription::Subscription;
7use anyhow::Result;
8use bytes::Bytes;
9use futures_util::StreamExt;
10use std::net::SocketAddr;
11use std::sync::Arc;
12#[cfg(feature = "otel")]
13use std::time::Instant;
14
15use tokio::net::{TcpListener, TcpStream};
16use tokio_tungstenite::accept_async;
17use tracing::{debug, error, info, warn};
18use uuid::Uuid;
19
20#[cfg(feature = "otel")]
21use crate::metrics::Metrics;
22
23pub struct WebSocketServer {
24    bind_addr: SocketAddr,
25    client_manager: ClientManager,
26    bus_manager: BusManager,
27    entity_cache: EntityCache,
28    view_index: Arc<ViewIndex>,
29    max_clients: usize,
30    #[cfg(feature = "otel")]
31    metrics: Option<Arc<Metrics>>,
32}
33
34impl WebSocketServer {
35    #[cfg(feature = "otel")]
36    pub fn new(
37        bind_addr: SocketAddr,
38        bus_manager: BusManager,
39        entity_cache: EntityCache,
40        view_index: Arc<ViewIndex>,
41        metrics: Option<Arc<Metrics>>,
42    ) -> Self {
43        Self {
44            bind_addr,
45            client_manager: ClientManager::new(),
46            bus_manager,
47            entity_cache,
48            view_index,
49            max_clients: 10000,
50            metrics,
51        }
52    }
53
54    #[cfg(not(feature = "otel"))]
55    pub fn new(
56        bind_addr: SocketAddr,
57        bus_manager: BusManager,
58        entity_cache: EntityCache,
59        view_index: Arc<ViewIndex>,
60    ) -> Self {
61        Self {
62            bind_addr,
63            client_manager: ClientManager::new(),
64            bus_manager,
65            entity_cache,
66            view_index,
67            max_clients: 10000,
68        }
69    }
70
71    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
72        self.max_clients = max_clients;
73        self
74    }
75
76    pub async fn start(self) -> Result<()> {
77        info!(
78            "Starting WebSocket server on {} (max_clients: {})",
79            self.bind_addr, self.max_clients
80        );
81
82        let listener = TcpListener::bind(&self.bind_addr).await?;
83        info!("WebSocket server listening on {}", self.bind_addr);
84
85        // Start cleanup task
86        self.client_manager.start_cleanup_task().await;
87
88        // Accept incoming connections
89        loop {
90            match listener.accept().await {
91                Ok((stream, addr)) => {
92                    // Check if we've reached the maximum number of clients
93                    let client_count = self.client_manager.client_count().await;
94                    if client_count >= self.max_clients {
95                        warn!(
96                            "Rejecting connection from {} - max clients ({}) reached",
97                            addr, self.max_clients
98                        );
99                        // Accept the connection but immediately close it
100                        if let Ok(mut ws_stream) = accept_async(stream).await {
101                            let _ = ws_stream.close(None).await;
102                        }
103                        continue;
104                    }
105
106                    // Record connection metric
107                    #[cfg(feature = "otel")]
108                    if let Some(ref metrics) = self.metrics {
109                        metrics.record_ws_connection();
110                    }
111
112                    info!(
113                        "New WebSocket connection from {} ({}/{} clients)",
114                        addr,
115                        client_count + 1,
116                        self.max_clients
117                    );
118                    let client_manager = self.client_manager.clone();
119                    let bus_manager = self.bus_manager.clone();
120                    let entity_cache = self.entity_cache.clone();
121                    let view_index = self.view_index.clone();
122                    #[cfg(feature = "otel")]
123                    let metrics = self.metrics.clone();
124
125                    tokio::spawn(async move {
126                        #[cfg(feature = "otel")]
127                        let result = handle_connection(
128                            stream,
129                            client_manager,
130                            bus_manager,
131                            entity_cache,
132                            view_index,
133                            metrics,
134                        )
135                        .await;
136                        #[cfg(not(feature = "otel"))]
137                        let result = handle_connection(
138                            stream,
139                            client_manager,
140                            bus_manager,
141                            entity_cache,
142                            view_index,
143                        )
144                        .await;
145
146                        if let Err(e) = result {
147                            error!("WebSocket connection error: {}", e);
148                        }
149                    });
150                }
151                Err(e) => {
152                    error!("Failed to accept connection: {}", e);
153                }
154            }
155        }
156    }
157}
158
/// Runs one client session (metrics-enabled build): performs the WebSocket
/// handshake, registers the client, turns every text message that parses as a
/// `Subscription` into a bus attachment, and unregisters the client when the
/// peer closes, errors, or the stream ends.
///
/// # Errors
///
/// Returns an error if the handshake or the client registration fails. Read
/// errors during the session are logged and end the session with `Ok(())`.
#[cfg(feature = "otel")]
async fn handle_connection(
    stream: TcpStream,
    client_manager: ClientManager,
    bus_manager: BusManager,
    entity_cache: EntityCache,
    view_index: Arc<ViewIndex>,
    metrics: Option<Arc<Metrics>>,
) -> Result<()> {
    let accepted_at = Instant::now();

    // The accept loop records the connection before this function runs, so
    // every exit path, including a failed handshake, reports the matching
    // disconnection to keep the two metrics paired.
    let ws_stream = match accept_async(stream).await {
        Ok(ws_stream) => ws_stream,
        Err(e) => {
            if let Some(ref m) = metrics {
                m.record_ws_disconnection(accepted_at.elapsed().as_secs_f64());
            }
            return Err(e.into());
        }
    };
    let client_id = Uuid::new_v4();
    let connection_start = Instant::now();

    info!("WebSocket connection established for client {}", client_id);

    let (ws_sender, mut ws_receiver) = ws_stream.split();

    if let Err(e) = client_manager.add_client(client_id, ws_sender).await {
        if let Some(ref m) = metrics {
            m.record_ws_disconnection(connection_start.elapsed().as_secs_f64());
        }
        return Err(e.into());
    }

    // View ids whose creation was reported to metrics; each gets a matching
    // "removed" record when the session ends.
    let mut active_subscriptions: Vec<String> = Vec::new();

    loop {
        let msg = match ws_receiver.next().await {
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                warn!("WebSocket error for client {}: {}", client_id, e);
                break;
            }
            None => {
                debug!("WebSocket stream ended for client {}", client_id);
                break;
            }
        };

        if msg.is_close() {
            info!("Client {} requested close", client_id);
            break;
        }

        // Any non-close frame counts as activity.
        client_manager.update_client_last_seen(client_id).await;

        if !msg.is_text() {
            continue;
        }

        if let Some(ref m) = metrics {
            m.record_ws_message_received();
        }

        let text = match msg.to_text() {
            Ok(text) => text,
            Err(_) => continue,
        };
        debug!("Received text message from client {}: {}", client_id, text);

        match serde_json::from_str::<Subscription>(text) {
            Ok(subscription) => {
                let view_id = subscription.view.clone();
                client_manager
                    .update_subscription(client_id, subscription.clone())
                    .await;

                // Only count subscriptions for views that exist:
                // `attach_client_to_bus` ignores unknown view ids, and the id is
                // client-supplied, so it must not reach metrics unchecked.
                if view_index.get_view(&view_id).is_some() {
                    if let Some(ref m) = metrics {
                        m.record_subscription_created(&view_id);
                    }
                    active_subscriptions.push(view_id);
                }

                attach_client_to_bus(
                    client_id,
                    subscription,
                    &client_manager,
                    &bus_manager,
                    &entity_cache,
                    &view_index,
                    metrics.clone(),
                )
                .await;
            }
            Err(_) => {
                debug!(
                    "Received non-subscription message from client {}: {}",
                    client_id, text
                );
            }
        }
    }

    client_manager.remove_client(client_id).await;

    if let Some(ref m) = metrics {
        m.record_ws_disconnection(connection_start.elapsed().as_secs_f64());

        for view_id in &active_subscriptions {
            m.record_subscription_removed(view_id);
        }
    }

    info!("Client {} disconnected", client_id);

    Ok(())
}
252
253#[cfg(not(feature = "otel"))]
254async fn handle_connection(
255    stream: TcpStream,
256    client_manager: ClientManager,
257    bus_manager: BusManager,
258    entity_cache: EntityCache,
259    view_index: Arc<ViewIndex>,
260) -> Result<()> {
261    let ws_stream = accept_async(stream).await?;
262    let client_id = Uuid::new_v4();
263
264    info!("WebSocket connection established for client {}", client_id);
265
266    let (ws_sender, mut ws_receiver) = ws_stream.split();
267
268    client_manager.add_client(client_id, ws_sender).await?;
269
270    loop {
271        tokio::select! {
272            ws_msg = ws_receiver.next() => {
273                match ws_msg {
274                    Some(Ok(msg)) => {
275                        if msg.is_close() {
276                            info!("Client {} requested close", client_id);
277                            break;
278                        }
279
280                        client_manager.update_client_last_seen(client_id).await;
281
282                        if msg.is_text() {
283                            if let Ok(text) = msg.to_text() {
284                                debug!("Received text message from client {}: {}", client_id, text);
285
286                                if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
287                                    client_manager.update_subscription(client_id, subscription.clone()).await;
288
289                                    attach_client_to_bus(
290                                        client_id,
291                                        subscription,
292                                        &client_manager,
293                                        &bus_manager,
294                                        &entity_cache,
295                                        &view_index,
296                                    ).await;
297                                } else {
298                                    debug!("Received non-subscription message from client {}: {}", client_id, text);
299                                }
300                            }
301                        }
302                    }
303                    Some(Err(e)) => {
304                        warn!("WebSocket error for client {}: {}", client_id, e);
305                        break;
306                    }
307                    None => {
308                        debug!("WebSocket stream ended for client {}", client_id);
309                        break;
310                    }
311                }
312            }
313        }
314    }
315
316    client_manager.remove_client(client_id).await;
317    info!("Client {} disconnected", client_id);
318
319    Ok(())
320}
321
/// Connects a client to the bus behind `subscription.view` (metrics-enabled
/// build).
///
/// Unknown view ids are logged and ignored. For `Mode::State` the current bus
/// value is sent first if it is non-empty; for `Mode::List` / `Mode::Append`
/// every cached entity whose key matches the subscription is sent as an
/// `"upsert"` frame. In both cases a background task is then spawned that
/// forwards subsequent bus updates to the client. `record_ws_message_sent` is
/// called only for sends that succeed.
#[cfg(feature = "otel")]
async fn attach_client_to_bus(
    client_id: Uuid,
    subscription: Subscription,
    client_manager: &ClientManager,
    bus_manager: &BusManager,
    entity_cache: &EntityCache,
    view_index: &ViewIndex,
    metrics: Option<Arc<Metrics>>,
) {
    let view_id = &subscription.view;

    let view_spec = match view_index.get_view(view_id) {
        Some(spec) => spec,
        None => {
            warn!("Unknown view ID: {}", view_id);
            return;
        }
    };

    match view_spec.mode {
        Mode::State => {
            let key = subscription.key.as_deref().unwrap_or("");
            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;

            // Read the current value once so the emptiness check and the payload
            // that is sent refer to the same value.
            let initial = rx.borrow().clone();
            if !initial.is_empty()
                && client_manager
                    .send_to_client(client_id, initial)
                    .await
                    .is_ok()
            {
                if let Some(ref m) = metrics {
                    m.record_ws_message_sent();
                }
            }

            let client_mgr = client_manager.clone();
            // The task ends when the bus reports an error from `changed()` or a
            // send to the client fails.
            tokio::spawn(async move {
                while rx.changed().await.is_ok() {
                    let data = rx.borrow().clone();
                    if client_mgr.send_to_client(client_id, data).await.is_err() {
                        break;
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            });
        }
        Mode::List | Mode::Append => {
            // Obtain the receiver before reading the cache snapshot.
            // NOTE(review): presumably so updates published while the snapshot
            // is being sent are not missed — confirm against the bus semantics.
            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;

            let snapshots = entity_cache.get_all(view_id).await;
            for (key, entity) in snapshots {
                if !subscription.matches_key(&key) {
                    continue;
                }
                let frame = Frame {
                    mode: view_spec.mode,
                    export: view_id.clone(),
                    op: "upsert",
                    key,
                    data: entity,
                };
                match serde_json::to_vec(&frame) {
                    Ok(payload) => {
                        let sent = client_manager
                            .send_to_client(client_id, Arc::new(Bytes::from(payload)))
                            .await;
                        if sent.is_ok() {
                            if let Some(ref m) = metrics {
                                m.record_ws_message_sent();
                            }
                        }
                    }
                    Err(e) => {
                        warn!(
                            "Failed to serialize snapshot frame for view {}: {}",
                            view_id, e
                        );
                    }
                }
            }

            let client_mgr = client_manager.clone();
            let sub = subscription.clone();
            // NOTE(review): the loop stops on the first `Err` from `recv()`; if
            // the list bus can report a recoverable error (e.g. a lagged
            // receiver), the subscription silently ends — confirm this is
            // intended. A send is only attempted for matching envelopes, so a
            // task whose filter never matches keeps running after the client is
            // gone until the bus closes.
            tokio::spawn(async move {
                while let Ok(envelope) = rx.recv().await {
                    if sub.matches(&envelope.entity, &envelope.key) {
                        if client_mgr
                            .send_to_client(client_id, envelope.payload.clone())
                            .await
                            .is_err()
                        {
                            break;
                        }
                        if let Some(ref m) = metrics {
                            m.record_ws_message_sent();
                        }
                    }
                }
            });
        }
    }

    info!(
        "Client {} subscribed to {} (mode: {:?})",
        client_id, view_id, view_spec.mode
    );
}
420
421#[cfg(not(feature = "otel"))]
422async fn attach_client_to_bus(
423    client_id: Uuid,
424    subscription: Subscription,
425    client_manager: &ClientManager,
426    bus_manager: &BusManager,
427    entity_cache: &EntityCache,
428    view_index: &ViewIndex,
429) {
430    let view_id = &subscription.view;
431
432    let view_spec = match view_index.get_view(view_id) {
433        Some(spec) => spec,
434        None => {
435            warn!("Unknown view ID: {}", view_id);
436            return;
437        }
438    };
439
440    match view_spec.mode {
441        Mode::State => {
442            let key = subscription.key.as_deref().unwrap_or("");
443            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;
444
445            if !rx.borrow().is_empty() {
446                let data = rx.borrow().clone();
447                let _ = client_manager.send_to_client(client_id, data).await;
448            }
449
450            let client_mgr = client_manager.clone();
451            tokio::spawn(async move {
452                while rx.changed().await.is_ok() {
453                    let data = rx.borrow().clone();
454                    if client_mgr.send_to_client(client_id, data).await.is_err() {
455                        break;
456                    }
457                }
458            });
459        }
460        Mode::List | Mode::Append => {
461            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;
462
463            let snapshots = entity_cache.get_all(view_id).await;
464            for (key, entity) in snapshots {
465                if subscription.matches_key(&key) {
466                    let frame = Frame {
467                        mode: view_spec.mode,
468                        export: view_id.clone(),
469                        op: "upsert",
470                        key,
471                        data: entity,
472                    };
473                    if let Ok(payload) = serde_json::to_vec(&frame) {
474                        let _ = client_manager
475                            .send_to_client(client_id, Arc::new(Bytes::from(payload)))
476                            .await;
477                    }
478                }
479            }
480
481            let client_mgr = client_manager.clone();
482            let sub = subscription.clone();
483            tokio::spawn(async move {
484                while let Ok(envelope) = rx.recv().await {
485                    if sub.matches(&envelope.entity, &envelope.key)
486                        && client_mgr
487                            .send_to_client(client_id, envelope.payload.clone())
488                            .await
489                            .is_err()
490                    {
491                        break;
492                    }
493                }
494            });
495        }
496    }
497
498    info!(
499        "Client {} subscribed to {} (mode: {:?})",
500        client_id, view_id, view_spec.mode
501    );
502}