1use crate::bus::BusManager;
2use crate::cache::{EntityCache, SnapshotBatchConfig};
3use crate::compression::maybe_compress;
4use crate::view::ViewIndex;
5use crate::websocket::client_manager::ClientManager;
6use crate::websocket::frame::{Mode, SnapshotEntity, SnapshotFrame};
7use crate::websocket::subscription::{ClientMessage, Subscription};
8use anyhow::Result;
9use futures_util::StreamExt;
10use std::net::SocketAddr;
11use std::sync::Arc;
12#[cfg(feature = "otel")]
13use std::time::Instant;
14
15use tokio::net::{TcpListener, TcpStream};
16use tokio_tungstenite::accept_async;
17use tokio_util::sync::CancellationToken;
18use tracing::{debug, error, info, info_span, warn, Instrument};
19use uuid::Uuid;
20
21#[cfg(feature = "otel")]
22use crate::metrics::Metrics;
23
24struct SubscriptionContext<'a> {
25 client_id: Uuid,
26 client_manager: &'a ClientManager,
27 bus_manager: &'a BusManager,
28 entity_cache: &'a EntityCache,
29 view_index: &'a ViewIndex,
30 #[cfg(feature = "otel")]
31 metrics: Option<Arc<Metrics>>,
32}
33
34pub struct WebSocketServer {
35 bind_addr: SocketAddr,
36 client_manager: ClientManager,
37 bus_manager: BusManager,
38 entity_cache: EntityCache,
39 view_index: Arc<ViewIndex>,
40 max_clients: usize,
41 #[cfg(feature = "otel")]
42 metrics: Option<Arc<Metrics>>,
43}
44
45impl WebSocketServer {
46 #[cfg(feature = "otel")]
47 pub fn new(
48 bind_addr: SocketAddr,
49 bus_manager: BusManager,
50 entity_cache: EntityCache,
51 view_index: Arc<ViewIndex>,
52 metrics: Option<Arc<Metrics>>,
53 ) -> Self {
54 Self {
55 bind_addr,
56 client_manager: ClientManager::new(),
57 bus_manager,
58 entity_cache,
59 view_index,
60 max_clients: 10000,
61 metrics,
62 }
63 }
64
65 #[cfg(not(feature = "otel"))]
66 pub fn new(
67 bind_addr: SocketAddr,
68 bus_manager: BusManager,
69 entity_cache: EntityCache,
70 view_index: Arc<ViewIndex>,
71 ) -> Self {
72 Self {
73 bind_addr,
74 client_manager: ClientManager::new(),
75 bus_manager,
76 entity_cache,
77 view_index,
78 max_clients: 10000,
79 }
80 }
81
82 pub fn with_max_clients(mut self, max_clients: usize) -> Self {
83 self.max_clients = max_clients;
84 self
85 }
86
87 pub async fn start(self) -> Result<()> {
88 info!(
89 "Starting WebSocket server on {} (max_clients: {})",
90 self.bind_addr, self.max_clients
91 );
92
93 let listener = TcpListener::bind(&self.bind_addr).await?;
94 info!("WebSocket server listening on {}", self.bind_addr);
95
96 self.client_manager.start_cleanup_task();
97
98 loop {
99 match listener.accept().await {
100 Ok((stream, addr)) => {
101 let client_count = self.client_manager.client_count();
102 if client_count >= self.max_clients {
103 warn!(
104 "Rejecting connection from {} - max clients ({}) reached",
105 addr, self.max_clients
106 );
107 drop(stream);
108 continue;
109 }
110
111 #[cfg(feature = "otel")]
112 if let Some(ref metrics) = self.metrics {
113 metrics.record_ws_connection();
114 }
115
116 info!(
117 "New WebSocket connection from {} ({}/{} clients)",
118 addr,
119 client_count + 1,
120 self.max_clients
121 );
122 let client_manager = self.client_manager.clone();
123 let bus_manager = self.bus_manager.clone();
124 let entity_cache = self.entity_cache.clone();
125 let view_index = self.view_index.clone();
126 #[cfg(feature = "otel")]
127 let metrics = self.metrics.clone();
128
129 tokio::spawn(
130 async move {
131 #[cfg(feature = "otel")]
132 let result = handle_connection(
133 stream,
134 client_manager,
135 bus_manager,
136 entity_cache,
137 view_index,
138 metrics,
139 )
140 .await;
141 #[cfg(not(feature = "otel"))]
142 let result = handle_connection(
143 stream,
144 client_manager,
145 bus_manager,
146 entity_cache,
147 view_index,
148 )
149 .await;
150
151 if let Err(e) = result {
152 error!("WebSocket connection error: {}", e);
153 }
154 }
155 .instrument(info_span!("ws.connection", %addr)),
156 );
157 }
158 Err(e) => {
159 error!("Failed to accept connection: {}", e);
160 }
161 }
162 }
163 }
164}
165
/// Serves one client connection from handshake to disconnect (`otel` build).
///
/// Upgrades `stream` to a WebSocket, registers the client, then processes
/// incoming text frames as `ClientMessage`s (falling back to a bare
/// `Subscription` payload) until the peer closes, the stream ends or a
/// transport error occurs. On exit all of the client's subscriptions are
/// cancelled, the client is removed and disconnect metrics are recorded.
///
/// # Errors
///
/// Returns an error if the WebSocket handshake fails.
#[cfg(feature = "otel")]
async fn handle_connection(
    stream: TcpStream,
    client_manager: ClientManager,
    bus_manager: BusManager,
    entity_cache: EntityCache,
    view_index: Arc<ViewIndex>,
    metrics: Option<Arc<Metrics>>,
) -> Result<()> {
    let ws_stream = accept_async(stream).await?;
    let client_id = Uuid::new_v4();
    let connection_start = Instant::now();

    info!("WebSocket connection established for client {}", client_id);

    let (ws_sender, mut ws_receiver) = ws_stream.split();

    client_manager.add_client(client_id, ws_sender);

    let ctx = SubscriptionContext {
        client_id,
        client_manager: &client_manager,
        bus_manager: &bus_manager,
        entity_cache: &entity_cache,
        view_index: &view_index,
        metrics: metrics.clone(),
    };

    // View ids with a recorded "created" metric that has not yet been balanced
    // by a "removed" metric. Whatever is left at disconnect is reported then.
    let mut active_subscriptions: Vec<String> = Vec::new();

    loop {
        let msg = match ws_receiver.next().await {
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                warn!("WebSocket error for client {}: {}", client_id, e);
                break;
            }
            None => {
                debug!("WebSocket stream ended for client {}", client_id);
                break;
            }
        };

        if msg.is_close() {
            info!("Client {} requested close", client_id);
            break;
        }

        client_manager.update_client_last_seen(client_id);

        if !msg.is_text() {
            continue;
        }

        if let Some(ref m) = metrics {
            m.record_ws_message_received();
        }

        let text = match msg.to_text() {
            Ok(text) => text,
            Err(_) => continue,
        };
        debug!("Received text message from client {}: {}", client_id, text);

        if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(text) {
            match client_msg {
                ClientMessage::Subscribe(subscription) => {
                    subscribe_client(&ctx, subscription, &mut active_subscriptions).await;
                }
                ClientMessage::Unsubscribe(unsub) => {
                    let sub_key = unsub.sub_key();
                    let removed = client_manager
                        .remove_client_subscription(client_id, &sub_key)
                        .await;

                    if removed {
                        info!("Client {} unsubscribed from {}", client_id, sub_key);
                        if let Some(ref m) = metrics {
                            m.record_subscription_removed(&unsub.view);
                        }
                        // The removal has just been reported; forget one matching
                        // entry so the disconnect path does not report it again.
                        if let Some(pos) = active_subscriptions
                            .iter()
                            .position(|view| *view == unsub.view)
                        {
                            active_subscriptions.swap_remove(pos);
                        }
                    }
                }
                ClientMessage::Ping => {
                    debug!("Received ping from client {}", client_id);
                }
            }
        } else if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
            // Fallback: a bare subscription payload without the message wrapper.
            subscribe_client(&ctx, subscription, &mut active_subscriptions).await;
        } else {
            debug!("Received non-subscription message from client {}: {}", client_id, text);
        }
    }

    client_manager
        .cancel_all_client_subscriptions(client_id)
        .await;
    client_manager.remove_client(client_id);

    if let Some(ref m) = metrics {
        let duration_secs = connection_start.elapsed().as_secs_f64();
        m.record_ws_disconnection(duration_secs);

        for view_id in active_subscriptions {
            m.record_subscription_removed(&view_id);
        }
    }

    info!("Client {} disconnected", client_id);

    Ok(())
}

