hyperstack_server/websocket/
server.rs

1use crate::bus::BusManager;
2use crate::cache::{EntityCache, SnapshotBatchConfig};
3use crate::compression::maybe_compress;
4use crate::view::ViewIndex;
5use crate::websocket::client_manager::ClientManager;
6use crate::websocket::frame::{Mode, SnapshotEntity, SnapshotFrame};
7use crate::websocket::subscription::{ClientMessage, Subscription};
8use anyhow::Result;
9use futures_util::StreamExt;
10use std::net::SocketAddr;
11use std::sync::Arc;
12#[cfg(feature = "otel")]
13use std::time::Instant;
14
15use tokio::net::{TcpListener, TcpStream};
16use tokio_tungstenite::accept_async;
17use tokio_util::sync::CancellationToken;
18use tracing::{debug, error, info, info_span, warn, Instrument};
19use uuid::Uuid;
20
21#[cfg(feature = "otel")]
22use crate::metrics::Metrics;
23
24struct SubscriptionContext<'a> {
25    client_id: Uuid,
26    client_manager: &'a ClientManager,
27    bus_manager: &'a BusManager,
28    entity_cache: &'a EntityCache,
29    view_index: &'a ViewIndex,
30    #[cfg(feature = "otel")]
31    metrics: Option<Arc<Metrics>>,
32}
33
34pub struct WebSocketServer {
35    bind_addr: SocketAddr,
36    client_manager: ClientManager,
37    bus_manager: BusManager,
38    entity_cache: EntityCache,
39    view_index: Arc<ViewIndex>,
40    max_clients: usize,
41    #[cfg(feature = "otel")]
42    metrics: Option<Arc<Metrics>>,
43}
44
45impl WebSocketServer {
46    #[cfg(feature = "otel")]
47    pub fn new(
48        bind_addr: SocketAddr,
49        bus_manager: BusManager,
50        entity_cache: EntityCache,
51        view_index: Arc<ViewIndex>,
52        metrics: Option<Arc<Metrics>>,
53    ) -> Self {
54        Self {
55            bind_addr,
56            client_manager: ClientManager::new(),
57            bus_manager,
58            entity_cache,
59            view_index,
60            max_clients: 10000,
61            metrics,
62        }
63    }
64
65    #[cfg(not(feature = "otel"))]
66    pub fn new(
67        bind_addr: SocketAddr,
68        bus_manager: BusManager,
69        entity_cache: EntityCache,
70        view_index: Arc<ViewIndex>,
71    ) -> Self {
72        Self {
73            bind_addr,
74            client_manager: ClientManager::new(),
75            bus_manager,
76            entity_cache,
77            view_index,
78            max_clients: 10000,
79        }
80    }
81
82    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
83        self.max_clients = max_clients;
84        self
85    }
86
87    pub async fn start(self) -> Result<()> {
88        info!(
89            "Starting WebSocket server on {} (max_clients: {})",
90            self.bind_addr, self.max_clients
91        );
92
93        let listener = TcpListener::bind(&self.bind_addr).await?;
94        info!("WebSocket server listening on {}", self.bind_addr);
95
96        self.client_manager.start_cleanup_task();
97
98        loop {
99            match listener.accept().await {
100                Ok((stream, addr)) => {
101                    let client_count = self.client_manager.client_count();
102                    if client_count >= self.max_clients {
103                        warn!(
104                            "Rejecting connection from {} - max clients ({}) reached",
105                            addr, self.max_clients
106                        );
107                        drop(stream);
108                        continue;
109                    }
110
111                    #[cfg(feature = "otel")]
112                    if let Some(ref metrics) = self.metrics {
113                        metrics.record_ws_connection();
114                    }
115
116                    info!(
117                        "New WebSocket connection from {} ({}/{} clients)",
118                        addr,
119                        client_count + 1,
120                        self.max_clients
121                    );
122                    let client_manager = self.client_manager.clone();
123                    let bus_manager = self.bus_manager.clone();
124                    let entity_cache = self.entity_cache.clone();
125                    let view_index = self.view_index.clone();
126                    #[cfg(feature = "otel")]
127                    let metrics = self.metrics.clone();
128
129                    tokio::spawn(
130                        async move {
131                            #[cfg(feature = "otel")]
132                            let result = handle_connection(
133                                stream,
134                                client_manager,
135                                bus_manager,
136                                entity_cache,
137                                view_index,
138                                metrics,
139                            )
140                            .await;
141                            #[cfg(not(feature = "otel"))]
142                            let result = handle_connection(
143                                stream,
144                                client_manager,
145                                bus_manager,
146                                entity_cache,
147                                view_index,
148                            )
149                            .await;
150
151                            if let Err(e) = result {
152                                error!("WebSocket connection error: {}", e);
153                            }
154                        }
155                        .instrument(info_span!("ws.connection", %addr)),
156                    );
157                }
158                Err(e) => {
159                    error!("Failed to accept connection: {}", e);
160                }
161            }
162        }
163    }
164}
165
/// Drives one client connection (metrics-enabled build).
///
/// Performs the WebSocket handshake, registers the client, then processes incoming
/// text frames until the peer closes, the stream ends or a transport error occurs.
/// Text frames are parsed as `ClientMessage`; a bare `Subscription` object is also
/// accepted and treated as a subscribe request. On exit all of the client's
/// subscriptions are cancelled and the client is removed from the registry.
///
/// # Errors
///
/// Returns an error only when the WebSocket handshake fails; failures after that
/// point are logged and end the session normally.
#[cfg(feature = "otel")]
async fn handle_connection(
    stream: TcpStream,
    client_manager: ClientManager,
    bus_manager: BusManager,
    entity_cache: EntityCache,
    view_index: Arc<ViewIndex>,
    metrics: Option<Arc<Metrics>>,
) -> Result<()> {
    let handshake_start = Instant::now();
    let ws_stream = match accept_async(stream).await {
        Ok(ws_stream) => ws_stream,
        Err(e) => {
            // The accept loop records the connection before this task starts, so a
            // failed handshake has to record the matching disconnection; otherwise
            // the two counters drift apart with every failed upgrade.
            if let Some(ref m) = metrics {
                m.record_ws_disconnection(handshake_start.elapsed().as_secs_f64());
            }
            return Err(e.into());
        }
    };
    let client_id = Uuid::new_v4();
    let connection_start = Instant::now();

    info!("WebSocket connection established for client {}", client_id);

    let (ws_sender, mut ws_receiver) = ws_stream.split();

    client_manager.add_client(client_id, ws_sender);

    let ctx = SubscriptionContext {
        client_id,
        client_manager: &client_manager,
        bus_manager: &bus_manager,
        entity_cache: &entity_cache,
        view_index: &view_index,
        metrics: metrics.clone(),
    };

    // Views with a "created" metric that has not yet been balanced by a "removed" one.
    let mut active_subscriptions: Vec<String> = Vec::new();

    loop {
        let msg = match ws_receiver.next().await {
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                warn!("WebSocket error for client {}: {}", client_id, e);
                break;
            }
            None => {
                debug!("WebSocket stream ended for client {}", client_id);
                break;
            }
        };

        if msg.is_close() {
            info!("Client {} requested close", client_id);
            break;
        }

        client_manager.update_client_last_seen(client_id);

        if !msg.is_text() {
            continue;
        }

        if let Some(ref m) = metrics {
            m.record_ws_message_received();
        }

        let text = match msg.to_text() {
            Ok(text) => text,
            Err(_) => continue,
        };
        debug!("Received text message from client {}: {}", client_id, text);

        // Prefer the tagged protocol; fall back to a bare `Subscription` payload and
        // treat it as a subscribe request so both forms share one code path.
        let parsed = serde_json::from_str::<ClientMessage>(text).ok().or_else(|| {
            serde_json::from_str::<Subscription>(text)
                .ok()
                .map(ClientMessage::Subscribe)
        });

        match parsed {
            Some(ClientMessage::Subscribe(subscription)) => {
                let view_id = subscription.view.clone();
                let sub_key = subscription.sub_key();
                client_manager.update_subscription(client_id, subscription.clone());

                let cancel_token = CancellationToken::new();
                let is_new = client_manager
                    .add_client_subscription(client_id, sub_key.clone(), cancel_token.clone())
                    .await;

                if !is_new {
                    debug!("Client {} already subscribed to {}, ignoring duplicate", client_id, sub_key);
                    continue;
                }

                if let Some(ref m) = metrics {
                    m.record_subscription_created(&view_id);
                }
                active_subscriptions.push(view_id);

                attach_client_to_bus(&ctx, subscription, cancel_token).await;
            }
            Some(ClientMessage::Unsubscribe(unsub)) => {
                let sub_key = unsub.sub_key();
                let removed = client_manager
                    .remove_client_subscription(client_id, &sub_key)
                    .await;

                if removed {
                    info!("Client {} unsubscribed from {}", client_id, sub_key);
                    if let Some(ref m) = metrics {
                        m.record_subscription_removed(&unsub.view);
                    }
                    // The removal has just been recorded; forget one matching entry so
                    // the disconnect teardown below does not record it a second time.
                    if let Some(pos) = active_subscriptions
                        .iter()
                        .position(|view| *view == unsub.view)
                    {
                        active_subscriptions.swap_remove(pos);
                    }
                }
            }
            Some(ClientMessage::Ping) => {
                debug!("Received ping from client {}", client_id);
            }
            None => {
                debug!("Received non-subscription message from client {}: {}", client_id, text);
            }
        }
    }

    client_manager
        .cancel_all_client_subscriptions(client_id)
        .await;
    client_manager.remove_client(client_id);

    if let Some(ref m) = metrics {
        let duration_secs = connection_start.elapsed().as_secs_f64();
        m.record_ws_disconnection(duration_secs);

        // Balance the subscriptions that were still open when the session ended.
        for view_id in active_subscriptions {
            m.record_subscription_removed(&view_id);
        }
    }

    info!("Client {} disconnected", client_id);

    Ok(())
}
319
320#[cfg(not(feature = "otel"))]
321async fn handle_connection(
322    stream: TcpStream,
323    client_manager: ClientManager,
324    bus_manager: BusManager,
325    entity_cache: EntityCache,
326    view_index: Arc<ViewIndex>,
327) -> Result<()> {
328    let ws_stream = accept_async(stream).await?;
329    let client_id = Uuid::new_v4();
330
331    info!("WebSocket connection established for client {}", client_id);
332
333    let (ws_sender, mut ws_receiver) = ws_stream.split();
334
335    client_manager.add_client(client_id, ws_sender);
336
337    let ctx = SubscriptionContext {
338        client_id,
339        client_manager: &client_manager,
340        bus_manager: &bus_manager,
341        entity_cache: &entity_cache,
342        view_index: &view_index,
343    };
344
345    loop {
346        tokio::select! {
347            ws_msg = ws_receiver.next() => {
348                match ws_msg {
349                    Some(Ok(msg)) => {
350                        if msg.is_close() {
351                            info!("Client {} requested close", client_id);
352                            break;
353                        }
354
355                        client_manager.update_client_last_seen(client_id);
356
357                        if msg.is_text() {
358                            if let Ok(text) = msg.to_text() {
359                                debug!("Received text message from client {}: {}", client_id, text);
360
361                                if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(text) {
362                                    match client_msg {
363                                        ClientMessage::Subscribe(subscription) => {
364                                            let sub_key = subscription.sub_key();
365                                            client_manager.update_subscription(client_id, subscription.clone());
366
367                                            let cancel_token = CancellationToken::new();
368                                            let is_new = client_manager.add_client_subscription(
369                                                client_id,
370                                                sub_key.clone(),
371                                                cancel_token.clone(),
372                                            ).await;
373
374                                            if !is_new {
375                                                debug!("Client {} already subscribed to {}, ignoring duplicate", client_id, sub_key);
376                                                continue;
377                                            }
378
379                                            attach_client_to_bus(&ctx, subscription, cancel_token).await;
380                                        }
381                                        ClientMessage::Unsubscribe(unsub) => {
382                                            let sub_key = unsub.sub_key();
383                                            let removed = client_manager
384                                                .remove_client_subscription(client_id, &sub_key)
385                                                .await;
386
387                                            if removed {
388                                                info!("Client {} unsubscribed from {}", client_id, sub_key);
389                                            }
390                                        }
391                                        ClientMessage::Ping => {
392                                            debug!("Received ping from client {}", client_id);
393                                        }
394                                    }
395                                } else if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
396                                    let sub_key = subscription.sub_key();
397                                    client_manager.update_subscription(client_id, subscription.clone());
398
399                                    let cancel_token = CancellationToken::new();
400                                    let is_new = client_manager.add_client_subscription(
401                                        client_id,
402                                        sub_key.clone(),
403                                        cancel_token.clone(),
404                                    ).await;
405
406                                    if !is_new {
407                                        debug!("Client {} already subscribed to {}, ignoring duplicate", client_id, sub_key);
408                                        continue;
409                                    }
410
411                                    attach_client_to_bus(&ctx, subscription, cancel_token).await;
412                                } else {
413                                    debug!("Received non-subscription message from client {}: {}", client_id, text);
414                                }
415                            }
416                        }
417                    }
418                    Some(Err(e)) => {
419                        warn!("WebSocket error for client {}: {}", client_id, e);
420                        break;
421                    }
422                    None => {
423                        debug!("WebSocket stream ended for client {}", client_id);
424                        break;
425                    }
426                }
427            }
428        }
429    }
430
431    client_manager
432        .cancel_all_client_subscriptions(client_id)
433        .await;
434    client_manager.remove_client(client_id);
435    info!("Client {} disconnected", client_id);
436
437    Ok(())
438}
439
440async fn send_snapshot_batches(
441    client_id: Uuid,
442    entities: &[SnapshotEntity],
443    mode: Mode,
444    view_id: &str,
445    client_manager: &ClientManager,
446    batch_config: &SnapshotBatchConfig,
447    #[cfg(feature = "otel")] metrics: Option<&Arc<Metrics>>,
448) -> Result<()> {
449    let total = entities.len();
450    if total == 0 {
451        return Ok(());
452    }
453
454    let mut offset = 0;
455    let mut batch_num = 0;
456
457    while offset < total {
458        let batch_size = if batch_num == 0 {
459            batch_config.initial_batch_size
460        } else {
461            batch_config.subsequent_batch_size
462        };
463
464        let end = (offset + batch_size).min(total);
465        let batch_data: Vec<SnapshotEntity> = entities[offset..end].to_vec();
466        let is_complete = end >= total;
467
468        let snapshot_frame = SnapshotFrame {
469            mode,
470            export: view_id.to_string(),
471            op: "snapshot",
472            data: batch_data,
473            complete: is_complete,
474        };
475
476        if let Ok(json_payload) = serde_json::to_vec(&snapshot_frame) {
477            let payload = maybe_compress(&json_payload);
478            if client_manager
479                .send_compressed_async(client_id, payload)
480                .await
481                .is_err()
482            {
483                return Err(anyhow::anyhow!("Failed to send snapshot batch"));
484            }
485            #[cfg(feature = "otel")]
486            if let Some(m) = metrics {
487                m.record_ws_message_sent();
488            }
489        }
490
491        offset = end;
492        batch_num += 1;
493    }
494
495    debug!(
496        "Sent {} snapshot batches ({} entities) for {} to client {}",
497        batch_num, total, view_id, client_id
498    );
499
500    Ok(())
501}
502
/// Connects an already-registered subscription to its update source
/// (metrics-enabled build).
///
/// * `State` views: sends the current value if it is non-empty, then spawns a task
///   that forwards every subsequent change.
/// * `List` / `Append` views: sends a batched snapshot of the cached entities whose
///   key matches the subscription, then spawns a task that forwards matching updates.
///
/// The spawned task stops when `cancel_token` is cancelled, when the update source is
/// closed, or when sending to the client fails.
///
/// If the view is unknown or the snapshot cannot be delivered, no task is spawned and
/// the subscription is removed from the client registry again, so that a later
/// subscribe request for the same key is not rejected as a duplicate.
#[cfg(feature = "otel")]
async fn attach_client_to_bus(
    ctx: &SubscriptionContext<'_>,
    subscription: Subscription,
    cancel_token: CancellationToken,
) {
    let view_id = &subscription.view;

    let lookup = ctx.view_index.get_view(view_id);
    let view_spec = match lookup {
        Some(spec) => spec,
        None => {
            warn!("Unknown view ID: {}", view_id);
            // Nothing will ever feed this subscription; undo its registration.
            let sub_key = subscription.sub_key();
            ctx.client_manager
                .remove_client_subscription(ctx.client_id, &sub_key)
                .await;
            return;
        }
    };

    match view_spec.mode {
        Mode::State => {
            let key = subscription.key.as_deref().unwrap_or("");
            let mut rx = ctx.bus_manager.get_or_create_state_bus(view_id, key).await;

            // Read the current value once so the emptiness check and the payload that
            // is sent refer to the same value.
            let initial = rx.borrow().clone();
            if !initial.is_empty() {
                let delivered = ctx
                    .client_manager
                    .send_to_client(ctx.client_id, initial)
                    .is_ok();
                // Only count frames that were actually handed to the client.
                if delivered {
                    if let Some(ref m) = ctx.metrics {
                        m.record_ws_message_sent();
                    }
                }
            }

            let client_id = ctx.client_id;
            let client_mgr = ctx.client_manager.clone();
            let metrics_clone = ctx.metrics.clone();
            let view_id_clone = view_id.clone();
            let key_clone = key.to_string();
            tokio::spawn(
                async move {
                    loop {
                        tokio::select! {
                            _ = cancel_token.cancelled() => {
                                debug!("State subscription cancelled for client {}", client_id);
                                break;
                            }
                            result = rx.changed() => {
                                if result.is_err() {
                                    break;
                                }
                                let data = rx.borrow().clone();
                                if client_mgr.send_to_client(client_id, data).is_err() {
                                    break;
                                }
                                if let Some(ref m) = metrics_clone {
                                    m.record_ws_message_sent();
                                }
                            }
                        }
                    }
                }
                .instrument(info_span!("ws.subscribe.state", %client_id, view = %view_id_clone, key = %key_clone)),
            );
        }
        Mode::List | Mode::Append => {
            // Obtain the receiver before reading the cache so updates published while
            // the snapshot is being sent are queued rather than missed.
            let mut rx = ctx.bus_manager.get_or_create_list_bus(view_id).await;

            let snapshots = ctx.entity_cache.get_all(view_id).await;
            let snapshot_entities: Vec<SnapshotEntity> = snapshots
                .into_iter()
                .filter(|(key, _)| subscription.matches_key(key))
                .map(|(key, data)| SnapshotEntity { key, data })
                .collect();

            if !snapshot_entities.is_empty() {
                let batch_config = ctx.entity_cache.snapshot_config();
                let snapshot_failed = send_snapshot_batches(
                    ctx.client_id,
                    &snapshot_entities,
                    view_spec.mode,
                    view_id,
                    ctx.client_manager,
                    &batch_config,
                    #[cfg(feature = "otel")]
                    ctx.metrics.as_ref(),
                )
                .await
                .is_err();

                if snapshot_failed {
                    // No forwarding task will be started; undo the registration.
                    let sub_key = subscription.sub_key();
                    ctx.client_manager
                        .remove_client_subscription(ctx.client_id, &sub_key)
                        .await;
                    return;
                }
            }

            let client_id = ctx.client_id;
            let client_mgr = ctx.client_manager.clone();
            let sub = subscription.clone();
            let metrics_clone = ctx.metrics.clone();
            let view_id_clone = view_id.clone();
            let mode = view_spec.mode;
            tokio::spawn(
                async move {
                    loop {
                        tokio::select! {
                            _ = cancel_token.cancelled() => {
                                debug!("List subscription cancelled for client {}", client_id);
                                break;
                            }
                            result = rx.recv() => {
                                match result {
                                    Ok(envelope) => {
                                        if sub.matches(&envelope.entity, &envelope.key) {
                                            if client_mgr
                                                .send_to_client(client_id, envelope.payload.clone())
                                                .is_err()
                                            {
                                                break;
                                            }
                                            if let Some(ref m) = metrics_clone {
                                                m.record_ws_message_sent();
                                            }
                                        }
                                    }
                                    Err(_) => break,
                                }
                            }
                        }
                    }
                }
                .instrument(info_span!("ws.subscribe.list", %client_id, view = %view_id_clone, mode = ?mode)),
            );
        }
    }

    info!(
        "Client {} subscribed to {} (mode: {:?})",
        ctx.client_id, view_id, view_spec.mode
    );
}
637
638#[cfg(not(feature = "otel"))]
639async fn attach_client_to_bus(
640    ctx: &SubscriptionContext<'_>,
641    subscription: Subscription,
642    cancel_token: CancellationToken,
643) {
644    let view_id = &subscription.view;
645
646    let view_spec = match ctx.view_index.get_view(view_id) {
647        Some(spec) => spec,
648        None => {
649            warn!("Unknown view ID: {}", view_id);
650            return;
651        }
652    };
653
654    match view_spec.mode {
655        Mode::State => {
656            let key = subscription.key.as_deref().unwrap_or("");
657            let mut rx = ctx.bus_manager.get_or_create_state_bus(view_id, key).await;
658
659            if !rx.borrow().is_empty() {
660                let data = rx.borrow().clone();
661                let _ = ctx.client_manager.send_to_client(ctx.client_id, data);
662            }
663
664            let client_id = ctx.client_id;
665            let client_mgr = ctx.client_manager.clone();
666            let view_id_clone = view_id.clone();
667            let key_clone = key.to_string();
668            tokio::spawn(
669                async move {
670                    loop {
671                        tokio::select! {
672                            _ = cancel_token.cancelled() => {
673                                debug!("State subscription cancelled for client {}", client_id);
674                                break;
675                            }
676                            result = rx.changed() => {
677                                if result.is_err() {
678                                    break;
679                                }
680                                let data = rx.borrow().clone();
681                                if client_mgr.send_to_client(client_id, data).is_err() {
682                                    break;
683                                }
684                            }
685                        }
686                    }
687                }
688                .instrument(info_span!("ws.subscribe.state", %client_id, view = %view_id_clone, key = %key_clone)),
689            );
690        }
691        Mode::List | Mode::Append => {
692            let mut rx = ctx.bus_manager.get_or_create_list_bus(view_id).await;
693
694            let snapshots = ctx.entity_cache.get_all(view_id).await;
695            let snapshot_entities: Vec<SnapshotEntity> = snapshots
696                .into_iter()
697                .filter(|(key, _)| subscription.matches_key(key))
698                .map(|(key, data)| SnapshotEntity { key, data })
699                .collect();
700
701            if !snapshot_entities.is_empty() {
702                let batch_config = ctx.entity_cache.snapshot_config();
703                if send_snapshot_batches(
704                    ctx.client_id,
705                    &snapshot_entities,
706                    view_spec.mode,
707                    view_id,
708                    ctx.client_manager,
709                    &batch_config,
710                )
711                .await
712                .is_err()
713                {
714                    return;
715                }
716            }
717
718            let client_id = ctx.client_id;
719            let client_mgr = ctx.client_manager.clone();
720            let sub = subscription.clone();
721            let view_id_clone = view_id.clone();
722            let mode = view_spec.mode;
723            tokio::spawn(
724                async move {
725                    loop {
726                        tokio::select! {
727                            _ = cancel_token.cancelled() => {
728                                debug!("List subscription cancelled for client {}", client_id);
729                                break;
730                            }
731                            result = rx.recv() => {
732                                match result {
733                                    Ok(envelope) => {
734                                        if sub.matches(&envelope.entity, &envelope.key)
735                                            && client_mgr
736                                                .send_to_client(client_id, envelope.payload.clone())
737                                                .is_err()
738                                        {
739                                            break;
740                                        }
741                                    }
742                                    Err(_) => break,
743                                }
744                            }
745                        }
746                    }
747                }
748                .instrument(info_span!("ws.subscribe.list", %client_id, view = %view_id_clone, mode = ?mode)),
749            );
750        }
751    }
752
753    info!(
754        "Client {} subscribed to {} (mode: {:?})",
755        ctx.client_id, view_id, view_spec.mode
756    );
757}