use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, SinkExt, stream::{SplitSink, SplitStream}};
use serde_json::{json, Value};
use tokio::sync::{broadcast, Mutex};
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
Credentials, AccountType, Symbol,
ExchangeError, ExchangeResult,
ConnectionStatus, StreamEvent, StreamType, SubscriptionRequest,
};
use crate::core::types::{WebSocketResult, WebSocketError};
use crate::core::traits::WebSocketConnector;
use super::endpoints::{MexcUrls, MexcWsChannels, format_symbol};
use super::parser::MexcParser;
/// A client WebSocket connection whose transport may or may not be TLS-wrapped.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// Outgoing (write) half of a split [`WsStream`]; accepts whole [`Message`]s.
type WsSink = SplitSink<WsStream, Message>;

/// Incoming (read) half of a split [`WsStream`]; consumed by the reader task.
type WsReader = SplitStream<WsStream>;
/// WebSocket client for MEXC streams.
///
/// Mutable state is kept behind `Arc`s so it can be shared with the background
/// reader and ping tasks that are spawned when a connection is opened.
pub struct MexcWebSocket {
    /// Placeholder for authentication state; the constructor always stores `None`.
    _auth: Option<()>,
    /// Market the connection targets; passed to `format_symbol` when building channels.
    account_type: AccountType,
    /// Current connection state, shared with the reader task.
    status: Arc<Mutex<ConnectionStatus>>,
    /// Subscriptions recorded after their subscribe request was sent successfully.
    subscriptions: Arc<Mutex<HashSet<SubscriptionRequest>>>,
    /// Sender feeding `event_stream` consumers; `None` before connecting and after the
    /// reader task has finished.
    event_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
    /// Write half of the socket; `None` while disconnected.
    ws_writer: Arc<Mutex<Option<WsSink>>>,
    /// Period of the keep-alive ping task.
    ping_interval: Duration,
    /// Time the most recent protocol-level ping frame was sent (basis for RTT).
    last_ping: Arc<Mutex<Instant>>,
    /// Most recently measured ping round-trip time, in milliseconds.
    ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
impl MexcWebSocket {
pub async fn new(_credentials: Option<Credentials>) -> ExchangeResult<Self> {
Ok(Self {
_auth: None, account_type: AccountType::Spot, status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
subscriptions: Arc::new(Mutex::new(HashSet::new())),
event_tx: Arc::new(StdMutex::new(None)),
ws_writer: Arc::new(Mutex::new(None)),
ping_interval: Duration::from_secs(20), last_ping: Arc::new(Mutex::new(Instant::now())),
ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
})
}
async fn send_message(&self, msg: &Value) -> ExchangeResult<()> {
let msg_json = serde_json::to_string(msg)
.map_err(|e| ExchangeError::Parse(format!("Failed to serialize message: {}", e)))?;
let mut writer_guard = self.ws_writer.lock().await;
let writer = writer_guard.as_mut()
.ok_or_else(|| ExchangeError::Network("WebSocket not connected".to_string()))?;
writer.send(Message::Text(msg_json)).await
.map_err(|e| ExchangeError::Network(format!("Failed to send message: {}", e)))?;
Ok(())
}
pub async fn subscribe_ticker(&self, symbol: Symbol) -> ExchangeResult<()> {
let symbol_str = format_symbol(&symbol, self.account_type);
let stream_name = MexcWsChannels::mini_ticker(&symbol_str);
let msg = json!({
"method": "SUBSCRIPTION",
"params": [stream_name]
});
self.send_message(&msg).await?;
let mut subs = self.subscriptions.lock().await;
subs.insert(SubscriptionRequest {
stream_type: StreamType::Ticker,
symbol: symbol.clone(),
account_type: crate::core::AccountType::default(),
});
Ok(())
}
pub async fn subscribe_trades(&self, symbol: Symbol) -> ExchangeResult<()> {
let symbol_str = format_symbol(&symbol, self.account_type);
let stream_name = MexcWsChannels::aggre_deals(&symbol_str);
let msg = json!({
"method": "SUBSCRIPTION",
"params": [stream_name]
});
self.send_message(&msg).await?;
let mut subs = self.subscriptions.lock().await;
subs.insert(SubscriptionRequest {
stream_type: StreamType::Trade,
symbol: symbol.clone(),
account_type: crate::core::AccountType::default(),
});
Ok(())
}
fn start_message_loop(
mut reader: WsReader,
event_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
status: Arc<Mutex<ConnectionStatus>>,
last_ping: Arc<Mutex<Instant>>,
ws_ping_rtt_ms: Arc<Mutex<u64>>,
) {
tokio::spawn(async move {
while let Some(msg) = reader.next().await {
match msg {
Ok(Message::Text(text)) => {
match serde_json::from_str::<Value>(&text) {
Ok(json) => {
if json.get("msg").and_then(|m| m.as_str()) == Some("PONG") {
continue;
}
if json.get("msg").and_then(|m| m.as_str()) == Some("pong") {
continue;
}
if json.get("code").and_then(|c| c.as_i64()) == Some(0) {
if let Some(msg_str) = json.get("msg").and_then(|m| m.as_str()) {
if msg_str.starts_with("spot@") {
continue;
}
}
}
if let Some(msg_str) = json.get("msg").and_then(|m| m.as_str()) {
if msg_str.contains("Blocked") || msg_str.contains("Not Subscribed successfully") {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Err(WebSocketError::Subscription(msg_str.to_string())));
}
continue;
}
}
if let Ok((channel, data)) = MexcParser::parse_ws_message(&json) {
let event_result = if channel.contains("deals") || channel.contains("bookTicker") || channel.contains("depth") {
MexcParser::parse_ws_ticker(data).map(StreamEvent::Ticker)
} else {
Err(ExchangeError::Parse(format!("Unknown channel: {}", channel)))
};
if let Ok(event) = event_result {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Ok(event));
}
}
}
}
Err(e) => {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Err(WebSocketError::Parse(e.to_string())));
}
}
}
}
Ok(Message::Binary(data)) => {
match MexcParser::parse_protobuf_message(&data) {
Ok((_channel, event)) => {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Ok(event));
}
}
Err(e) => {
let err_msg = e.to_string();
if !err_msg.contains("Unsupported protobuf channel") {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Err(WebSocketError::Parse(err_msg)));
}
}
}
}
}
Ok(Message::Pong(_)) => {
let rtt = last_ping.lock().await.elapsed().as_millis() as u64;
*ws_ping_rtt_ms.lock().await = rtt;
}
Ok(Message::Close(_)) => {
*status.lock().await = ConnectionStatus::Disconnected;
break;
}
Err(e) => {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Err(WebSocketError::ConnectionError(e.to_string())));
}
break;
}
_ => {}
}
}
let _ = event_tx.lock().unwrap().take();
*status.lock().await = ConnectionStatus::Disconnected;
});
}
fn start_ping_task(
ws_writer: Arc<Mutex<Option<WsSink>>>,
ping_interval: Duration,
last_ping: Arc<Mutex<Instant>>,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(ping_interval);
interval.tick().await;
loop {
interval.tick().await;
let ping_msg = json!({"method": "PING"});
if let Ok(ping_text) = serde_json::to_string(&ping_msg) {
let mut writer_guard = ws_writer.lock().await;
if let Some(ref mut writer) = *writer_guard {
if writer.send(Message::Text(ping_text)).await.is_err() {
break;
}
*last_ping.lock().await = Instant::now();
if writer.send(Message::Ping(vec![])).await.is_err() {
break;
}
} else {
break;
}
}
}
});
}
}
#[async_trait]
impl WebSocketConnector for MexcWebSocket {
async fn connect(&mut self, account_type: AccountType) -> WebSocketResult<()> {
self.account_type = account_type;
*self.status.lock().await = ConnectionStatus::Connecting;
let ws_url = MexcUrls::ws_url();
let (ws_stream, _) = connect_async(ws_url).await
.map_err(|e| WebSocketError::ConnectionError(format!("WebSocket connection failed: {}", e)))?;
let (write, read) = ws_stream.split();
*self.ws_writer.lock().await = Some(write);
let (tx, _) = broadcast::channel(1000);
*self.event_tx.lock().unwrap() = Some(tx);
Self::start_message_loop(
read,
self.event_tx.clone(),
self.status.clone(),
self.last_ping.clone(),
self.ws_ping_rtt_ms.clone(),
);
Self::start_ping_task(
self.ws_writer.clone(),
self.ping_interval,
self.last_ping.clone(),
);
*self.status.lock().await = ConnectionStatus::Connected;
Ok(())
}
async fn disconnect(&mut self) -> WebSocketResult<()> {
if let Some(mut writer) = self.ws_writer.lock().await.take() {
let _ = writer.close().await;
}
*self.status.lock().await = ConnectionStatus::Disconnected;
self.subscriptions.lock().await.clear();
Ok(())
}
fn connection_status(&self) -> ConnectionStatus {
self.status.try_lock()
.map(|guard| *guard)
.unwrap_or(ConnectionStatus::Disconnected)
}
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
match request.stream_type {
StreamType::Ticker => {
self.subscribe_ticker(request.symbol.clone()).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))?;
}
StreamType::Trade => {
self.subscribe_trades(request.symbol.clone()).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))?;
}
_ => {
return Err(WebSocketError::Subscription(
format!("Unsupported stream type: {:?}", request.stream_type)
));
}
}
Ok(())
}
async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
let symbol_str = format_symbol(&request.symbol, self.account_type);
let channel = match &request.stream_type {
StreamType::Ticker => MexcWsChannels::mini_ticker(&symbol_str),
StreamType::Trade => MexcWsChannels::aggre_deals(&symbol_str),
StreamType::Orderbook | StreamType::OrderbookDelta => MexcWsChannels::aggre_depth(&symbol_str),
StreamType::Kline { interval } => MexcWsChannels::kline(&symbol_str, interval),
_ => {
self.subscriptions.lock().await.remove(&request);
return Ok(());
}
};
let msg = json!({
"method": "UNSUBSCRIPTION",
"params": [channel]
});
self.send_message(&msg).await
.map_err(|e| WebSocketError::Subscription(format!("Unsubscribe failed: {}", e)))?;
self.subscriptions.lock().await.remove(&request);
Ok(())
}
fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
let tx_guard = self.event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let rx = tx.subscribe();
Box::pin(tokio_stream::wrappers::BroadcastStream::new(rx).map(|r| {
r.map_err(|e| WebSocketError::ConnectionError(format!("Broadcast error: {}", e)))
.and_then(|x| x)
}))
} else {
Box::pin(futures_util::stream::empty())
}
}
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
self.subscriptions.try_lock()
.map(|guard| guard.iter().cloned().collect())
.unwrap_or_default()
}
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
Some(self.ws_ping_rtt_ms.clone())
}
}