/// Registers `subscription` for the context's client and attaches it to its bus.
///
/// Duplicate subscriptions (same subscription key) are ignored. For new ones
/// the "created" metric is recorded and the view id is appended to
/// `active_subscriptions` before the client is attached.
#[cfg(feature = "otel")]
async fn subscribe_client(
    ctx: &SubscriptionContext<'_>,
    subscription: Subscription,
    active_subscriptions: &mut Vec<String>,
) {
    let view_id = subscription.view.clone();
    let sub_key = subscription.sub_key();
    ctx.client_manager
        .update_subscription(ctx.client_id, subscription.clone());

    let cancel_token = CancellationToken::new();
    let is_new = ctx
        .client_manager
        .add_client_subscription(ctx.client_id, sub_key.clone(), cancel_token.clone())
        .await;

    if !is_new {
        debug!("Client {} already subscribed to {}, ignoring duplicate", ctx.client_id, sub_key);
        return;
    }

    if let Some(ref m) = ctx.metrics {
        m.record_subscription_created(&view_id);
    }
    active_subscriptions.push(view_id);

    attach_client_to_bus(ctx, subscription, cancel_token).await;
}
319
320#[cfg(not(feature = "otel"))]
321async fn handle_connection(
322 stream: TcpStream,
323 client_manager: ClientManager,
324 bus_manager: BusManager,
325 entity_cache: EntityCache,
326 view_index: Arc<ViewIndex>,
327) -> Result<()> {
328 let ws_stream = accept_async(stream).await?;
329 let client_id = Uuid::new_v4();
330
331 info!("WebSocket connection established for client {}", client_id);
332
333 let (ws_sender, mut ws_receiver) = ws_stream.split();
334
335 client_manager.add_client(client_id, ws_sender);
336
337 let ctx = SubscriptionContext {
338 client_id,
339 client_manager: &client_manager,
340 bus_manager: &bus_manager,
341 entity_cache: &entity_cache,
342 view_index: &view_index,
343 };
344
345 loop {
346 tokio::select! {
347 ws_msg = ws_receiver.next() => {
348 match ws_msg {
349 Some(Ok(msg)) => {
350 if msg.is_close() {
351 info!("Client {} requested close", client_id);
352 break;
353 }
354
355 client_manager.update_client_last_seen(client_id);
356
357 if msg.is_text() {
358 if let Ok(text) = msg.to_text() {
359 debug!("Received text message from client {}: {}", client_id, text);
360
361 if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(text) {
362 match client_msg {
363 ClientMessage::Subscribe(subscription) => {
364 let sub_key = subscription.sub_key();
365 client_manager.update_subscription(client_id, subscription.clone());
366
367 let cancel_token = CancellationToken::new();
368 let is_new = client_manager.add_client_subscription(
369 client_id,
370 sub_key.clone(),
371 cancel_token.clone(),
372 ).await;
373
374 if !is_new {
375 debug!("Client {} already subscribed to {}, ignoring duplicate", client_id, sub_key);
376 continue;
377 }
378
379 attach_client_to_bus(&ctx, subscription, cancel_token).await;
380 }
381 ClientMessage::Unsubscribe(unsub) => {
382 let sub_key = unsub.sub_key();
383 let removed = client_manager
384 .remove_client_subscription(client_id, &sub_key)
385 .await;
386
387 if removed {
388 info!("Client {} unsubscribed from {}", client_id, sub_key);
389 }
390 }
391 ClientMessage::Ping => {
392 debug!("Received ping from client {}", client_id);
393 }
394 }
395 } else if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
396 let sub_key = subscription.sub_key();
397 client_manager.update_subscription(client_id, subscription.clone());
398
399 let cancel_token = CancellationToken::new();
400 let is_new = client_manager.add_client_subscription(
401 client_id,
402 sub_key.clone(),
403 cancel_token.clone(),
404 ).await;
405
406 if !is_new {
407 debug!("Client {} already subscribed to {}, ignoring duplicate", client_id, sub_key);
408 continue;
409 }
410
411 attach_client_to_bus(&ctx, subscription, cancel_token).await;
412 } else {
413 debug!("Received non-subscription message from client {}: {}", client_id, text);
414 }
415 }
416 }
417 }
418 Some(Err(e)) => {
419 warn!("WebSocket error for client {}: {}", client_id, e);
420 break;
421 }
422 None => {
423 debug!("WebSocket stream ended for client {}", client_id);
424 break;
425 }
426 }
427 }
428 }
429 }
430
431 client_manager
432 .cancel_all_client_subscriptions(client_id)
433 .await;
434 client_manager.remove_client(client_id);
435 info!("Client {} disconnected", client_id);
436
437 Ok(())
438}
439
440async fn send_snapshot_batches(
441 client_id: Uuid,
442 entities: &[SnapshotEntity],
443 mode: Mode,
444 view_id: &str,
445 client_manager: &ClientManager,
446 batch_config: &SnapshotBatchConfig,
447 #[cfg(feature = "otel")] metrics: Option<&Arc<Metrics>>,
448) -> Result<()> {
449 let total = entities.len();
450 if total == 0 {
451 return Ok(());
452 }
453
454 let mut offset = 0;
455 let mut batch_num = 0;
456
457 while offset < total {
458 let batch_size = if batch_num == 0 {
459 batch_config.initial_batch_size
460 } else {
461 batch_config.subsequent_batch_size
462 };
463
464 let end = (offset + batch_size).min(total);
465 let batch_data: Vec<SnapshotEntity> = entities[offset..end].to_vec();
466 let is_complete = end >= total;
467
468 let snapshot_frame = SnapshotFrame {
469 mode,
470 export: view_id.to_string(),
471 op: "snapshot",
472 data: batch_data,
473 complete: is_complete,
474 };
475
476 if let Ok(json_payload) = serde_json::to_vec(&snapshot_frame) {
477 let payload = maybe_compress(&json_payload);
478 if client_manager
479 .send_compressed_async(client_id, payload)
480 .await
481 .is_err()
482 {
483 return Err(anyhow::anyhow!("Failed to send snapshot batch"));
484 }
485 #[cfg(feature = "otel")]
486 if let Some(m) = metrics {
487 m.record_ws_message_sent();
488 }
489 }
490
491 offset = end;
492 batch_num += 1;
493 }
494
495 debug!(
496 "Sent {} snapshot batches ({} entities) for {} to client {}",
497 batch_num, total, view_id, client_id
498 );
499
500 Ok(())
501}
502
/// Connects a client to the bus backing `subscription` (`otel` build).
///
/// Looks up the view; unknown views are logged and ignored. Depending on the
/// view's mode:
///
/// * `State`: sends the current value if it is non-empty, then spawns a task
///   that forwards every change to the client.
/// * `List` / `Append`: obtains the list bus receiver, sends the cached
///   entities matching the subscription as snapshot batches, then spawns a
///   task that forwards matching envelopes to the client.
///
/// A forwarding task stops when `cancel_token` is cancelled, when the bus
/// receiver reports an error, or when sending to the client fails. If the
/// snapshot cannot be sent, no forwarding task is started.
#[cfg(feature = "otel")]
async fn attach_client_to_bus(
    ctx: &SubscriptionContext<'_>,
    subscription: Subscription,
    cancel_token: CancellationToken,
) {
    let view_id = &subscription.view;

    let view_spec = match ctx.view_index.get_view(view_id) {
        Some(spec) => spec,
        None => {
            warn!("Unknown view ID: {}", view_id);
            return;
        }
    };

    match view_spec.mode {
        Mode::State => {
            let key = subscription.key.as_deref().unwrap_or("");
            let mut rx = ctx.bus_manager.get_or_create_state_bus(view_id, key).await;

            // Read the current value under a single borrow so the emptiness
            // check and the copy that gets sent refer to the same value.
            let initial = {
                let current = rx.borrow();
                if current.is_empty() {
                    None
                } else {
                    Some(current.clone())
                }
            };

            if let Some(data) = initial {
                // Count the message only when it was actually handed off.
                let sent = ctx
                    .client_manager
                    .send_to_client(ctx.client_id, data)
                    .is_ok();
                if sent {
                    if let Some(ref m) = ctx.metrics {
                        m.record_ws_message_sent();
                    }
                }
            }

            let client_id = ctx.client_id;
            let client_mgr = ctx.client_manager.clone();
            let metrics_clone = ctx.metrics.clone();
            let view_id_clone = view_id.clone();
            let key_clone = key.to_string();
            tokio::spawn(
                async move {
                    loop {
                        tokio::select! {
                            _ = cancel_token.cancelled() => {
                                debug!("State subscription cancelled for client {}", client_id);
                                break;
                            }
                            result = rx.changed() => {
                                if result.is_err() {
                                    break;
                                }
                                let data = rx.borrow().clone();
                                if client_mgr.send_to_client(client_id, data).is_err() {
                                    break;
                                }
                                if let Some(ref m) = metrics_clone {
                                    m.record_ws_message_sent();
                                }
                            }
                        }
                    }
                }
                .instrument(info_span!("ws.subscribe.state", %client_id, view = %view_id_clone, key = %key_clone)),
            );
        }
        Mode::List | Mode::Append => {
            // The receiver is obtained before the cache is read.
            let mut rx = ctx.bus_manager.get_or_create_list_bus(view_id).await;

            let snapshots = ctx.entity_cache.get_all(view_id).await;
            let snapshot_entities: Vec<SnapshotEntity> = snapshots
                .into_iter()
                .filter(|(key, _)| subscription.matches_key(key))
                .map(|(key, data)| SnapshotEntity { key, data })
                .collect();

            if !snapshot_entities.is_empty() {
                let batch_config = ctx.entity_cache.snapshot_config();
                if send_snapshot_batches(
                    ctx.client_id,
                    &snapshot_entities,
                    view_spec.mode,
                    view_id,
                    ctx.client_manager,
                    &batch_config,
                    ctx.metrics.as_ref(),
                )
                .await
                .is_err()
                {
                    return;
                }
            }

            let client_id = ctx.client_id;
            let client_mgr = ctx.client_manager.clone();
            let sub = subscription.clone();
            let metrics_clone = ctx.metrics.clone();
            let view_id_clone = view_id.clone();
            let mode = view_spec.mode;
            tokio::spawn(
                async move {
                    loop {
                        tokio::select! {
                            _ = cancel_token.cancelled() => {
                                debug!("List subscription cancelled for client {}", client_id);
                                break;
                            }
                            result = rx.recv() => {
                                match result {
                                    Ok(envelope) => {
                                        if sub.matches(&envelope.entity, &envelope.key) {
                                            if client_mgr
                                                .send_to_client(client_id, envelope.payload.clone())
                                                .is_err()
                                            {
                                                break;
                                            }
                                            if let Some(ref m) = metrics_clone {
                                                m.record_ws_message_sent();
                                            }
                                        }
                                    }
                                    // NOTE(review): any receive error ends the
                                    // forwarding task; confirm this is intended
                                    // for every error the bus can report.
                                    Err(_) => break,
                                }
                            }
                        }
                    }
                }
                .instrument(info_span!("ws.subscribe.list", %client_id, view = %view_id_clone, mode = ?mode)),
            );
        }
    }

    info!(
        "Client {} subscribed to {} (mode: {:?})",
        ctx.client_id, view_id, view_spec.mode
    );
}
637
638#[cfg(not(feature = "otel"))]
639async fn attach_client_to_bus(
640 ctx: &SubscriptionContext<'_>,
641 subscription: Subscription,
642 cancel_token: CancellationToken,
643) {
644 let view_id = &subscription.view;
645
646 let view_spec = match ctx.view_index.get_view(view_id) {
647 Some(spec) => spec,
648 None => {
649 warn!("Unknown view ID: {}", view_id);
650 return;
651 }
652 };
653
654 match view_spec.mode {
655 Mode::State => {
656 let key = subscription.key.as_deref().unwrap_or("");
657 let mut rx = ctx.bus_manager.get_or_create_state_bus(view_id, key).await;
658
659 if !rx.borrow().is_empty() {
660 let data = rx.borrow().clone();
661 let _ = ctx.client_manager.send_to_client(ctx.client_id, data);
662 }
663
664 let client_id = ctx.client_id;
665 let client_mgr = ctx.client_manager.clone();
666 let view_id_clone = view_id.clone();
667 let key_clone = key.to_string();
668 tokio::spawn(
669 async move {
670 loop {
671 tokio::select! {
672 _ = cancel_token.cancelled() => {
673 debug!("State subscription cancelled for client {}", client_id);
674 break;
675 }
676 result = rx.changed() => {
677 if result.is_err() {
678 break;
679 }
680 let data = rx.borrow().clone();
681 if client_mgr.send_to_client(client_id, data).is_err() {
682 break;
683 }
684 }
685 }
686 }
687 }
688 .instrument(info_span!("ws.subscribe.state", %client_id, view = %view_id_clone, key = %key_clone)),
689 );
690 }
691 Mode::List | Mode::Append => {
692 let mut rx = ctx.bus_manager.get_or_create_list_bus(view_id).await;
693
694 let snapshots = ctx.entity_cache.get_all(view_id).await;
695 let snapshot_entities: Vec<SnapshotEntity> = snapshots
696 .into_iter()
697 .filter(|(key, _)| subscription.matches_key(key))
698 .map(|(key, data)| SnapshotEntity { key, data })
699 .collect();
700
701 if !snapshot_entities.is_empty() {
702 let batch_config = ctx.entity_cache.snapshot_config();
703 if send_snapshot_batches(
704 ctx.client_id,
705 &snapshot_entities,
706 view_spec.mode,
707 view_id,
708 ctx.client_manager,
709 &batch_config,
710 )
711 .await
712 .is_err()
713 {
714 return;
715 }
716 }
717
718 let client_id = ctx.client_id;
719 let client_mgr = ctx.client_manager.clone();
720 let sub = subscription.clone();
721 let view_id_clone = view_id.clone();
722 let mode = view_spec.mode;
723 tokio::spawn(
724 async move {
725 loop {
726 tokio::select! {
727 _ = cancel_token.cancelled() => {
728 debug!("List subscription cancelled for client {}", client_id);
729 break;
730 }
731 result = rx.recv() => {
732 match result {
733 Ok(envelope) => {
734 if sub.matches(&envelope.entity, &envelope.key)
735 && client_mgr
736 .send_to_client(client_id, envelope.payload.clone())
737 .is_err()
738 {
739 break;
740 }
741 }
742 Err(_) => break,
743 }
744 }
745 }
746 }
747 }
748 .instrument(info_span!("ws.subscribe.list", %client_id, view = %view_id_clone, mode = ?mode)),
749 );
750 }
751 }
752
753 info!(
754 "Client {} subscribed to {} (mode: {:?})",
755 ctx.client_id, view_id, view_spec.mode
756 );
757}