1use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use parking_lot::Mutex;
8use serde_json::Value;
9use tokio::sync::mpsc;
10use tracing::debug;
11
12use crate::hub::BextHub;
13use crate::message::{ClientMessage, HubEvent, ServerMessage};
14
/// Heartbeat timing for a `WsSession`.
///
/// A session counts as alive while less than `heartbeat_interval + pong_timeout` has
/// elapsed since the client's last pong (or since the session was created).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsSessionConfig {
    /// Heartbeat interval. Default: 30 seconds.
    ///
    /// NOTE(review): presumably the cadence at which the owner calls `send_ping`;
    /// the ping loop itself is not visible here.
    pub heartbeat_interval: Duration,
    /// Extra grace period, on top of `heartbeat_interval`, allowed for the pong to
    /// arrive before the session is considered dead. Default: 10 seconds.
    pub pong_timeout: Duration,
}
23
24impl Default for WsSessionConfig {
25 fn default() -> Self {
26 Self {
27 heartbeat_interval: Duration::from_secs(30),
28 pong_timeout: Duration::from_secs(10),
29 }
30 }
31}
32
/// Per-connection state for one WebSocket client attached to a `BextHub`.
///
/// The session decodes client messages, drives hub subscribe/unsubscribe/publish calls,
/// and queues `ServerMessage`s for the client on a bounded outbound channel.
///
/// Field declaration order is also the order fields are dropped in (after `Drop::drop`
/// runs), so keep it stable when editing.
pub struct WsSession {
    /// Shared hub used for subscribe, add/remove topics, publish and unsubscribe.
    hub: Arc<BextHub>,
    /// Hub subscriber id; `None` before the first successful subscribe and after `cleanup`.
    subscriber_id: Option<u64>,
    /// Receiver for hub events addressed to this subscriber; held until
    /// `take_hub_receiver` moves it out.
    hub_receiver: Option<mpsc::Receiver<HubEvent>>,
    /// Sending half of the bounded outbound queue toward the client.
    outbound: mpsc::Sender<ServerMessage>,
    /// Receiving half of the outbound queue; held until `take_outbound_receiver` moves it out.
    outbound_rx: Option<mpsc::Receiver<ServerMessage>>,
    /// Time of the most recent pong (initialized to the construction time); read by `is_alive`.
    last_pong: Arc<Mutex<Instant>>,
    /// Heartbeat timing used by `is_alive`.
    config: WsSessionConfig,
}
54
55impl WsSession {
56 pub fn new(hub: Arc<BextHub>, config: WsSessionConfig) -> Self {
61 let (outbound_tx, outbound_rx) = mpsc::channel(256);
62 Self {
63 hub,
64 subscriber_id: None,
65 hub_receiver: None,
66 outbound: outbound_tx,
67 outbound_rx: Some(outbound_rx),
68 last_pong: Arc::new(Mutex::new(Instant::now())),
69 config,
70 }
71 }
72
73 pub fn take_outbound_receiver(&mut self) -> Option<mpsc::Receiver<ServerMessage>> {
78 self.outbound_rx.take()
79 }
80
81 pub fn take_hub_receiver(&mut self) -> Option<mpsc::Receiver<HubEvent>> {
86 self.hub_receiver.take()
87 }
88
89 pub fn handle_text(&mut self, text: &str) -> Result<(), String> {
94 let msg: ClientMessage =
95 serde_json::from_str(text).map_err(|e| format!("invalid message: {}", e))?;
96 self.handle_message(msg);
97 Ok(())
98 }
99
100 pub fn handle_message(&mut self, msg: ClientMessage) {
102 match msg {
103 ClientMessage::Subscribe { topics } => self.handle_subscribe(topics),
104 ClientMessage::Unsubscribe { topics } => self.handle_unsubscribe(topics),
105 ClientMessage::Publish { topic, data } => self.handle_publish(topic, data),
106 ClientMessage::Pong => self.handle_pong(),
107 }
108 }
109
110 pub fn forward_hub_event(&self, event: HubEvent) {
112 let msg = ServerMessage::Event {
113 topic: event.topic,
114 data: event.data,
115 id: event.id,
116 };
117 let _ = self.outbound.try_send(msg);
118 }
119
120 pub fn send_ping(&self) {
122 let _ = self.outbound.try_send(ServerMessage::Ping);
123 }
124
125 pub fn is_alive(&self) -> bool {
127 let last = *self.last_pong.lock();
128 last.elapsed() < self.config.heartbeat_interval + self.config.pong_timeout
129 }
130
131 pub fn send_error(&self, message: String) {
133 let _ = self.outbound.try_send(ServerMessage::Error { message });
134 }
135
136 pub fn subscriber_id(&self) -> Option<u64> {
138 self.subscriber_id
139 }
140
141 pub fn config(&self) -> &WsSessionConfig {
143 &self.config
144 }
145
146 pub fn cleanup(&mut self) {
148 if let Some(id) = self.subscriber_id.take() {
149 self.hub.unsubscribe(id);
150 debug!(subscriber_id = id, "ws session cleaned up");
151 }
152 }
153
154 fn handle_subscribe(&mut self, topics: Vec<String>) {
157 if topics.is_empty() {
158 self.send_error("subscribe: topics list is empty".to_string());
159 return;
160 }
161
162 if let Some(id) = self.subscriber_id {
163 self.hub.add_topics(id, topics.clone());
165 } else {
166 match self.hub.subscribe(topics.clone()) {
168 Some((id, rx)) => {
169 self.subscriber_id = Some(id);
170 self.hub_receiver = Some(rx);
171 debug!(subscriber_id = id, "ws client subscribed");
172 }
173 None => {
174 self.send_error("max connections reached".to_string());
175 return;
176 }
177 }
178 }
179
180 let _ = self.outbound.try_send(ServerMessage::Subscribed { topics });
181 }
182
183 fn handle_unsubscribe(&mut self, topics: Vec<String>) {
184 if let Some(id) = self.subscriber_id {
185 self.hub.remove_topics(id, topics);
186 }
187 }
188
189 fn handle_publish(&self, topic: String, data: Value) {
190 self.hub.publish(&topic, data);
191 }
192
193 fn handle_pong(&self) {
194 let mut last = self.last_pong.lock();
195 *last = Instant::now();
196 }
197}
198
199impl Drop for WsSession {
200 fn drop(&mut self) {
201 self.cleanup();
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hub::{BextHub, HubConfig};
    use serde_json::json;
    use std::sync::Arc;

    /// Fresh hub with default limits.
    fn test_hub() -> Arc<BextHub> {
        Arc::new(BextHub::new(HubConfig::default()))
    }

    /// Session on `hub` with the default heartbeat configuration.
    fn test_session(hub: Arc<BextHub>) -> WsSession {
        WsSession::new(hub, WsSessionConfig::default())
    }

    #[test]
    fn handle_text_valid_subscribe() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let result = session.handle_text(r#"{"type":"subscribe","topics":["app/events"]}"#);
        assert!(result.is_ok());
        assert!(session.subscriber_id().is_some());
    }

    #[test]
    fn handle_text_valid_pong() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let result = session.handle_text(r#"{"type":"pong"}"#);
        assert!(result.is_ok());
        assert!(session.is_alive());
    }

    #[test]
    fn handle_text_invalid_json() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let err = session.handle_text("not json").unwrap_err();
        assert!(err.starts_with("invalid message"), "unexpected error: {}", err);
    }

    #[test]
    fn handle_text_unknown_type() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let result = session.handle_text(r#"{"type":"unknown"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn subscribe_creates_subscriber() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["test".to_string()],
        });

        assert!(session.subscriber_id().is_some());
        assert_eq!(hub.subscriber_count(), 1);

        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Subscribed { topics } => {
                assert_eq!(topics, vec!["test".to_string()]);
            }
            other => panic!("expected Subscribed, got {:?}", other),
        }
    }

    #[test]
    fn subscribe_empty_topics_sends_error() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe { topics: vec![] });

        assert!(session.subscriber_id().is_none());

        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Error { message } => {
                assert!(message.contains("empty"));
            }
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn subscribe_twice_adds_topics() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let _outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string()],
        });
        let first_id = session.subscriber_id().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["b".to_string()],
        });
        // A second subscribe reuses the existing hub subscriber instead of creating one.
        assert_eq!(session.subscriber_id().unwrap(), first_id);
        assert_eq!(hub.topic_count(), 2);
    }

    #[test]
    fn unsubscribe_removes_topics() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let _outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string(), "b".to_string()],
        });
        assert_eq!(hub.topic_count(), 2);

        session.handle_message(ClientMessage::Unsubscribe {
            topics: vec!["a".to_string()],
        });
        assert_eq!(hub.topic_count(), 1);
    }

    #[test]
    fn unsubscribe_without_subscribe_is_noop() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        session.handle_message(ClientMessage::Unsubscribe {
            topics: vec!["a".to_string()],
        });
        assert!(session.subscriber_id().is_none());
        assert_eq!(hub.subscriber_count(), 0);
    }

    #[tokio::test]
    async fn publish_from_ws_delivers_to_other_subscribers() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let _outbound = session.take_outbound_receiver().unwrap();

        // An independent hub subscriber that should receive the session's publish.
        let (_id, mut rx) = hub.subscribe(vec!["chat".to_string()]).unwrap();

        // The session itself never subscribed: publishing must not require it.
        session.handle_message(ClientMessage::Publish {
            topic: "chat".to_string(),
            data: json!({"text": "hello"}),
        });

        let evt = rx.recv().await.unwrap();
        assert_eq!(evt.topic, "chat");
        assert_eq!(evt.data, json!({"text": "hello"}));
    }

    #[test]
    fn pong_updates_last_pong_time() {
        let hub = test_hub();
        let mut session = test_session(hub);

        {
            // Backdate the last pong well past heartbeat_interval + pong_timeout (40s).
            let mut last = session.last_pong.lock();
            *last = Instant::now() - Duration::from_secs(100);
        }

        assert!(!session.is_alive());

        session.handle_message(ClientMessage::Pong);
        assert!(session.is_alive());
    }

    #[test]
    fn is_alive_true_initially() {
        let hub = test_hub();
        let session = test_session(hub);
        assert!(session.is_alive());
    }

    #[test]
    fn send_ping_queues_ping_message() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.send_ping();

        let msg = outbound.try_recv().unwrap();
        assert_eq!(msg, ServerMessage::Ping);
    }

    #[test]
    fn forward_hub_event_sends_event_message() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        let event = HubEvent {
            id: 5,
            topic: "test".to_string(),
            data: json!({"key": "val"}),
            timestamp: chrono::Utc::now(),
        };
        session.forward_hub_event(event);

        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Event { topic, data, id } => {
                assert_eq!(topic, "test");
                assert_eq!(data, json!({"key": "val"}));
                assert_eq!(id, 5);
            }
            other => panic!("expected Event, got {:?}", other),
        }
    }

    #[test]
    fn cleanup_unsubscribes_from_hub() {
        let hub = test_hub();
        let mut session = test_session(hub.clone());
        let _outbound = session.take_outbound_receiver().unwrap();

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string()],
        });
        assert_eq!(hub.subscriber_count(), 1);

        session.cleanup();
        assert_eq!(hub.subscriber_count(), 0);
        assert!(session.subscriber_id().is_none());
    }

    #[test]
    fn drop_triggers_cleanup() {
        let hub = test_hub();
        {
            let mut session = test_session(hub.clone());
            let _outbound = session.take_outbound_receiver().unwrap();

            session.handle_message(ClientMessage::Subscribe {
                topics: vec!["a".to_string()],
            });
            assert_eq!(hub.subscriber_count(), 1);
        } // `session` is dropped here; `Drop` must unsubscribe it from the hub.
        assert_eq!(hub.subscriber_count(), 0);
    }

    #[test]
    fn subscribe_at_max_connections_sends_error() {
        let hub = Arc::new(BextHub::new(HubConfig {
            max_connections: 1,
            ..Default::default()
        }));

        // First session takes the only available slot.
        let mut s1 = test_session(hub.clone());
        let _out1 = s1.take_outbound_receiver().unwrap();
        s1.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string()],
        });
        assert!(s1.subscriber_id().is_some());

        // Second session must be refused with an error message.
        let mut s2 = test_session(hub.clone());
        let mut out2 = s2.take_outbound_receiver().unwrap();
        s2.handle_message(ClientMessage::Subscribe {
            topics: vec!["b".to_string()],
        });
        assert!(s2.subscriber_id().is_none());

        let msg = out2.try_recv().unwrap();
        match msg {
            ServerMessage::Error { message } => {
                assert!(message.contains("max connections"));
            }
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn send_error_queues_error_message() {
        let hub = test_hub();
        let mut session = test_session(hub);
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.send_error("test error".to_string());

        let msg = outbound.try_recv().unwrap();
        match msg {
            ServerMessage::Error { message } => {
                assert_eq!(message, "test error");
            }
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn take_outbound_receiver_only_once() {
        let mut session = test_session(test_hub());
        assert!(session.take_outbound_receiver().is_some());
        assert!(session.take_outbound_receiver().is_none());
    }

    #[test]
    fn hub_receiver_available_only_after_subscribe() {
        let mut session = test_session(test_hub());
        let _outbound = session.take_outbound_receiver().unwrap();

        assert!(session.take_hub_receiver().is_none());

        session.handle_message(ClientMessage::Subscribe {
            topics: vec!["a".to_string()],
        });
        assert!(session.take_hub_receiver().is_some());
        assert!(session.take_hub_receiver().is_none());
    }
}