use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use whisky::WError;
use crate::client::stream::ReconnectConfig;
use crate::responses::stream::{
MarketDepthMessage, MarketPriceMessage, MarketStreamMessage, OhlcMessage, Trade,
};
/// Notifications delivered by the `*_with_reconnect` subscriptions: data frames plus
/// connection lifecycle changes.
#[derive(Debug, Clone)]
pub enum MarketStreamEvent {
    /// A frame received from the server, already run through the endpoint's parser.
    Message(MarketStreamMessage),
    /// A WebSocket connection has been (re-)established.
    Connected,
    /// A reconnect is scheduled; `attempt` counts from 1 and `delay_ms` is the wait before it.
    Reconnecting { attempt: u32, delay_ms: u64 },
    /// The connection was lost or could not be opened; `reason` is a human-readable cause.
    Disconnected { reason: String },
    /// The reconnect policy declined to retry; the stream has stopped.
    MaxRetriesExceeded,
}
/// Controls a running market stream background task.
///
/// Calling [`MarketStreamHandle::close`] or dropping the handle asks the task to stop.
#[derive(Debug)]
pub struct MarketStreamHandle {
    // `None` once a close request has been issued through this handle.
    close_tx: Option<mpsc::Sender<()>>,
}

impl MarketStreamHandle {
    /// Asks the background task to shut down. Calling it again is a no-op.
    pub async fn close(&mut self) {
        if let Some(tx) = self.close_tx.take() {
            // A failed send only means the task has already exited.
            let _ = tx.send(()).await;
        }
    }

    /// Returns `true` while no close has been requested and the background task is alive.
    pub fn is_active(&self) -> bool {
        // The task owns the receiving half of the close channel. Once the task finishes,
        // that half is dropped and the sender reports itself closed.
        matches!(&self.close_tx, Some(tx) if !tx.is_closed())
    }
}

