use crate::models::WebSocketMessage;
use crate::websocket::{ConnectionEvent, SubscriptionManager};
use crate::MarketDataError;
use futures_util::stream::SplitStream;
use futures_util::StreamExt;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::tungstenite::Message;
/// Read half (`SplitStream`) of a WebSocket running over a `MaybeTlsStream<TcpStream>`;
/// this is the stream type that `dispatch_messages` consumes.
type WsStream = SplitStream<
    WebSocketStream<MaybeTlsStream<TcpStream>>,
>;
/// Thread-safe handle for synchronously consuming decoded [`WebSocketMessage`]s.
///
/// The underlying `std::sync::mpsc::Receiver` is wrapped in a [`Mutex`] so the handle can
/// be shared between threads. Only one caller can hold the receiver at a time: concurrent
/// callers of the blocking methods wait for the lock.
pub struct MessageReceiver {
    rx: Mutex<mpsc::Receiver<WebSocketMessage>>,
}

impl MessageReceiver {
    /// Wraps the receiving end of the message channel.
    pub fn new(rx: mpsc::Receiver<WebSocketMessage>) -> Self {
        Self { rx: Mutex::new(rx) }
    }

    /// Locks the receiver, reporting a poisoned lock as a `ConnectionError`.
    fn lock_rx(
        &self,
    ) -> Result<std::sync::MutexGuard<'_, mpsc::Receiver<WebSocketMessage>>, MarketDataError>
    {
        self.rx.lock().map_err(|_| MarketDataError::ConnectionError {
            msg: "Message receiver lock poisoned".to_string(),
        })
    }

    /// Error returned once every sender of the channel has been dropped.
    fn channel_closed() -> MarketDataError {
        MarketDataError::ConnectionError {
            msg: "Message channel closed".to_string(),
        }
    }

    /// Blocks until the next message arrives.
    ///
    /// # Errors
    ///
    /// `MarketDataError::ConnectionError` if the lock is poisoned, or if the channel is
    /// empty and every sender has been dropped.
    pub fn receive(&self) -> Result<WebSocketMessage, MarketDataError> {
        self.lock_rx()?.recv().map_err(|_| Self::channel_closed())
    }

    /// Waits up to `timeout` for the next message.
    ///
    /// Returns `Ok(None)` if nothing arrived in time. The timeout covers only the wait on
    /// the channel; time spent waiting to acquire the lock is not counted.
    ///
    /// # Errors
    ///
    /// `MarketDataError::ConnectionError` if the lock is poisoned, or if the channel is
    /// empty and every sender has been dropped.
    pub fn receive_timeout(
        &self,
        timeout: Duration,
    ) -> Result<Option<WebSocketMessage>, MarketDataError> {
        match self.lock_rx()?.recv_timeout(timeout) {
            Ok(msg) => Ok(Some(msg)),
            Err(mpsc::RecvTimeoutError::Timeout) => Ok(None),
            Err(mpsc::RecvTimeoutError::Disconnected) => Err(Self::channel_closed()),
        }
    }

