hyperstack_server/websocket/server.rs

1use crate::bus::BusManager;
2use crate::cache::EntityCache;
3use crate::compression::maybe_compress;
4use crate::view::ViewIndex;
5use crate::websocket::client_manager::ClientManager;
6use crate::websocket::frame::{Mode, SnapshotEntity, SnapshotFrame};
7use crate::websocket::subscription::{ClientMessage, Subscription};
8use anyhow::Result;
9use futures_util::StreamExt;
10use std::net::SocketAddr;
11use std::sync::Arc;
12#[cfg(feature = "otel")]
13use std::time::Instant;
14
15use tokio::net::{TcpListener, TcpStream};
16use tokio_tungstenite::accept_async;
17use tokio_util::sync::CancellationToken;
18use tracing::{debug, error, info, info_span, warn, Instrument};
19use uuid::Uuid;
20
21#[cfg(feature = "otel")]
22use crate::metrics::Metrics;
23
24pub struct WebSocketServer {
25    bind_addr: SocketAddr,
26    client_manager: ClientManager,
27    bus_manager: BusManager,
28    entity_cache: EntityCache,
29    view_index: Arc<ViewIndex>,
30    max_clients: usize,
31    #[cfg(feature = "otel")]
32    metrics: Option<Arc<Metrics>>,
33}
34
35impl WebSocketServer {
36    #[cfg(feature = "otel")]
37    pub fn new(
38        bind_addr: SocketAddr,
39        bus_manager: BusManager,
40        entity_cache: EntityCache,
41        view_index: Arc<ViewIndex>,
42        metrics: Option<Arc<Metrics>>,
43    ) -> Self {
44        Self {
45            bind_addr,
46            client_manager: ClientManager::new(),
47            bus_manager,
48            entity_cache,
49            view_index,
50            max_clients: 10000,
51            metrics,
52        }
53    }
54
55    #[cfg(not(feature = "otel"))]
56    pub fn new(
57        bind_addr: SocketAddr,
58        bus_manager: BusManager,
59        entity_cache: EntityCache,
60        view_index: Arc<ViewIndex>,
61    ) -> Self {
62        Self {
63            bind_addr,
64            client_manager: ClientManager::new(),
65            bus_manager,
66            entity_cache,
67            view_index,
68            max_clients: 10000,
69        }
70    }
71
72    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
73        self.max_clients = max_clients;
74        self
75    }
76
77    pub async fn start(self) -> Result<()> {
78        info!(
79            "Starting WebSocket server on {} (max_clients: {})",
80            self.bind_addr, self.max_clients
81        );
82
83        let listener = TcpListener::bind(&self.bind_addr).await?;
84        info!("WebSocket server listening on {}", self.bind_addr);
85
86        self.client_manager.start_cleanup_task();
87
88        loop {
89            match listener.accept().await {
90                Ok((stream, addr)) => {
91                    let client_count = self.client_manager.client_count();
92                    if client_count >= self.max_clients {
93                        warn!(
94                            "Rejecting connection from {} - max clients ({}) reached",
95                            addr, self.max_clients
96                        );
97                        drop(stream);
98                        continue;
99                    }
100
101                    #[cfg(feature = "otel")]
102                    if let Some(ref metrics) = self.metrics {
103                        metrics.record_ws_connection();
104                    }
105
106                    info!(
107                        "New WebSocket connection from {} ({}/{} clients)",
108                        addr,
109                        client_count + 1,
110                        self.max_clients
111                    );
112                    let client_manager = self.client_manager.clone();
113                    let bus_manager = self.bus_manager.clone();
114                    let entity_cache = self.entity_cache.clone();
115                    let view_index = self.view_index.clone();
116                    #[cfg(feature = "otel")]
117                    let metrics = self.metrics.clone();
118
119                    tokio::spawn(
120                        async move {
121                            #[cfg(feature = "otel")]
122                            let result = handle_connection(
123                                stream,
124                                client_manager,
125                                bus_manager,
126                                entity_cache,
127                                view_index,
128                                metrics,
129                            )
130                            .await;
131                            #[cfg(not(feature = "otel"))]
132                            let result = handle_connection(
133                                stream,
134                                client_manager,
135                                bus_manager,
136                                entity_cache,
137                                view_index,
138                            )
139                            .await;
140
141                            if let Err(e) = result {
142                                error!("WebSocket connection error: {}", e);
143                            }
144                        }
145                        .instrument(info_span!("ws.connection", %addr)),
146                    );
147                }
148                Err(e) => {
149                    error!("Failed to accept connection: {}", e);
150                }
151            }
152        }
153    }
154}
155
/// Drives one client connection from WebSocket handshake until disconnect (otel build).
///
/// The upgrade handshake is bounded by a timeout. After it completes, the client is
/// registered with `client_manager` and inbound frames are processed until the peer
/// closes, the stream errors, or it ends:
/// - every non-close frame refreshes the client's last-seen time;
/// - each text frame is counted, then parsed as a [`ClientMessage`], falling back to a
///   bare [`Subscription`]. Anything else is logged at debug level and ignored.
///
/// On exit, all of the client's subscriptions are cancelled and the client is removed.
/// The disconnect, plus every subscription still open, is reported to `metrics`.
///
/// # Errors
/// Returns an error if the handshake fails or times out. Errors on an established
/// connection end the loop and are logged, not returned.
#[cfg(feature = "otel")]
async fn handle_connection(
    stream: TcpStream,
    client_manager: ClientManager,
    bus_manager: BusManager,
    entity_cache: EntityCache,
    view_index: Arc<ViewIndex>,
    metrics: Option<Arc<Metrics>>,
) -> Result<()> {
    // Bound the HTTP upgrade. A peer that connects but never finishes the handshake is
    // not yet registered in `client_manager`, so the `max_clients` check cannot see it.
    // Without a bound it could hold this task and its socket indefinitely.
    const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let accepted_at = Instant::now();
    let handshake = match tokio::time::timeout(HANDSHAKE_TIMEOUT, accept_async(stream)).await {
        Ok(result) => result.map_err(anyhow::Error::from),
        Err(_) => Err(anyhow::anyhow!(
            "WebSocket handshake timed out after {:?}",
            HANDSHAKE_TIMEOUT
        )),
    };
    let ws_stream = match handshake {
        Ok(ws_stream) => ws_stream,
        Err(e) => {
            // `start` records a connection for every socket it hands to this function.
            // Record the matching disconnection so a failed handshake does not leave the
            // connect/disconnect metrics unpaired.
            if let Some(ref m) = metrics {
                m.record_ws_disconnection(accepted_at.elapsed().as_secs_f64());
            }
            return Err(e);
        }
    };

    let client_id = Uuid::new_v4();
    let connection_start = Instant::now();

    info!("WebSocket connection established for client {}", client_id);

    let (ws_sender, mut ws_receiver) = ws_stream.split();
    client_manager.add_client(client_id, ws_sender);

    // View ids of the subscriptions that are still open; each one is reported as removed
    // on disconnect. An explicit unsubscribe takes its entry out so it is not reported
    // twice.
    let mut active_subscriptions: Vec<String> = Vec::new();

    loop {
        let msg = match ws_receiver.next().await {
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                warn!("WebSocket error for client {}: {}", client_id, e);
                break;
            }
            None => {
                debug!("WebSocket stream ended for client {}", client_id);
                break;
            }
        };

        if msg.is_close() {
            info!("Client {} requested close", client_id);
            break;
        }

        client_manager.update_client_last_seen(client_id);

        if !msg.is_text() {
            continue;
        }
        if let Some(ref m) = metrics {
            m.record_ws_message_received();
        }
        let text = match msg.to_text() {
            Ok(text) => text,
            Err(_) => continue,
        };
        debug!("Received text message from client {}: {}", client_id, text);

        if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(text) {
            match client_msg {
                ClientMessage::Subscribe(subscription) => {
                    let view_id = subscribe_client(
                        client_id,
                        subscription,
                        &client_manager,
                        &bus_manager,
                        &entity_cache,
                        &view_index,
                        metrics.as_ref(),
                    )
                    .await;
                    active_subscriptions.push(view_id);
                }
                ClientMessage::Unsubscribe(unsub) => {
                    let sub_key = unsub.sub_key();
                    let removed = client_manager
                        .remove_client_subscription(client_id, &sub_key)
                        .await;

                    if removed {
                        info!("Client {} unsubscribed from {}", client_id, sub_key);
                        if let Some(ref m) = metrics {
                            m.record_subscription_removed(&unsub.view);
                        }
                        // Its removal was just recorded; do not report it again on disconnect.
                        if let Some(pos) = active_subscriptions.iter().position(|v| *v == unsub.view)
                        {
                            active_subscriptions.swap_remove(pos);
                        }
                    }
                }
                ClientMessage::Ping => {
                    debug!("Received ping from client {}", client_id);
                }
            }
        } else if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
            let view_id = subscribe_client(
                client_id,
                subscription,
                &client_manager,
                &bus_manager,
                &entity_cache,
                &view_index,
                metrics.as_ref(),
            )
            .await;
            active_subscriptions.push(view_id);
        } else {
            debug!("Received non-subscription message from client {}: {}", client_id, text);
        }
    }

    client_manager
        .cancel_all_client_subscriptions(client_id)
        .await;
    client_manager.remove_client(client_id);

    if let Some(ref m) = metrics {
        let duration_secs = connection_start.elapsed().as_secs_f64();
        m.record_ws_disconnection(duration_secs);

        for view_id in &active_subscriptions {
            m.record_subscription_removed(view_id);
        }
    }

    info!("Client {} disconnected", client_id);

    Ok(())
}

