use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex as StdMutex, OnceLock};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, SinkExt, stream::{SplitSink, SplitStream}};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tokio::sync::{broadcast, Mutex};
use tokio::time::sleep;
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
Credentials, AccountType, ExchangeResult,
ConnectionStatus, StreamEvent, StreamType, SubscriptionRequest,
timestamp_seconds,
};
use crate::core::types::{WebSocketResult, WebSocketError, OrderBookLevel, OrderbookCapabilities, WsBookChannel};
use crate::core::traits::WebSocketConnector;
use crate::core::utils::WeightRateLimiter;
use super::auth::GateioAuth;
use super::endpoints::{GateioUrls, format_symbol};
use super::parser::GateioParser;
/// Process-wide WebSocket request limiter shared by every `GateioWebSocket`.
///
/// Created lazily on first use by [`get_ws_rate_limiter`].
static WS_RATE_LIMITER: OnceLock<Arc<StdMutex<WeightRateLimiter>>> = OnceLock::new();

/// Returns the shared limiter, constructing it on the first call with
/// `WeightRateLimiter::new(100, 1 s)` (presumably 100 weight units per
/// one-second window — confirm against `WeightRateLimiter`).
fn get_ws_rate_limiter() -> &'static Arc<StdMutex<WeightRateLimiter>> {
    WS_RATE_LIMITER.get_or_init(|| {
        let window = Duration::from_secs(1);
        let limiter = WeightRateLimiter::new(100, window);
        Arc::new(StdMutex::new(limiter))
    })
}
/// Full-duplex socket type produced by `connect_async` (plain TCP or TLS).
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Write half of a split [`WsStream`]; kept in `GateioWebSocket::ws_writer`.
type WsSink = SplitSink<WsStream, Message>;
/// Read half of a split [`WsStream`]; consumed by the message-loop task.
type WsReader = SplitStream<WsStream>;
/// Client-to-server request frame (`subscribe` / `unsubscribe` in this file).
///
/// Serialized with serde; `payload` and `auth` are left out of the JSON
/// entirely when they are `None`.
#[derive(Debug, Clone, Serialize)]
struct OutgoingMessage {
/// Request time in seconds (callers fill it from `timestamp_seconds()`).
time: i64,
/// Full channel name such as `spot.tickers` (see `build_channel`).
channel: String,
/// Request event name, e.g. `"subscribe"` or `"unsubscribe"`.
event: String,
/// Channel arguments (see `build_payload`); omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
payload: Option<Vec<String>>,
/// Signature block, set only for private channels; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
auth: Option<AuthData>,
}
/// `auth` object attached to private-channel requests.
///
/// Serializes as `{"method": ..., "KEY": ..., "SIGN": ...}` because of the
/// serde renames below.
#[derive(Debug, Clone, Serialize)]
struct AuthData {
/// Auth method name; `subscribe` sets it to `"api_key"`.
method: String,
/// API key, serialized as `KEY`.
#[serde(rename = "KEY")]
key: String,
/// Signature from `GateioAuth::sign_ws`, serialized as `SIGN`.
#[serde(rename = "SIGN")]
sign: String,
}
/// Generic server-to-client frame, deserialized from every text frame in
/// `handle_message`. Keys missing from a frame deserialize as `None`.
#[derive(Debug, Clone, Deserialize)]
// `time` and `time_ms` are deserialized but never read in this file.
#[allow(dead_code)]
struct IncomingMessage {
time: Option<i64>,
time_ms: Option<i64>,
/// Channel name; `.pong` suffixes and data channels are dispatched on it.
channel: Option<String>,
/// Event name such as `subscribe` or `update`.
event: Option<String>,
/// Channel-specific payload of `update` frames.
result: Option<Value>,
/// Error details, checked on `subscribe` acknowledgements.
error: Option<ErrorData>,
}
/// `error` object carried by server replies.
#[derive(Debug, Clone, Deserialize)]
// Only `message` is read (in `handle_message`); `code` is currently unused.
#[allow(dead_code)]
struct ErrorData {
code: Option<i32>,
message: Option<String>,
}
/// Gate.io WebSocket connector implementing [`WebSocketConnector`].
///
/// State touched by the background tasks spawned in `connect` (message loop
/// and ping task) is wrapped in `Arc<Mutex<..>>` so it can be shared with them.
pub struct GateioWebSocket {
/// Request signer; `None` means private channels cannot be subscribed.
auth: Option<GateioAuth>,
/// Whether testnet URLs were chosen; not read elsewhere in this file.
_testnet: bool,
/// Account type of the current/last connection; selects `spot.*` vs `futures.*` channels.
account_type: AccountType,
/// Connection state, updated by `connect`, `disconnect` and the message loop.
status: Arc<Mutex<ConnectionStatus>>,
/// Requests recorded by `subscribe` once sent (not once acknowledged).
subscriptions: Arc<Mutex<HashSet<SubscriptionRequest>>>,
/// Broadcast sender for parsed events; installed by `connect`, taken by the message loop when it ends.
event_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
/// Write half of the socket; `None` while disconnected.
ws_writer: Arc<Mutex<Option<WsSink>>>,
/// Minimum time between application-level pings (20 s, set in `new`).
ping_interval: Duration,
/// Instant the last ping was sent; initialised at construction.
last_ping: Arc<Mutex<Instant>>,
/// URL set (mainnet or testnet) chosen in `new`.
urls: GateioUrls,
/// Last ping-to-pong round-trip time in milliseconds (0 until the first pong).
ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
impl GateioWebSocket {
/// Creates a connector in the `Disconnected` state without opening a socket.
///
/// * `credentials` - optional API credentials; when present they are turned
///   into a [`GateioAuth`] signer, which private-channel subscriptions require.
/// * `testnet` - selects `GateioUrls::TESTNET` instead of `GateioUrls::MAINNET`.
/// * `account_type` - initial account type; `connect` overwrites it.
///
/// # Errors
/// Propagates any error returned by `GateioAuth::new`.
pub async fn new(
    credentials: Option<Credentials>,
    testnet: bool,
    account_type: AccountType,
) -> ExchangeResult<Self> {
    let auth = credentials.as_ref().map(GateioAuth::new).transpose()?;
    let urls = if testnet { GateioUrls::TESTNET } else { GateioUrls::MAINNET };

    Ok(Self {
        auth,
        _testnet: testnet,
        account_type,
        status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
        subscriptions: Arc::new(Mutex::new(HashSet::new())),
        event_tx: Arc::new(StdMutex::new(None)),
        ws_writer: Arc::new(Mutex::new(None)),
        // Cadence of the application-level JSON ping sent by `start_ping_task`.
        ping_interval: Duration::from_secs(20),
        last_ping: Arc::new(Mutex::new(Instant::now())),
        urls,
        ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
    })
}
/// Sends `text` as one WebSocket text frame on the current connection.
///
/// The writer lock is held for the whole send.
///
/// # Errors
/// `ConnectionError("Not connected")` when no writer is installed, or a
/// `ConnectionError` wrapping the transport error when the send fails.
async fn send_text(&self, text: String) -> WebSocketResult<()> {
    let mut guard = self.ws_writer.lock().await;
    let Some(sink) = guard.as_mut() else {
        return Err(WebSocketError::ConnectionError("Not connected".to_string()));
    };
    sink.send(Message::Text(text))
        .await
        .map_err(|err| WebSocketError::ConnectionError(err.to_string()))
}
/// Spawns the background task that reads frames from `reader` until the
/// stream ends, a `Close` frame arrives, or a read error occurs.
///
/// For every text frame:
/// - if its `channel` ends with `.pong`, the time elapsed since `last_ping`
///   (reset by the ping task after each ping) is stored in `ws_ping_rtt_ms`
///   as milliseconds;
/// - the frame is then passed to [`Self::handle_message`], and any error it
///   returns is broadcast to subscribers as an `Err` item.
/// A read error is broadcast as `ConnectionError` before the loop stops.
/// Other message kinds (binary, ping/pong control frames, ...) are ignored.
///
/// On exit the broadcast sender is taken out of `event_tx` and dropped, so
/// subscriber streams can end, and `status` is set to `Disconnected`.
///
/// NOTE(review): `event_tx` and `status` are shared with later connections
/// made through the same connector. If `connect` runs again while this task
/// is still alive, this task's cleanup removes the *new* connection's sender
/// and marks it `Disconnected`; consider guarding the cleanup with a
/// per-connection token.
fn start_message_loop(
mut reader: WsReader,
event_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
status: Arc<Mutex<ConnectionStatus>>,
account_type: AccountType,
last_ping: Arc<Mutex<Instant>>,
ws_ping_rtt_ms: Arc<Mutex<u64>>,
) {
tokio::spawn(async move {
while let Some(msg) = reader.next().await {
match msg {
Ok(Message::Text(text)) => {
// Cheap substring pre-filter so only likely pong frames pay for the
// extra JSON parse used to time the round trip.
if text.contains(".pong") {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
if parsed.get("channel")
.and_then(|c| c.as_str())
.map(|c| c.ends_with(".pong"))
.unwrap_or(false)
{
let rtt = last_ping.lock().await.elapsed().as_millis() as u64;
*ws_ping_rtt_ms.lock().await = rtt;
}
}
}
// The std mutex guard below lives only inside this block and is
// released before the next `.await`.
if let Err(e) = Self::handle_message(&text, &event_tx, account_type) {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Err(e));
}
}
}
// Server closed the connection.
Ok(Message::Close(_)) => {
*status.lock().await = ConnectionStatus::Disconnected;
break;
}
// Transport error: report it to subscribers, then stop reading.
Err(e) => {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Err(WebSocketError::ConnectionError(e.to_string())));
}
break;
}
// Binary and control frames are not used by this connector.
_ => {}
}
}
// Drop the sender so subscriber streams terminate.
let _ = event_tx.lock().unwrap().take();
*status.lock().await = ConnectionStatus::Disconnected;
});
}
fn handle_message(
text: &str,
event_tx: &Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
_account_type: AccountType,
) -> WebSocketResult<()> {
let msg: IncomingMessage = serde_json::from_str(text)
.map_err(|e| WebSocketError::Parse(format!("Failed to parse message: {}", e)))?;
if let Some(channel) = &msg.channel {
if channel.ends_with(".pong") {
return Ok(());
}
}
if msg.event.as_deref() == Some("subscribe") {
if let Some(error) = msg.error {
return Err(WebSocketError::ProtocolError(
error.message.unwrap_or_else(|| "Subscription failed".to_string())
));
}
return Ok(());
}
if msg.event.as_deref() == Some("update") {
if let (Some(channel), Some(result)) = (&msg.channel, &msg.result) {
if let Some(event) = Self::parse_data_message(channel, result)? {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Ok(event));
}
}
}
}
Ok(())
}
/// Dispatches an `update` payload to the parser that matches its channel.
///
/// Channels are recognised by substring, tested in this order: `.tickers`,
/// `.trades`, `.order_book`, `.candlesticks`, `.orders`, `.balances`,
/// `.positions`. Any other channel yields `Ok(None)`. Every `.order_book*`
/// channel is reported as an `OrderbookSnapshot`.
///
/// # Errors
/// `WebSocketError::Parse` carrying the underlying parser's message.
fn parse_data_message(
    channel: &str,
    data: &Value,
) -> WebSocketResult<Option<StreamEvent>> {
    let event = match channel {
        c if c.contains(".tickers") => StreamEvent::Ticker(
            GateioParser::parse_ws_ticker(data)
                .map_err(|e| WebSocketError::Parse(e.to_string()))?,
        ),
        c if c.contains(".trades") => StreamEvent::Trade(
            GateioParser::parse_ws_trade(data)
                .map_err(|e| WebSocketError::Parse(e.to_string()))?,
        ),
        c if c.contains(".order_book") => StreamEvent::OrderbookSnapshot(
            Self::parse_orderbook_ws(data)
                .map_err(|e| WebSocketError::Parse(e.to_string()))?,
        ),
        c if c.contains(".candlesticks") => StreamEvent::Kline(
            Self::parse_kline_ws(data)
                .map_err(|e| WebSocketError::Parse(e.to_string()))?,
        ),
        c if c.contains(".orders") => StreamEvent::OrderUpdate(
            GateioParser::parse_ws_order_update(data)
                .map_err(|e| WebSocketError::Parse(e.to_string()))?,
        ),
        c if c.contains(".balances") => StreamEvent::BalanceUpdate(
            GateioParser::parse_ws_balance_update(data)
                .map_err(|e| WebSocketError::Parse(e.to_string()))?,
        ),
        c if c.contains(".positions") => StreamEvent::PositionUpdate(
            GateioParser::parse_ws_position_update(data)
                .map_err(|e| WebSocketError::Parse(e.to_string()))?,
        ),
        _ => return Ok(None),
    };
    Ok(Some(event))
}
fn start_ping_task(
ws_writer: Arc<Mutex<Option<WsSink>>>,
ping_interval: Duration,
last_ping: Arc<Mutex<Instant>>,
account_type: AccountType,
) {
tokio::spawn(async move {
loop {
sleep(Duration::from_millis(1000)).await;
let last = *last_ping.lock().await;
if last.elapsed() >= ping_interval {
let ping_channel = match account_type {
AccountType::Spot | AccountType::Margin => "spot.ping",
AccountType::FuturesCross | AccountType::FuturesIsolated => "futures.ping",
AccountType::Earn | AccountType::Lending | AccountType::Options | AccountType::Convert => "spot.ping",
};
let ping = json!({
"time": timestamp_seconds() as i64,
"channel": ping_channel
});
let msg_json = serde_json::to_string(&ping)
.expect("JSON serialization should never fail for valid struct");
let mut writer_guard = ws_writer.lock().await;
if let Some(ref mut writer) = *writer_guard {
if writer.send(Message::Text(msg_json)).await.is_ok() {
*last_ping.lock().await = Instant::now();
} else {
break;
}
} else {
break;
}
}
}
});
}
/// Returns the Gate.io channel name for `request`, as `"<family>.<topic>"`.
///
/// Futures account types use the `futures` family; every other account type
/// uses `spot`. Mark-price and funding-rate requests are served from the
/// `tickers` topic. `Orderbook` and `OrderbookDelta` both map to `order_book`.
fn build_channel(request: &SubscriptionRequest, account_type: AccountType) -> String {
    let family = match account_type {
        AccountType::FuturesCross | AccountType::FuturesIsolated => "futures",
        AccountType::Spot
        | AccountType::Margin
        | AccountType::Earn
        | AccountType::Lending
        | AccountType::Options
        | AccountType::Convert => "spot",
    };
    let topic = match &request.stream_type {
        StreamType::Ticker | StreamType::MarkPrice | StreamType::FundingRate => "tickers",
        StreamType::Trade => "trades",
        StreamType::Orderbook | StreamType::OrderbookDelta => "order_book",
        StreamType::Kline { .. } => "candlesticks",
        StreamType::OrderUpdate => "orders",
        StreamType::BalanceUpdate => "balances",
        StreamType::PositionUpdate => "positions",
    };
    format!("{}.{}", family, topic)
}
/// Builds the `payload` list of a (un)subscribe request.
///
/// - order book: `[symbol, depth, speed]`. Depth defaults to 20; speed is
///   `"<n>ms"` from `update_speed_ms`, defaulting to `"1000ms"`.
/// - klines: `[interval, symbol]`.
/// - balance updates: empty; callers then omit the field.
/// - all other stream types: `[symbol]`.
///
/// `symbol` comes from `format_symbol` for the given `account_type`.
fn build_payload(request: &SubscriptionRequest, account_type: AccountType) -> Vec<String> {
    let symbol = format_symbol(&request.symbol.base, &request.symbol.quote, account_type);
    match &request.stream_type {
        StreamType::BalanceUpdate => Vec::new(),
        StreamType::Kline { interval } => vec![interval.to_string(), symbol],
        StreamType::Orderbook | StreamType::OrderbookDelta => {
            let depth = request.depth.unwrap_or(20);
            let speed = match request.update_speed_ms {
                Some(ms) => format!("{}ms", ms),
                None => String::from("1000ms"),
            };
            vec![symbol, depth.to_string(), speed]
        }
        StreamType::Ticker
        | StreamType::Trade
        | StreamType::MarkPrice
        | StreamType::FundingRate
        | StreamType::OrderUpdate
        | StreamType::PositionUpdate => vec![symbol],
    }
}
/// Returns `true` for stream types that need a signed `auth` block: order,
/// balance and position updates.
///
/// The match is exhaustive, so adding a `StreamType` variant forces an
/// explicit public/private decision here.
fn is_private(stream_type: &StreamType) -> bool {
    match stream_type {
        StreamType::OrderUpdate | StreamType::BalanceUpdate | StreamType::PositionUpdate => true,
        StreamType::Ticker
        | StreamType::Trade
        | StreamType::Orderbook
        | StreamType::OrderbookDelta
        | StreamType::Kline { .. }
        | StreamType::MarkPrice
        | StreamType::FundingRate => false,
    }
}
async fn ws_rate_limit_wait(weight: u32) {
loop {
let wait_time = {
let limiter = get_ws_rate_limiter();
let mut guard = limiter.lock().expect("Mutex poisoned");
if guard.try_acquire(weight) {
return; }
guard.time_until_ready(weight)
};
if wait_time > Duration::ZERO {
sleep(wait_time).await;
}
}
}
/// Signs a private-channel WebSocket request.
///
/// The string passed to `GateioAuth::sign_ws` is
/// `channel=<channel>&event=<event>&time=<timestamp>`.
///
/// Returns `(api_key, signature)`.
///
/// # Errors
/// Propagates any error from `sign_ws`.
fn generate_auth_signature(
    auth: &GateioAuth,
    channel: &str,
    event: &str,
    timestamp: i64,
) -> ExchangeResult<(String, String)> {
    let to_sign = format!("channel={}&event={}&time={}", channel, event, timestamp);
    let signature = auth.sign_ws(&to_sign)?;
    Ok((auth.api_key().to_string(), signature))
}
fn parse_orderbook_ws(data: &Value) -> ExchangeResult<crate::core::OrderBook> {
let parse_levels = |key: &str| -> Vec<OrderBookLevel> {
data.get(key)
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|level| {
let pair = level.as_array()?;
if pair.len() < 2 { return None; }
let price = pair[0].as_str()?.parse::<f64>().ok()?;
let size = pair[1].as_str()?.parse::<f64>().ok()?;
Some(OrderBookLevel::new(price, size))
})
.collect()
})
.unwrap_or_default()
};
Ok(crate::core::OrderBook {
timestamp: data.get("t")
.and_then(|t| t.as_i64())
.unwrap_or(0),
bids: parse_levels("bids"),
asks: parse_levels("asks"),
sequence: data.get("lastUpdateId")
.and_then(|s| s.as_i64())
.map(|n| n.to_string()),
last_update_id: None,
first_update_id: None,
prev_update_id: None,
event_time: None,
transaction_time: None,
checksum: None,
})
}
fn parse_kline_ws(data: &Value) -> ExchangeResult<crate::core::Kline> {
let open_time = data.get("t")
.and_then(|t| t.as_str())
.and_then(|s| s.parse::<i64>().ok())
.unwrap_or(0) * 1000;
let parse_f64 = |key: &str| -> f64 {
data.get(key)
.and_then(|v| v.as_str())
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(0.0)
};
Ok(crate::core::Kline {
open_time,
open: parse_f64("o"),
high: parse_f64("h"),
low: parse_f64("l"),
close: parse_f64("c"),
volume: parse_f64("v"),
quote_volume: Some(parse_f64("a")),
close_time: None,
trades: None,
})
}
}
#[async_trait]
impl WebSocketConnector for GateioWebSocket {
/// Opens the WebSocket for `account_type` and starts the background tasks.
///
/// Steps:
/// 1. Set status to `Connecting` and run the handshake.
/// 2. Install the write half and a fresh broadcast sender (capacity 1000).
/// 3. Spawn the message loop and the ping task.
/// 4. Set status to `Connected`.
///
/// # Errors
/// `ConnectionError` when the handshake fails. In that case the status is
/// reset to `Disconnected` rather than being left at `Connecting`.
async fn connect(&mut self, account_type: AccountType) -> WebSocketResult<()> {
    *self.status.lock().await = ConnectionStatus::Connecting;
    self.account_type = account_type;
    let ws_url = self.urls.ws_url(account_type);

    let handshake = connect_async(ws_url).await.map_err(|e| {
        WebSocketError::ConnectionError(format!("WebSocket connection failed: {}", e))
    });
    let (ws_stream, _) = match handshake {
        Ok(connected) => connected,
        Err(err) => {
            // Without this the connector would report `Connecting` indefinitely.
            *self.status.lock().await = ConnectionStatus::Disconnected;
            return Err(err);
        }
    };

    let (write, read) = ws_stream.split();
    *self.ws_writer.lock().await = Some(write);
    let (tx, _) = broadcast::channel(1000);
    *self.event_tx.lock().unwrap() = Some(tx);

    Self::start_message_loop(
        read,
        self.event_tx.clone(),
        self.status.clone(),
        account_type,
        self.last_ping.clone(),
        self.ws_ping_rtt_ms.clone(),
    );
    Self::start_ping_task(
        self.ws_writer.clone(),
        self.ping_interval,
        self.last_ping.clone(),
        account_type,
    );

    *self.status.lock().await = ConnectionStatus::Connected;
    Ok(())
}
/// Closes the current socket, if any, and resets local state.
///
/// The writer is taken out and closed while its lock is held; close errors
/// are ignored. Status then becomes `Disconnected` and the recorded
/// subscriptions are cleared. The broadcast sender is left in place; the
/// message loop removes it when the read side ends.
async fn disconnect(&mut self) -> WebSocketResult<()> {
    {
        let mut writer_slot = self.ws_writer.lock().await;
        if let Some(mut sink) = writer_slot.take() {
            let _ = sink.close().await;
        }
    }
    *self.status.lock().await = ConnectionStatus::Disconnected;
    self.subscriptions.lock().await.clear();
    Ok(())
}
/// Returns the current connection status without waiting.
///
/// If another task holds the status lock at that moment, `Disconnected`
/// is reported.
fn connection_status(&self) -> ConnectionStatus {
    self.status
        .try_lock()
        .map(|guard| *guard)
        .unwrap_or(ConnectionStatus::Disconnected)
}
/// Subscribes to `request` on the current connection.
///
/// Steps:
/// 1. Wait for one unit of the shared WebSocket rate limit.
/// 2. Build the channel and payload, signing the request when the stream
///    type is private.
/// 3. Send the frame and record `request` in the local subscription set.
///
/// The set is updated as soon as the frame is sent. The server's
/// acknowledgement (including any rejection) is handled by the message loop.
///
/// # Errors
/// `ConnectionError` when a private channel is requested without credentials,
/// when signing fails, or when the frame cannot be sent. `ProtocolError` when
/// the request cannot be serialized.
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
    Self::ws_rate_limit_wait(1).await;

