Skip to main content

bext_realtime/
ws.rs

1//! WebSocket session management: handles ping/pong heartbeats, JSON-framed
2//! client messages (subscribe, unsubscribe, publish), and server push delivery.
3
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use parking_lot::Mutex;
8use serde_json::Value;
9use tokio::sync::mpsc;
10use tracing::debug;
11
12use crate::hub::BextHub;
13use crate::message::{ClientMessage, HubEvent, ServerMessage};
14
/// Heartbeat tunables for a single WebSocket session.
///
/// The session periodically pings the client. It is considered dead once more
/// than `heartbeat_interval + pong_timeout` has passed without a Pong.
#[derive(Debug, Clone)]
pub struct WsSessionConfig {
    /// Interval between server-sent Ping frames.
    pub heartbeat_interval: Duration,
    /// Extra grace period, on top of `heartbeat_interval`, allowed for the
    /// client's Pong to arrive before the connection is treated as dead.
    pub pong_timeout: Duration,
}

impl WsSessionConfig {
    /// Default ping cadence: every 30 seconds.
    const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
    /// Default Pong grace period: 10 seconds.
    const DEFAULT_PONG_TIMEOUT: Duration = Duration::from_secs(10);
}

impl Default for WsSessionConfig {
    /// 30 s heartbeat interval and 10 s Pong timeout.
    fn default() -> Self {
        WsSessionConfig {
            heartbeat_interval: Self::DEFAULT_HEARTBEAT_INTERVAL,
            pong_timeout: Self::DEFAULT_PONG_TIMEOUT,
        }
    }
}
32
/// Manages one WebSocket connection's lifecycle and message routing.
///
/// This struct doesn't own the WebSocket transport directly — it provides
/// the logic layer. The transport integration (e.g. actix-web, tungstenite)
/// calls into `WsSession` methods.
///
/// Outbound traffic is queued on a bounded channel: confirmations, errors,
/// pings and forwarded hub events. The transport obtains that channel's
/// receiver once via `take_outbound_receiver`. Hub events addressed to this
/// session arrive on a separate receiver, which exists only after the first
/// successful subscribe.
pub struct WsSession {
    /// Shared hub reference.
    hub: Arc<BextHub>,
    /// This session's subscriber ID in the hub (set after first subscribe).
    /// `cleanup` takes it and unsubscribes that ID from the hub.
    subscriber_id: Option<u64>,
    /// Receiver for hub events routed to this subscriber (bounded).
    /// `None` until the first successful subscribe, and again once taken.
    hub_receiver: Option<mpsc::Receiver<HubEvent>>,
    /// Outbound message queue (read by the transport layer, bounded).
    outbound: mpsc::Sender<ServerMessage>,
    /// Outbound receiver (consumed by the transport layer).
    /// `Some` only until `take_outbound_receiver` is first called.
    outbound_rx: Option<mpsc::Receiver<ServerMessage>>,
    /// Last time we received a Pong; initialised to the session's creation time.
    /// NOTE(review): this Arc is never cloned in this file; a plain Mutex would do
    /// unless the Arc is shared elsewhere — confirm.
    last_pong: Arc<Mutex<Instant>>,
    /// Configuration.
    config: WsSessionConfig,
}
54
55impl WsSession {
56    /// Create a new WebSocket session.
57    ///
58    /// Call `take_outbound_receiver()` to get the stream of messages to send
59    /// to the WebSocket client.
60    pub fn new(hub: Arc<BextHub>, config: WsSessionConfig) -> Self {
61        let (outbound_tx, outbound_rx) = mpsc::channel(256);
62        Self {
63            hub,
64            subscriber_id: None,
65            hub_receiver: None,
66            outbound: outbound_tx,
67            outbound_rx: Some(outbound_rx),
68            last_pong: Arc::new(Mutex::new(Instant::now())),
69            config,
70        }
71    }
72
73    /// Take the outbound message receiver.
74    ///
75    /// The transport layer reads from this to send messages over the WebSocket.
76    /// Can only be called once.
77    pub fn take_outbound_receiver(&mut self) -> Option<mpsc::Receiver<ServerMessage>> {
78        self.outbound_rx.take()
79    }
80
81    /// Take the hub event receiver.
82    ///
83    /// The transport layer reads from this and calls `forward_hub_event` for each.
84    /// Can only be called once (after at least one subscribe).
85    pub fn take_hub_receiver(&mut self) -> Option<mpsc::Receiver<HubEvent>> {
86        self.hub_receiver.take()
87    }
88
89    /// Handle an incoming text message from the WebSocket client.
90    ///
91    /// Parses JSON into `ClientMessage` and dispatches accordingly.
92    /// Returns an error string if parsing fails.
93    pub fn handle_text(&mut self, text: &str) -> Result<(), String> {
94        let msg: ClientMessage =
95            serde_json::from_str(text).map_err(|e| format!("invalid message: {}", e))?;
96        self.handle_message(msg);
97        Ok(())
98    }
99
100    /// Handle a parsed `ClientMessage`.
101    pub fn handle_message(&mut self, msg: ClientMessage) {
102        match msg {
103            ClientMessage::Subscribe { topics } => self.handle_subscribe(topics),
104            ClientMessage::Unsubscribe { topics } => self.handle_unsubscribe(topics),
105            ClientMessage::Publish { topic, data } => self.handle_publish(topic, data),
106            ClientMessage::Pong => self.handle_pong(),
107        }
108    }
109
110    /// Forward a hub event to the WebSocket client as a `ServerMessage::Event`.
111    pub fn forward_hub_event(&self, event: HubEvent) {
112        let msg = ServerMessage::Event {
113            topic: event.topic,
114            data: event.data,
115            id: event.id,
116        };
117        let _ = self.outbound.try_send(msg);
118    }
119
120    /// Send a Ping to the client. Called periodically by the transport layer.
121    pub fn send_ping(&self) {
122        let _ = self.outbound.try_send(ServerMessage::Ping);
123    }
124
125    /// Check if the connection is alive (received a pong within timeout).
126    pub fn is_alive(&self) -> bool {
127        let last = *self.last_pong.lock();
128        last.elapsed() < self.config.heartbeat_interval + self.config.pong_timeout
129    }
130
131    /// Send an error message to the client.
132    pub fn send_error(&self, message: String) {
133        let _ = self.outbound.try_send(ServerMessage::Error { message });
134    }
135
136    /// Get the subscriber ID (if subscribed).
137    pub fn subscriber_id(&self) -> Option<u64> {
138        self.subscriber_id
139    }
140
141    /// Get the configuration.
142    pub fn config(&self) -> &WsSessionConfig {
143        &self.config
144    }
145
146    /// Clean up on disconnect — unsubscribe from the hub.
147    pub fn cleanup(&mut self) {
148        if let Some(id) = self.subscriber_id.take() {
149            self.hub.unsubscribe(id);
150            debug!(subscriber_id = id, "ws session cleaned up");
151        }
152    }
153
154    // ── Private handlers ────────────────────────────────────────────
155
156    fn handle_subscribe(&mut self, topics: Vec<String>) {
157        if topics.is_empty() {
158            self.send_error("subscribe: topics list is empty".to_string());
159            return;
160        }
161
162        if let Some(id) = self.subscriber_id {
163            // Already subscribed — add more topics
164            self.hub.add_topics(id, topics.clone());
165        } else {
166            // First subscription — register with hub
167            match self.hub.subscribe(topics.clone()) {
168                Some((id, rx)) => {
169                    self.subscriber_id = Some(id);
170                    self.hub_receiver = Some(rx);
171                    debug!(subscriber_id = id, "ws client subscribed");
172                }
173                None => {
174                    self.send_error("max connections reached".to_string());
175                    return;
176                }
177            }
178        }
179
180        let _ = self.outbound.try_send(ServerMessage::Subscribed { topics });
181    }
182
183    fn handle_unsubscribe(&mut self, topics: Vec<String>) {
184        if let Some(id) = self.subscriber_id {
185            self.hub.remove_topics(id, topics);
186        }
187    }
188
189    fn handle_publish(&self, topic: String, data: Value) {
190        self.hub.publish(&topic, data);
191    }
192
193    fn handle_pong(&self) {
194        let mut last = self.last_pong.lock();
195        *last = Instant::now();
196    }
197}
198
199impl Drop for WsSession {
200    fn drop(&mut self) {
201        self.cleanup();
202    }
203}
204
/// Unit tests for `WsSession`. They cover:
/// - message parsing;
/// - subscribe, unsubscribe and publish flows;
/// - heartbeat liveness;
/// - outbound queueing;
/// - cleanup on drop.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hub::{BextHub, HubConfig};
    use serde_json::json;
    use std::sync::Arc;

    /// Fresh hub built from `HubConfig::default()`.
    fn test_hub() -> Arc<BextHub> {
        Arc::new(BextHub::new(HubConfig::default()))
    }

    /// Session on `hub` using the default heartbeat configuration.
    fn test_session(hub: Arc<BextHub>) -> WsSession {
        WsSession::new(hub, WsSessionConfig::default())
    }

    // ── Message parsing ─────────────────────────────────────────────

    #[test]
    fn handle_text_valid_subscribe() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let result = session.handle_text(r#"{"type":"subscribe","topics":["app/events"]}"#);
        assert!(result.is_ok());
        assert!(session.subscriber_id().is_some());
    }

    #[test]
    fn handle_text_valid_pong() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let result = session.handle_text(r#"{"type":"pong"}"#);
        assert!(result.is_ok());
    }

    #[test]
    fn handle_text_invalid_json() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let result = session.handle_text("not json");
        assert!(result.is_err());
    }

    #[test]
    fn handle_text_unknown_type() {
        // Well-formed JSON whose "type" tag matches no `ClientMessage` variant
        // must still be rejected during deserialization.
        let hub = test_hub();
        let mut session = test_session(hub);
        let result = session.handle_text(r#"{"type":"unknown"}"#);
        assert!(result.is_err());
    }

    // ── Subscribe flow ──────────────────────────────────────────────

    #[test]
    fn subscribe_creates_subscriber() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["test".to_string()],
        });

        assert!(session.subscriber_id().is_some());
        assert_eq!(hub.subscriber_count(), 1);

        // Should receive Subscribed confirmation
        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Subscribed { topics } => {
                assert_eq!(topics, vec!["test".to_string()]);
            }
            other => panic!("expected Subscribed, got {:?}", other),
        }
    }

    #[test]
    fn subscribe_empty_topics_sends_error() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe { topics: vec![] });

        assert!(session.subscriber_id().is_none());

        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Error { message } => {
                assert!(message.contains("empty"));
            }
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn subscribe_twice_adds_topics() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        // Bound to `_outbound` (not `_`) so the receiver lives for the whole
        // test; `let _ = ...` would drop it immediately and close the channel.
        let _outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string()],
        });
        let first_id = session.subscriber_id().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["b".to_string()],
        });
        // Should keep the same subscriber ID
        assert_eq!(session.subscriber_id().unwrap(), first_id);
        // Hub should have 2 topics
        assert_eq!(hub.topic_count(), 2);
    }

    // ── Unsubscribe flow ────────────────────────────────────────────

    #[test]
    fn unsubscribe_removes_topics() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let _outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string(), "b".to_string()],
        });
        assert_eq!(hub.topic_count(), 2);

        session.handle_message(ClientMessage::Unsubscribe {
            topics: vec!["a".to_string()],
        });
        assert_eq!(hub.topic_count(), 1);
    }

    #[test]
    fn unsubscribe_without_subscribe_is_noop() {
        let hub = test_hub();
        let mut session = test_session(hub);
        session.handle_message(ClientMessage::Unsubscribe {
            topics: vec!["a".to_string()],
        });
        // Should not panic
    }

    // ── Publish flow ────────────────────────────────────────────────

    #[tokio::test]
    async fn publish_from_ws_delivers_to_other_subscribers() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let _outbound = session.take_outbound_receiver().unwrap();

        // Another subscriber listens
        let (_id, mut rx) = hub.subscribe(vec!["chat".to_string()]).unwrap();

        // Publish via ws session
        session.handle_message(ClientMessage::Publish {
            topic: "chat".to_string(),
            data: json!({"text": "hello"}),
        });

        let evt = rx.recv().await.unwrap();
        assert_eq!(evt.topic, "chat");
        assert_eq!(evt.data, json!({"text": "hello"}));
    }

    // ── Pong / liveness ─────────────────────────────────────────────

    #[test]
    fn pong_updates_last_pong_time() {
        let hub = test_hub();
        let mut session = test_session(hub);

        // Set last_pong to the past. 100 s exceeds the default 30 s + 10 s window.
        // NOTE(review): `Instant - Duration` panics if the result can't be
        // represented, which can happen shortly after boot on some platforms.
        // Consider `checked_sub` if this ever flakes in CI.
        {
            let mut last = session.last_pong.lock();
            *last = Instant::now() - Duration::from_secs(100);
        }

        assert!(!session.is_alive());

        session.handle_message(ClientMessage::Pong);
        assert!(session.is_alive());
    }

    #[test]
    fn is_alive_true_initially() {
        let hub = test_hub();
        let session = test_session(hub);
        assert!(session.is_alive());
    }

    // ── Ping ────────────────────────────────────────────────────────

    #[test]
    fn send_ping_queues_ping_message() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.send_ping();

        let msg = outbound.try_recv().unwrap();
        assert_eq!(msg, ServerMessage::Ping);
    }

    // ── Forward hub event ───────────────────────────────────────────

    #[test]
    fn forward_hub_event_sends_event_message() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        let event = HubEvent {
            id: 5,
            topic: "test".to_string(),
            data: json!({"key": "val"}),
            timestamp: chrono::Utc::now(),
        };
        session.forward_hub_event(event);

        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Event { topic, data, id } => {
                assert_eq!(topic, "test");
                assert_eq!(data, json!({"key": "val"}));
                assert_eq!(id, 5);
            }
            other => panic!("expected Event, got {:?}", other),
        }
    }

    // ── Cleanup / Drop ──────────────────────────────────────────────

    #[test]
    fn cleanup_unsubscribes_from_hub() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let _outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string()],
        });
        assert_eq!(hub.subscriber_count(), 1);

        session.cleanup();
        assert_eq!(hub.subscriber_count(), 0);
        assert!(session.subscriber_id().is_none());
    }

    #[test]
    fn drop_triggers_cleanup() {
        let hub = test_hub();
        {
            let mut session = test_session(hub.clone());
            let _outbound = session.take_outbound_receiver().unwrap();

            session.handle_message(ClientMessage::Subscribe {
                topics: vec!["a".to_string()],
            });
            assert_eq!(hub.subscriber_count(), 1);
        } // session dropped here

        assert_eq!(hub.subscriber_count(), 0);
    }

    // ── Max connections via WS ──────────────────────────────────────

    #[test]
    fn subscribe_at_max_connections_sends_error() {
        let hub = Arc::new(BextHub::new(HubConfig {
            max_connections: 1,
            ..Default::default()
        }));

        // First session succeeds
        let mut s1 = test_session(hub.clone());
        let _out1 = s1.take_outbound_receiver().unwrap();
        s1.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string()],
        });
        assert!(s1.subscriber_id().is_some());

        // Second session fails
        let mut s2 = test_session(hub.clone());
        let mut out2 = s2.take_outbound_receiver().unwrap();
        s2.handle_message(ClientMessage::Subscribe {
            topics: vec!["b".to_string()],
        });
        assert!(s2.subscriber_id().is_none());

        let msg = out2.try_recv().unwrap();
        match msg {
            ServerMessage::Error { message } => {
                assert!(message.contains("max connections"));
            }
            other => panic!("expected Error, got {:?}", other),
        }
    }

    // ── Send error ──────────────────────────────────────────────────

    #[test]
    fn send_error_queues_error_message() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.send_error("test error".to_string());

        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Error { message } => {
                assert_eq!(message, "test error");
            }
            other => panic!("expected Error, got {:?}", other),
        }
    }
}