/// Registers `subscription` for `client_id`, records the subscription metric, and starts
/// streaming its updates through `attach_client_to_bus` (otel build).
///
/// A fresh cancellation token is stored under the subscription's key. The streaming task
/// spawned by `attach_client_to_bus` stops when that token is cancelled.
///
/// Returns the subscribed view id, so the caller can track the subscription as open.
#[cfg(feature = "otel")]
async fn subscribe_client(
    client_id: Uuid,
    subscription: Subscription,
    client_manager: &ClientManager,
    bus_manager: &BusManager,
    entity_cache: &EntityCache,
    view_index: &ViewIndex,
    metrics: Option<&Arc<Metrics>>,
) -> String {
    let view_id = subscription.view.clone();
    let sub_key = subscription.sub_key();
    client_manager.update_subscription(client_id, subscription.clone());

    if let Some(m) = metrics {
        m.record_subscription_created(&view_id);
    }

    let cancel_token = CancellationToken::new();
    client_manager
        .add_client_subscription(client_id, sub_key, cancel_token.clone())
        .await;

    attach_client_to_bus(
        client_id,
        subscription,
        client_manager,
        bus_manager,
        entity_cache,
        view_index,
        cancel_token,
        metrics.cloned(),
    )
    .await;

    view_id
}
308
309#[cfg(not(feature = "otel"))]
310async fn handle_connection(
311    stream: TcpStream,
312    client_manager: ClientManager,
313    bus_manager: BusManager,
314    entity_cache: EntityCache,
315    view_index: Arc<ViewIndex>,
316) -> Result<()> {
317    let ws_stream = accept_async(stream).await?;
318    let client_id = Uuid::new_v4();
319
320    info!("WebSocket connection established for client {}", client_id);
321
322    let (ws_sender, mut ws_receiver) = ws_stream.split();
323
324    client_manager.add_client(client_id, ws_sender);
325
326    loop {
327        tokio::select! {
328            ws_msg = ws_receiver.next() => {
329                match ws_msg {
330                    Some(Ok(msg)) => {
331                        if msg.is_close() {
332                            info!("Client {} requested close", client_id);
333                            break;
334                        }
335
336                        client_manager.update_client_last_seen(client_id);
337
338                        if msg.is_text() {
339                            if let Ok(text) = msg.to_text() {
340                                debug!("Received text message from client {}: {}", client_id, text);
341
342                                if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(text) {
343                                    match client_msg {
344                                        ClientMessage::Subscribe(subscription) => {
345                                            let sub_key = subscription.sub_key();
346                                            client_manager.update_subscription(client_id, subscription.clone());
347
348                                            let cancel_token = CancellationToken::new();
349                                            client_manager.add_client_subscription(
350                                                client_id,
351                                                sub_key,
352                                                cancel_token.clone(),
353                                            ).await;
354
355                                            attach_client_to_bus(
356                                                client_id,
357                                                subscription,
358                                                &client_manager,
359                                                &bus_manager,
360                                                &entity_cache,
361                                                &view_index,
362                                                cancel_token,
363                                            ).await;
364                                        }
365                                        ClientMessage::Unsubscribe(unsub) => {
366                                            let sub_key = unsub.sub_key();
367                                            let removed = client_manager
368                                                .remove_client_subscription(client_id, &sub_key)
369                                                .await;
370
371                                            if removed {
372                                                info!("Client {} unsubscribed from {}", client_id, sub_key);
373                                            }
374                                        }
375                                        ClientMessage::Ping => {
376                                            debug!("Received ping from client {}", client_id);
377                                        }
378                                    }
379                                } else if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
380                                    let sub_key = subscription.sub_key();
381                                    client_manager.update_subscription(client_id, subscription.clone());
382
383                                    let cancel_token = CancellationToken::new();
384                                    client_manager.add_client_subscription(
385                                        client_id,
386                                        sub_key,
387                                        cancel_token.clone(),
388                                    ).await;
389
390                                    attach_client_to_bus(
391                                        client_id,
392                                        subscription,
393                                        &client_manager,
394                                        &bus_manager,
395                                        &entity_cache,
396                                        &view_index,
397                                        cancel_token,
398                                    ).await;
399                                } else {
400                                    debug!("Received non-subscription message from client {}: {}", client_id, text);
401                                }
402                            }
403                        }
404                    }
405                    Some(Err(e)) => {
406                        warn!("WebSocket error for client {}: {}", client_id, e);
407                        break;
408                    }
409                    None => {
410                        debug!("WebSocket stream ended for client {}", client_id);
411                        break;
412                    }
413                }
414            }
415        }
416    }
417
418    client_manager
419        .cancel_all_client_subscriptions(client_id)
420        .await;
421    client_manager.remove_client(client_id);
422    info!("Client {} disconnected", client_id);
423
424    Ok(())
425}
426
/// Starts streaming updates for `subscription` to `client_id` (otel build).
///
/// The subscribed view is looked up in `view_index`; unknown views are logged and
/// ignored. Then, depending on the view's [`Mode`]:
/// - `Mode::State`: sends the current value of the view/key state bus (if non-empty),
///   then spawns a task that forwards every newer value;
/// - `Mode::List` / `Mode::Append`: sends one compressed `"snapshot"` frame with the
///   cached entities matching the subscription. It then spawns a task that forwards the
///   bus envelopes accepted by `Subscription::matches`. If the snapshot send fails,
///   nothing is spawned.
///
/// The spawned task ends when `cancel_token` is cancelled, the bus closes, or a send to
/// the client fails. Each successful send is counted with `record_ws_message_sent` when
/// `metrics` is `Some`.
#[cfg(feature = "otel")]
async fn attach_client_to_bus(
    client_id: Uuid,
    subscription: Subscription,
    client_manager: &ClientManager,
    bus_manager: &BusManager,
    entity_cache: &EntityCache,
    view_index: &ViewIndex,
    cancel_token: CancellationToken,
    metrics: Option<Arc<Metrics>>,
) {
    use tokio::sync::broadcast::error::RecvError;

    let view_id = &subscription.view;

    let view_spec = match view_index.get_view(view_id) {
        Some(spec) => spec,
        None => {
            warn!("Unknown view ID: {}", view_id);
            return;
        }
    };

    match view_spec.mode {
        Mode::State => {
            let key = subscription.key.as_deref().unwrap_or("");
            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;

            // Read the current value once and mark it as seen, so `changed()` below only
            // wakes for values newer than the one sent here. The read guard is released
            // at the end of this block, before sending.
            let initial = {
                let current = rx.borrow_and_update();
                if current.is_empty() {
                    None
                } else {
                    Some(current.clone())
                }
            };
            if let Some(data) = initial {
                // Best-effort: a failed initial send does not prevent the forwarding task
                // below from starting, but it is not counted as a sent message.
                if client_manager.send_to_client(client_id, data).is_ok() {
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            }

            let client_mgr = client_manager.clone();
            let metrics_clone = metrics.clone();
            let view_id_clone = view_id.clone();
            let key_clone = key.to_string();
            tokio::spawn(
                async move {
                    loop {
                        tokio::select! {
                            _ = cancel_token.cancelled() => {
                                debug!("State subscription cancelled for client {}", client_id);
                                break;
                            }
                            result = rx.changed() => {
                                if result.is_err() {
                                    break;
                                }
                                // Mark the value as seen while reading it, so a value that
                                // arrived after `changed()` returned is not sent twice.
                                let data = rx.borrow_and_update().clone();
                                if client_mgr.send_to_client(client_id, data).is_err() {
                                    break;
                                }
                                if let Some(ref m) = metrics_clone {
                                    m.record_ws_message_sent();
                                }
                            }
                        }
                    }
                }
                .instrument(info_span!("ws.subscribe.state", %client_id, view = %view_id_clone, key = %key_clone)),
            );
        }
        Mode::List | Mode::Append => {
            // Subscribe before reading the cache so no update is missed between the
            // snapshot and the forwarding loop.
            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;

            let snapshots = entity_cache.get_all(view_id).await;
            let snapshot_entities: Vec<SnapshotEntity> = snapshots
                .into_iter()
                .filter(|(key, _)| subscription.matches_key(key))
                .map(|(key, data)| SnapshotEntity { key, data })
                .collect();

            if !snapshot_entities.is_empty() {
                let snapshot_frame = SnapshotFrame {
                    mode: view_spec.mode,
                    export: view_id.clone(),
                    op: "snapshot",
                    data: snapshot_entities,
                };
                if let Ok(json_payload) = serde_json::to_vec(&snapshot_frame) {
                    let payload = maybe_compress(&json_payload);
                    if client_manager
                        .send_to_client_async(client_id, Arc::new(payload))
                        .await
                        .is_err()
                    {
                        return;
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            }

            let client_mgr = client_manager.clone();
            let sub = subscription.clone();
            let metrics_clone = metrics.clone();
            let view_id_clone = view_id.clone();
            let mode = view_spec.mode;
            tokio::spawn(
                async move {
                    loop {
                        tokio::select! {
                            _ = cancel_token.cancelled() => {
                                debug!("List subscription cancelled for client {}", client_id);
                                break;
                            }
                            result = rx.recv() => {
                                match result {
                                    Ok(envelope) => {
                                        if sub.matches(&envelope.entity, &envelope.key) {
                                            if client_mgr
                                                .send_to_client(client_id, envelope.payload.clone())
                                                .is_err()
                                            {
                                                break;
                                            }
                                            if let Some(ref m) = metrics_clone {
                                                m.record_ws_message_sent();
                                            }
                                        }
                                    }
                                    // The receiver fell behind and skipped some messages, but
                                    // the bus is still open. Keep forwarding rather than
                                    // silently ending this client's subscription.
                                    Err(RecvError::Lagged(skipped)) => {
                                        warn!(
                                            "Client {} lagged behind the list bus; skipped {} messages",
                                            client_id, skipped
                                        );
                                    }
                                    Err(_) => break,
                                }
                            }
                        }
                    }
                }
                .instrument(info_span!("ws.subscribe.list", %client_id, view = %view_id_clone, mode = ?mode)),
            );
        }
    }

    info!(
        "Client {} subscribed to {} (mode: {:?})",
        client_id, view_id, view_spec.mode
    );
}
567
568#[cfg(not(feature = "otel"))]
569async fn attach_client_to_bus(
570    client_id: Uuid,
571    subscription: Subscription,
572    client_manager: &ClientManager,
573    bus_manager: &BusManager,
574    entity_cache: &EntityCache,
575    view_index: &ViewIndex,
576    cancel_token: CancellationToken,
577) {
578    let view_id = &subscription.view;
579
580    let view_spec = match view_index.get_view(view_id) {
581        Some(spec) => spec,
582        None => {
583            warn!("Unknown view ID: {}", view_id);
584            return;
585        }
586    };
587
588    match view_spec.mode {
589        Mode::State => {
590            let key = subscription.key.as_deref().unwrap_or("");
591            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;
592
593            if !rx.borrow().is_empty() {
594                let data = rx.borrow().clone();
595                let _ = client_manager.send_to_client(client_id, data);
596            }
597
598            let client_mgr = client_manager.clone();
599            let view_id_clone = view_id.clone();
600            let key_clone = key.to_string();
601            tokio::spawn(
602                async move {
603                    loop {
604                        tokio::select! {
605                            _ = cancel_token.cancelled() => {
606                                debug!("State subscription cancelled for client {}", client_id);
607                                break;
608                            }
609                            result = rx.changed() => {
610                                if result.is_err() {
611                                    break;
612                                }
613                                let data = rx.borrow().clone();
614                                if client_mgr.send_to_client(client_id, data).is_err() {
615                                    break;
616                                }
617                            }
618                        }
619                    }
620                }
621                .instrument(info_span!("ws.subscribe.state", %client_id, view = %view_id_clone, key = %key_clone)),
622            );
623        }
624        Mode::List | Mode::Append => {
625            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;
626
627            let snapshots = entity_cache.get_all(view_id).await;
628            let snapshot_entities: Vec<SnapshotEntity> = snapshots
629                .into_iter()
630                .filter(|(key, _)| subscription.matches_key(key))
631                .map(|(key, data)| SnapshotEntity { key, data })
632                .collect();
633
634            if !snapshot_entities.is_empty() {
635                let snapshot_frame = SnapshotFrame {
636                    mode: view_spec.mode,
637                    export: view_id.clone(),
638                    op: "snapshot",
639                    data: snapshot_entities,
640                };
641                if let Ok(json_payload) = serde_json::to_vec(&snapshot_frame) {
642                    let payload = maybe_compress(&json_payload);
643                    if client_manager
644                        .send_to_client_async(client_id, Arc::new(payload))
645                        .await
646                        .is_err()
647                    {
648                        return;
649                    }
650                }
651            }
652
653            let client_mgr = client_manager.clone();
654            let sub = subscription.clone();
655            let view_id_clone = view_id.clone();
656            let mode = view_spec.mode;
657            tokio::spawn(
658                async move {
659                    loop {
660                        tokio::select! {
661                            _ = cancel_token.cancelled() => {
662                                debug!("List subscription cancelled for client {}", client_id);
663                                break;
664                            }
665                            result = rx.recv() => {
666                                match result {
667                                    Ok(envelope) => {
668                                        if sub.matches(&envelope.entity, &envelope.key)
669                                            && client_mgr
670                                                .send_to_client(client_id, envelope.payload.clone())
671                                                .is_err()
672                                        {
673                                            break;
674                                        }
675                                    }
676                                    Err(_) => break,
677                                }
678                            }
679                        }
680                    }
681                }
682                .instrument(info_span!("ws.subscribe.list", %client_id, view = %view_id_clone, mode = ?mode)),
683            );
684        }
685    }
686
687    info!(
688        "Client {} subscribed to {} (mode: {:?})",
689        client_id, view_id, view_spec.mode
690    );
691}