bybit/
ws.rs

1use crate::prelude::*;
2
3use futures::{SinkExt, StreamExt};
4use log::trace;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Instant;
8use tokio::net::TcpStream;
9use tokio::sync::mpsc;
10use tokio::time::Duration;
11use tokio_tungstenite::WebSocketStream;
12use tokio_tungstenite::{tungstenite::Message as WsMessage, MaybeTlsStream};
13
/// Entry point for the Bybit WebSocket streams implemented in this file.
///
/// The connecting methods open their socket through the wrapped `Client`
/// (`Client::wss_connect`).
#[derive(Clone)]
pub struct Stream {
    /// Client used to open the WebSocket connections.
    pub client: Client,
}
18
19impl Stream {
20    /// Tests for connectivity by sending a ping request to the Bybit server.
21    ///
22    /// # Returns
23    ///
24    /// Returns a `Result` containing a `String` with the response message if successful,
25
26    /// * `private` is set to `true` if the request is for a private endpoint
27    /// or a `BybitError` if an error occurs.
28    pub async fn ws_ping(&self, private: bool) -> Result<(), BybitError> {
29        let mut parameters: BTreeMap<String, Value> = BTreeMap::new();
30        parameters.insert("req_id".into(), generate_random_uid(8).into());
31        parameters.insert("op".into(), "ping".into());
32        let request = build_json_request(&parameters);
33        let endpoint = if private {
34            WebsocketAPI::Private
35        } else {
36            WebsocketAPI::PublicLinear
37        };
38        let mut response = self
39            .client
40            .wss_connect(endpoint, Some(request), private, None)
41            .await?;
42        let Some(data) = response.next().await else {
43            return Err(BybitError::Base(
44                "Failed to receive ping response".to_string(),
45            ));
46        };
47
48        let data = data
49            .map_err(|e| BybitError::Base(format!("Failed to get ping response, error {}", e)))?;
50        if let WsMessage::Text(data) = data {
51            let response: PongResponse = serde_json::from_str(&data)?;
52            match response {
53                PongResponse::PublicPong(pong) => {
54                    trace!("Pong received successfully: {:#?}", pong);
55                }
56                PongResponse::PrivatePong(pong) => {
57                    trace!("Pong received successfully: {:#?}", pong);
58                }
59            }
60        }
61        Ok(())
62    }
63
64    pub async fn ws_priv_subscribe<'b, F>(
65        &self,
66        req: Subscription<'_>,
67        handler: F,
68    ) -> Result<(), BybitError>
69    where
70        F: FnMut(WebsocketEvents) -> Result<(), BybitError> + 'static + Send,
71    {
72        let request = Self::build_subscription(req);
73        let response = self
74            .client
75            .wss_connect(WebsocketAPI::Private, Some(request), true, Some(10))
76            .await?;
77        if let Ok(_) = Self::event_loop(response, handler, None).await {}
78        Ok(())
79    }
80
81    pub async fn ws_subscribe<'b, F>(
82        &self,
83        req: Subscription<'_>,
84        category: Category,
85        handler: F,
86    ) -> Result<(), BybitError>
87    where
88        F: FnMut(WebsocketEvents) -> Result<(), BybitError> + 'static + Send,
89    {
90        let endpoint = {
91            match category {
92                Category::Linear => WebsocketAPI::PublicLinear,
93                Category::Inverse => WebsocketAPI::PublicInverse,
94                Category::Spot => WebsocketAPI::PublicSpot,
95                _ => unimplemented!("Option has not been implemented"),
96            }
97        };
98        let request = Self::build_subscription(req);
99        let response = self
100            .client
101            .wss_connect(endpoint, Some(request), false, None)
102            .await?;
103        Self::event_loop(response, handler, None).await?;
104        Ok(())
105    }
106
107    pub fn build_subscription(action: Subscription) -> String {
108        let mut parameters: BTreeMap<String, Value> = BTreeMap::new();
109        parameters.insert("req_id".into(), generate_random_uid(8).into());
110        parameters.insert("op".into(), action.op.into());
111        let args_value: Value = action
112            .args
113            .iter()
114            .map(ToString::to_string)
115            .collect::<Vec<_>>()
116            .into();
117        parameters.insert("args".into(), args_value);
118
119        build_json_request(&parameters)
120    }
121
122    pub fn build_trade_subscription(orders: RequestType, recv_window: Option<u64>) -> String {
123        let mut parameters: BTreeMap<String, Value> = BTreeMap::new();
124        parameters.insert("reqId".into(), generate_random_uid(16).into());
125        let mut header_map: BTreeMap<String, String> = BTreeMap::new();
126        header_map.insert("X-BAPI-TIMESTAMP".into(), get_timestamp().to_string());
127        header_map.insert(
128            "X-BAPI-RECV-WINDOW".into(),
129            recv_window.unwrap_or(5000).to_string(),
130        );
131        parameters.insert("header".into(), json!(header_map));
132        match orders {
133            RequestType::Create(order) => {
134                parameters.insert("op".into(), "order.create".into());
135                parameters.insert("args".into(), build_ws_orders(RequestType::Create(order)));
136            }
137            RequestType::CreateBatch(order) => {
138                parameters.insert("op".into(), "order.create-batch".into());
139                parameters.insert(
140                    "args".into(),
141                    build_ws_orders(RequestType::CreateBatch(order)),
142                );
143            }
144            RequestType::Amend(order) => {
145                parameters.insert("op".into(), "order.amend".into());
146                parameters.insert("args".into(), build_ws_orders(RequestType::Amend(order)));
147            }
148            RequestType::AmendBatch(order) => {
149                parameters.insert("op".into(), "order.amend-batch".into());
150                parameters.insert(
151                    "args".into(),
152                    build_ws_orders(RequestType::AmendBatch(order)),
153                );
154            }
155            RequestType::Cancel(order) => {
156                parameters.insert("op".into(), "order.cancel".into());
157                parameters.insert("args".into(), build_ws_orders(RequestType::Cancel(order)));
158            }
159            RequestType::CancelBatch(order) => {
160                parameters.insert("op".into(), "order.cancel-batch".into());
161                parameters.insert(
162                    "args".into(),
163                    build_ws_orders(RequestType::CancelBatch(order)),
164                );
165            }
166        }
167        build_json_request(&parameters)
168    }
169
170    /// Subscribes to the specified order book updates and handles the order book events
171    ///
172    /// # Arguments
173    ///
174    /// * `subs` - A vector of tuples containing the order book ID and symbol
175    /// * `category` - The category of the order book
176    ///
177    /// # Example
178    ///
179    /// ```
180    /// use your_crate_name::Category;
181    /// let subs = vec![(1, "BTC"), (2, "ETH")];
182    /// ```
183    pub async fn ws_orderbook(
184        &self,
185        subs: Vec<(i32, &str)>,
186        category: Category,
187        sender: mpsc::UnboundedSender<OrderBookUpdate>,
188    ) -> Result<(), BybitError> {
189        let arr: Vec<String> = subs
190            .into_iter()
191            .map(|(num, sym)| format!("orderbook.{}.{}", num, sym.to_uppercase()))
192            .collect();
193        let request = Subscription::new("subscribe", arr.iter().map(AsRef::as_ref).collect());
194        self.ws_subscribe(request, category, move |event| {
195            if let WebsocketEvents::OrderBookEvent(order_book) = event {
196                sender
197                    .send(order_book)
198                    .map_err(|e| BybitError::ChannelSendError {
199                        underlying: e.to_string(),
200                    })?;
201            }
202            Ok(())
203        })
204        .await
205    }
206
207    /// This function subscribes to the specified trades and handles the trade events.
208    /// # Arguments
209    ///
210    /// * `subs` - A vector of trade subscriptions
211    /// * `category` - The category of the trades
212    ///
213    /// # Example
214    ///
215    /// ```
216    /// use your_crate_name::Category;
217    /// let subs = vec!["BTCUSD", "ETHUSD"];
218    /// let category = Category::Linear;
219    /// ws_trades(subs, category);
220    /// ```
221    pub async fn ws_trades(
222        &self,
223        subs: Vec<&str>,
224        category: Category,
225        sender: mpsc::UnboundedSender<WsTrade>,
226    ) -> Result<(), BybitError> {
227        let arr: Vec<String> = subs
228            .iter()
229            .map(|&sub| format!("publicTrade.{}", sub.to_uppercase()))
230            .collect();
231        let request = Subscription::new("subscribe", arr.iter().map(AsRef::as_ref).collect());
232        let handler = move |event| {
233            if let WebsocketEvents::TradeEvent(trades) = event {
234                for trade in trades.data {
235                    sender
236                        .send(trade)
237                        .map_err(|e| BybitError::ChannelSendError {
238                            underlying: e.to_string(),
239                        })?;
240                }
241            }
242            Ok(())
243        };
244
245        self.ws_subscribe(request, category, handler).await
246    }
247
248    /// Subscribes to ticker events for the specified symbols and category.
249    ///
250    /// # Arguments
251    ///
252    /// * `subs` - A vector of symbols for which ticker events are subscribed.
253    /// * `category` - The category for which ticker events are subscribed.
254    ///
255    /// # Examples
256    ///
257    /// ```
258    /// use your_crate_name::Category;
259    /// let subs = vec!["BTCUSD", "ETHUSD"];
260    /// let category = Category::Linear;
261    /// let sender = UnboundedSender<Ticker>;
262    /// ws_tickers(subs, category, sender);
263    /// ```
264    pub async fn ws_tickers(
265        &self,
266        subs: Vec<&str>,
267        category: Category,
268        sender: mpsc::UnboundedSender<Ticker>,
269    ) -> Result<(), BybitError> {
270        self._ws_tickers(subs, category, sender, |ws_ticker| Some(ws_ticker.data))
271            .await
272    }
273
274    /// Subscribes to ticker events with timestamp for the specified symbols and category.
275    ///
276    /// # Arguments
277    ///
278    /// * `subs` - A vector of symbols for which ticker events are subscribed.
279    /// * `category` - The category for which ticker events are subscribed.
280    ///
281    /// # Examples
282    ///
283    /// ```
284    /// use your_crate_name::Category;
285    /// let subs = vec!["BTCUSD", "ETHUSD"];
286    /// let category = Category::Linear;
287    /// let sender = UnboundedSender<Ticker>;
288    /// ws_timed_tickers(subs, category, sender);
289    /// ```
290    pub async fn ws_timed_tickers(
291        &self,
292        subs: Vec<&str>,
293        category: Category,
294        sender: mpsc::UnboundedSender<Timed<Ticker>>,
295    ) -> Result<(), BybitError> {
296        self._ws_tickers(subs, category, sender, |ticker| {
297            Some(Timed {
298                time: ticker.ts,
299                data: ticker.data,
300            })
301        })
302        .await
303    }
304
305    /// A high abstraction level stream of timed linear snapshots, which you can
306    /// subscribe to using the receiver of the sender. Internally this method
307    /// consumes the linear ticker API but instead of returning a stream of deltas
308    /// we update the initial snapshot with all subsequent streams, and thanks
309    /// to internally using `.scan` we you get `Timed<LinearTickerDataSnapshot>`,
310    /// instead of `Timed<LinearTickerDataDelta>`.
311    ///
312    /// If you provide multiple symbols, the `LinearTickerDataSnapshot` values
313    /// will be interleaved.
314    ///
315    /// # Usage
316    /// ```no_run
317    /// use bybit::prelude::*;
318    /// use tokio::sync::mpsc;
319    /// use std::sync::Arc;
320    ///
321    /// #[tokio::main]
322    /// async fn main() {
323    ///
324    /// let ws: Arc<Stream> = Arc::new(Bybit::new(None, None));
325    /// let (tx, mut rx) = mpsc::unbounded_channel::<Timed<LinearTickerDataSnapshot>>();
326    /// tokio::spawn(async move {
327    ///     ws.ws_timed_linear_tickers(vec!["BTCUSDT".to_owned(), "ETHUSDT".to_owned()], tx)
328    ///         .await
329    ///         .unwrap();
330    /// });
331    /// while let Some(ticker_snapshot) = rx.recv().await {
332    ///     println!("{:#?}", ticker_snapshot);
333    /// }
334    /// }
335    /// ```
336    pub async fn ws_timed_linear_tickers(
337        self: Arc<Self>,
338        subs: Vec<String>,
339        sender: mpsc::UnboundedSender<Timed<LinearTickerDataSnapshot>>,
340    ) -> Result<(), BybitError> {
341        let (tx, mut rx) = mpsc::unbounded_channel::<Timed<LinearTickerData>>();
342        // Spawn the WebSocket task
343        tokio::spawn({
344            let self_arc = Arc::clone(&self);
345            let subs = subs.clone();
346            async move {
347                self_arc
348                    ._ws_tickers(
349                        subs.iter().map(|s| s.as_str()).collect(),
350                        Category::Linear,
351                        tx,
352                        |ticker| match &ticker.data {
353                            Ticker::Linear(linear) => Some(Timed {
354                                time: ticker.ts,
355                                data: linear.clone(),
356                            }),
357                            Ticker::Spot(_) => None,
358                        },
359                    )
360                    .await
361            }
362        });
363
364        // State to store snapshots for each symbol
365        let mut snapshots: HashMap<String, Timed<LinearTickerDataSnapshot>> = HashMap::new();
366
367        // Process incoming messages
368        while let Some(ticker) = rx.recv().await {
369            match ticker.data {
370                LinearTickerData::Snapshot(snapshot) => {
371                    let symbol = snapshot.symbol.clone();
372                    let timed_snapshot = Timed {
373                        time: ticker.time,
374                        data: snapshot,
375                    };
376                    // Store the snapshot and send it
377                    snapshots.insert(symbol.clone(), timed_snapshot.clone());
378                    sender
379                        .send(timed_snapshot)
380                        .map_err(|e| BybitError::ChannelSendError {
381                            underlying: e.to_string(),
382                        })?
383                }
384                LinearTickerData::Delta(delta) => {
385                    let symbol = delta.symbol.clone();
386                    if let Some(snapshot_timed) = snapshots.get_mut(&symbol) {
387                        let mut snapshot = snapshot_timed.data.clone();
388                        snapshot.update(delta);
389                        let new = Timed {
390                            data: snapshot,
391                            time: ticker.time,
392                        };
393                        *snapshot_timed = new.clone();
394                        sender.send(new).map_err(|e| BybitError::ChannelSendError {
395                            underlying: e.to_string(),
396                        })?
397                    }
398                    // If no snapshot exists for the symbol, skip the delta
399                }
400            }
401        }
402
403        Ok(())
404    }
405
406    async fn _ws_tickers<T, F>(
407        &self,
408        subs: Vec<&str>,
409        category: Category,
410        sender: mpsc::UnboundedSender<T>,
411        filter_map: F,
412    ) -> Result<(), BybitError>
413    where
414        T: 'static + Sync + Send,
415        F: 'static + Sync + Send + Fn(WsTicker) -> Option<T>,
416    {
417        let arr: Vec<String> = subs
418            .into_iter()
419            .map(|sub| format!("tickers.{}", sub.to_uppercase()))
420            .collect();
421        let request = Subscription::new("subscribe", arr.iter().map(String::as_str).collect());
422
423        let handler = move |event| {
424            if let WebsocketEvents::TickerEvent(ticker) = event {
425                if let Some(mapped) = filter_map(ticker) {
426                    sender
427                        .send(mapped)
428                        .map_err(|e| BybitError::ChannelSendError {
429                            underlying: e.to_string(),
430                        })?;
431                }
432            }
433            Ok(())
434        };
435
436        self.ws_subscribe(request, category, handler).await
437    }
438    pub async fn ws_liquidations(
439        &self,
440        subs: Vec<&str>,
441        category: Category,
442        sender: mpsc::UnboundedSender<LiquidationData>,
443    ) -> Result<(), BybitError> {
444        let arr: Vec<String> = subs
445            .into_iter()
446            .map(|sub| format!("liquidation.{}", sub.to_uppercase()))
447            .collect();
448        let request = Subscription::new("subscribe", arr.iter().map(String::as_str).collect());
449
450        let handler = move |event| {
451            if let WebsocketEvents::LiquidationEvent(liquidation) = event {
452                sender
453                    .send(liquidation.data)
454                    .map_err(|e| BybitError::ChannelSendError {
455                        underlying: e.to_string(),
456                    })?;
457            }
458            Ok(())
459        };
460
461        self.ws_subscribe(request, category, handler).await
462    }
463    pub async fn ws_klines(
464        &self,
465        subs: Vec<(&str, &str)>,
466        category: Category,
467        sender: mpsc::UnboundedSender<WsKline>,
468    ) -> Result<(), BybitError> {
469        let arr: Vec<String> = subs
470            .into_iter()
471            .map(|(interval, sym)| format!("kline.{}.{}", interval, sym.to_uppercase()))
472            .collect();
473        let request = Subscription::new("subscribe", arr.iter().map(AsRef::as_ref).collect());
474        self.ws_subscribe(request, category, move |event| {
475            if let WebsocketEvents::KlineEvent(kline) = event {
476                sender
477                    .send(kline)
478                    .map_err(|e| BybitError::ChannelSendError {
479                        underlying: e.to_string(),
480                    })?;
481            }
482            Ok(())
483        })
484        .await
485    }
486
487    pub async fn ws_position(
488        &self,
489        cat: Option<Category>,
490        sender: mpsc::UnboundedSender<PositionData>,
491    ) -> Result<(), BybitError> {
492        let sub_str = if let Some(v) = cat {
493            match v {
494                Category::Linear => "position.linear",
495                Category::Inverse => "position.inverse",
496                _ => "",
497            }
498        } else {
499            "position"
500        };
501
502        let request = Subscription::new("subscribe", vec![sub_str]);
503        self.ws_priv_subscribe(request, move |event| {
504            if let WebsocketEvents::PositionEvent(position) = event {
505                for v in position.data {
506                    sender.send(v).map_err(|e| BybitError::ChannelSendError {
507                        underlying: e.to_string(),
508                    })?;
509                }
510            }
511            Ok(())
512        })
513        .await
514    }
515
516    pub async fn ws_executions(
517        &self,
518        cat: Option<Category>,
519        sender: mpsc::UnboundedSender<ExecutionData>,
520    ) -> Result<(), BybitError> {
521        let sub_str = if let Some(v) = cat {
522            match v {
523                Category::Linear => "execution.linear",
524                Category::Inverse => "execution.inverse",
525                Category::Spot => "execution.spot",
526                Category::Option => "execution.option",
527            }
528        } else {
529            "execution"
530        };
531
532        let request = Subscription::new("subscribe", vec![sub_str]);
533        self.ws_priv_subscribe(request, move |event| {
534            if let WebsocketEvents::ExecutionEvent(execute) = event {
535                for v in execute.data {
536                    sender.send(v).map_err(|e| BybitError::ChannelSendError {
537                        underlying: e.to_string(),
538                    })?;
539                }
540            }
541            Ok(())
542        })
543        .await
544    }
545
546    pub async fn ws_fast_exec(
547        &self,
548        sender: mpsc::UnboundedSender<FastExecData>,
549    ) -> Result<(), BybitError> {
550        let sub_str = "execution.fast";
551        let request = Subscription::new("subscribe", vec![sub_str]);
552
553        self.ws_priv_subscribe(request, move |event| {
554            if let WebsocketEvents::FastExecEvent(execution) = event {
555                for v in execution.data {
556                    sender.send(v).map_err(|e| BybitError::ChannelSendError {
557                        underlying: e.to_string(),
558                    })?;
559                }
560            }
561            Ok(())
562        })
563        .await
564    }
565
566    pub async fn ws_orders(
567        &self,
568        cat: Option<Category>,
569        sender: mpsc::UnboundedSender<OrderData>,
570    ) -> Result<(), BybitError> {
571        let sub_str = if let Some(v) = cat {
572            match v {
573                Category::Linear => "order.linear",
574                Category::Inverse => "order.inverse",
575                Category::Spot => "order.spot",
576                Category::Option => "order.option",
577            }
578        } else {
579            "order"
580        };
581
582        let request = Subscription::new("subscribe", vec![sub_str]);
583        self.ws_priv_subscribe(request, move |event| {
584            if let WebsocketEvents::OrderEvent(order) = event {
585                for v in order.data {
586                    sender.send(v).map_err(|e| BybitError::ChannelSendError {
587                        underlying: e.to_string(),
588                    })?;
589                }
590            }
591            Ok(())
592        })
593        .await
594    }
595
596    pub async fn ws_wallet(
597        &self,
598        sender: mpsc::UnboundedSender<WalletData>,
599    ) -> Result<(), BybitError> {
600        let sub_str = "wallet";
601        let request = Subscription::new("subscribe", vec![sub_str]);
602        self.ws_priv_subscribe(request, move |event| {
603            if let WebsocketEvents::Wallet(wallet) = event {
604                for v in wallet.data {
605                    sender.send(v).map_err(|e| BybitError::ChannelSendError {
606                        underlying: e.to_string(),
607                    })?;
608                }
609            }
610            Ok(())
611        })
612        .await
613    }
614
615    pub async fn ws_trade_stream<'a, F>(
616        &self,
617        req: mpsc::UnboundedReceiver<RequestType<'a>>,
618        handler: F,
619    ) -> Result<(), BybitError>
620    where
621        F: FnMut(WebsocketEvents) -> Result<(), BybitError> + 'static + Send,
622        'a: 'static,
623    {
624        let response = self
625            .client
626            .wss_connect(WebsocketAPI::TradeStream, None, true, Some(10))
627            .await?;
628        Self::event_loop(response, handler, Some(req)).await?;
629
630        Ok(())
631    }
632
633    pub async fn event_loop<'a, H>(
634        mut stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
635        mut handler: H,
636        mut order_sender: Option<mpsc::UnboundedReceiver<RequestType<'_>>>,
637    ) -> Result<(), BybitError>
638    where
639        H: WebSocketHandler,
640    {
641        let mut interval = Instant::now();
642        loop {
643            let msg = stream.next().await;
644            match msg {
645                Some(Ok(WsMessage::Text(msg))) => {
646                    if let Err(_) = handler.handle_msg(&msg) {
647                        return Err(BybitError::Base(
648                            "Error handling stream message".to_string(),
649                        ));
650                    }
651                }
652                Some(Err(e)) => {
653                    return Err(BybitError::from(e.to_string()));
654                }
655                None => {
656                    return Err(BybitError::Base("Stream was closed".to_string()));
657                }
658                _ => {}
659            }
660            if let Some(sender) = order_sender.as_mut() {
661                if let Some(v) = sender.recv().await {
662                    let order_req = Self::build_trade_subscription(v, Some(3000));
663                    stream.send(WsMessage::Text(order_req)).await?;
664                }
665            }
666
667            if interval.elapsed() > Duration::from_secs(300) {
668                let mut parameters: BTreeMap<String, Value> = BTreeMap::new();
669                if order_sender.is_none() {
670                    parameters.insert("req_id".into(), generate_random_uid(8).into());
671                }
672                parameters.insert("op".into(), "ping".into());
673                let request = build_json_request(&parameters);
674                let _ = stream
675                    .send(WsMessage::Text(request))
676                    .await
677                    .map_err(BybitError::from);
678                interval = Instant::now();
679            }
680        }
681    }
682}
683
/// Consumer of the text frames read from a WebSocket connection.
///
/// `Stream::event_loop` calls `handle_msg` once for every text frame it
/// receives and stops with an error as soon as a call fails.
pub trait WebSocketHandler {
    /// Event type associated with the handler; the closure implementation in
    /// this file sets it to `WebsocketEvents`.
    type Event;
    /// Processes one text frame. `msg` is the frame's text exactly as it was
    /// received.
    fn handle_msg(&mut self, msg: &str) -> Result<(), BybitError>;
}
688
689impl<F> WebSocketHandler for F
690where
691    F: FnMut(WebsocketEvents) -> Result<(), BybitError>,
692{
693    type Event = WebsocketEvents;
694    fn handle_msg(&mut self, msg: &str) -> Result<(), BybitError> {
695        let update: Value = serde_json::from_str(msg)?;
696        if let Ok(event) = serde_json::from_value::<WebsocketEvents>(update.clone()) {
697            self(event)?;
698        }
699
700        Ok(())
701    }
702}