use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use serde_json::Value;
use tokio::sync::mpsc;
use tracing::debug;
use crate::hub::BextHub;
use crate::message::{ClientMessage, HubEvent, ServerMessage};
/// Heartbeat timing knobs for a [`WsSession`].
///
/// A session is considered alive while the time since its last pong is
/// shorter than `heartbeat_interval + pong_timeout`.
#[derive(Clone, Debug)]
pub struct WsSessionConfig {
    /// Interval used for the heartbeat; first half of the liveness window.
    pub heartbeat_interval: Duration,
    /// Extra grace period granted on top of `heartbeat_interval` before the
    /// session is reported as dead.
    pub pong_timeout: Duration,
}
impl Default for WsSessionConfig {
    /// Defaults: a 30 second heartbeat interval and a 10 second pong timeout.
    fn default() -> Self {
        let heartbeat_interval = Duration::from_secs(30);
        let pong_timeout = Duration::from_secs(10);
        Self {
            heartbeat_interval,
            pong_timeout,
        }
    }
}
/// State for one WebSocket client: its hub subscription, the queue of
/// messages waiting to be written to the socket, and heartbeat bookkeeping.
pub struct WsSession {
/// Hub this session subscribes to and publishes through.
hub: Arc<BextHub>,
/// Subscriber id handed out by the hub; `None` until the first successful
/// subscribe and again after `cleanup`.
subscriber_id: Option<u64>,
/// Receiver returned by the hub on the first subscribe; `None` before that
/// and after `take_hub_receiver` has moved it out.
hub_receiver: Option<mpsc::Receiver<HubEvent>>,
/// Sending half of the bounded outbound queue (created with capacity 256).
outbound: mpsc::Sender<ServerMessage>,
/// Receiving half of the outbound queue; `None` once
/// `take_outbound_receiver` has moved it out.
outbound_rx: Option<mpsc::Receiver<ServerMessage>>,
/// Time of the most recent pong; initialised to the session's creation time.
last_pong: Arc<Mutex<Instant>>,
/// Heartbeat timing used by `is_alive`.
config: WsSessionConfig,
}
impl WsSession {
pub fn new(hub: Arc<BextHub>, config: WsSessionConfig) -> Self {
let (outbound_tx, outbound_rx) = mpsc::channel(256);
Self {
hub,
subscriber_id: None,
hub_receiver: None,
outbound: outbound_tx,
outbound_rx: Some(outbound_rx),
last_pong: Arc::new(Mutex::new(Instant::now())),
config,
}
}
pub fn take_outbound_receiver(&mut self) -> Option<mpsc::Receiver<ServerMessage>> {
self.outbound_rx.take()
}
pub fn take_hub_receiver(&mut self) -> Option<mpsc::Receiver<HubEvent>> {
self.hub_receiver.take()
}
pub fn handle_text(&mut self, text: &str) -> Result<(), String> {
let msg: ClientMessage =
serde_json::from_str(text).map_err(|e| format!("invalid message: {}", e))?;
self.handle_message(msg);
Ok(())
}
pub fn handle_message(&mut self, msg: ClientMessage) {
match msg {
ClientMessage::Subscribe { topics } => self.handle_subscribe(topics),
ClientMessage::Unsubscribe { topics } => self.handle_unsubscribe(topics),
ClientMessage::Publish { topic, data } => self.handle_publish(topic, data),
ClientMessage::Pong => self.handle_pong(),
}
}
pub fn forward_hub_event(&self, event: HubEvent) {
let msg = ServerMessage::Event {
topic: event.topic,
data: event.data,
id: event.id,
};
let _ = self.outbound.try_send(msg);
}
pub fn send_ping(&self) {
let _ = self.outbound.try_send(ServerMessage::Ping);
}
pub fn is_alive(&self) -> bool {
let last = *self.last_pong.lock();
last.elapsed() < self.config.heartbeat_interval + self.config.pong_timeout
}
pub fn send_error(&self, message: String) {
let _ = self.outbound.try_send(ServerMessage::Error { message });
}
pub fn subscriber_id(&self) -> Option<u64> {
self.subscriber_id
}
pub fn config(&self) -> &WsSessionConfig {
&self.config
}
pub fn cleanup(&mut self) {
if let Some(id) = self.subscriber_id.take() {
self.hub.unsubscribe(id);
debug!(subscriber_id = id, "ws session cleaned up");
}
}
fn handle_subscribe(&mut self, topics: Vec<String>) {
if topics.is_empty() {
self.send_error("subscribe: topics list is empty".to_string());
return;
}
if let Some(id) = self.subscriber_id {
self.hub.add_topics(id, topics.clone());
} else {
match self.hub.subscribe(topics.clone()) {
Some((id, rx)) => {
self.subscriber_id = Some(id);
self.hub_receiver = Some(rx);
debug!(subscriber_id = id, "ws client subscribed");
}
None => {
self.send_error("max connections reached".to_string());
return;
}
}
}
let _ = self.outbound.try_send(ServerMessage::Subscribed { topics });
}
fn handle_unsubscribe(&mut self, topics: Vec<String>) {
if let Some(id) = self.subscriber_id {
self.hub.remove_topics(id, topics);
}
}
fn handle_publish(&self, topic: String, data: Value) {
self.hub.publish(&topic, data);
}
fn handle_pong(&self) {
let mut last = self.last_pong.lock();
*last = Instant::now();
}
}
impl Drop for WsSession {
fn drop(&mut self) {
self.cleanup();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hub::{BextHub, HubConfig};
    use serde_json::json;
    use std::sync::Arc;

    /// Hub built from the default configuration, shared behind an `Arc`.
    fn test_hub() -> Arc<BextHub> {
        let hub = BextHub::new(HubConfig::default());
        Arc::new(hub)
    }

    /// Session with the default heartbeat settings attached to `hub`.
    fn test_session(hub: Arc<BextHub>) -> WsSession {
        let config = WsSessionConfig::default();
        WsSession::new(hub, config)
    }

    /// Builds an owned topic list from string literals.
    fn topic_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Feeds a `Subscribe` message for `names` into `session`.
    fn subscribe(session: &mut WsSession, names: &[&str]) {
        session.handle_message(ClientMessage::Subscribe {
            topics: topic_list(names),
        });
    }

    /// Unwraps the text of a `ServerMessage::Error`, panicking on any other variant.
    fn error_text(msg: ServerMessage) -> String {
        match msg {
            ServerMessage::Error { message } => message,
            other => panic!("expected Error, got {:?}", other),
        }
    }

    // --- handle_text -------------------------------------------------------

    #[test]
    fn handle_text_valid_subscribe() {
        let mut session = test_session(test_hub());
        let outcome = session.handle_text(r#"{"type":"subscribe","topics":["app/events"]}"#);
        assert!(outcome.is_ok());
        assert!(session.subscriber_id().is_some());
    }

    #[test]
    fn handle_text_valid_pong() {
        let mut session = test_session(test_hub());
        let outcome = session.handle_text(r#"{"type":"pong"}"#);
        assert!(outcome.is_ok());
    }

    #[test]
    fn handle_text_invalid_json() {
        let mut session = test_session(test_hub());
        let outcome = session.handle_text("not json");
        assert!(outcome.is_err());
    }

    #[test]
    fn handle_text_unknown_type() {
        let mut session = test_session(test_hub());
        let outcome = session.handle_text(r#"{"type":"unknown"}"#);
        assert!(outcome.is_err());
    }

    // --- subscribe / unsubscribe -------------------------------------------

    #[test]
    fn subscribe_creates_subscriber() {
        let hub = test_hub();
        let mut session = test_session(Arc::clone(&hub));
        let mut outbound = session.take_outbound_receiver().unwrap();

        subscribe(&mut session, &["test"]);

        assert!(session.subscriber_id().is_some());
        assert_eq!(hub.subscriber_count(), 1);
        match outbound.try_recv().unwrap() {
            ServerMessage::Subscribed { topics } => {
                assert_eq!(topics, vec!["test".to_string()]);
            }
            other => panic!("expected Subscribed, got {:?}", other),
        }
    }

    #[test]
    fn subscribe_empty_topics_sends_error() {
        let mut session = test_session(test_hub());
        let mut outbound = session.take_outbound_receiver().unwrap();

        subscribe(&mut session, &[]);

        assert!(session.subscriber_id().is_none());
        let text = error_text(outbound.try_recv().unwrap());
        assert!(text.contains("empty"));
    }

    #[test]
    fn subscribe_twice_adds_topics() {
        let hub = test_hub();
        let mut session = test_session(Arc::clone(&hub));
        let _outbound = session.take_outbound_receiver().unwrap();

        subscribe(&mut session, &["a"]);
        let first_id = session.subscriber_id().unwrap();
        subscribe(&mut session, &["b"]);

        // The second subscribe reuses the subscriber instead of creating a new one.
        assert_eq!(session.subscriber_id().unwrap(), first_id);
        assert_eq!(hub.topic_count(), 2);
    }

    #[test]
    fn unsubscribe_removes_topics() {
        let hub = test_hub();
        let mut session = test_session(Arc::clone(&hub));
        let _outbound = session.take_outbound_receiver().unwrap();

        subscribe(&mut session, &["a", "b"]);
        assert_eq!(hub.topic_count(), 2);

        session.handle_message(ClientMessage::Unsubscribe {
            topics: topic_list(&["a"]),
        });
        assert_eq!(hub.topic_count(), 1);
    }

    #[test]
    fn unsubscribe_without_subscribe_is_noop() {
        let mut session = test_session(test_hub());
        // Passes as long as this does not panic.
        session.handle_message(ClientMessage::Unsubscribe {
            topics: topic_list(&["a"]),
        });
    }

    // --- publish -----------------------------------------------------------

    #[tokio::test]
    async fn publish_from_ws_delivers_to_other_subscribers() {
        let hub = test_hub();
        let mut session = test_session(Arc::clone(&hub));
        let _outbound = session.take_outbound_receiver().unwrap();
        let (_id, mut rx) = hub.subscribe(topic_list(&["chat"])).unwrap();

        session.handle_message(ClientMessage::Publish {
            topic: "chat".to_string(),
            data: json!({"text": "hello"}),
        });

        let received = rx.recv().await.unwrap();
        assert_eq!(received.topic, "chat");
        assert_eq!(received.data, json!({"text": "hello"}));
    }

    // --- heartbeat ---------------------------------------------------------

    #[test]
    fn pong_updates_last_pong_time() {
        let mut session = test_session(test_hub());
        // Backdate the last pong well past the default liveness window.
        *session.last_pong.lock() = Instant::now() - Duration::from_secs(100);
        assert!(!session.is_alive());

        session.handle_message(ClientMessage::Pong);
        assert!(session.is_alive());
    }

    #[test]
    fn is_alive_true_initially() {
        let session = test_session(test_hub());
        assert!(session.is_alive());
    }

    #[test]
    fn send_ping_queues_ping_message() {
        let mut session = test_session(test_hub());
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.send_ping();

        assert_eq!(outbound.try_recv().unwrap(), ServerMessage::Ping);
    }

    // --- outbound forwarding -----------------------------------------------

    #[test]
    fn forward_hub_event_sends_event_message() {
        let mut session = test_session(test_hub());
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.forward_hub_event(HubEvent {
            id: 5,
            topic: "test".to_string(),
            data: json!({"key": "val"}),
            timestamp: chrono::Utc::now(),
        });

        match outbound.try_recv().unwrap() {
            ServerMessage::Event { topic, data, id } => {
                assert_eq!(topic, "test");
                assert_eq!(data, json!({"key": "val"}));
                assert_eq!(id, 5);
            }
            other => panic!("expected Event, got {:?}", other),
        }
    }

    #[test]
    fn send_error_queues_error_message() {
        let mut session = test_session(test_hub());
        let mut outbound = session.take_outbound_receiver().unwrap();

        session.send_error("test error".to_string());

        let text = error_text(outbound.try_recv().unwrap());
        assert_eq!(text, "test error");
    }

    // --- lifecycle ---------------------------------------------------------

    #[test]
    fn cleanup_unsubscribes_from_hub() {
        let hub = test_hub();
        let mut session = test_session(Arc::clone(&hub));
        let _outbound = session.take_outbound_receiver().unwrap();

        subscribe(&mut session, &["a"]);
        assert_eq!(hub.subscriber_count(), 1);

        session.cleanup();
        assert_eq!(hub.subscriber_count(), 0);
        assert!(session.subscriber_id().is_none());
    }

    #[test]
    fn drop_triggers_cleanup() {
        let hub = test_hub();
        {
            // The session lives only inside this scope; leaving it runs `Drop`.
            let mut session = test_session(Arc::clone(&hub));
            let _outbound = session.take_outbound_receiver().unwrap();
            subscribe(&mut session, &["a"]);
            assert_eq!(hub.subscriber_count(), 1);
        }
        assert_eq!(hub.subscriber_count(), 0);
    }

    #[test]
    fn subscribe_at_max_connections_sends_error() {
        let hub = Arc::new(BextHub::new(HubConfig {
            max_connections: 1,
            ..Default::default()
        }));

        // The first session takes the only available slot.
        let mut s1 = test_session(Arc::clone(&hub));
        let _out1 = s1.take_outbound_receiver().unwrap();
        subscribe(&mut s1, &["a"]);
        assert!(s1.subscriber_id().is_some());

        // The second session is refused and told why.
        let mut s2 = test_session(Arc::clone(&hub));
        let mut out2 = s2.take_outbound_receiver().unwrap();
        subscribe(&mut s2, &["b"]);
        assert!(s2.subscriber_id().is_none());

        let text = error_text(out2.try_recv().unwrap());
        assert!(text.contains("max connections"));
    }
}