use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, SinkExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{Mutex, broadcast};
use tokio::time::sleep;
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
ExchangeResult, ExchangeError, timestamp_millis,
AccountType, ConnectionStatus, StreamEvent, StreamType, SubscriptionRequest,
};
use crate::core::types::{WebSocketResult, WebSocketError, OrderBookLevel, OrderbookDelta as OrderbookDeltaData, OrderbookCapabilities};
use crate::core::traits::WebSocketConnector;
use super::auth::CryptoComAuth;
use super::endpoints::{InstrumentType, format_symbol as fmt_symbol};
/// Client WebSocket connection (plain TCP or TLS) as produced by `connect_async` and stored in
/// `CryptoComWebSocket::ws_stream`.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// JSON request frame written to the socket (`public/auth`, `subscribe`, `unsubscribe`,
/// `public/respond-heartbeat`).
///
/// Fields left as `None` are omitted from the serialized JSON. Declaration order is the
/// serialized key order.
#[derive(Debug, Clone, Serialize)]
struct OutgoingMessage {
/// Request id; taken from the client's message-id counter.
id: i64,
/// Method name, e.g. `"subscribe"`.
method: String,
/// API key; only set on the `public/auth` request.
#[serde(skip_serializing_if = "Option::is_none")]
api_key: Option<String>,
/// Request signature; only set on the `public/auth` request.
#[serde(skip_serializing_if = "Option::is_none")]
sig: Option<String>,
/// Millisecond timestamp from `timestamp_millis()`; set on auth and (un)subscribe requests.
#[serde(skip_serializing_if = "Option::is_none")]
nonce: Option<i64>,
/// Channel list for `subscribe` / `unsubscribe`.
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<SubscribeParams>,
}
/// `params` payload of a `subscribe` / `unsubscribe` request.
#[derive(Debug, Clone, Serialize)]
struct SubscribeParams {
/// Channel names such as `ticker.{instrument}` or `book.{instrument}.{depth}`.
channels: Vec<String>,
}
/// Envelope of any JSON frame received from the server.
///
/// Every field defaults to `None` when absent, so heartbeats, responses and channel pushes all
/// deserialize into this one shape.
#[derive(Debug, Clone, Deserialize)]
struct IncomingMessage {
/// Request or heartbeat id.
#[serde(default)]
// Silences the lint in case no code path reads `id`.
#[allow(dead_code)]
id: Option<i64>,
/// Method name, e.g. `"public/heartbeat"`, `"subscribe"`, `"public/auth"`.
#[serde(default)]
method: Option<String>,
/// Status code; the handlers treat `Some(0)` as success.
#[serde(default)]
code: Option<i64>,
/// Human-readable status text, used as the error text for failed auth.
#[serde(default)]
message: Option<String>,
/// Payload; for channel pushes it carries `channel` and `data`.
#[serde(default)]
result: Option<Value>,
}
/// Raw event published on `CryptoComWebSocket::event_stream`.
///
/// Payload variants hold an entry from a channel push's `data` array, or the whole `result`
/// object when the push carries no `data` array.
#[derive(Debug, Clone)]
pub enum WsEvent {
/// Push on the `ticker` channel.
Ticker(Value),
/// Push on the `book` channel.
OrderBook(Value),
/// Push on the `trade` channel.
Trade(Value),
/// Push on the `user.order` channel (user stream only).
UserOrder(Value),
/// Push on the `user.balance` channel (user stream only).
UserBalance(Value),
/// A `public/heartbeat` frame was received.
Heartbeat,
/// A `subscribe` response with code 0; holds the `subscription` name.
SubscriptionSuccess(String),
/// Failure reported by the server or the transport (e.g. failed `public/auth`, socket read
/// error).
Error(String),
}
/// WebSocket client for the Crypto.com Exchange market or user stream.
///
/// Shared state lives behind `Arc`s so the spawned reader and ping tasks can access it after
/// `connect` returns.
pub struct CryptoComWebSocket {
/// Credentials for `public/auth`; required when `is_user_stream` is true.
auth: Option<CryptoComAuth>,
/// Selects the user (private) endpoint and triggers authentication on connect.
is_user_stream: bool,
/// The socket, shared by the reader task and all writers; `None` while disconnected.
ws_stream: Arc<Mutex<Option<WsStream>>>,
/// Fan-out of raw `WsEvent`s (see the inherent `event_stream`).
broadcast_tx: broadcast::Sender<WsEvent>,
/// Fan-out of normalized `StreamEvent`s for the `WebSocketConnector` trait.
/// A std mutex because the synchronous trait `event_stream` reads it; `None` while disconnected.
stream_broadcast_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
/// Channel strings sent in successful subscribe requests.
subscriptions: Arc<Mutex<HashSet<String>>>,
/// Requests subscribed through the `WebSocketConnector` trait.
trait_subscriptions: Arc<Mutex<HashSet<SubscriptionRequest>>>,
/// Counter for outgoing request ids (starts at 1).
message_id: Arc<Mutex<i64>>,
/// Connection flag read by `is_connected` / `connection_status`.
is_connected: Arc<Mutex<bool>>,
/// Chooses spot vs perpetual symbol formatting when building channel names.
account_type: AccountType,
/// When the most recent WebSocket ping was sent; used to compute the round-trip time.
last_ping: Arc<Mutex<Instant>>,
/// Last measured ping→pong round-trip time in milliseconds.
ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
impl CryptoComWebSocket {
pub fn new(auth: Option<CryptoComAuth>, is_user_stream: bool) -> Self {
let (tx, _) = broadcast::channel(1000);
Self {
auth,
is_user_stream,
ws_stream: Arc::new(Mutex::new(None)),
broadcast_tx: tx,
stream_broadcast_tx: Arc::new(StdMutex::new(None)),
subscriptions: Arc::new(Mutex::new(HashSet::new())),
trait_subscriptions: Arc::new(Mutex::new(HashSet::new())),
message_id: Arc::new(Mutex::new(1)),
is_connected: Arc::new(Mutex::new(false)),
account_type: AccountType::Spot,
last_ping: Arc::new(Mutex::new(Instant::now())),
ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
}
}
fn get_ws_url(&self) -> &'static str {
if self.is_user_stream {
"wss://stream.crypto.com/exchange/v1/user"
} else {
"wss://stream.crypto.com/exchange/v1/market"
}
}
async fn next_id(&self) -> i64 {
let mut id = self.message_id.lock().await;
let current = *id;
*id += 1;
current
}
pub async fn connect(&mut self) -> ExchangeResult<()> {
let url = self.get_ws_url();
let (ws_stream, _) = connect_async(url).await
.map_err(|e| ExchangeError::Network(format!("WebSocket connection failed: {}", e)))?;
sleep(Duration::from_secs(1)).await;
*self.ws_stream.lock().await = Some(ws_stream);
*self.is_connected.lock().await = true;
let (stream_sender, _) = broadcast::channel(1000);
*self.stream_broadcast_tx.lock().unwrap() = Some(stream_sender);
if self.is_user_stream {
self.authenticate().await?;
}
self.start_message_handler();
self.start_heartbeat_handler();
self.start_ws_ping_task();
Ok(())
}
async fn authenticate(&self) -> ExchangeResult<()> {
let auth = self.auth.as_ref()
.ok_or_else(|| ExchangeError::Auth("User stream requires authentication".to_string()))?;
let id = self.next_id().await;
let nonce = timestamp_millis();
let signature = auth.sign_ws_auth(id, nonce as i64);
let msg = OutgoingMessage {
id,
method: "public/auth".to_string(),
api_key: Some(auth.api_key().to_string()),
sig: Some(signature),
nonce: Some(nonce as i64),
params: None,
};
self.send_message(&msg).await?;
sleep(Duration::from_millis(500)).await;
Ok(())
}
async fn send_message(&self, msg: &OutgoingMessage) -> ExchangeResult<()> {
let msg_json = serde_json::to_string(msg)
.map_err(|e| ExchangeError::Parse(format!("Failed to serialize message: {}", e)))?;
let mut stream_guard = self.ws_stream.lock().await;
let stream = stream_guard.as_mut()
.ok_or_else(|| ExchangeError::Network("Not connected".to_string()))?;
stream.send(Message::Text(msg_json)).await
.map_err(|e| ExchangeError::Network(format!("Failed to send message: {}", e)))?;
Ok(())
}
fn start_message_handler(&self) {
let ws_stream = self.ws_stream.clone();
let broadcast_tx = self.broadcast_tx.clone();
let stream_broadcast_tx = self.stream_broadcast_tx.clone();
let is_connected = self.is_connected.clone();
let last_ping = self.last_ping.clone();
let ws_ping_rtt_ms = self.ws_ping_rtt_ms.clone();
tokio::spawn(async move {
loop {
let mut stream_guard = ws_stream.lock().await;
let stream = match stream_guard.as_mut() {
Some(s) => s,
None => {
drop(stream_guard);
sleep(Duration::from_millis(100)).await;
continue;
}
};
match stream.next().await {
Some(Ok(Message::Text(text))) => {
drop(stream_guard);
if let Some(event) = Self::parse_message(&text) {
if let Some(stream_event) = Self::ws_event_to_stream_event(&event) {
if let Some(tx) = stream_broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(Ok(stream_event));
}
}
let _ = broadcast_tx.send(event);
}
}
Some(Ok(Message::Pong(_))) => {
drop(stream_guard);
let rtt = last_ping.lock().await.elapsed().as_millis() as u64;
*ws_ping_rtt_ms.lock().await = rtt;
}
Some(Ok(Message::Close(_))) => {
drop(stream_guard);
*is_connected.lock().await = false;
break;
}
Some(Err(e)) => {
drop(stream_guard);
if let Some(tx) = stream_broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(Err(WebSocketError::ConnectionError(e.to_string())));
}
let _ = broadcast_tx.send(WsEvent::Error(e.to_string()));
break;
}
None => {
drop(stream_guard);
*is_connected.lock().await = false;
break;
}
_ => {
drop(stream_guard);
}
}
}
let _ = stream_broadcast_tx.lock().unwrap().take();
});
}
fn ws_event_to_stream_event(event: &WsEvent) -> Option<StreamEvent> {
match event {
WsEvent::Ticker(data) => {
let ticker = super::parser::CryptoComParser::parse_ws_ticker(data).ok()?;
Some(StreamEvent::Ticker(ticker))
}
WsEvent::OrderBook(data) => {
let bids = data.get("bids")
.and_then(|b| b.as_array())
.map(|arr| {
arr.iter().filter_map(|entry| {
let price = entry.get(0)?.as_str()?.parse::<f64>().ok()?;
let qty = entry.get(1)?.as_str()?.parse::<f64>().ok()?;
Some(OrderBookLevel::new(price, qty))
}).collect::<Vec<_>>()
})
.unwrap_or_default();
let asks = data.get("asks")
.and_then(|a| a.as_array())
.map(|arr| {
arr.iter().filter_map(|entry| {
let price = entry.get(0)?.as_str()?.parse::<f64>().ok()?;
let qty = entry.get(1)?.as_str()?.parse::<f64>().ok()?;
Some(OrderBookLevel::new(price, qty))
}).collect::<Vec<_>>()
})
.unwrap_or_default();
let timestamp = data.get("t")
.and_then(|t| t.as_i64())
.unwrap_or(0);
Some(StreamEvent::OrderbookDelta(OrderbookDeltaData {
bids,
asks,
timestamp,
first_update_id: None,
last_update_id: None,
prev_update_id: None,
event_time: None,
checksum: None,
}))
}
WsEvent::Trade(data) => {
let trade = super::parser::CryptoComParser::parse_ws_trade(data).ok()?;
Some(StreamEvent::Trade(trade))
}
WsEvent::UserOrder(_) | WsEvent::UserBalance(_) => {
None
}
WsEvent::Heartbeat | WsEvent::SubscriptionSuccess(_) | WsEvent::Error(_) => None,
}
}
fn parse_message(text: &str) -> Option<WsEvent> {
let msg: IncomingMessage = match serde_json::from_str(text) {
Ok(m) => m,
Err(e) => {
eprintln!("Failed to parse message: {} - {}", e, text);
return None;
}
};
match msg.method.as_deref() {
Some("public/heartbeat") => Some(WsEvent::Heartbeat),
Some("subscribe") => {
if let Some(ref result) = msg.result {
if result.get("data").is_some() {
return Self::parse_data_message(result);
}
if msg.code == Some(0) {
if let Some(subscription) = result.get("subscription").and_then(|s| s.as_str()) {
return Some(WsEvent::SubscriptionSuccess(subscription.to_string()));
}
}
}
None
}
Some("public/auth") => {
if msg.code != Some(0) {
let error_msg = msg.message.unwrap_or_else(|| "Authentication failed".to_string());
Some(WsEvent::Error(error_msg))
} else {
None }
}
None => {
if let Some(result) = msg.result {
return Self::parse_data_message(&result);
}
eprintln!("Unknown message format (no method, no result): {}", text);
None
}
Some(method) => {
if let Some(result) = msg.result {
Self::parse_data_message(&result)
} else {
eprintln!("Unknown method '{}': {}", method, text);
None
}
}
}
}
fn parse_data_message(result: &Value) -> Option<WsEvent> {
let channel = result.get("channel")?.as_str()?;
let data = result
.get("data")
.and_then(|d| d.as_array())
.and_then(|arr| arr.first())
.cloned()
.unwrap_or_else(|| result.clone());
match channel {
"ticker" => Some(WsEvent::Ticker(data)),
"book" => Some(WsEvent::OrderBook(data)),
"trade" => Some(WsEvent::Trade(data)),
"user.order" => Some(WsEvent::UserOrder(data)),
"user.balance" => Some(WsEvent::UserBalance(data)),
_ => None,
}
}
fn start_ws_ping_task(&self) {
let ws_stream = self.ws_stream.clone();
let last_ping = self.last_ping.clone();
let is_connected = self.is_connected.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
interval.tick().await;
loop {
interval.tick().await;
if !*is_connected.lock().await {
break;
}
let mut stream_guard = ws_stream.lock().await;
if let Some(stream) = stream_guard.as_mut() {
*last_ping.lock().await = Instant::now();
if stream.send(Message::Ping(vec![])).await.is_err() {
break;
}
} else {
break;
}
}
});
}
fn start_heartbeat_handler(&self) {
let ws_stream = self.ws_stream.clone();
let message_id = self.message_id.clone();
let is_connected = self.is_connected.clone();
let mut rx = self.broadcast_tx.subscribe();
tokio::spawn(async move {
loop {
if !*is_connected.lock().await {
break;
}
tokio::select! {
_ = sleep(Duration::from_secs(30)) => {
}
event = rx.recv() => {
if let Ok(WsEvent::Heartbeat) = event {
let id = {
let mut mid = message_id.lock().await;
let current = *mid;
*mid += 1;
current
};
let pong = OutgoingMessage {
id,
method: "public/respond-heartbeat".to_string(),
api_key: None,
sig: None,
nonce: None,
params: None,
};
if let Ok(msg_json) = serde_json::to_string(&pong) {
let mut stream_guard = ws_stream.lock().await;
if let Some(stream) = stream_guard.as_mut() {
let _ = stream.send(Message::Text(msg_json)).await;
}
}
}
}
}
}
});
}
pub async fn subscribe_ticker(&mut self, instrument_name: &str) -> ExchangeResult<()> {
let channel = format!("ticker.{}", instrument_name);
self.subscribe_channels(vec![channel]).await
}
pub async fn subscribe_orderbook(&mut self, instrument_name: &str, depth: u32) -> ExchangeResult<()> {
let channel = format!("book.{}.{}", instrument_name, depth);
self.subscribe_channels(vec![channel]).await
}
pub async fn subscribe_trades(&mut self, instrument_name: &str) -> ExchangeResult<()> {
let channel = format!("trade.{}", instrument_name);
self.subscribe_channels(vec![channel]).await
}
pub async fn subscribe_user_orders(&mut self, instrument_name: &str) -> ExchangeResult<()> {
if !self.is_user_stream {
return Err(ExchangeError::UnsupportedOperation(
"User orders require user stream".to_string()
));
}
let channel = format!("user.order.{}", instrument_name);
self.subscribe_channels(vec![channel]).await
}
pub async fn subscribe_user_balance(&mut self) -> ExchangeResult<()> {
if !self.is_user_stream {
return Err(ExchangeError::UnsupportedOperation(
"User balance requires user stream".to_string()
));
}
self.subscribe_channels(vec!["user.balance".to_string()]).await
}
async fn subscribe_channels(&mut self, channels: Vec<String>) -> ExchangeResult<()> {
let id = self.next_id().await;
let nonce = timestamp_millis();
let msg = OutgoingMessage {
id,
method: "subscribe".to_string(),
api_key: None,
sig: None,
nonce: Some(nonce as i64),
params: Some(SubscribeParams { channels: channels.clone() }),
};
self.send_message(&msg).await?;
let mut subs = self.subscriptions.lock().await;
for channel in channels {
subs.insert(channel);
}
Ok(())
}
async fn unsubscribe_channels(&self, channels: Vec<String>) -> ExchangeResult<()> {
let id = self.next_id().await;
let nonce = timestamp_millis();
let msg = OutgoingMessage {
id,
method: "unsubscribe".to_string(),
api_key: None,
sig: None,
nonce: Some(nonce as i64),
params: Some(SubscribeParams { channels: channels.clone() }),
};
self.send_message(&msg).await?;
let mut subs = self.subscriptions.lock().await;
for channel in &channels {
subs.remove(channel);
}
Ok(())
}
fn build_channel(request: &SubscriptionRequest, account_type: AccountType) -> Vec<String> {
let instrument_type = match account_type {
AccountType::FuturesCross | AccountType::FuturesIsolated => InstrumentType::Perpetual,
_ => InstrumentType::Spot,
};
let symbol_str = fmt_symbol(&request.symbol.base, &request.symbol.quote, instrument_type);
match &request.stream_type {
StreamType::Ticker => vec![format!("ticker.{}", symbol_str)],
StreamType::Trade => vec![format!("trade.{}", symbol_str)],
StreamType::Orderbook | StreamType::OrderbookDelta => {
let depth = request.depth.unwrap_or(10);
vec![format!("book.{}.{}", symbol_str, depth)]
}
StreamType::Kline { interval } => vec![format!("candlestick.{}.{}", interval, symbol_str)],
StreamType::OrderUpdate => vec![format!("user.order.{}", symbol_str)],
StreamType::BalanceUpdate => vec!["user.balance".to_string()],
_ => vec![],
}
}
pub fn event_stream(&self) -> broadcast::Receiver<WsEvent> {
self.broadcast_tx.subscribe()
}
pub async fn is_connected(&self) -> bool {
*self.is_connected.lock().await
}
pub async fn disconnect(&mut self) -> ExchangeResult<()> {
*self.is_connected.lock().await = false;
*self.ws_stream.lock().await = None;
let _ = self.stream_broadcast_tx.lock().unwrap().take();
self.subscriptions.lock().await.clear();
Ok(())
}
}
#[async_trait]
impl WebSocketConnector for CryptoComWebSocket {
async fn connect(&mut self, account_type: AccountType) -> WebSocketResult<()> {
self.account_type = account_type;
let url = self.get_ws_url();
let (ws_stream, _) = connect_async(url).await
.map_err(|e| WebSocketError::ConnectionError(format!("WebSocket connection failed: {}", e)))?;
sleep(Duration::from_secs(1)).await;
*self.ws_stream.lock().await = Some(ws_stream);
*self.is_connected.lock().await = true;
let (stream_sender, _) = broadcast::channel(1000);
*self.stream_broadcast_tx.lock().unwrap() = Some(stream_sender);
if self.is_user_stream {
self.authenticate().await
.map_err(|e| WebSocketError::Auth(e.to_string()))?;
}
self.start_message_handler();
self.start_heartbeat_handler();
self.start_ws_ping_task();
Ok(())
}
async fn disconnect(&mut self) -> WebSocketResult<()> {
*self.is_connected.lock().await = false;
*self.ws_stream.lock().await = None;
let _ = self.stream_broadcast_tx.lock().unwrap().take();
self.subscriptions.lock().await.clear();
self.trait_subscriptions.lock().await.clear();
Ok(())
}
fn connection_status(&self) -> ConnectionStatus {
match self.is_connected.try_lock() {
Ok(connected) => {
if *connected {
ConnectionStatus::Connected
} else {
ConnectionStatus::Disconnected
}
}
Err(_) => ConnectionStatus::Disconnected,
}
}
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
let channels = Self::build_channel(&request, self.account_type);
if channels.is_empty() {
return Err(WebSocketError::UnsupportedOperation(
format!("Unsupported stream type: {:?}", request.stream_type),
));
}
self.subscribe_channels(channels).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))?;
self.trait_subscriptions.lock().await.insert(request);
Ok(())
}
async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
let channels = Self::build_channel(&request, self.account_type);
if channels.is_empty() {
return Err(WebSocketError::UnsupportedOperation(
format!("Unsupported stream type: {:?}", request.stream_type),
));
}
self.unsubscribe_channels(channels).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))?;
self.trait_subscriptions.lock().await.remove(&request);
Ok(())
}
fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
let rx = self.stream_broadcast_tx.lock().unwrap().as_ref()
.map(|tx| tx.subscribe())
.unwrap_or_else(|| broadcast::channel(1).1);
Box::pin(
tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(|result| async move {
match result {
Ok(event) => Some(event),
Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(_)) => {
Some(Err(WebSocketError::ReceiveError(
"Event stream lagged behind".to_string(),
)))
}
}
}),
)
}
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
match self.trait_subscriptions.try_lock() {
Ok(subs) => subs.iter().cloned().collect(),
Err(_) => Vec::new(),
}
}
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
Some(self.ws_ping_rtt_ms.clone())
}
fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
OrderbookCapabilities {
ws_depths: &[10, 50],
ws_default_depth: Some(50),
rest_max_depth: Some(50),
rest_depth_values: &[],
supports_snapshot: true,
supports_delta: true,
update_speeds_ms: &[100, 500],
default_speed_ms: Some(100),
ws_channels: &[],
checksum: None,
has_sequence: false,
has_prev_sequence: false,
supports_aggregation: false,
aggregation_levels: &[],
}
}
}
async fn _wait_after_connection() {
tokio::time::sleep(Duration::from_secs(1)).await;
}
/// Builds a signed `public/auth` request as a JSON value.
///
/// The signature comes from `auth.sign_ws_auth(id, nonce)`; the key from `auth.api_key()`.
fn _build_auth_message(auth: &CryptoComAuth, id: i64, nonce: i64) -> serde_json::Value {
    let sig = auth.sign_ws_auth(id, nonce);
    let api_key = auth.api_key();
    serde_json::json!({
        "id": id,
        "method": "public/auth",
        "api_key": api_key,
        "sig": sig,
        "nonce": nonce
    })
}
/// Builds a `public/respond-heartbeat` reply carrying `id`.
fn _build_heartbeat_response(id: i64) -> serde_json::Value {
    let method = "public/respond-heartbeat";
    serde_json::json!({ "id": id, "method": method })
}
/// Builds a `subscribe` request for `channels` as a JSON value.
fn _build_subscribe_message(id: i64, channels: Vec<String>, nonce: i64) -> serde_json::Value {
    let params = serde_json::json!({ "channels": channels });
    serde_json::json!({
        "id": id,
        "method": "subscribe",
        "params": params,
        "nonce": nonce
    })
}