    /// Returns the next message if one is immediately available, without blocking.
    ///
    /// Returns `None` in any of these cases:
    /// - the channel is empty or closed;
    /// - the lock is poisoned;
    /// - another caller currently holds the receiver, for example while blocked in
    ///   [`receive`](Self::receive). That caller will take the next message anyway.
    pub fn try_receive(&self) -> Option<WebSocketMessage> {
        // `try_lock`, not `lock`: a plain `lock` would wait for a concurrent blocking
        // `receive`/`receive_timeout` to return, which defeats a non-blocking poll.
        self.rx.try_lock().ok()?.try_recv().ok()
    }
}
pub(crate) async fn dispatch_messages(
mut ws_read: WsStream,
message_tx: mpsc::Sender<WebSocketMessage>,
event_tx: mpsc::Sender<ConnectionEvent>,
heartbeat_timeout: Option<Duration>,
subscriptions: Arc<SubscriptionManager>,
) -> Option<u16> {
loop {
let frame_result = match heartbeat_timeout {
Some(timeout) => match tokio::time::timeout(timeout, ws_read.next()).await {
Ok(opt) => opt,
Err(_elapsed) => {
let _ = event_tx.send(ConnectionEvent::HeartbeatTimeout {
elapsed: timeout,
});
return None;
}
},
None => ws_read.next().await,
};
let msg_result = match frame_result {
Some(r) => r,
None => {
let _ = event_tx.send(ConnectionEvent::Disconnected {
code: None,
reason: "Connection closed".to_string(),
});
return None;
}
};
match msg_result {
Ok(Message::Text(text)) => {
match serde_json::from_str::<WebSocketMessage>(&text) {
Ok(ws_msg) => {
handle_subscribed_event(&subscriptions, &ws_msg);
if message_tx.send(ws_msg).is_err() {
return None;
}
}
Err(e) => {
let _ = event_tx.send(ConnectionEvent::Error {
message: format!("Failed to deserialize message: {}", e),
code: 2003,
});
}
}
}
Ok(Message::Binary(data)) => {
match serde_json::from_slice::<WebSocketMessage>(&data) {
Ok(ws_msg) => {
handle_subscribed_event(&subscriptions, &ws_msg);
if message_tx.send(ws_msg).is_err() {
return None;
}
}
Err(e) => {
let _ = event_tx.send(ConnectionEvent::Error {
message: format!("Failed to deserialize binary message: {}", e),
code: 2003,
});
}
}
}
Ok(Message::Pong(_)) => {
}
Ok(Message::Close(close_frame)) => {
let code = close_frame.as_ref().map(|cf| cf.code.into());
let reason = close_frame
.as_ref()
.map(|cf| cf.reason.to_string())
.unwrap_or_else(|| "Server initiated close".to_string());
let _ = event_tx.send(ConnectionEvent::Disconnected {
code,
reason,
});
return code;
}
Ok(Message::Ping(_)) => {
}
Err(e) => {
let _ = event_tx.send(ConnectionEvent::Error {
message: format!("WebSocket error: {}", e),
code: 2001,
});
return None;
}
Ok(Message::Frame(_)) => {
}
}
}
}
/// Builds the key under which a subscription's server id is recorded.
///
/// The key is `"{channel}:{symbol}"`, followed by `":afterhours"` when `after_hours` is
/// set, or otherwise by `":oddlot"` when `odd_lot` is set. If both flags are set, only
/// the after-hours suffix is used.
fn build_sub_key(channel: &str, symbol: &str, after_hours: bool, odd_lot: bool) -> String {
    let suffix = match (after_hours, odd_lot) {
        (true, _) => ":afterhours",
        (false, true) => ":oddlot",
        (false, false) => "",
    };
    format!("{channel}:{symbol}{suffix}")
}
/// Records the server-assigned subscription id(s) carried by a `"subscribed"` event.
///
/// Any other event is ignored. Two payload shapes are handled:
///
/// * `data` is a JSON array: each element must contain string `id`, `channel` and
///   `symbol` fields, otherwise that element is skipped. Top-level message fields are not
///   consulted for array elements.
/// * Any other shape: `id`, `channel` and `symbol` are read as strings from a `data`
///   object when present. Each one falls back independently to the message's own
///   top-level field. Nothing is recorded unless all three are found.
///
/// In both shapes, the optional boolean `afterHours` / `intradayOddLot` flags are read
/// only from the array element or the `data` object, never from the top level. Absent or
/// non-boolean flags count as `false`. The id is stored under the key produced by
/// [`build_sub_key`].
pub(crate) fn handle_subscribed_event(
    subscriptions: &SubscriptionManager,
    msg: &WebSocketMessage,
) {
    if msg.event != "subscribed" {
        return;
    }

    match msg.data.as_ref().and_then(|d| d.as_array()) {
        Some(entries) => {
            for entry in entries {
                let required = (
                    entry.get("id").and_then(|v| v.as_str()),
                    entry.get("channel").and_then(|v| v.as_str()),
                    entry.get("symbol").and_then(|v| v.as_str()),
                );
                let (Some(id), Some(channel), Some(symbol)) = required else {
                    continue;
                };
                let after_hours =
                    entry.get("afterHours").and_then(|v| v.as_bool()) == Some(true);
                let odd_lot =
                    entry.get("intradayOddLot").and_then(|v| v.as_bool()) == Some(true);
                subscriptions.record_server_id(
                    build_sub_key(channel, symbol, after_hours, odd_lot),
                    id.to_string(),
                );
            }
        }
        None => {
            let object = msg.data.as_ref().and_then(|d| d.as_object());
            // String field from the `data` object, if it is there and is a string.
            let text_field = |key: &str| {
                object
                    .and_then(|d| d.get(key))
                    .and_then(|v| v.as_str())
                    .map(String::from)
            };
            // Boolean flag from the `data` object; anything else counts as `false`.
            let flag_field = |key: &str| {
                object
                    .and_then(|d| d.get(key))
                    .and_then(|v| v.as_bool())
                    .unwrap_or(false)
            };

            let id = text_field("id").or_else(|| msg.id.clone());
            let channel = text_field("channel").or_else(|| msg.channel.clone());
            let symbol = text_field("symbol").or_else(|| msg.symbol.clone());

            if let (Some(id), Some(channel), Some(symbol)) = (id, channel, symbol) {
                let key = build_sub_key(
                    &channel,
                    &symbol,
                    flag_field("afterHours"),
                    flag_field("intradayOddLot"),
                );
                subscriptions.record_server_id(key, id);
            }
        }
    }
}
#[allow(dead_code)] pub(crate) async fn send_pings(
mut ws_sink: futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
ping_rx: mpsc::Receiver<()>,
) {
use futures_util::SinkExt;
while ping_rx.recv().is_ok() {
if ws_sink.send(Message::Ping(vec![].into())).await.is_err() {
break;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"data"` event with no payload or id and the given channel/symbol.
    fn data_event(channel: Option<&str>, symbol: Option<&str>) -> WebSocketMessage {
        WebSocketMessage {
            event: "data".to_string(),
            data: None,
            channel: channel.map(String::from),
            symbol: symbol.map(String::from),
            id: None,
        }
    }

    /// Deserializes a JSON literal into a `WebSocketMessage`, panicking if it is malformed.
    fn parse_msg(json: &str) -> WebSocketMessage {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_receive_blocking() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        // Deliver from another thread after a short delay so `receive` has to wait.
        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            tx.send(data_event(Some("trades"), Some("2330"))).unwrap();
        });
        let msg = receiver.receive().expect("a message should arrive");
        assert_eq!(msg.event, "data");
        assert_eq!(msg.channel.as_deref(), Some("trades"));
    }

    #[test]
    fn test_receive_timeout_returns_none() {
        // Keep the sender alive so the wait ends by timeout, not by disconnection.
        let (_tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        let outcome = receiver.receive_timeout(Duration::from_millis(50));
        assert!(matches!(outcome, Ok(None)));
    }

    #[test]
    fn test_receive_timeout_returns_message() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        tx.send(data_event(Some("trades"), Some("2330"))).unwrap();
        let received = receiver
            .receive_timeout(Duration::from_secs(1))
            .expect("receive_timeout should succeed")
            .expect("a message was queued");
        assert_eq!(received.event, "data");
    }

    #[test]
    fn test_try_receive_non_blocking() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        assert!(receiver.try_receive().is_none());
        tx.send(data_event(None, None)).unwrap();
        let received = receiver
            .try_receive()
            .expect("queued message should be available");
        assert_eq!(received.event, "data");
    }

    #[test]
    fn test_channel_closed_returns_error() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        drop(tx);
        match receiver.receive() {
            Err(MarketDataError::ConnectionError { msg }) => assert!(msg.contains("closed")),
            _ => panic!("Expected ConnectionError"),
        }
    }

    #[test]
    fn test_channel_closed_timeout_returns_error() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        drop(tx);
        assert!(receiver.receive_timeout(Duration::from_secs(1)).is_err());
    }

    #[test]
    fn test_try_receive_after_close() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        tx.send(data_event(None, None)).unwrap();
        drop(tx);
        // A message queued before the sender dropped is still delivered...
        assert!(receiver.try_receive().is_some());
        // ...after which the closed, empty channel yields nothing.
        assert!(receiver.try_receive().is_none());
    }

    #[test]
    fn test_handle_subscribed_ignores_non_subscribed() {
        let manager = SubscriptionManager::new();
        handle_subscribed_event(
            &manager,
            &parse_msg(r#"{"event":"data","id":"sub-1","channel":"trades","symbol":"2330"}"#),
        );
        assert!(manager.take_server_id("trades:2330").is_none());
    }

    #[test]
    fn test_handle_subscribed_single_top_level() {
        let manager = SubscriptionManager::new();
        handle_subscribed_event(
            &manager,
            &parse_msg(r#"{"event":"subscribed","id":"sub-abc","channel":"trades","symbol":"2330"}"#),
        );
        let recorded = manager.take_server_id("trades:2330");
        assert_eq!(recorded.as_deref(), Some("sub-abc"));
    }

    #[test]
    fn test_handle_subscribed_single_with_after_hours() {
        let manager = SubscriptionManager::new();
        handle_subscribed_event(
            &manager,
            &parse_msg(r#"{"event":"subscribed","data":{"id":"sub-ah","channel":"books","symbol":"TXFE6","afterHours":true}}"#),
        );
        let recorded = manager.take_server_id("books:TXFE6:afterhours");
        assert_eq!(recorded.as_deref(), Some("sub-ah"));
        // The plain key must not be populated when the after-hours flag is set.
        assert!(manager.take_server_id("books:TXFE6").is_none());
    }

    #[test]
    fn test_handle_subscribed_single_with_odd_lot() {
        let manager = SubscriptionManager::new();
        handle_subscribed_event(
            &manager,
            &parse_msg(r#"{"event":"subscribed","data":{"id":"sub-odd","channel":"trades","symbol":"2330","intradayOddLot":true}}"#),
        );
        let recorded = manager.take_server_id("trades:2330:oddlot");
        assert_eq!(recorded.as_deref(), Some("sub-odd"));
    }

    #[test]
    fn test_handle_subscribed_batched_array() {
        let manager = SubscriptionManager::new();
        let batched = parse_msg(
            r#"{"event":"subscribed","data":[
{"id":"sub-1","channel":"trades","symbol":"2330"},
{"id":"sub-2","channel":"books","symbol":"TXFE6","afterHours":true},
{"id":"sub-3","channel":"trades","symbol":"2317","intradayOddLot":true}
]}"#,
        );
        handle_subscribed_event(&manager, &batched);
        let expectations = [
            ("trades:2330", "sub-1"),
            ("books:TXFE6:afterhours", "sub-2"),
            ("trades:2317:oddlot", "sub-3"),
        ];
        for (key, expected_id) in expectations {
            assert_eq!(manager.take_server_id(key).as_deref(), Some(expected_id));
        }
    }

    #[test]
    fn test_handle_subscribed_missing_fields_no_op() {
        let manager = SubscriptionManager::new();
        handle_subscribed_event(&manager, &parse_msg(r#"{"event":"subscribed","symbol":"2330"}"#));
        assert!(manager.take_server_id("trades:2330").is_none());
    }
}