bybit_async/websocket/
mod.rs

1pub mod topics;
2
3use crate::{
4    error::BybitError::{self, *},
5    models::Product,
6    Config,
7};
8use fehler::{throw, throws};
9use futures::{stream::Stream, SinkExt, StreamExt};
10use hmac::{Hmac, Mac};
11use reqwest::Url;
12use serde::{Deserialize, Serialize};
13use serde_json::{from_str, value::RawValue};
14use sha2::Sha256;
15use std::time::SystemTime;
16use std::{
17    marker::PhantomData,
18    pin::Pin,
19    task::{Context, Poll},
20};
21use tokio::net::TcpStream;
22use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
23use tungstenite::Message;
24
/// Concrete websocket stream type stored by `BybitWebsocket`: a websocket running over a
/// Tokio TCP stream that may or may not be wrapped in TLS.
type WSStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
26
/// Conversion from raw websocket text into the message type `M` yielded by
/// `BybitWebsocket<M>` when it is used as a `Stream`.
pub trait ParseMessage: Sized {
    /// Builds a message from a topic frame.
    ///
    /// `topic` is the frame's `topic` field and `data` is the raw, still-serialized JSON
    /// text of its `data` field.
    fn parse(topic: &str, data: &str) -> Result<Self, BybitError>;
    /// Builds a message from the full text of a frame that did not deserialize as a topic
    /// frame (the stream implementation uses this for auth/subscribe acknowledgements).
    fn parse_succ(succ: &str) -> Result<Self, BybitError>;
    /// Returns the value handed to the stream consumer when a websocket Ping frame arrives.
    fn ping() -> Self;
}
32
/// A websocket connection that yields parsed messages of type `M`.
pub struct BybitWebsocket<M> {
    /// The underlying websocket connection used for both sending and receiving.
    stream: WSStream,
    /// Ties the otherwise unused message type `M` to the struct.
    _phantom: PhantomData<M>,
    /// `true` when the connection was created with both an API key and an API secret;
    /// selects which frame envelope incoming text is deserialized as.
    private: bool,
}
38
39impl<M> BybitWebsocket<M>
40where
41    M: ParseMessage,
42{
43    #[throws(BybitError)]
44    pub async fn new(config: Config) -> BybitWebsocket<M> {
45        let base = if config.api_key.is_none() {
46            match config.product {
47                Product::Spot => &config.spot_ws_endpoint,
48                Product::UsdMFutures => &config.usdm_futures_ws_endpoint,
49                Product::CoinMFutures => &config.coinm_futures_ws_endpoint,
50                Product::EuropeanOptions => &config.european_options_ws_endpoint,
51            }
52        } else {
53            &config.private_ws_endpoint
54        };
55        let endpoint = Url::parse(base).unwrap();
56        let (mut stream, _) = match connect_async(endpoint).await {
57            Ok(v) => v,
58            Err(tungstenite::Error::Http(ref http)) => throw!(StartWebsocketError(
59                http.status(),
60                String::from_utf8_lossy(http.body().as_deref().unwrap_or_default()).to_string()
61            )),
62            Err(e) => throw!(e),
63        };
64
65        let private = config.api_secret.is_some() && config.api_key.is_some();
66        if private {
67            let since_epoch = SystemTime::now()
68                .duration_since(SystemTime::UNIX_EPOCH)
69                .unwrap();
70            let in_ms = since_epoch.as_millis() as u64;
71            let expires = in_ms + 1_000;
72
73            let mut mac =
74                Hmac::<Sha256>::new_from_slice(config.api_secret.unwrap().as_bytes()).unwrap();
75            let sign_message = format!("GET/realtime{}", expires);
76            mac.update(sign_message.as_bytes());
77            let signature = hex::encode(mac.finalize().into_bytes());
78
79            let msg = serde_json::to_string(&serde_json::json!({
80                "op": "auth",
81                "args": [config.api_key.unwrap(), expires, signature],
82            }))?;
83            stream.send(Message::Text(msg)).await?;
84        }
85
86        Self {
87            stream,
88            _phantom: PhantomData,
89            private,
90        }
91    }
92
93    pub async fn subscribe(&mut self, topics: Vec<&str>) -> Result<(), BybitError> {
94        let topics: Vec<&str> = topics.into_iter().collect();
95        let msg = serde_json::to_string(&serde_json::json!({
96            "op": "subscribe",
97            "args": topics,
98        }))?;
99        self.stream.send(Message::Text(msg)).await?;
100        Ok(())
101    }
102}
103
104impl<M> BybitWebsocket<M> {
105    #[throws(BybitError)]
106    pub async fn pong(&mut self) {
107        self.stream.send(Message::Pong(vec![])).await?
108    }
109}
110
/// Envelope that incoming text frames are deserialized as when the connection is not
/// private. Only `topic` and `data` are read by the stream implementation in this file.
#[derive(Deserialize)]
struct PublicMessage<'a> {
    /// Topic name, passed to `ParseMessage::parse`.
    pub topic: String,
    /// Value of the JSON `type` field (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub type_: String,
    /// Value of the JSON `ts` field — presumably a timestamp; not read in this file.
    pub ts: u64,
    /// Payload kept as unparsed JSON borrowed from the frame text, so the concrete
    /// message type can deserialize it itself.
    #[serde(borrow)]
    pub data: &'a RawValue,
}
120
/// Envelope that incoming text frames are deserialized as when the connection is
/// private. Only `topic` and `data` are read by the stream implementation in this file.
#[derive(Deserialize)]
struct PrivateMessage<'a> {
    /// Value of the JSON `id` field; not read in this file.
    pub id: String,
    /// Topic name, passed to `ParseMessage::parse`.
    pub topic: String,
    /// Value of the JSON `creationTime` field — presumably a timestamp; not read in this file.
    #[serde(rename = "creationTime")]
    pub creation_time: u64,
    /// Payload kept as unparsed JSON borrowed from the frame text, so the concrete
    /// message type can deserialize it itself.
    #[serde(borrow)]
    pub data: &'a RawValue,
}
130
131
132impl<M> Stream for BybitWebsocket<M>
133where
134    M: ParseMessage + Unpin + std::fmt::Debug,
135{
136    type Item = Result<M, BybitError>;
137
138    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
139        let c = match self.stream.poll_next_unpin(cx) {
140            Poll::Ready(Some(Ok(c))) => c,
141            Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
142            Poll::Pending => return Poll::Pending,
143            Poll::Ready(None) => return Poll::Ready(None),
144        };
145        let msg = match c {
146            Message::Text(msg) => msg,
147            Message::Ping(..) => return Poll::Ready(Some(Ok(M::ping()))),
148            Message::Binary(_) | Message::Frame(_) | Message::Pong(..) => return Poll::Pending,
149            Message::Close(_) => return Poll::Ready(None),
150        };
151
152        if self.private {
153            let try_message = from_str::<PrivateMessage>(&msg);
154            if try_message.is_ok() {
155                let message = try_message.unwrap();
156                Poll::Ready(Some(M::parse(&message.topic, message.data.get())))
157            } else {
158                // auth/sub success
159                Poll::Ready(Some(Ok(M::parse_succ(&msg).unwrap())))
160            }
161        } else {
162            let try_message = from_str::<PublicMessage>(&msg);
163            if try_message.is_ok() {
164                let message = try_message.unwrap();
165                Poll::Ready(Some(M::parse(&message.topic, message.data.get())))
166            } else {
167                Poll::Ready(Some(Ok(M::parse_succ(&msg).unwrap())))
168            }
169        }
170    }
171}