hyperstack_server/websocket/
server.rs

1use crate::bus::BusManager;
2use crate::view::ViewIndex;
3use crate::websocket::client_manager::ClientManager;
4use crate::websocket::frame::Mode;
5use crate::websocket::subscription::Subscription;
6use anyhow::Result;
7use futures_util::StreamExt;
8use std::net::SocketAddr;
9use std::sync::Arc;
10#[cfg(feature = "otel")]
11use std::time::Instant;
12
13use tokio::net::{TcpListener, TcpStream};
14use tokio_tungstenite::accept_async;
15use tracing::{debug, error, info, warn};
16use uuid::Uuid;
17
18#[cfg(feature = "otel")]
19use crate::metrics::Metrics;
20
21pub struct WebSocketServer {
22    bind_addr: SocketAddr,
23    client_manager: ClientManager,
24    bus_manager: BusManager,
25    view_index: Arc<ViewIndex>,
26    max_clients: usize,
27    #[cfg(feature = "otel")]
28    metrics: Option<Arc<Metrics>>,
29}
30
31impl WebSocketServer {
32    #[cfg(feature = "otel")]
33    pub fn new(
34        bind_addr: SocketAddr,
35        bus_manager: BusManager,
36        view_index: Arc<ViewIndex>,
37        metrics: Option<Arc<Metrics>>,
38    ) -> Self {
39        Self {
40            bind_addr,
41            client_manager: ClientManager::new(),
42            bus_manager,
43            view_index,
44            max_clients: 10000,
45            metrics,
46        }
47    }
48
49    #[cfg(not(feature = "otel"))]
50    pub fn new(bind_addr: SocketAddr, bus_manager: BusManager, view_index: Arc<ViewIndex>) -> Self {
51        Self {
52            bind_addr,
53            client_manager: ClientManager::new(),
54            bus_manager,
55            view_index,
56            max_clients: 10000,
57        }
58    }
59
60    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
61        self.max_clients = max_clients;
62        self
63    }
64
65    pub async fn start(self) -> Result<()> {
66        info!(
67            "Starting WebSocket server on {} (max_clients: {})",
68            self.bind_addr, self.max_clients
69        );
70
71        let listener = TcpListener::bind(&self.bind_addr).await?;
72        info!("WebSocket server listening on {}", self.bind_addr);
73
74        // Start cleanup task
75        self.client_manager.start_cleanup_task().await;
76
77        // Accept incoming connections
78        loop {
79            match listener.accept().await {
80                Ok((stream, addr)) => {
81                    // Check if we've reached the maximum number of clients
82                    let client_count = self.client_manager.client_count().await;
83                    if client_count >= self.max_clients {
84                        warn!(
85                            "Rejecting connection from {} - max clients ({}) reached",
86                            addr, self.max_clients
87                        );
88                        // Accept the connection but immediately close it
89                        if let Ok(mut ws_stream) = accept_async(stream).await {
90                            let _ = ws_stream.close(None).await;
91                        }
92                        continue;
93                    }
94
95                    // Record connection metric
96                    #[cfg(feature = "otel")]
97                    if let Some(ref metrics) = self.metrics {
98                        metrics.record_ws_connection();
99                    }
100
101                    info!(
102                        "New WebSocket connection from {} ({}/{} clients)",
103                        addr,
104                        client_count + 1,
105                        self.max_clients
106                    );
107                    let client_manager = self.client_manager.clone();
108                    let bus_manager = self.bus_manager.clone();
109                    let view_index = self.view_index.clone();
110                    #[cfg(feature = "otel")]
111                    let metrics = self.metrics.clone();
112
113                    tokio::spawn(async move {
114                        #[cfg(feature = "otel")]
115                        let result = handle_connection(
116                            stream,
117                            client_manager,
118                            bus_manager,
119                            view_index,
120                            metrics,
121                        )
122                        .await;
123                        #[cfg(not(feature = "otel"))]
124                        let result =
125                            handle_connection(stream, client_manager, bus_manager, view_index)
126                                .await;
127
128                        if let Err(e) = result {
129                            error!("WebSocket connection error: {}", e);
130                        }
131                    });
132                }
133                Err(e) => {
134                    error!("Failed to accept connection: {}", e);
135                }
136            }
137        }
138    }
139}
140
/// Serves one client connection until it closes, errors, or its stream ends.
///
/// Performs the WebSocket handshake (bounded by a timeout) and registers the
/// client with `client_manager`. Each text frame that parses as a
/// [`Subscription`] attaches the client to a bus. On exit the client is
/// removed, and the disconnection plus every subscription counted here are
/// recorded in `metrics`.
///
/// # Errors
///
/// Returns an error if the handshake fails or times out, or if the client
/// cannot be registered. Errors on an established connection end the session
/// but are only logged.
#[cfg(feature = "otel")]
async fn handle_connection(
    stream: TcpStream,
    client_manager: ClientManager,
    bus_manager: BusManager,
    view_index: Arc<ViewIndex>,
    metrics: Option<Arc<Metrics>>,
) -> Result<()> {
    // A peer that never completes the upgrade would otherwise hold its socket
    // forever without counting towards the server's client limit.
    const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    // Started before the handshake: the caller has already recorded this
    // connection, so the duration spans the whole time it was counted.
    let connection_start = Instant::now();

    // The caller records `record_ws_connection` before spawning this task, so
    // every exit path (including a failed handshake) must record the matching
    // disconnection or the connection metrics drift.
    let record_disconnection = || {
        if let Some(ref m) = metrics {
            m.record_ws_disconnection(connection_start.elapsed().as_secs_f64());
        }
    };

    let ws_stream = match tokio::time::timeout(HANDSHAKE_TIMEOUT, accept_async(stream)).await {
        Ok(Ok(ws_stream)) => ws_stream,
        Ok(Err(e)) => {
            record_disconnection();
            return Err(e.into());
        }
        Err(_) => {
            record_disconnection();
            return Err(anyhow::anyhow!(
                "WebSocket handshake timed out after {:?}",
                HANDSHAKE_TIMEOUT
            ));
        }
    };
    let client_id = Uuid::new_v4();

    info!("WebSocket connection established for client {}", client_id);

    let (ws_sender, mut ws_receiver) = ws_stream.split();

    if let Err(e) = client_manager.add_client(client_id, ws_sender).await {
        record_disconnection();
        return Err(e.into());
    }

    // View ids whose subscription metric was incremented; decremented on exit.
    let mut active_subscriptions: Vec<String> = Vec::new();

    loop {
        let msg = match ws_receiver.next().await {
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                warn!("WebSocket error for client {}: {}", client_id, e);
                break;
            }
            None => {
                debug!("WebSocket stream ended for client {}", client_id);
                break;
            }
        };

        if msg.is_close() {
            info!("Client {} requested close", client_id);
            break;
        }

        client_manager.update_client_last_seen(client_id).await;

        // Non-text frames only refresh the last-seen timestamp above.
        if !msg.is_text() {
            continue;
        }
        if let Some(ref m) = metrics {
            m.record_ws_message_received();
        }

        let text = match msg.to_text() {
            Ok(text) => text,
            Err(_) => continue,
        };
        debug!("Received text message from client {}: {}", client_id, text);

        match serde_json::from_str::<Subscription>(text) {
            Ok(subscription) => {
                let view_id = subscription.view.clone();
                client_manager
                    .update_subscription(client_id, subscription.clone())
                    .await;

                if let Some(ref m) = metrics {
                    m.record_subscription_created(&view_id);
                }
                active_subscriptions.push(view_id);

                attach_client_to_bus(
                    client_id,
                    subscription,
                    &client_manager,
                    &bus_manager,
                    &view_index,
                    metrics.clone(),
                )
                .await;
            }
            Err(_) => {
                debug!("Received non-subscription message from client {}: {}", client_id, text);
            }
        }
    }

    client_manager.remove_client(client_id).await;

    record_disconnection();
    if let Some(ref m) = metrics {
        for view_id in active_subscriptions {
            m.record_subscription_removed(&view_id);
        }
    }

    info!("Client {} disconnected", client_id);

    Ok(())
}
242
243#[cfg(not(feature = "otel"))]
244async fn handle_connection(
245    stream: TcpStream,
246    client_manager: ClientManager,
247    bus_manager: BusManager,
248    view_index: Arc<ViewIndex>,
249) -> Result<()> {
250    let ws_stream = accept_async(stream).await?;
251    let client_id = Uuid::new_v4();
252
253    info!("WebSocket connection established for client {}", client_id);
254
255    let (ws_sender, mut ws_receiver) = ws_stream.split();
256
257    // Register client
258    client_manager.add_client(client_id, ws_sender).await?;
259
260    // Handle incoming messages from client
261    loop {
262        tokio::select! {
263            ws_msg = ws_receiver.next() => {
264                match ws_msg {
265                    Some(Ok(msg)) => {
266                        if msg.is_close() {
267                            info!("Client {} requested close", client_id);
268                            break;
269                        }
270
271                        client_manager.update_client_last_seen(client_id).await;
272
273                        if msg.is_text() {
274                            if let Ok(text) = msg.to_text() {
275                                debug!("Received text message from client {}: {}", client_id, text);
276
277                                // Try to parse as subscription
278                                if let Ok(subscription) = serde_json::from_str::<Subscription>(text) {
279                                    client_manager.update_subscription(client_id, subscription.clone()).await;
280
281                                    // Attach client to appropriate bus
282                                    attach_client_to_bus(client_id, subscription, &client_manager, &bus_manager, &view_index).await;
283                                } else {
284                                    debug!("Received non-subscription message from client {}: {}", client_id, text);
285                                }
286                            }
287                        }
288                    }
289                    Some(Err(e)) => {
290                        warn!("WebSocket error for client {}: {}", client_id, e);
291                        break;
292                    }
293                    None => {
294                        debug!("WebSocket stream ended for client {}", client_id);
295                        break;
296                    }
297                }
298            }
299        }
300    }
301
302    // Clean up client
303    client_manager.remove_client(client_id).await;
304    info!("Client {} disconnected", client_id);
305
306    Ok(())
307}
308
/// Attaches `client_id` to the bus that serves `subscription.view`.
///
/// The view's [`Mode`] picks the bus:
/// - `State` forwards the latest value for `subscription.key`, sending the
///   current value immediately if it is non-empty.
/// - `Kv` / `Append` and `List` forward every bus message that `subscription`
///   matches.
///
/// Forwarding runs on a spawned task that stops when a send to the client
/// fails or the bus closes. Each message actually sent is recorded in
/// `metrics`. Unknown views are logged and ignored.
///
/// NOTE(review): forwarding tasks are never cancelled explicitly. A task only
/// notices a departed client on its next failed send, and tasks from earlier
/// subscriptions keep running after the client re-subscribes. Consider handing
/// an abort handle back to the connection loop.
#[cfg(feature = "otel")]
async fn attach_client_to_bus(
    client_id: Uuid,
    subscription: Subscription,
    client_manager: &ClientManager,
    bus_manager: &BusManager,
    view_index: &ViewIndex,
    metrics: Option<Arc<Metrics>>,
) {
    // Assumes the kv/list buses are `tokio::sync::broadcast` receivers
    // (`recv()` yielding `Result`) — TODO confirm against BusManager.
    use tokio::sync::broadcast::error::RecvError;

    let view_id = &subscription.view;

    // The view spec determines which bus (and semantics) to use.
    let view_spec = match view_index.get_view(view_id) {
        Some(spec) => spec,
        None => {
            warn!("Unknown view ID: {}", view_id);
            return;
        }
    };

    match view_spec.mode {
        Mode::State => {
            let key = subscription.key.as_deref().unwrap_or("");
            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;

            // Read the current value once and mark it as seen, so the
            // `changed()` loop below does not resend it and the emptiness check
            // and the clone observe the same value.
            let current = rx.borrow_and_update().clone();
            if !current.is_empty()
                && client_manager.send_to_client(client_id, current).await.is_ok()
            {
                if let Some(ref m) = metrics {
                    m.record_ws_message_sent();
                }
            }

            let client_mgr = client_manager.clone();
            tokio::spawn(async move {
                while rx.changed().await.is_ok() {
                    let data = rx.borrow_and_update().clone();
                    if client_mgr.send_to_client(client_id, data).await.is_err() {
                        break; // Client disconnected
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            });
        }
        Mode::Kv | Mode::Append => {
            let mut rx = bus_manager.get_or_create_kv_bus(view_id).await;

            let client_mgr = client_manager.clone();
            let sub = subscription.clone();
            tokio::spawn(async move {
                loop {
                    let envelope = match rx.recv().await {
                        Ok(envelope) => envelope,
                        // The client fell behind the bus buffer. The skipped
                        // messages are lost, but later ones still arrive, so keep
                        // forwarding instead of silently dropping the subscription.
                        Err(RecvError::Lagged(skipped)) => {
                            warn!(
                                "Client {} lagged on view {}; skipped {} messages",
                                client_id, sub.view, skipped
                            );
                            continue;
                        }
                        // Bus closed: nothing more will arrive.
                        Err(_) => break,
                    };
                    if !sub.matches(&envelope.entity, &envelope.key) {
                        continue;
                    }
                    if client_mgr
                        .send_to_client(client_id, envelope.payload.clone())
                        .await
                        .is_err()
                    {
                        break; // Client disconnected
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            });
        }
        Mode::List => {
            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;

            let client_mgr = client_manager.clone();
            let sub = subscription.clone();
            tokio::spawn(async move {
                loop {
                    let envelope = match rx.recv().await {
                        Ok(envelope) => envelope,
                        // See the Kv/Append arm: keep forwarding after a lag.
                        Err(RecvError::Lagged(skipped)) => {
                            warn!(
                                "Client {} lagged on view {}; skipped {} messages",
                                client_id, sub.view, skipped
                            );
                            continue;
                        }
                        Err(_) => break,
                    };
                    if !sub.matches(&envelope.entity, &envelope.key) {
                        continue;
                    }
                    if client_mgr
                        .send_to_client(client_id, envelope.payload.clone())
                        .await
                        .is_err()
                    {
                        break; // Client disconnected
                    }
                    if let Some(ref m) = metrics {
                        m.record_ws_message_sent();
                    }
                }
            });
        }
    }

    info!(
        "Client {} subscribed to {} (mode: {:?})",
        client_id, view_id, view_spec.mode
    );
}
413
414#[cfg(not(feature = "otel"))]
415async fn attach_client_to_bus(
416    client_id: Uuid,
417    subscription: Subscription,
418    client_manager: &ClientManager,
419    bus_manager: &BusManager,
420    view_index: &ViewIndex,
421) {
422    let view_id = &subscription.view;
423
424    // Get the view spec to determine the mode
425    let view_spec = match view_index.get_view(view_id) {
426        Some(spec) => spec,
427        None => {
428            warn!("Unknown view ID: {}", view_id);
429            return;
430        }
431    };
432
433    match view_spec.mode {
434        Mode::State => {
435            let key = subscription.key.as_deref().unwrap_or("");
436            let mut rx = bus_manager.get_or_create_state_bus(view_id, key).await;
437
438            // Send current value immediately (latest-only semantics)
439            if !rx.borrow().is_empty() {
440                let data = rx.borrow().clone();
441                let _ = client_manager.send_to_client(client_id, data).await;
442            }
443
444            // Spawn task to listen for updates
445            let client_mgr = client_manager.clone();
446            tokio::spawn(async move {
447                while rx.changed().await.is_ok() {
448                    let data = rx.borrow().clone();
449                    if client_mgr.send_to_client(client_id, data).await.is_err() {
450                        break; // Client disconnected
451                    }
452                }
453            });
454        }
455        Mode::Kv | Mode::Append => {
456            let mut rx = bus_manager.get_or_create_kv_bus(view_id).await;
457
458            let client_mgr = client_manager.clone();
459            let sub = subscription.clone();
460            tokio::spawn(async move {
461                while let Ok(envelope) = rx.recv().await {
462                    // Filter messages based on subscription
463                    if sub.matches(&envelope.entity, &envelope.key)
464                        && client_mgr
465                            .send_to_client(client_id, envelope.payload.clone())
466                            .await
467                            .is_err()
468                    {
469                        break; // Client disconnected
470                    }
471                }
472            });
473        }
474        Mode::List => {
475            let mut rx = bus_manager.get_or_create_list_bus(view_id).await;
476
477            let client_mgr = client_manager.clone();
478            let sub = subscription.clone();
479            tokio::spawn(async move {
480                while let Ok(envelope) = rx.recv().await {
481                    // Filter messages based on subscription
482                    if sub.matches(&envelope.entity, &envelope.key)
483                        && client_mgr
484                            .send_to_client(client_id, envelope.payload.clone())
485                            .await
486                            .is_err()
487                    {
488                        break; // Client disconnected
489                    }
490                }
491            });
492        }
493    }
494
495    info!(
496        "Client {} subscribed to {} (mode: {:?})",
497        client_id, view_id, view_spec.mode
498    );
499}