1use crate::okx::parser::{parse_orderbook, parse_ticker, parse_trade};
8use ccxt_core::error::{Error, Result};
9use ccxt_core::types::{Market, OrderBook, Ticker, Trade};
10use ccxt_core::ws_client::{WsClient, WsConfig, WsConnectionState};
11use ccxt_core::ws_exchange::MessageStream;
12use futures::Stream;
13use serde_json::Value;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use tokio::sync::{RwLock, mpsc};
18
/// Keep-alive ping interval handed to `WsConfig::ping_interval`
/// (milliseconds, per the `_MS` suffix).
const DEFAULT_PING_INTERVAL_MS: u64 = 25_000;

/// Delay handed to `WsConfig::reconnect_interval` (milliseconds, per the `_MS` suffix).
const DEFAULT_RECONNECT_INTERVAL_MS: u64 = 5_000;

/// Cap handed to `WsConfig::max_reconnect_attempts`.
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
28
29pub struct OkxWs {
33 client: Arc<WsClient>,
35 subscriptions: Arc<RwLock<Vec<String>>>,
37}
38
39impl OkxWs {
40 pub fn new(url: String) -> Self {
46 let config = WsConfig {
47 url: url.clone(),
48 connect_timeout: 10000,
49 ping_interval: DEFAULT_PING_INTERVAL_MS,
50 reconnect_interval: DEFAULT_RECONNECT_INTERVAL_MS,
51 max_reconnect_attempts: MAX_RECONNECT_ATTEMPTS,
52 auto_reconnect: true,
53 enable_compression: false,
54 pong_timeout: 90000,
55 ..Default::default()
56 };
57
58 Self {
59 client: Arc::new(WsClient::new(config)),
60 subscriptions: Arc::new(RwLock::new(Vec::new())),
61 }
62 }
63
64 pub async fn connect(&self) -> Result<()> {
66 self.client.connect().await
67 }
68
69 pub async fn disconnect(&self) -> Result<()> {
71 self.client.disconnect().await
72 }
73
74 pub fn state(&self) -> WsConnectionState {
76 self.client.state()
77 }
78
79 pub fn is_connected(&self) -> bool {
81 self.client.is_connected()
82 }
83
84 pub async fn receive(&self) -> Option<Value> {
86 self.client.receive().await
87 }
88
89 pub async fn subscribe_ticker(&self, symbol: &str) -> Result<()> {
95 #[allow(clippy::disallowed_methods)]
98 let msg = serde_json::json!({
99 "op": "subscribe",
100 "args": [{
101 "channel": "tickers",
102 "instId": symbol
103 }]
104 });
105
106 self.client.send_json(&msg).await?;
107
108 let sub_key = format!("ticker:{}", symbol);
109 self.subscriptions.write().await.push(sub_key);
110
111 Ok(())
112 }
113
114 pub async fn subscribe_tickers(&self, symbols: &[String]) -> Result<()> {
120 let mut args = Vec::new();
121 for symbol in symbols {
122 let mut arg_map = serde_json::Map::new();
123 arg_map.insert(
124 "channel".to_string(),
125 serde_json::Value::String("tickers".to_string()),
126 );
127 arg_map.insert(
128 "instId".to_string(),
129 serde_json::Value::String(symbol.clone()),
130 );
131 args.push(serde_json::Value::Object(arg_map));
132 }
133
134 #[allow(clippy::disallowed_methods)]
136 let msg = serde_json::json!({
137 "op": "subscribe",
138 "args": args
139 });
140
141 self.client.send_json(&msg).await?;
142
143 let mut subs = self.subscriptions.write().await;
144 for symbol in symbols {
145 subs.push(format!("ticker:{}", symbol));
146 }
147
148 Ok(())
149 }
150
151 pub async fn watch_tickers(&self, symbols: &[String]) -> Result<MessageStream<Vec<Ticker>>> {
163 if !self.is_connected() {
165 self.connect().await?;
166 }
167
168 self.subscribe_tickers(symbols).await?;
170
171 let (tx, rx) = mpsc::unbounded_channel::<Result<Vec<Ticker>>>();
173 let symbols_owned: Vec<String> = symbols.to_vec();
174 let client = Arc::clone(&self.client);
175
176 tokio::spawn(async move {
178 while let Some(msg) = client.receive().await {
179 if let Some(arg) = msg.get("arg") {
181 let channel = arg.get("channel").and_then(|c| c.as_str());
182 let inst_id = arg.get("instId").and_then(|i| i.as_str());
183
184 if channel == Some("tickers") {
185 if let Some(id) = inst_id {
186 if symbols_owned.iter().any(|s| s == id) {
187 match parse_ws_ticker(&msg, None) {
188 Ok(ticker) => {
189 if tx.send(Ok(vec![ticker])).is_err() {
190 break; }
192 }
193 Err(e) => {
194 if tx.send(Err(e)).is_err() {
195 break;
196 }
197 }
198 }
199 }
200 }
201 }
202 }
203 }
204 });
205
206 Ok(Box::pin(ReceiverStream::new(rx)))
207 }
208 pub async fn subscribe_orderbook(&self, symbol: &str, depth: u32) -> Result<()> {
210 let channel = match depth {
212 d if d <= 5 => "books5",
213 d if d <= 50 => "books50-l2",
214 _ => "books",
215 };
216
217 #[allow(clippy::disallowed_methods)]
219 let msg = serde_json::json!({
220 "op": "subscribe",
221 "args": [{
222 "channel": channel,
223 "instId": symbol
224 }]
225 });
226
227 self.client.send_json(&msg).await?;
228
229 let sub_key = format!("orderbook:{}", symbol);
230 self.subscriptions.write().await.push(sub_key);
231
232 Ok(())
233 }
234
235 pub async fn subscribe_trades(&self, symbol: &str) -> Result<()> {
241 #[allow(clippy::disallowed_methods)]
243 let msg = serde_json::json!({
244 "op": "subscribe",
245 "args": [{
246 "channel": "trades",
247 "instId": symbol
248 }]
249 });
250
251 self.client.send_json(&msg).await?;
252
253 let sub_key = format!("trades:{}", symbol);
254 self.subscriptions.write().await.push(sub_key);
255
256 Ok(())
257 }
258
259 pub async fn subscribe_kline(&self, symbol: &str, interval: &str) -> Result<()> {
266 let channel = format!("candle{}", interval);
267
268 #[allow(clippy::disallowed_methods)]
270 let msg = serde_json::json!({
271 "op": "subscribe",
272 "args": [{
273 "channel": channel,
274 "instId": symbol
275 }]
276 });
277
278 self.client.send_json(&msg).await?;
279
280 let sub_key = format!("kline:{}:{}", symbol, interval);
281 self.subscriptions.write().await.push(sub_key);
282
283 Ok(())
284 }
285
286 pub async fn unsubscribe(&self, stream_name: String) -> Result<()> {
292 let parts: Vec<&str> = stream_name.split(':').collect();
294 if parts.len() < 2 {
295 return Err(Error::invalid_request(format!(
296 "Invalid stream name: {}",
297 stream_name
298 )));
299 }
300
301 let channel_type = parts[0];
302 let symbol = parts[1];
303
304 let channel = match channel_type {
305 "ticker" => "tickers".to_string(),
306 "orderbook" => "books5".to_string(),
307 "trades" => "trades".to_string(),
308 "kline" => {
309 if parts.len() >= 3 {
310 format!("candle{}", parts[2])
311 } else {
312 return Err(Error::invalid_request(
313 "Kline unsubscribe requires interval",
314 ));
315 }
316 }
317 _ => {
318 return Err(Error::invalid_request(format!(
319 "Unknown channel: {}",
320 channel_type
321 )));
322 }
323 };
324
325 #[allow(clippy::disallowed_methods)]
327 let msg = serde_json::json!({
328 "op": "unsubscribe",
329 "args": [{
330 "channel": channel,
331 "instId": symbol
332 }]
333 });
334
335 self.client.send_json(&msg).await?;
336
337 let mut subs = self.subscriptions.write().await;
339 subs.retain(|s| s != &stream_name);
340
341 Ok(())
342 }
343
344 pub async fn subscriptions(&self) -> Vec<String> {
346 self.subscriptions.read().await.clone()
347 }
348
349 pub async fn watch_ticker(
362 &self,
363 symbol: &str,
364 market: Option<Market>,
365 ) -> Result<MessageStream<Ticker>> {
366 if !self.is_connected() {
368 self.connect().await?;
369 }
370
371 self.subscribe_ticker(symbol).await?;
373
374 let (tx, rx) = mpsc::unbounded_channel::<Result<Ticker>>();
376 let symbol_owned = symbol.to_string();
377 let client = Arc::clone(&self.client);
378
379 tokio::spawn(async move {
381 while let Some(msg) = client.receive().await {
382 if is_ticker_message(&msg, &symbol_owned) {
384 match parse_ws_ticker(&msg, market.as_ref()) {
385 Ok(ticker) => {
386 if tx.send(Ok(ticker)).is_err() {
387 break; }
389 }
390 Err(e) => {
391 if tx.send(Err(e)).is_err() {
392 break;
393 }
394 }
395 }
396 }
397 }
398 });
399
400 Ok(Box::pin(ReceiverStream::new(rx)))
401 }
402
403 pub async fn watch_order_book(
416 &self,
417 symbol: &str,
418 limit: Option<u32>,
419 ) -> Result<MessageStream<OrderBook>> {
420 if !self.is_connected() {
422 self.connect().await?;
423 }
424
425 let depth = limit.unwrap_or(5);
427 self.subscribe_orderbook(symbol, depth).await?;
428
429 let (tx, rx) = mpsc::unbounded_channel::<Result<OrderBook>>();
431 let symbol_owned = symbol.to_string();
432 let unified_symbol = format_unified_symbol(&symbol_owned);
433 let client = Arc::clone(&self.client);
434
435 tokio::spawn(async move {
437 while let Some(msg) = client.receive().await {
438 if is_orderbook_message(&msg, &symbol_owned) {
440 match parse_ws_orderbook(&msg, unified_symbol.clone()) {
441 Ok(orderbook) => {
442 if tx.send(Ok(orderbook)).is_err() {
443 break; }
445 }
446 Err(e) => {
447 if tx.send(Err(e)).is_err() {
448 break;
449 }
450 }
451 }
452 }
453 }
454 });
455
456 Ok(Box::pin(ReceiverStream::new(rx)))
457 }
458
459 pub async fn watch_trades(
472 &self,
473 symbol: &str,
474 market: Option<Market>,
475 ) -> Result<MessageStream<Vec<Trade>>> {
476 if !self.is_connected() {
478 self.connect().await?;
479 }
480
481 self.subscribe_trades(symbol).await?;
483
484 let (tx, rx) = mpsc::unbounded_channel::<Result<Vec<Trade>>>();
486 let symbol_owned = symbol.to_string();
487 let client = Arc::clone(&self.client);
488
489 tokio::spawn(async move {
491 while let Some(msg) = client.receive().await {
492 if is_trade_message(&msg, &symbol_owned) {
494 match parse_ws_trades(&msg, market.as_ref()) {
495 Ok(trades) => {
496 if tx.send(Ok(trades)).is_err() {
497 break; }
499 }
500 Err(e) => {
501 if tx.send(Err(e)).is_err() {
502 break;
503 }
504 }
505 }
506 }
507 }
508 });
509
510 Ok(Box::pin(ReceiverStream::new(rx)))
511 }
512}
513
514struct ReceiverStream<T> {
520 receiver: mpsc::UnboundedReceiver<T>,
521}
522
523impl<T> ReceiverStream<T> {
524 fn new(receiver: mpsc::UnboundedReceiver<T>) -> Self {
525 Self { receiver }
526 }
527}
528
529impl<T> Stream for ReceiverStream<T> {
530 type Item = T;
531
532 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
533 self.receiver.poll_recv(cx)
534 }
535}
536
537fn is_ticker_message(msg: &Value, symbol: &str) -> bool {
543 if let Some(arg) = msg.get("arg") {
546 let channel = arg.get("channel").and_then(|c| c.as_str());
547 let inst_id = arg.get("instId").and_then(|i| i.as_str());
548 channel == Some("tickers") && inst_id == Some(symbol)
549 } else {
550 false
551 }
552}
553
554fn is_orderbook_message(msg: &Value, symbol: &str) -> bool {
556 if let Some(arg) = msg.get("arg") {
559 let channel = arg.get("channel").and_then(|c| c.as_str());
560 let inst_id = arg.get("instId").and_then(|i| i.as_str());
561 if let (Some(ch), Some(id)) = (channel, inst_id) {
563 ch.starts_with("books") && id == symbol
564 } else {
565 false
566 }
567 } else {
568 false
569 }
570}
571
572fn is_trade_message(msg: &Value, symbol: &str) -> bool {
574 if let Some(arg) = msg.get("arg") {
577 let channel = arg.get("channel").and_then(|c| c.as_str());
578 let inst_id = arg.get("instId").and_then(|i| i.as_str());
579 channel == Some("trades") && inst_id == Some(symbol)
580 } else {
581 false
582 }
583}
584
/// Converts an OKX instrument id to slash form by replacing every `-` with `/`
/// (e.g. `BTC-USDT` becomes `BTC/USDT`).
///
/// NOTE(review): every dash is replaced, so `BTC-USDT-SWAP` becomes
/// `BTC/USDT/SWAP`. That is probably not the intended unified symbol for
/// derivatives — confirm.
fn format_unified_symbol(symbol: &str) -> String {
    symbol
        .chars()
        .map(|ch| if ch == '-' { '/' } else { ch })
        .collect()
}
589
590pub fn parse_ws_ticker(msg: &Value, market: Option<&Market>) -> Result<Ticker> {
596 let data = msg
599 .get("data")
600 .and_then(|d| d.as_array())
601 .and_then(|arr| arr.first())
602 .ok_or_else(|| Error::invalid_request("Missing data in ticker message"))?;
603
604 parse_ticker(data, market)
605}
606
607pub fn parse_ws_orderbook(msg: &Value, symbol: String) -> Result<OrderBook> {
609 let data = msg
612 .get("data")
613 .and_then(|d| d.as_array())
614 .and_then(|arr| arr.first())
615 .ok_or_else(|| Error::invalid_request("Missing data in orderbook message"))?;
616
617 parse_orderbook(data, symbol)
618}
619
620pub fn parse_ws_trade(msg: &Value, market: Option<&Market>) -> Result<Trade> {
622 let data = msg
625 .get("data")
626 .and_then(|d| d.as_array())
627 .and_then(|arr| arr.first())
628 .ok_or_else(|| Error::invalid_request("Missing data in trade message"))?;
629
630 parse_trade(data, market)
631}
632
633pub fn parse_ws_trades(msg: &Value, market: Option<&Market>) -> Result<Vec<Trade>> {
635 let data_array = msg
638 .get("data")
639 .and_then(|d| d.as_array())
640 .ok_or_else(|| Error::invalid_request("Missing data in trade message"))?;
641
642 let mut trades = Vec::with_capacity(data_array.len());
643 for data in data_array {
644 trades.push(parse_trade(data, market)?);
645 }
646
647 Ok(trades)
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use ccxt_core::types::financial::Price;
    use futures::StreamExt;
    use rust_decimal_macros::dec;

    /// Endpoint used only to construct clients; no test connects to it.
    const PUBLIC_URL: &str = "wss://ws.okx.com:8443/ws/v5/public";

    /// Parses a JSON fixture, panicking with a clear message if the fixture is malformed.
    fn fixture(raw: &str) -> Value {
        serde_json::from_str(raw).expect("test fixture must be valid JSON")
    }

    #[test]
    fn test_okx_ws_creation() {
        let ws = OkxWs::new(PUBLIC_URL.to_string());
        assert!(ws.subscriptions.try_read().is_ok());
    }

    #[tokio::test]
    async fn test_subscriptions_empty_by_default() {
        let ws = OkxWs::new(PUBLIC_URL.to_string());
        let subs = ws.subscriptions().await;
        assert!(subs.is_empty());
    }

    #[tokio::test]
    async fn test_receiver_stream_yields_in_order_then_ends() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut stream = ReceiverStream::new(rx);
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        // Dropping the only sender must terminate the stream after the buffer drains.
        drop(tx);

        assert_eq!(stream.next().await, Some(1));
        assert_eq!(stream.next().await, Some(2));
        assert_eq!(stream.next().await, None);
    }

    // ---- message classification -------------------------------------------------

    #[test]
    fn test_is_ticker_message_true() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "tickers", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(is_ticker_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_ticker_message_wrong_symbol() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "tickers", "instId": "ETH-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(!is_ticker_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_ticker_message_wrong_channel() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(!is_ticker_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_orderbook_message_books5() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "books5", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(is_orderbook_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_orderbook_message_books50() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "books50-l2", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(is_orderbook_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_orderbook_message_full_books() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "books", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(is_orderbook_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_orderbook_message_wrong_symbol() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "books5", "instId": "ETH-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(!is_orderbook_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_orderbook_message_non_book_channel() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "tickers", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(!is_orderbook_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_trade_message_true() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(is_trade_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_is_trade_message_wrong_symbol_or_channel() {
        let wrong_symbol = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "ETH-USDT"},
                "data": [{}]
            }"#,
        );
        let wrong_channel = fixture(
            r#"{
                "arg": {"channel": "tickers", "instId": "BTC-USDT"},
                "data": [{}]
            }"#,
        );
        assert!(!is_trade_message(&wrong_symbol, "BTC-USDT"));
        assert!(!is_trade_message(&wrong_channel, "BTC-USDT"));
    }

    #[test]
    fn test_messages_without_arg_match_nothing() {
        let msg = fixture(r#"{"data": [{}]}"#);
        assert!(!is_ticker_message(&msg, "BTC-USDT"));
        assert!(!is_orderbook_message(&msg, "BTC-USDT"));
        assert!(!is_trade_message(&msg, "BTC-USDT"));
    }

    #[test]
    fn test_format_unified_symbol() {
        assert_eq!(format_unified_symbol("BTC-USDT"), "BTC/USDT");
        assert_eq!(format_unified_symbol("ETH-BTC"), "ETH/BTC");
    }

    // ---- ticker parsing ---------------------------------------------------------

    #[test]
    fn test_parse_ws_ticker() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "tickers", "instId": "BTC-USDT"},
                "data": [{
                    "instId": "BTC-USDT",
                    "last": "50000.00",
                    "high24h": "51000.00",
                    "low24h": "49000.00",
                    "bidPx": "49999.00",
                    "askPx": "50001.00",
                    "vol24h": "1000.5",
                    "ts": "1700000000000"
                }]
            }"#,
        );

        let ticker = parse_ws_ticker(&msg, None).unwrap();
        assert_eq!(ticker.symbol, "BTC/USDT");
        assert_eq!(ticker.last, Some(Price::new(dec!(50000.00))));
        assert_eq!(ticker.high, Some(Price::new(dec!(51000.00))));
        assert_eq!(ticker.low, Some(Price::new(dec!(49000.00))));
    }

    #[test]
    fn test_parse_ws_ticker_with_market() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "tickers", "instId": "BTC-USDT"},
                "data": [{
                    "instId": "BTC-USDT",
                    "last": "50000.00",
                    "ts": "1700000000000"
                }]
            }"#,
        );

        let market = Market {
            id: "BTC-USDT".to_string(),
            symbol: "BTC/USDT".to_string(),
            base: "BTC".to_string(),
            quote: "USDT".to_string(),
            ..Default::default()
        };

        let ticker = parse_ws_ticker(&msg, Some(&market)).unwrap();
        assert_eq!(ticker.symbol, "BTC/USDT");
    }

    #[test]
    fn test_parse_ws_ticker_missing_data() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "tickers", "instId": "BTC-USDT"}
            }"#,
        );
        assert!(parse_ws_ticker(&msg, None).is_err());
    }

    // ---- order-book parsing -----------------------------------------------------

    #[test]
    fn test_parse_ws_orderbook() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "books5", "instId": "BTC-USDT"},
                "data": [{
                    "asks": [
                        ["50001.00", "1.0", "0", "1"],
                        ["50002.00", "3.0", "0", "2"],
                        ["50003.00", "2.5", "0", "1"]
                    ],
                    "bids": [
                        ["50000.00", "1.5", "0", "2"],
                        ["49999.00", "2.0", "0", "1"],
                        ["49998.00", "0.5", "0", "1"]
                    ],
                    "ts": "1700000000000"
                }]
            }"#,
        );

        let orderbook = parse_ws_orderbook(&msg, "BTC/USDT".to_string()).unwrap();
        assert_eq!(orderbook.symbol, "BTC/USDT");
        assert_eq!(orderbook.bids.len(), 3);
        assert_eq!(orderbook.asks.len(), 3);

        assert_eq!(orderbook.bids[0].price, Price::new(dec!(50000.00)));
        assert_eq!(orderbook.bids[1].price, Price::new(dec!(49999.00)));

        assert_eq!(orderbook.asks[0].price, Price::new(dec!(50001.00)));
        assert_eq!(orderbook.asks[1].price, Price::new(dec!(50002.00)));
    }

    #[test]
    fn test_parse_ws_orderbook_missing_data() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "books5", "instId": "BTC-USDT"}
            }"#,
        );
        assert!(parse_ws_orderbook(&msg, "BTC/USDT".to_string()).is_err());
    }

    // ---- trade parsing ----------------------------------------------------------

    #[test]
    fn test_parse_ws_trade() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "BTC-USDT"},
                "data": [{
                    "instId": "BTC-USDT",
                    "tradeId": "123456789",
                    "px": "50000.00",
                    "sz": "0.5",
                    "side": "buy",
                    "ts": "1700000000000"
                }]
            }"#,
        );

        let trade = parse_ws_trade(&msg, None).unwrap();
        assert_eq!(trade.timestamp, 1700000000000);
    }

    #[test]
    fn test_parse_ws_trades_multiple() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "BTC-USDT"},
                "data": [
                    {
                        "instId": "BTC-USDT",
                        "tradeId": "123456789",
                        "px": "50000.00",
                        "sz": "0.5",
                        "side": "buy",
                        "ts": "1700000000000"
                    },
                    {
                        "instId": "BTC-USDT",
                        "tradeId": "123456790",
                        "px": "50001.00",
                        "sz": "1.0",
                        "side": "sell",
                        "ts": "1700000000001"
                    }
                ]
            }"#,
        );

        let trades = parse_ws_trades(&msg, None).unwrap();
        assert_eq!(trades.len(), 2);

        // `parse_ws_trade` must pick the first entry only.
        let first = parse_ws_trade(&msg, None).unwrap();
        assert_eq!(first.timestamp, 1700000000000);
    }

    #[test]
    fn test_parse_ws_trade_missing_data() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "BTC-USDT"}
            }"#,
        );
        assert!(parse_ws_trade(&msg, None).is_err());
    }

    #[test]
    fn test_parse_ws_trades_non_array_data() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "BTC-USDT"},
                "data": {}
            }"#,
        );
        assert!(parse_ws_trades(&msg, None).is_err());
    }

    #[test]
    fn test_parse_ws_trades_empty_array() {
        let msg = fixture(
            r#"{
                "arg": {"channel": "trades", "instId": "BTC-USDT"},
                "data": []
            }"#,
        );

        let trades = parse_ws_trades(&msg, None).unwrap();
        assert!(trades.is_empty());
    }
}