use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::Duration;
use async_trait::async_trait;
use futures_util::{SinkExt, Stream, StreamExt};
use serde_json::{json, Value};
use tokio::sync::{broadcast, Mutex};
use tokio::time::{interval, Instant};
use tokio_tungstenite::{
connect_async,
tungstenite::Message,
MaybeTlsStream,
WebSocketStream,
};
use crate::core::{
AccountType, ConnectionStatus, Credentials, ExchangeResult, OrderBook,
StreamEvent, SubscriptionRequest, timestamp_iso8601,
};
use crate::core::types::OrderbookDelta;
use crate::core::traits::WebSocketConnector;
use crate::core::types::{WebSocketError, WebSocketResult, OrderbookCapabilities, WsBookChannel, ChecksumInfo, ChecksumAlgorithm};
use super::auth::OkxAuth;
use super::endpoints::{format_symbol, OkxUrls};
use super::parser::OkxParser;
/// Full-duplex WebSocket connection (plain TCP or TLS) as returned by `connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// Receiving half of a split [`WsStream`]; yields inbound frames.
type WsReader = futures_util::stream::SplitStream<WsStream>;

/// Sending half of a split [`WsStream`]; accepts outbound [`Message`]s.
type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
/// WebSocket connector for OKX.
///
/// All connection state lives behind `Arc`s so it can be shared with the
/// background ping and message-handler tasks that `connect` spawns.
pub struct OkxWebSocket {
    /// Request signer built from the user's credentials; `None` for anonymous use.
    auth: Option<OkxAuth>,
    /// Endpoint set (mainnet or testnet) chosen in `new`.
    urls: OkxUrls,
    /// Current connection state; also written by the message-handler task.
    status: Arc<Mutex<ConnectionStatus>>,
    /// Requests sent via `subscribe` and not yet removed by `unsubscribe`.
    subscriptions: Arc<Mutex<HashSet<SubscriptionRequest>>>,
    /// Fan-out sender for parsed events; `None` while disconnected.
    /// A std (blocking) mutex because the synchronous `event_stream` reads it.
    broadcast_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
    /// Write half of the socket, shared with the ping task; `None` when disconnected.
    ws_sink: Arc<Mutex<Option<WsSink>>>,
    /// Read half of the socket, polled by the message-handler task; `None` when disconnected.
    ws_reader: Arc<Mutex<Option<WsReader>>>,
    /// When the most recent text "ping" was sent; used to compute round-trip time.
    last_ping: Arc<Mutex<Instant>>,
    /// Last measured ping/pong round-trip time in milliseconds (0 until measured).
    ws_ping_rtt_ms: Arc<Mutex<u64>>,
    /// Whether the current connection was opened on the private endpoint.
    is_private: bool,
}
impl OkxWebSocket {
    /// Creates a disconnected OKX WebSocket client.
    ///
    /// `testnet` selects `OkxUrls::TESTNET` instead of `OkxUrls::MAINNET`.
    /// When `credentials` are supplied they are converted into an `OkxAuth`,
    /// and `connect` will then log in on the private endpoint.
    ///
    /// # Errors
    /// Propagates any error returned by `OkxAuth::new`.
    pub async fn new(
        credentials: Option<Credentials>,
        testnet: bool,
    ) -> ExchangeResult<Self> {
        let urls = if testnet {
            OkxUrls::TESTNET
        } else {
            OkxUrls::MAINNET
        };
        let auth = credentials
            .as_ref()
            .map(OkxAuth::new)
            .transpose()?;
        Ok(Self {
            auth,
            urls,
            status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
            subscriptions: Arc::new(Mutex::new(HashSet::new())),
            broadcast_tx: Arc::new(StdMutex::new(None)),
            ws_sink: Arc::new(Mutex::new(None)),
            ws_reader: Arc::new(Mutex::new(None)),
            last_ping: Arc::new(Mutex::new(Instant::now())),
            ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
            is_private: false,
        })
    }