    let account_type = self.account_type;
    let channel = Self::build_channel(&request, account_type);
    let payload = Some(Self::build_payload(&request, account_type)).filter(|p| !p.is_empty());
    let time = timestamp_seconds() as i64;

    let auth = if Self::is_private(&request.stream_type) {
        let signer = self.auth.as_ref().ok_or_else(|| {
            WebSocketError::ConnectionError(
                "Authentication required for private channels".to_string(),
            )
        })?;
        let (key, sign) = Self::generate_auth_signature(signer, &channel, "subscribe", time)
            .map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
        Some(AuthData {
            method: "api_key".to_string(),
            key,
            sign,
        })
    } else {
        None
    };

    let frame = OutgoingMessage {
        time,
        channel,
        event: "subscribe".to_string(),
        payload,
        auth,
    };
    let encoded = serde_json::to_string(&frame)
        .map_err(|e| WebSocketError::ProtocolError(e.to_string()))?;
    self.send_text(encoded).await?;
    self.subscriptions.lock().await.insert(request);
    Ok(())
}
async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
Self::ws_rate_limit_wait(1).await;
let channel = Self::build_channel(&request, self.account_type);
let payload = Self::build_payload(&request, self.account_type);
let msg = OutgoingMessage {
time: timestamp_seconds() as i64,
channel,
event: "unsubscribe".to_string(),
payload: if payload.is_empty() { None } else { Some(payload) },
auth: None,
};
let msg_json = serde_json::to_string(&msg)
.map_err(|e| WebSocketError::ProtocolError(e.to_string()))?;
self.send_text(msg_json).await?;
self.subscriptions.lock().await.remove(&request);
Ok(())
}
/// Returns a new stream of events from the current connection.
///
/// Each call subscribes a fresh broadcast receiver. Receiver-side errors
/// (e.g. a consumer that lagged behind) surface as
/// `ConnectionError("Broadcast error: ...")`. When no sender is installed
/// (never connected, or the message loop has ended), an empty stream is
/// returned.
fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
    let receiver = match self.event_tx.lock().unwrap().as_ref() {
        Some(sender) => sender.subscribe(),
        None => return Box::pin(futures_util::stream::empty()),
    };
    let events = tokio_stream::wrappers::BroadcastStream::new(receiver).map(|item| match item {
        Ok(result) => result,
        Err(recv_err) => Err(WebSocketError::ConnectionError(format!(
            "Broadcast error: {}",
            recv_err
        ))),
    });
    Box::pin(events)
}
/// Returns a snapshot of the requests recorded by `subscribe` and not yet
/// removed.
///
/// If another task holds the set's lock at that moment, an empty list is
/// returned.
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
    self.subscriptions
        .try_lock()
        .map(|subs| subs.iter().cloned().collect::<Vec<_>>())
        .unwrap_or_default()
}
/// Returns a shared handle to the last measured ping round-trip time, in
/// milliseconds (0 until the first pong arrives).
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
    let handle = Arc::clone(&self.ws_ping_rtt_ms);
    Some(handle)
}
/// Describes the order-book streaming capabilities advertised for Gate.io.
///
/// `_account_type` is ignored: the same table of spot-named channels is
/// returned for every account type. NOTE(review): no `futures.*` channels are
/// listed — confirm that is intended.
///
/// NOTE(review): `ws_default_depth` / `default_speed_ms` advertise 100 / 100 ms,
/// while `build_payload` falls back to depth 20 and `"1000ms"` when a request
/// leaves those fields unset.
fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
// Static channel table; each entry is (name, depth, snapshot-vs-delta, update speed).
static GATEIO_CHANNELS: &[WsBookChannel] = &[
WsBookChannel { name: "spot.book_ticker", depth: Some(1), is_snapshot: true, update_speed_ms: Some(10), requires_auth_tier: false },
WsBookChannel { name: "spot.order_book_update", depth: Some(100), is_snapshot: false, update_speed_ms: Some(100), requires_auth_tier: false },
WsBookChannel { name: "spot.order_book_update", depth: Some(20), is_snapshot: false, update_speed_ms: Some(20), requires_auth_tier: false },
WsBookChannel { name: "spot.order_book", depth: Some(100), is_snapshot: true, update_speed_ms: Some(100), requires_auth_tier: false },
WsBookChannel { name: "spot.obu", depth: Some(400), is_snapshot: false, update_speed_ms: Some(100), requires_auth_tier: false },
WsBookChannel { name: "spot.obu", depth: Some(50), is_snapshot: false, update_speed_ms: Some(20), requires_auth_tier: false },
];
OrderbookCapabilities {
ws_depths: &[5, 10, 20, 50, 100, 400],
ws_default_depth: Some(100),
rest_max_depth: Some(1000),
rest_depth_values: &[],
supports_snapshot: true,
supports_delta: true,
update_speeds_ms: &[10, 20, 100, 1000],
default_speed_ms: Some(100),
ws_channels: GATEIO_CHANNELS,
checksum: None,
// Spot snapshots carry `lastUpdateId`, which `parse_orderbook_ws` exposes as `sequence`.
has_sequence: true,
has_prev_sequence: false,
supports_aggregation: false,
aggregation_levels: &[],
}
}
}