use crate::models::WebSocketMessage;
use crate::MarketDataError;
use std::sync::mpsc;
use std::sync::Mutex;
use std::time::Duration;
/// Receiving end of a `WebSocketMessage` channel, wrapped in a [`Mutex`] so
/// messages can be received through a shared `&self` reference.
pub struct MessageReceiver {
    // Every `mpsc::Receiver` call below is made while holding this lock.
    rx: Mutex<mpsc::Receiver<WebSocketMessage>>,
}

impl MessageReceiver {
    /// Wraps the receiving half of an `mpsc` channel.
    pub fn new(rx: mpsc::Receiver<WebSocketMessage>) -> Self {
        Self { rx: Mutex::new(rx) }
    }

    /// Blocks until the next message arrives.
    ///
    /// The internal lock is held for the whole wait, so any other call on the
    /// same receiver that needs the lock waits behind this one.
    ///
    /// # Errors
    ///
    /// Returns `MarketDataError::ConnectionError` if the internal lock is
    /// poisoned, or if the channel is empty and every sender has been dropped.
    pub fn receive(&self) -> Result<WebSocketMessage, MarketDataError> {
        self.lock_rx()?.recv().map_err(|_| Self::channel_closed())
    }

    /// Waits up to `timeout` for the next message, returning `Ok(None)` if
    /// none arrives in time.
    ///
    /// Time spent acquiring the internal lock is deducted from `timeout`
    /// before waiting on the channel, so lock contention does not add a full
    /// extra `timeout` on top. The lock acquisition itself cannot be cut
    /// short, though: while another thread holds the receiver (in particular
    /// one blocked in [`receive`](Self::receive)), this call can still run
    /// past `timeout`.
    ///
    /// # Errors
    ///
    /// Returns `MarketDataError::ConnectionError` if the internal lock is
    /// poisoned, or if the channel is empty and every sender has been dropped.
    pub fn receive_timeout(
        &self,
        timeout: Duration,
    ) -> Result<Option<WebSocketMessage>, MarketDataError> {
        let started = std::time::Instant::now();
        let rx = self.lock_rx()?;
        // Charge the lock wait against the caller's budget; saturate at zero
        // so a long lock wait still gets one attempt at a queued message.
        let remaining = timeout.saturating_sub(started.elapsed());
        match rx.recv_timeout(remaining) {
            Ok(msg) => Ok(Some(msg)),
            Err(mpsc::RecvTimeoutError::Timeout) => Ok(None),
            Err(mpsc::RecvTimeoutError::Disconnected) => Err(Self::channel_closed()),
        }
    }

    /// Returns the next message if one is immediately available.
    ///
    /// Never blocks. `None` covers several cases: the channel is empty, every
    /// sender has been dropped, the internal lock is poisoned, or another
    /// thread currently holds the lock (e.g. while blocked in
    /// [`receive`](Self::receive)). Use
    /// [`receive_timeout`](Self::receive_timeout) to tell a closed channel
    /// apart from an empty one.
    pub fn try_receive(&self) -> Option<WebSocketMessage> {
        // `try_lock`, not `lock`: waiting for a lock held by a thread parked in
        // `receive` would make this "non-blocking" call block until a message
        // arrives.
        self.rx.try_lock().ok()?.try_recv().ok()
    }

    /// Locks the inner receiver, reporting a poisoned lock as a `ConnectionError`.
    fn lock_rx(
        &self,
    ) -> Result<std::sync::MutexGuard<'_, mpsc::Receiver<WebSocketMessage>>, MarketDataError> {
        self.rx.lock().map_err(|_| MarketDataError::ConnectionError {
            msg: "Message receiver lock poisoned".to_string(),
        })
    }

    /// Error reported once the channel is empty and all senders are gone.
    fn channel_closed() -> MarketDataError {
        MarketDataError::ConnectionError {
            msg: "Message channel closed".to_string(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"data"` event with the given channel and symbol.
    fn data_message(channel: Option<&str>, symbol: Option<&str>) -> WebSocketMessage {
        WebSocketMessage {
            event: "data".to_string(),
            data: None,
            channel: channel.map(str::to_string),
            symbol: symbol.map(str::to_string),
            id: None,
        }
    }

    /// Asserts that `result` is a `ConnectionError` reporting a closed channel.
    fn assert_channel_closed<T>(result: Result<T, MarketDataError>) {
        match result {
            Err(MarketDataError::ConnectionError { msg }) => {
                assert!(msg.contains("closed"), "unexpected error message: {msg}");
            }
            _ => panic!("Expected ConnectionError"),
        }
    }

    #[test]
    fn test_receive_blocking() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        let sender = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            tx.send(data_message(Some("trades"), Some("2330"))).unwrap();
        });
        let msg = receiver
            .receive()
            .expect("receive should yield the message sent by the other thread");
        assert_eq!(msg.event, "data");
        assert_eq!(msg.channel, Some("trades".to_string()));
        // Surface any panic in the sender thread instead of silently dropping it.
        sender.join().expect("sender thread panicked");
    }

    #[test]
    fn test_receive_timeout_returns_none() {
        let (_tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        let received = receiver
            .receive_timeout(Duration::from_millis(50))
            .expect("an open, empty channel should time out rather than error");
        assert!(received.is_none());
    }

    #[test]
    fn test_receive_timeout_returns_message() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        tx.send(data_message(Some("trades"), Some("2330"))).unwrap();
        let received = receiver
            .receive_timeout(Duration::from_secs(1))
            .expect("receive_timeout should not error on an open channel")
            .expect("a queued message should be returned before the timeout");
        assert_eq!(received.event, "data");
    }

    #[test]
    fn test_receive_timeout_zero_returns_queued_message() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        tx.send(data_message(None, None)).unwrap();
        let received = receiver
            .receive_timeout(Duration::ZERO)
            .expect("receive_timeout should not error on an open channel")
            .expect("an already-queued message should be returned even with a zero timeout");
        assert_eq!(received.event, "data");
    }

    #[test]
    fn test_try_receive_non_blocking() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        assert!(receiver.try_receive().is_none());
        tx.send(data_message(None, None)).unwrap();
        let received = receiver
            .try_receive()
            .expect("a queued message should be available immediately");
        assert_eq!(received.event, "data");
    }

    #[test]
    fn test_channel_closed_returns_error() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        drop(tx);
        assert_channel_closed(receiver.receive());
    }

    #[test]
    fn test_channel_closed_timeout_returns_error() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        drop(tx);
        assert_channel_closed(receiver.receive_timeout(Duration::from_secs(1)));
    }

    #[test]
    fn test_try_receive_after_close() {
        let (tx, rx) = mpsc::channel();
        let receiver = MessageReceiver::new(rx);
        tx.send(data_message(None, None)).unwrap();
        drop(tx);
        // A message queued before the sender was dropped is still delivered...
        assert!(receiver.try_receive().is_some());
        // ...after which the disconnected channel reads as empty.
        assert!(receiver.try_receive().is_none());
    }
}