    /// Sends the OKX `login` operation on `sink`.
    ///
    /// The payload carries the API key, the passphrase, an ISO-8601 timestamp
    /// and the signature `OkxAuth::sign_websocket_login` produces for that same
    /// timestamp. This only sends the request; it does not read the reply.
    ///
    /// # Errors
    /// - `WebSocketError::Auth` if the client was created without credentials.
    /// - `WebSocketError::ConnectionError` if the frame cannot be sent.
    async fn send_login(&self, sink: &mut WsSink) -> WebSocketResult<()> {
        let auth = self.auth.as_ref().ok_or_else(|| {
            WebSocketError::Auth("Private channels require authentication".to_string())
        })?;
        let timestamp = timestamp_iso8601();
        // The signature must cover exactly the timestamp string sent below.
        let signature = auth.sign_websocket_login(&timestamp);
        let login_msg = json!({
            "op": "login",
            "args": [{
                "apiKey": auth.api_key(),
                "passphrase": auth.passphrase,
                "timestamp": timestamp,
                "sign": signature,
            }]
        });
        sink.send(Message::Text(login_msg.to_string()))
            .await
            .map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
        Ok(())
    }

    /// Spawns a task that sends a text `"ping"` every 5 seconds.
    ///
    /// Each successful send records the send time in `last_ping`. The message
    /// handler uses it to compute RTT when `"pong"` arrives. The task exits
    /// once the sink has been cleared (disconnect) or a send fails.
    fn start_ping_task(
        ws_sink: Arc<Mutex<Option<WsSink>>>,
        last_ping: Arc<Mutex<Instant>>,
    ) {
        tokio::spawn(async move {
            let mut ticker = interval(Duration::from_secs(5));
            loop {
                ticker.tick().await;
                let mut sink_guard = ws_sink.lock().await;
                if let Some(sink) = sink_guard.as_mut() {
                    if sink.send(Message::Text("ping".to_string())).await.is_ok() {
                        *last_ping.lock().await = Instant::now();
                    } else {
                        break;
                    }
                } else {
                    break;
                }
            }
        });
    }

    /// Spawns the task that reads frames and turns them into broadcast events.
    ///
    /// A `"pong"` text updates `ws_ping_rtt_ms`. JSON frames go to
    /// `dispatch_frame`. A Close frame, a read error or end of stream marks
    /// the connection `Disconnected`. When the loop ends for any reason, including
    /// the reader being cleared by `disconnect`, the broadcast sender is dropped
    /// so that open event streams terminate.
    fn start_message_handler(
        ws_reader: Arc<Mutex<Option<WsReader>>>,
        broadcast_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
        status: Arc<Mutex<ConnectionStatus>>,
        last_ping: Arc<Mutex<Instant>>,
        ws_ping_rtt_ms: Arc<Mutex<u64>>,
    ) {
        tokio::spawn(async move {
            loop {
                // The reader lock is held while awaiting the next frame, so
                // `disconnect` can only clear the reader between frames.
                let msg = {
                    let mut reader_guard = ws_reader.lock().await;
                    match reader_guard.as_mut() {
                        Some(reader) => reader.next().await,
                        None => break,
                    }
                };
                let text = match msg {
                    Some(Ok(Message::Text(text))) => text,
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => {
                        *status.lock().await = ConnectionStatus::Disconnected;
                        break;
                    }
                    // Binary/ping/pong frames carry nothing we consume.
                    Some(Ok(_)) => continue,
                };
                if text.trim() == "pong" {
                    let elapsed_ms = last_ping.lock().await.elapsed().as_millis();
                    *ws_ping_rtt_ms.lock().await = u64::try_from(elapsed_ms).unwrap_or(u64::MAX);
                    continue;
                }
                // Non-JSON text frames are ignored.
                if let Ok(value) = serde_json::from_str::<Value>(&text) {
                    Self::dispatch_frame(&value, &broadcast_tx);
                }
            }
            let _ = broadcast_tx.lock().unwrap().take();
            *status.lock().await = ConnectionStatus::Disconnected;
        });
    }

