hyperstack_server/websocket/server.rs

1use crate::bus::BusManager;
2use crate::view::ViewIndex;
3use crate::websocket::client_manager::ClientManager;
4use crate::websocket::frame::Mode;
5use crate::websocket::subscription::Subscription;
6use anyhow::Result;
7use futures_util::StreamExt;
8use std::net::SocketAddr;
9use std::sync::Arc;
10
11use tokio::net::{TcpListener, TcpStream};
12use tokio_tungstenite::accept_async;
13use tracing::{debug, error, info, warn};
14use uuid::Uuid;
15
16#[cfg(feature = "otel")]
17use crate::metrics::Metrics;
18
19pub struct WebSocketServer {
20    bind_addr: SocketAddr,
21    client_manager: ClientManager,
22    bus_manager: BusManager,
23    view_index: Arc<ViewIndex>,
24    max_clients: usize,
25    #[cfg(feature = "otel")]
26    metrics: Option<Arc<Metrics>>,
27}
28
29impl WebSocketServer {
30    #[cfg(feature = "otel")]
31    pub fn new(
32        bind_addr: SocketAddr,
33        bus_manager: BusManager,
34        view_index: Arc<ViewIndex>,
35        metrics: Option<Arc<Metrics>>,
36    ) -> Self {
37        Self {
38            bind_addr,
39            client_manager: ClientManager::new(),
40            bus_manager,
41            view_index,
42            max_clients: 10000,
43            metrics,
44        }
45    }
46
47    #[cfg(not(feature = "otel"))]
48    pub fn new(
49        bind_addr: SocketAddr,
50        bus_manager: BusManager,
51        view_index: Arc<ViewIndex>,
52    ) -> Self {
53        Self {
54            bind_addr,
55            client_manager: ClientManager::new(),
56            bus_manager,
57            view_index,
58            max_clients: 10000,
59        }
60    }
61
62    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
63        self.max_clients = max_clients;
64        self
65    }
66
67    pub async fn start(self) -> Result<()> {
68        info!(
69            "Starting WebSocket server on {} (max_clients: {})",
70            self.bind_addr, self.max_clients
71        );
72
73        let listener = TcpListener::bind(&self.bind_addr).await?;
74        info!("WebSocket server listening on {}", self.bind_addr);
75
76        // Start cleanup task
77        self.client_manager.start_cleanup_task().await;
78
79        // Accept incoming connections
80        loop {
81            match listener.accept().await {
82                Ok((stream, addr)) => {
83                    // Check if we've reached the maximum number of clients
84                    let client_count = self.client_manager.client_count().await;
85                    if client_count >= self.max_clients {
86                        warn!(
87                            "Rejecting connection from {} - max clients ({}) reached",
88                            addr, self.max_clients
89                        );
90                        // Accept the connection but immediately close it
91                        if let Ok(mut ws_stream) = accept_async(stream).await {
92                            let _ = ws_stream.close(None).await;
93                        }
94                        continue;
95                    }
96
97                    // Record connection metric
98                    #[cfg(feature = "otel")]
99                    if let Some(ref metrics) = self.metrics {
100                        metrics.record_ws_connection();
101                    }
102
103                    info!(
104                        "New WebSocket connection from {} ({}/{} clients)",
105                        addr,
106                        client_count + 1,
107                        self.max_clients
108                    );
109                    let client_manager = self.client_manager.clone();
110                    let bus_manager = self.bus_manager.clone();
111                    let view_index = self.view_index.clone();
112                    #[cfg(feature = "otel")]
113                    let metrics = self.metrics.clone();
114
115                    tokio::spawn(async move {
116                        #[cfg(feature = "otel")]
117                        let result = handle_connection(
118                            stream,
119                            client_manager,
120                            bus_manager,
121                            view_index,
122                            metrics,
123                        )
124                        .await;
125                        #[cfg(not(feature = "otel"))]
126                        let result =
127                            handle_connection(stream, client_manager, bus_manager, view_index)
128                                .await;
129
130                        if let Err(e) = result {
131                            error!("WebSocket connection error: {}", e);
132                        }
133                    });
134                }
135                Err(e) => {
136                    error!("Failed to accept connection: {}", e);
137                }
138            }
139        }
140    }
141}
142
/// Serves one client connection from handshake to teardown (`otel` build).
///
/// After the WebSocket handshake the client gets a fresh UUID and is registered
/// with `client_manager`. Frames are then read until the peer sends a close frame,
/// the stream yields an error, or the stream ends.
///
/// Each non-close frame calls `update_client_last_seen`. A text frame that
/// deserializes as a [`Subscription`] is passed to `update_subscription` and then
/// to `attach_client_to_bus`. Any other text is logged at debug level and ignored.
///
/// On exit the client is removed. When `metrics` is `Some`, this also records the
/// connection duration and one subscription removal per subscription accepted on
/// this connection.
///
/// # Errors
///
/// Returns an error if the handshake or `add_client` fails. In those cases the
/// function returns immediately, without the cleanup described above.
///
/// NOTE(review): `WebSocketServer::start` calls `record_ws_connection` before
/// spawning this task, so those early-error paths emit no matching
/// `record_ws_disconnection`. Confirm whether the metrics tolerate that.
#[cfg(feature = "otel")]
async fn handle_connection(
    stream: TcpStream,
    client_manager: ClientManager,
    bus_manager: BusManager,
    view_index: Arc<ViewIndex>,
    metrics: Option<Arc<Metrics>>,
) -> Result<()> {
    let ws_stream = accept_async(stream).await?;
    let client_id = Uuid::new_v4();
    // Fully qualified on purpose: `Instant` is not imported by this file, so the bare
    // name does not resolve and the `otel` build fails to compile.
    let connection_start = std::time::Instant::now();

    info!("WebSocket connection established for client {}", client_id);

    let (ws_sender, mut ws_receiver) = ws_stream.split();

    // Register client
    client_manager.add_client(client_id, ws_sender).await?;

    // View ids of every subscription accepted on this connection, so the matching
    // subscription metrics can be released on disconnect.
    let mut active_subscriptions: Vec<String> = Vec::new();

    loop {
        let msg = match ws_receiver.next().await {
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                warn!("WebSocket error for client {}: {}", client_id, e);
                break;
            }
            None => {
                debug!("WebSocket stream ended for client {}", client_id);
                break;
            }
        };

        if msg.is_close() {
            info!("Client {} requested close", client_id);
            break;
        }

        client_manager.update_client_last_seen(client_id).await;

        // Only text frames can carry subscription requests.
        if !msg.is_text() {
            continue;
        }

        if let Some(ref m) = metrics {
            m.record_ws_message_received();
        }

        let text = match msg.to_text() {
            Ok(text) => text,
            Err(_) => continue,
        };
        debug!("Received text message from client {}: {}", client_id, text);

        match serde_json::from_str::<Subscription>(text) {
            Ok(subscription) => {
                let view_id = subscription.view.clone();
                client_manager
                    .update_subscription(client_id, subscription.clone())
                    .await;

                if let Some(ref m) = metrics {
                    m.record_subscription_created(&view_id);
                }
                active_subscriptions.push(view_id);

                attach_client_to_bus(
                    client_id,
                    subscription,
                    &client_manager,
                    &bus_manager,
                    &view_index,
                    metrics.clone(),
                )
                .await;
            }
            Err(_) => {
                debug!(
                    "Received non-subscription message from client {}: {}",
                    client_id, text
                );
            }
        }
    }

    // Clean up client
    client_manager.remove_client(client_id).await;

    if let Some(ref m) = metrics {
        let duration_secs = connection_start.elapsed().as_secs_f64();
        m.record_ws_disconnection(duration_secs);

        for view_id in active_subscriptions {
            m.record_subscription_removed(&view_id);
        }
    }

    info!("Client {} disconnected", client_id);

    Ok(())
}
244
245#[cfg(not(feature = "otel"))]
246async fn handle_connection(
247    stream: TcpStream,
248    client_manager: ClientManager,
249    bus_manager: BusManager,
250    view_index: Arc<ViewIndex>,
251) -> Result<()> {
252    let ws_stream = accept_async(stream).await?;
253    let client_id = Uuid::new_v4();
254
255    info!("WebSocket connection established for client {}", client_id);
256
257    let (ws_sender, mut ws_receiver) = ws_stream.split();
258
259    // Register client
260    client_manager.add_client(client_id, ws_sender).await?;
261
262    // Handle incoming messages from client
263    loop {
264        tokio::select! {
265            ws_msg = ws_receiver.next() => {
266                match ws_msg {
267                    Some(Ok(msg)) => {
268                        if msg.is_close() {
269                            info!("Client {} requested close", client_id);
270                            break;
271                        }
272
273                        client_manager.update_client_last_seen(client_id).await;
274
275                        if msg.is_text() {
276                            if let Ok(text) = msg.to_text() {
277                                debug!("Received text message from client {}: {}", client_id, text);
278
279                                // Try to parse as subscription
280                                if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
281                                    client_manager.update_subscription(client_id, subscription.clone()).await;
282
283                                    // Attach client to appropriate bus
284                                    attach_client_to_bus(client_id, subscription, &client_manager, &bus_manager, &view_index).await;
285                                } else {
286                                    debug!("Received non-subscription message from client {}: {}", client_id, text);
287                                }
288                            }
289                        }
290                    }
291                    Some(Err(e)) => {
292                        warn!("WebSocket error for client {}: {}", client_id, e);
293                        break;
294                    }
295                    None => {
296                        debug!("WebSocket stream ended for client {}", client_id);
297                        break;
298                    }
299                }
300            }
301        }
302    }
303
304    // Clean up client
305    client_manager.remove_client(client_id).await;
306    info!("Client {} disconnected", client_id);
307
308    Ok(())
309}
310
/// Attaches `client_id` to the bus backing `subscription.view` and spawns a task that
/// forwards updates from that bus to the client (`otel` build).
///
/// The view's [`Mode`] selects the bus:
/// - `Mode::State`: the latest-value bus for `(view, key)`, with `key` defaulting to
///   `""`. The current value is sent immediately if non-empty; every later change is
///   then forwarded.
/// - `Mode::Kv` / `Mode::Append`: the view's KV bus. Only envelopes accepted by
///   `Subscription::matches` are forwarded.
/// - `Mode::List`: the view's list bus, filtered the same way.
///
/// An unknown view id is logged with `warn!` and nothing is attached.
///
/// A forwarding task stops when sending to the client fails or its bus closes. If a
/// broadcast receiver lags, the skipped count is logged and forwarding continues from
/// the oldest retained message, so the subscription is not dropped.
///
/// When `metrics` is `Some`, every successful send is recorded.
#[cfg(feature = "otel")]
async fn attach_client_to_bus(
    client_id: Uuid,
    subscription: Subscription,
    client_manager: &ClientManager,
    bus_manager: &BusManager,
    view_index: &ViewIndex,
    metrics: Option<Arc<Metrics>>,
) {
    let view_id = &subscription.view;

    // Get the view spec to determine the mode
    let view_spec = match view_index.get_view(view_id) {
        Some(spec) => spec,
        None => {
            warn!("Unknown view ID: {}", view_id);
            return;
        }
    };

    match view_spec.mode {
        Mode::State => {
            let key = subscription.key.as_deref().unwrap_or("");
            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;

            // Send the current value immediately (latest-only semantics).
            // `borrow_and_update` marks it as seen, so the first `changed()` below
            // does not fire for this same value and deliver it twice. The read guard
            // is dropped at the end of this block, before any `.await`.
            let initial = {
                let current = rx.borrow_and_update();
                if current.is_empty() {
                    None
                } else {
                    Some(current.clone())
                }
            };
            if let Some(data) = initial {
                let _ = client_manager.send_to_client(client_id, data).await;
                if let Some(ref m) = metrics {
                    m.record_ws_message_sent();
                }
            }

            let client_mgr = client_manager.clone();
            tokio::spawn(async move {
                while rx.changed().await.is_ok() {
                    // Mark the value we actually send as seen, so a value published
                    // between `changed()` and this read is not sent again.
                    let data = rx.borrow_and_update().clone();
                    if client_mgr.send_to_client(client_id, data).await.is_err() {
                        break; // Client disconnected
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            });
        }
        Mode::Kv | Mode::Append => {
            let mut rx = bus_manager.get_or_create_kv_bus(view_id).await;

            let client_mgr = client_manager.clone();
            let sub = subscription.clone();
            tokio::spawn(async move {
                loop {
                    let envelope = match rx.recv().await {
                        Ok(envelope) => envelope,
                        // The client fell behind the broadcast buffer. Skip what was
                        // dropped and keep the subscription alive instead of exiting.
                        Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                            warn!(
                                "Client {} lagged behind bus, skipped {} messages",
                                client_id, skipped
                            );
                            continue;
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                    };

                    // Filter messages based on subscription
                    if !sub.matches(&envelope.entity, &envelope.key) {
                        continue;
                    }
                    if client_mgr
                        .send_to_client(client_id, envelope.payload.clone())
                        .await
                        .is_err()
                    {
                        break; // Client disconnected
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            });
        }
        Mode::List => {
            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;

            let client_mgr = client_manager.clone();
            let sub = subscription.clone();
            tokio::spawn(async move {
                loop {
                    let envelope = match rx.recv().await {
                        Ok(envelope) => envelope,
                        // The client fell behind the broadcast buffer. Skip what was
                        // dropped and keep the subscription alive instead of exiting.
                        Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                            warn!(
                                "Client {} lagged behind bus, skipped {} messages",
                                client_id, skipped
                            );
                            continue;
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                    };

                    // Filter messages based on subscription
                    if !sub.matches(&envelope.entity, &envelope.key) {
                        continue;
                    }
                    if client_mgr
                        .send_to_client(client_id, envelope.payload.clone())
                        .await
                        .is_err()
                    {
                        break; // Client disconnected
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            });
        }
    }

    info!(
        "Client {} subscribed to {} (mode: {:?})",
        client_id, view_id, view_spec.mode
    );
}
415
416#[cfg(not(feature = "otel"))]
417async fn attach_client_to_bus(
418    client_id: Uuid,
419    subscription: Subscription,
420    client_manager: &ClientManager,
421    bus_manager: &BusManager,
422    view_index: &ViewIndex,
423) {
424    let view_id = &subscription.view;
425
426    // Get the view spec to determine the mode
427    let view_spec = match view_index.get_view(view_id) {
428        Some(spec) => spec,
429        None => {
430            warn!("Unknown view ID: {}", view_id);
431            return;
432        }
433    };
434
435    match view_spec.mode {
436        Mode::State => {
437            let key = subscription.key.as_deref().unwrap_or("");
438            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;
439
440            // Send current value immediately (latest-only semantics)
441            if !rx.borrow().is_empty() {
442                let data = rx.borrow().clone();
443                let _ = client_manager.send_to_client(client_id, data).await;
444            }
445
446            // Spawn task to listen for updates
447            let client_mgr = client_manager.clone();
448            tokio::spawn(async move {
449                while rx.changed().await.is_ok() {
450                    let data = rx.borrow().clone();
451                    if client_mgr.send_to_client(client_id, data).await.is_err() {
452                        break; // Client disconnected
453                    }
454                }
455            });
456        }
457        Mode::Kv | Mode::Append => {
458            let mut rx = bus_manager.get_or_create_kv_bus(view_id).await;
459
460            let client_mgr = client_manager.clone();
461            let sub = subscription.clone();
462            tokio::spawn(async move {
463                while let Ok(envelope) = rx.recv().await {
464                    // Filter messages based on subscription
465                    if sub.matches(&envelope.entity, &envelope.key)
466                        && client_mgr
467                            .send_to_client(client_id, envelope.payload.clone())
468                            .await
469                            .is_err()
470                    {
471                        break; // Client disconnected
472                    }
473                }
474            });
475        }
476        Mode::List => {
477            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;
478
479            let client_mgr = client_manager.clone();
480            let sub = subscription.clone();
481            tokio::spawn(async move {
482                while let Ok(envelope) = rx.recv().await {
483                    // Filter messages based on subscription
484                    if sub.matches(&envelope.entity, &envelope.key)
485                        && client_mgr
486                            .send_to_client(client_id, envelope.payload.clone())
487                            .await
488                            .is_err()
489                    {
490                        break; // Client disconnected
491                    }
492                }
493            });
494        }
495    }
496
497    info!(
498        "Client {} subscribed to {} (mode: {:?})",
499        client_id, view_id, view_spec.mode
500    );
501}