impl Drop for MarketStreamHandle {
    /// Best-effort close: `Drop` cannot `.await`, so the request is sent with `try_send`.
    fn drop(&mut self) {
        if let Some(tx) = self.close_tx.take() {
            let _ = tx.try_send(());
        }
    }
}
/// Why one connection attempt inside the reconnect loop came to an end.
enum ConnectionResult {
    /// The user asked to close via the handle.
    UserClosed,
    /// Connecting or streaming failed: a human-readable reason, plus whether the
    /// handshake had succeeded before the failure.
    Error(String, bool),
    /// The consumer dropped its event receiver.
    ReceiverDropped,
}
/// Client for the market-data WebSocket endpoints rooted at a base URL.
#[derive(Debug, Clone)]
pub struct MarketStream {
    // Base WebSocket URL. Endpoint paths such as `/market/ws/depth/{symbol}` are appended
    // verbatim, so it should not end with a trailing slash.
    ws_url: String,
}
impl MarketStream {
pub fn new(ws_url: String) -> Self {
MarketStream { ws_url }
}
pub async fn subscribe_depth(
&self,
symbol: &str,
buffer_size: Option<usize>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamMessage>), WError> {
let ws_endpoint = format!("{}/market/ws/depth/{}", self.ws_url, symbol);
self.connect_and_stream(ws_endpoint, buffer_size, |json| {
match serde_json::from_str::<MarketDepthMessage>(json) {
Ok(msg) => MarketStreamMessage::Depth(msg),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
pub async fn subscribe_price(
&self,
symbol: &str,
buffer_size: Option<usize>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamMessage>), WError> {
let ws_endpoint = format!("{}/market/ws/market-price/{}", self.ws_url, symbol);
self.connect_and_stream(ws_endpoint, buffer_size, |json| {
match serde_json::from_str::<MarketPriceMessage>(json) {
Ok(msg) => MarketStreamMessage::Price(msg),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
pub async fn subscribe_recent_trades(
&self,
symbol: &str,
buffer_size: Option<usize>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamMessage>), WError> {
let ws_endpoint = format!("{}/market/ws/recent-trade/{}", self.ws_url, symbol);
self.connect_and_stream(ws_endpoint, buffer_size, |json| {
match serde_json::from_str::<Vec<Trade>>(json) {
Ok(trades) => MarketStreamMessage::RecentTrades(trades),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
pub async fn subscribe_ohlc(
&self,
symbol: &str,
interval: &str,
buffer_size: Option<usize>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamMessage>), WError> {
let ws_endpoint = format!("{}/market/ws/graph/{}/{}", self.ws_url, symbol, interval);
self.connect_and_stream(ws_endpoint, buffer_size, |json| {
match serde_json::from_str::<OhlcMessage>(json) {
Ok(msg) => MarketStreamMessage::Ohlc(msg),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
async fn connect_and_stream<F>(
&self,
ws_endpoint: String,
buffer_size: Option<usize>,
parser: F,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamMessage>), WError>
where
F: Fn(&str) -> MarketStreamMessage + Send + 'static,
{
let buffer = buffer_size.unwrap_or(100);
let (message_tx, message_rx) = mpsc::channel::<MarketStreamMessage>(buffer);
let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
let connect_timeout = Duration::from_secs(30);
let (ws_stream, _response) = timeout(connect_timeout, connect_async(&ws_endpoint))
.await
.map_err(|_| WError::new("MarketStream", "Connection timeout"))?
.map_err(|e| WError::new("MarketStream", &format!("Connection failed: {}", e)))?;
let (mut write, mut read) = ws_stream.split();
tokio::spawn(async move {
loop {
tokio::select! {
_ = close_rx.recv() => {
let _ = write.send(Message::Close(None)).await;
break;
}
msg = read.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
let stream_msg = parser(&text);
if message_tx.send(stream_msg).await.is_err() {
let _ = write.send(Message::Close(None)).await;
break;
}
}
Some(Ok(Message::Ping(data))) => {
if write.send(Message::Pong(data)).await.is_err() {
break;
}
}
Some(Ok(Message::Close(_))) => {
break;
}
Some(Err(_)) => {
break;
}
None => {
break;
}
_ => {
}
}
}
}
}
});
let handle = MarketStreamHandle {
close_tx: Some(close_tx),
};
Ok((handle, message_rx))
}
pub async fn subscribe_depth_with_reconnect(
&self,
symbol: &str,
buffer_size: Option<usize>,
reconnect_config: Option<ReconnectConfig>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamEvent>), WError> {
let ws_endpoint = format!("{}/market/ws/depth/{}", self.ws_url, symbol);
self.connect_and_stream_with_reconnect(ws_endpoint, buffer_size, reconnect_config, |json| {
match serde_json::from_str::<MarketDepthMessage>(json) {
Ok(msg) => MarketStreamMessage::Depth(msg),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
pub async fn subscribe_price_with_reconnect(
&self,
symbol: &str,
buffer_size: Option<usize>,
reconnect_config: Option<ReconnectConfig>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamEvent>), WError> {
let ws_endpoint = format!("{}/market/ws/market-price/{}", self.ws_url, symbol);
self.connect_and_stream_with_reconnect(ws_endpoint, buffer_size, reconnect_config, |json| {
match serde_json::from_str::<MarketPriceMessage>(json) {
Ok(msg) => MarketStreamMessage::Price(msg),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
pub async fn subscribe_recent_trades_with_reconnect(
&self,
symbol: &str,
buffer_size: Option<usize>,
reconnect_config: Option<ReconnectConfig>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamEvent>), WError> {
let ws_endpoint = format!("{}/market/ws/recent-trade/{}", self.ws_url, symbol);
self.connect_and_stream_with_reconnect(ws_endpoint, buffer_size, reconnect_config, |json| {
match serde_json::from_str::<Vec<Trade>>(json) {
Ok(trades) => MarketStreamMessage::RecentTrades(trades),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
pub async fn subscribe_ohlc_with_reconnect(
&self,
symbol: &str,
interval: &str,
buffer_size: Option<usize>,
reconnect_config: Option<ReconnectConfig>,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamEvent>), WError> {
let ws_endpoint = format!("{}/market/ws/graph/{}/{}", self.ws_url, symbol, interval);
self.connect_and_stream_with_reconnect(ws_endpoint, buffer_size, reconnect_config, |json| {
match serde_json::from_str::<OhlcMessage>(json) {
Ok(msg) => MarketStreamMessage::Ohlc(msg),
Err(_) => MarketStreamMessage::Unknown(json.to_string()),
}
})
.await
}
async fn connect_and_stream_with_reconnect<F>(
&self,
ws_endpoint: String,
buffer_size: Option<usize>,
reconnect_config: Option<ReconnectConfig>,
parser: F,
) -> Result<(MarketStreamHandle, mpsc::Receiver<MarketStreamEvent>), WError>
where
F: Fn(&str) -> MarketStreamMessage + Send + Sync + 'static,
{
let buffer = buffer_size.unwrap_or(100);
let config = reconnect_config.unwrap_or_default();
let (event_tx, event_rx) = mpsc::channel::<MarketStreamEvent>(buffer);
let (close_tx, close_rx) = mpsc::channel::<()>(1);
tokio::spawn(Self::run_reconnecting_stream(
ws_endpoint,
config,
event_tx,
close_rx,
parser,
));
let handle = MarketStreamHandle {
close_tx: Some(close_tx),
};
Ok((handle, event_rx))
}
async fn run_single_connection<F>(
ws_endpoint: &str,
connect_timeout: Duration,
event_tx: &mpsc::Sender<MarketStreamEvent>,
close_rx: &mut mpsc::Receiver<()>,
parser: &F,
) -> ConnectionResult
where
F: Fn(&str) -> MarketStreamMessage,
{
let connect_result = timeout(connect_timeout, connect_async(ws_endpoint)).await;
let ws_stream = match connect_result {
Ok(Ok((stream, _response))) => stream,
Ok(Err(e)) => {
return ConnectionResult::Error(format!("Connection failed: {}", e), false);
}
Err(_) => {
return ConnectionResult::Error("Connection timeout".to_string(), false);
}
};
if event_tx.send(MarketStreamEvent::Connected).await.is_err() {
return ConnectionResult::ReceiverDropped;
}
let (mut write, mut read) = ws_stream.split();
loop {
tokio::select! {
_ = close_rx.recv() => {
let _ = write.send(Message::Close(None)).await;
return ConnectionResult::UserClosed;
}
msg = read.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
let stream_msg = parser(&text);
let event = MarketStreamEvent::Message(stream_msg);
if event_tx.send(event).await.is_err() {
let _ = write.send(Message::Close(None)).await;
return ConnectionResult::ReceiverDropped;
}
}
Some(Ok(Message::Ping(data))) => {
if write.send(Message::Pong(data)).await.is_err() {
return ConnectionResult::Error("Failed to send pong".to_string(), true);
}
}
Some(Ok(Message::Close(frame))) => {
let reason = frame
.map(|f| f.reason.to_string())
.unwrap_or_else(|| "Server closed connection".to_string());
return ConnectionResult::Error(reason, true);
}
Some(Err(e)) => {
return ConnectionResult::Error(format!("WebSocket error: {}", e), true);
}
None => {
return ConnectionResult::Error("Stream ended unexpectedly".to_string(), true);
}
_ => {
}
}
}
}
}
}
async fn run_reconnecting_stream<F>(
ws_endpoint: String,
config: ReconnectConfig,
event_tx: mpsc::Sender<MarketStreamEvent>,
mut close_rx: mpsc::Receiver<()>,
parser: F,
) where
F: Fn(&str) -> MarketStreamMessage + Send + Sync + 'static,
{
let connect_timeout = Duration::from_millis(config.connect_timeout_ms);
let mut attempt: u32 = 0;
loop {
if close_rx.try_recv().is_ok() {
break;
}
let result = Self::run_single_connection(
&ws_endpoint,
connect_timeout,
&event_tx,
&mut close_rx,
&parser,
)
.await;
match result {
ConnectionResult::UserClosed => {
break;
}
ConnectionResult::ReceiverDropped => {
break;
}
ConnectionResult::Error(reason, was_connected) => {
if was_connected {
attempt = 0;
}
let _ = event_tx
.send(MarketStreamEvent::Disconnected {
reason: reason.clone(),
})
.await;
if !config.should_retry(attempt) {
let _ = event_tx.send(MarketStreamEvent::MaxRetriesExceeded).await;
break;
}
let delay_ms = config.delay_for_attempt(attempt);
attempt += 1;
let _ = event_tx
.send(MarketStreamEvent::Reconnecting { attempt, delay_ms })
.await;
tokio::select! {
_ = close_rx.recv() => {
break;
}
_ = tokio::time::sleep(Duration::from_millis(delay_ms)) => {
}
}
}
}
}
}
pub async fn subscribe_depth_with_callback<F, Fut>(
&self,
symbol: &str,
mut callback: F,
) -> Result<(), WError>
where
F: FnMut(MarketStreamMessage) -> Fut,
Fut: std::future::Future<Output = bool>,
{
let (mut handle, mut receiver) = self.subscribe_depth(symbol, None).await?;
while let Some(message) = receiver.recv().await {
if !callback(message).await {
break;
}
}
handle.close().await;
Ok(())
}
pub async fn subscribe_price_with_callback<F, Fut>(
&self,
symbol: &str,
mut callback: F,
) -> Result<(), WError>
where
F: FnMut(MarketStreamMessage) -> Fut,
Fut: std::future::Future<Output = bool>,
{
let (mut handle, mut receiver) = self.subscribe_price(symbol, None).await?;
while let Some(message) = receiver.recv().await {
if !callback(message).await {
break;
}
}
handle.close().await;
Ok(())
}
pub async fn subscribe_recent_trades_with_callback<F, Fut>(
&self,
symbol: &str,
mut callback: F,
) -> Result<(), WError>
where
F: FnMut(MarketStreamMessage) -> Fut,
Fut: std::future::Future<Output = bool>,
{
let (mut handle, mut receiver) = self.subscribe_recent_trades(symbol, None).await?;
while let Some(message) = receiver.recv().await {
if !callback(message).await {
break;
}
}
handle.close().await;
Ok(())
}
pub async fn subscribe_ohlc_with_callback<F, Fut>(
&self,
symbol: &str,
interval: &str,
mut callback: F,
) -> Result<(), WError>
where
F: FnMut(MarketStreamMessage) -> Fut,
Fut: std::future::Future<Output = bool>,
{
let (mut handle, mut receiver) = self.subscribe_ohlc(symbol, interval, None).await?;
while let Some(message) = receiver.recv().await {
if !callback(message).await {
break;
}
}
handle.close().await;
Ok(())
}
}
#[cfg(test)]
mod tests {
    //! Deserialization checks for the stream payload types, plus the message-kind helpers.
    use super::*;

    #[test]
    fn test_market_depth_parsing() {
        let json = r#"{
"timestamp": 1704067200000,
"bids": [
{"price": 0.45, "quantity": 1000.0},
{"price": 0.44, "quantity": 2000.0}
],
"asks": [
{"price": 0.46, "quantity": 500.0},
{"price": 0.47, "quantity": 1500.0}
]
}"#;
        let depth = serde_json::from_str::<MarketDepthMessage>(json).unwrap();
        assert_eq!(depth.timestamp, 1704067200000);
        assert_eq!((depth.bids.len(), depth.asks.len()), (2, 2));
        let top_bid = &depth.bids[0];
        assert_eq!(top_bid.price, 0.45);
        assert_eq!(top_bid.quantity, 1000.0);
    }

    #[test]
    fn test_market_price_parsing() {
        let parsed = serde_json::from_str::<MarketPriceMessage>(r#"{"price": 0.456789}"#).unwrap();
        assert_eq!(parsed.price, 0.456789);
    }

    #[test]
    fn test_recent_trades_parsing() {
        let json = r#"[
{
"order_id": "order-123",
"timestamp": "2024-01-01T00:00:00Z",
"symbol": "ADAUSDM",
"price": 0.45,
"amount": 100.0,
"side": "buy"
}
]"#;
        let batch = serde_json::from_str::<Vec<Trade>>(json).unwrap();
        assert_eq!(batch.len(), 1);
        let first = &batch[0];
        assert_eq!(first.order_id, "order-123");
        assert_eq!(first.symbol, "ADAUSDM");
        assert_eq!(first.side, "buy");
    }

    #[test]
    fn test_ohlc_parsing() {
        let json = r#"{
"t": 1704067200,
"s": "ADAUSDM",
"o": 0.45,
"h": 0.48,
"l": 0.44,
"c": 0.47,
"v": 10000.5
}"#;
        let candle = serde_json::from_str::<OhlcMessage>(json).unwrap();
        assert_eq!(candle.timestamp, 1704067200);
        assert_eq!(candle.symbol, "ADAUSDM");
        assert_eq!(
            (candle.open, candle.high, candle.low, candle.close),
            (0.45, 0.48, 0.44, 0.47)
        );
        assert_eq!(candle.volume, 10000.5);
    }

    #[test]
    fn test_market_stream_message_helpers() {
        let depth = MarketStreamMessage::Depth(MarketDepthMessage {
            timestamp: 0,
            bids: vec![],
            asks: vec![],
        });
        assert!(depth.is_depth() && !depth.is_price());

        let price = MarketStreamMessage::Price(MarketPriceMessage { price: 0.5 });
        assert!(price.is_price() && !price.is_depth());

        assert!(MarketStreamMessage::RecentTrades(vec![]).is_recent_trades());

        let candle = OhlcMessage {
            timestamp: 0,
            symbol: "TEST".to_string(),
            open: 0.0,
            high: 0.0,
            low: 0.0,
            close: 0.0,
            volume: 0.0,
        };
        assert!(MarketStreamMessage::Ohlc(candle).is_ohlc());

        assert!(MarketStreamMessage::Unknown("test".to_string()).is_unknown());
    }
}