    /// Routes one parsed JSON frame to broadcast subscribers.
    ///
    /// `subscribe`/`unsubscribe`/`login` acknowledgements are dropped. `error`
    /// events are forwarded as `WebSocketError::ProtocolError("<code>: <msg>")`.
    /// Channel pushes (`arg.channel` plus a `data` array) are parsed into
    /// `StreamEvent`s. Anything else is ignored.
    fn dispatch_frame(
        value: &Value,
        broadcast_tx: &StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>,
    ) {
        match value.get("event").and_then(|e| e.as_str()) {
            Some("subscribe") | Some("unsubscribe") | Some("login") => return,
            Some("error") => {
                let code = value
                    .get("code")
                    .and_then(|c| c.as_str())
                    .unwrap_or("unknown");
                let msg_text = value
                    .get("msg")
                    .and_then(|m| m.as_str())
                    .unwrap_or("Unknown error");
                Self::publish(
                    broadcast_tx,
                    Err(WebSocketError::ProtocolError(format!("{}: {}", code, msg_text))),
                );
                return;
            }
            _ => {}
        }

        let channel = match value
            .get("arg")
            .and_then(|arg| arg.get("channel"))
            .and_then(|c| c.as_str())
        {
            Some(channel) => channel,
            None => return,
        };
        let data_arr = match value.get("data").and_then(|d| d.as_array()) {
            Some(data_arr) => data_arr,
            None => return,
        };
        let action = value.get("action").and_then(|a| a.as_str());
        for data in data_arr {
            for event in Self::parse_channel_events(channel, data, action) {
                Self::publish(broadcast_tx, Ok(event));
            }
        }
    }

    /// Sends `item` to current subscribers, if a broadcast sender exists.
    fn publish(
        broadcast_tx: &StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>,
        item: WebSocketResult<StreamEvent>,
    ) {
        if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
            // A send error only means there are no receivers right now.
            let _ = tx.send(item);
        }
    }

    /// Parses one element of a push's `data` array into every event it carries.
    ///
    /// An `account` element may list several currency balances under
    /// `details`. Each one that parses becomes its own `BalanceUpdate`, so no
    /// currency is dropped. Every other channel yields at most one event (see
    /// `parse_channel_data`).
    fn parse_channel_events(channel: &str, data: &Value, action: Option<&str>) -> Vec<StreamEvent> {
        if channel == "account" {
            Self::parse_account_balances(data)
        } else {
            Self::parse_channel_data(channel, data, action)
                .into_iter()
                .collect()
        }
    }

    /// Parses every entry of an `account` element's `details` array that
    /// `OkxParser::parse_ws_balance_update` accepts, preserving order.
    fn parse_account_balances(data: &Value) -> Vec<StreamEvent> {
        data.get("details")
            .and_then(|details| details.as_array())
            .map(|details| {
                details
                    .iter()
                    .filter_map(|detail| OkxParser::parse_ws_balance_update(detail).ok())
                    .map(StreamEvent::BalanceUpdate)
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Parses one element of a push's `data` array into at most one event.
    ///
    /// `action` is the push's top-level `action` field. Order-book pushes
    /// become an `OrderbookSnapshot` when `action` is `"snapshot"` or absent,
    /// and an `OrderbookDelta` otherwise. The `books5`/`bbo-tbt` channels,
    /// which `orderbook_capabilities` advertises as snapshot channels, push
    /// full books without an `action`. For `account` only the first balance
    /// that parses is returned; `parse_channel_events` returns all of them.
    /// Unknown channels and parse failures yield `None`.
    fn parse_channel_data(channel: &str, data: &Value, action: Option<&str>) -> Option<StreamEvent> {
        match channel {
            "tickers" => OkxParser::parse_ws_ticker(data)
                .ok()
                .map(StreamEvent::Ticker),
            "books" | "books5" | "bbo-tbt" | "books-l2-tbt" | "books50-l2-tbt" => {
                let (asks, bids) = OkxParser::parse_ws_orderbook(data).ok()?;
                let timestamp = OkxParser::get_i64(data, "ts").unwrap_or(0);
                let seq_id = data.get("seqId").and_then(|v| v.as_u64());
                let prev_seq_id = data.get("prevSeqId").and_then(|v| v.as_u64());
                let checksum = data.get("checksum").and_then(|v| v.as_i64());
                // Only an explicit "update" is incremental; a missing action
                // means the push is a full book.
                let is_snapshot = matches!(action, None | Some("snapshot"));
                if is_snapshot {
                    let orderbook = OrderBook {
                        asks,
                        bids,
                        timestamp,
                        sequence: None,
                        last_update_id: seq_id,
                        first_update_id: seq_id,
                        prev_update_id: prev_seq_id,
                        event_time: Some(timestamp),
                        transaction_time: None,
                        checksum,
                    };
                    Some(StreamEvent::OrderbookSnapshot(orderbook))
                } else {
                    let delta = OrderbookDelta {
                        asks,
                        bids,
                        timestamp,
                        first_update_id: seq_id,
                        last_update_id: seq_id,
                        prev_update_id: prev_seq_id,
                        event_time: Some(timestamp),
                        checksum,
                    };
                    Some(StreamEvent::OrderbookDelta(delta))
                }
            }
            "trades" => OkxParser::parse_ws_trade(data)
                .ok()
                .map(StreamEvent::Trade),
            "candle1m" | "candle5m" | "candle15m" | "candle30m" | "candle1H"
            | "candle4H" | "candle1D" => OkxParser::parse_ws_kline(data)
                .ok()
                .map(StreamEvent::Kline),
            "orders" => OkxParser::parse_ws_order_update(data)
                .ok()
                .map(StreamEvent::OrderUpdate),
            "account" => Self::parse_account_balances(data).into_iter().next(),
            "positions" => OkxParser::parse_ws_position_update(data)
                .ok()
                .map(StreamEvent::PositionUpdate),
            _ => None,
        }
    }

    /// Returns the last measured ping round-trip time in milliseconds.
    ///
    /// Non-blocking. Returns 0 before the first pong and also when the value
    /// is locked for an update at this instant.
    pub fn ping_rtt_ms(&self) -> u64 {
        self.ws_ping_rtt_ms
            .try_lock()
            .map(|guard| *guard)
            .unwrap_or(0)
    }

    /// Returns a shared handle to the round-trip-time cell updated by the
    /// message handler.
    pub fn ping_rtt_handle(&self) -> Arc<Mutex<u64>> {
        self.ws_ping_rtt_ms.clone()
    }
}
#[async_trait]
impl WebSocketConnector for OkxWebSocket {
async fn connect(&mut self, _account_type: AccountType) -> WebSocketResult<()> {
let url = if self.auth.is_some() {
self.is_private = true;
self.urls.ws_url(true)
} else {
self.is_private = false;
self.urls.ws_url(false)
};
let (ws_stream, _) = connect_async(url)
.await
.map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
let (mut sink, reader) = ws_stream.split();
if self.is_private {
self.send_login(&mut sink).await?;
}
*self.ws_sink.lock().await = Some(sink);
*self.ws_reader.lock().await = Some(reader);
*self.status.lock().await = ConnectionStatus::Connected;
let (tx, _) = broadcast::channel(1000);
*self.broadcast_tx.lock().unwrap() = Some(tx);
Self::start_ping_task(self.ws_sink.clone(), self.last_ping.clone());
Self::start_message_handler(
self.ws_reader.clone(),
self.broadcast_tx.clone(),
self.status.clone(),
self.last_ping.clone(),
self.ws_ping_rtt_ms.clone(),
);
Ok(())
}
async fn disconnect(&mut self) -> WebSocketResult<()> {
{
let mut sink_guard = self.ws_sink.lock().await;
if let Some(sink) = sink_guard.as_mut() {
let _ = sink.send(Message::Close(None)).await;
}
*sink_guard = None;
}
*self.ws_reader.lock().await = None;
*self.status.lock().await = ConnectionStatus::Disconnected;
let _ = self.broadcast_tx.lock().unwrap().take();
Ok(())
}
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
let channel = match &request.stream_type {
crate::core::StreamType::Ticker => "tickers",
crate::core::StreamType::Orderbook => {
match request.depth {
Some(5) => "books5",
Some(50) => "books50-l2-tbt",
_ => "books",
}
}
crate::core::StreamType::OrderbookDelta => {
match request.depth {
Some(50) => "books50-l2-tbt",
_ => "books-l2-tbt",
}
}
crate::core::StreamType::Trade => "trades",
crate::core::StreamType::Kline { interval } => match interval.as_str() {
"1m" => "candle1m",
"5m" => "candle5m",
"15m" => "candle15m",
"30m" => "candle30m",
"1h" => "candle1H",
"4h" => "candle4H",
"1d" => "candle1D",
_ => "candle1H",
},
crate::core::StreamType::MarkPrice => "mark-price",
crate::core::StreamType::FundingRate => "funding-rate",
crate::core::StreamType::OrderUpdate => "orders",
crate::core::StreamType::BalanceUpdate => "account",
crate::core::StreamType::PositionUpdate => "positions",
};
let account_type = request.account_type;
let inst_id = format_symbol(&request.symbol.base, &request.symbol.quote, account_type);
let sub_msg = json!({
"op": "subscribe",
"args": [{
"channel": channel,
"instId": inst_id,
}]
});
let mut sink_guard = self.ws_sink.lock().await;
if let Some(sink) = sink_guard.as_mut() {
sink.send(Message::Text(sub_msg.to_string()))
.await
.map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
self.subscriptions.lock().await.insert(request);
Ok(())
} else {
Err(WebSocketError::ConnectionError("Not connected".to_string()))
}
}
async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
let channel = match &request.stream_type {
crate::core::StreamType::Ticker => "tickers",
crate::core::StreamType::Orderbook => "books",
crate::core::StreamType::OrderbookDelta => "books",
crate::core::StreamType::Trade => "trades",
crate::core::StreamType::Kline { interval: _ } => "candle1H",
crate::core::StreamType::MarkPrice => "mark-price",
crate::core::StreamType::FundingRate => "funding-rate",
crate::core::StreamType::OrderUpdate => "orders",
crate::core::StreamType::BalanceUpdate => "account",
crate::core::StreamType::PositionUpdate => "positions",
};
let account_type = request.account_type;
let inst_id = format_symbol(&request.symbol.base, &request.symbol.quote, account_type);
let unsub_msg = json!({
"op": "unsubscribe",
"args": [{
"channel": channel,
"instId": inst_id,
}]
});
let mut sink_guard = self.ws_sink.lock().await;
if let Some(sink) = sink_guard.as_mut() {
sink.send(Message::Text(unsub_msg.to_string()))
.await
.map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
self.subscriptions.lock().await.remove(&request);
Ok(())
} else {
Err(WebSocketError::ConnectionError("Not connected".to_string()))
}
}
fn event_stream(
&self,
) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send + 'static>> {
let tx_guard = self.broadcast_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let rx = tx.subscribe();
Box::pin(tokio_stream::wrappers::BroadcastStream::new(rx).map(|r| {
r.map_err(|e| WebSocketError::ConnectionError(format!("Broadcast error: {}", e)))
.and_then(|x| x)
}))
} else {
Box::pin(futures_util::stream::empty())
}
}
fn connection_status(&self) -> ConnectionStatus {
match self.status.try_lock() {
Ok(guard) => *guard,
Err(_) => ConnectionStatus::Connected,
}
}
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
match self.subscriptions.try_lock() {
Ok(guard) => guard.iter().cloned().collect(),
Err(_) => Vec::new(),
}
}
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
Some(self.ws_ping_rtt_ms.clone())
}
fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
static OKX_CHANNELS: &[WsBookChannel] = &[
WsBookChannel::snapshot("bbo-tbt", 1, 10),
WsBookChannel::snapshot("books5", 5, 100),
WsBookChannel::delta("books", Some(400), Some(100)),
WsBookChannel::delta("books50-l2-tbt", Some(50), Some(10)).with_auth_tier(),
WsBookChannel::delta("books-l2-tbt", Some(400), Some(10)).with_auth_tier(),
];
OrderbookCapabilities {
ws_depths: &[1, 5, 50, 400],
ws_default_depth: Some(400),
rest_max_depth: Some(400),
rest_depth_values: &[],
supports_snapshot: true,
supports_delta: true,
update_speeds_ms: &[10, 100],
default_speed_ms: Some(100),
ws_channels: OKX_CHANNELS,
checksum: Some(ChecksumInfo {
algorithm: ChecksumAlgorithm::Crc32Interleaved,
levels_per_side: 25,
opt_in: false,
}),
has_sequence: true,
has_prev_sequence: true,
supports_aggregation: false,
aggregation_levels: &[],
}
}
}