// bybit_api/websocket/client.rs
use std::collections::HashMap;
use std::sync::Arc;

use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{interval, Duration};
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
use tracing::{debug, error, info, warn};

use crate::auth::{generate_ws_signature, get_timestamp};
use crate::config::WsConfig;
use crate::error::{BybitError, Result};
use crate::websocket::models::*;

17type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
18type Callback = Arc<dyn Fn(WsMessage) + Send + Sync>;
19
20pub struct BybitWebSocket {
22 config: WsConfig,
23 subscriptions: Arc<RwLock<Vec<String>>>,
24 callbacks: Arc<RwLock<HashMap<String, Callback>>>,
25 tx: Option<mpsc::Sender<Message>>,
26 is_connected: Arc<RwLock<bool>>,
27}
28
29impl BybitWebSocket {
30 pub fn public(url: &str) -> Self {
32 Self {
33 config: WsConfig::public(url),
34 subscriptions: Arc::new(RwLock::new(Vec::new())),
35 callbacks: Arc::new(RwLock::new(HashMap::new())),
36 tx: None,
37 is_connected: Arc::new(RwLock::new(false)),
38 }
39 }
40
41 pub fn private(api_key: &str, api_secret: &str, url: &str) -> Self {
43 Self {
44 config: WsConfig::private(api_key, api_secret).with_url(url),
45 subscriptions: Arc::new(RwLock::new(Vec::new())),
46 callbacks: Arc::new(RwLock::new(HashMap::new())),
47 tx: None,
48 is_connected: Arc::new(RwLock::new(false)),
49 }
50 }
51
52 pub async fn connect(&mut self) -> Result<()> {
54 let url = &self.config.url;
55 info!(url = %url, "Connecting to WebSocket");
56
57 let (ws_stream, _) = connect_async(url)
58 .await
59 .map_err(|e| BybitError::WebSocket(Box::new(e)))?;
60
61 let (write, read) = ws_stream.split();
62
63 let (tx, mut rx) = mpsc::channel::<Message>(100);
65 self.tx = Some(tx.clone());
66
67 *self.is_connected.write().await = true;
69
70 let write = Arc::new(tokio::sync::Mutex::new(write));
72 let write_clone = write.clone();
73 tokio::spawn(async move {
74 while let Some(msg) = rx.recv().await {
75 let mut w = write_clone.lock().await;
76 if let Err(e) = w.send(msg).await {
77 error!("Failed to send message: {}", e);
78 break;
79 }
80 }
81 });
82
83 if self.config.api_key.is_some() {
85 self.authenticate().await?;
86 }
87
88 let tx_ping = tx.clone();
90 let ping_interval = self.config.ping_interval;
91 tokio::spawn(async move {
92 let mut interval = interval(Duration::from_secs(ping_interval));
93 loop {
94 interval.tick().await;
95 let ping = WsPing::new();
96 let msg = serde_json::to_string(&ping).unwrap_or_default();
97 if tx_ping.send(Message::Text(msg)).await.is_err() {
98 break;
99 }
100 debug!("Ping sent");
101 }
102 });
103
104 let callbacks = self.callbacks.clone();
106 let is_connected = self.is_connected.clone();
107 let subscriptions = self.subscriptions.clone();
108 let config = self.config.clone();
109 let tx_reconnect = tx.clone();
110
111 tokio::spawn(async move {
112 Self::handle_messages(
113 read,
114 callbacks,
115 is_connected,
116 subscriptions,
117 config,
118 tx_reconnect,
119 )
120 .await;
121 });
122
123 info!("WebSocket connected");
124 Ok(())
125 }
126
127 async fn handle_messages(
129 mut read: futures_util::stream::SplitStream<WsStream>,
130 callbacks: Arc<RwLock<HashMap<String, Callback>>>,
131 is_connected: Arc<RwLock<bool>>,
132 _subscriptions: Arc<RwLock<Vec<String>>>,
133 _config: WsConfig,
134 _tx: mpsc::Sender<Message>,
135 ) {
136 while let Some(msg_result) = read.next().await {
137 match msg_result {
138 Ok(Message::Text(text)) => {
139 let json: serde_json::Value = match serde_json::from_str(&text) {
141 Ok(v) => v,
142 Err(e) => {
143 warn!(
144 "Failed to parse message: {}, text: {}",
145 e,
146 &text[..text.len().min(200)]
147 );
148 continue; }
150 };
151
152 if is_pong(&json) {
154 debug!("Pong received");
155 continue;
156 }
157
158 if is_auth_response(&json) {
159 if json
160 .get("success")
161 .and_then(|v| v.as_bool())
162 .unwrap_or(false)
163 {
164 info!("Authentication successful");
165 } else {
166 error!("Authentication failed: {:?}", json);
167 }
168 continue;
169 }
170
171 if is_subscription_response(&json) {
172 if json
173 .get("success")
174 .and_then(|v| v.as_bool())
175 .unwrap_or(false)
176 {
177 debug!("Subscription successful");
178 } else {
179 warn!("Subscription failed: {:?}", json);
180 }
181 continue;
182 }
183
184 if is_data_message(&json) {
186 if let Ok(ws_msg) = serde_json::from_value::<WsMessage>(json) {
187 let cbs = callbacks.read().await;
188 if let Some(callback) = cbs.get(&ws_msg.topic) {
189 callback(ws_msg.clone());
190 } else {
191 for (topic, callback) in cbs.iter() {
193 if ws_msg
194 .topic
195 .starts_with(topic.split('.').next().unwrap_or(""))
196 {
197 callback(ws_msg.clone());
198 break;
199 }
200 }
201 }
202 }
203 }
204 }
205 Ok(Message::Ping(_)) => {
206 debug!("Received ping frame");
207 }
209 Ok(Message::Close(_)) => {
210 info!("WebSocket closed");
211 *is_connected.write().await = false;
212 break;
213 }
214 Err(e) => {
215 error!("WebSocket error: {}", e);
216 *is_connected.write().await = false;
217 break;
218 }
219 _ => {}
220 }
221 }
222 }
223
224 async fn authenticate(&self) -> Result<()> {
226 let api_key = self
227 .config
228 .api_key
229 .as_ref()
230 .ok_or_else(|| BybitError::Auth("API key not set".into()))?;
231 let api_secret = self
232 .config
233 .api_secret
234 .as_ref()
235 .ok_or_else(|| BybitError::Auth("API secret not set".into()))?;
236
237 let expires = get_timestamp() + 10000;
238 let signature = generate_ws_signature(api_secret, expires);
239
240 let auth_msg = WsAuthRequest {
241 req_id: uuid::Uuid::new_v4().to_string(),
242 op: "auth".to_string(),
243 args: vec![
244 serde_json::Value::String(api_key.clone()),
245 serde_json::Value::Number(expires.into()),
246 serde_json::Value::String(signature),
247 ],
248 };
249
250 let msg = serde_json::to_string(&auth_msg).map_err(|e| BybitError::Parse(e.to_string()))?;
251
252 self.send(msg).await?;
253 info!("Authentication request sent");
254 Ok(())
255 }
256
257 pub async fn subscribe<F>(&mut self, topics: Vec<String>, callback: F) -> Result<()>
263 where
264 F: Fn(WsMessage) + Send + Sync + 'static,
265 {
266 let callback = Arc::new(callback) as Callback;
267
268 {
270 let mut cbs = self.callbacks.write().await;
271 for topic in &topics {
272 cbs.insert(topic.clone(), callback.clone());
273 }
274 }
275
276 {
278 let mut subs = self.subscriptions.write().await;
279 subs.extend(topics.clone());
280 }
281
282 let sub_msg = WsRequest {
284 req_id: uuid::Uuid::new_v4().to_string(),
285 op: "subscribe".to_string(),
286 args: topics,
287 };
288
289 let msg = serde_json::to_string(&sub_msg).map_err(|e| BybitError::Parse(e.to_string()))?;
290
291 self.send(msg).await
292 }
293
294 pub async fn unsubscribe(&mut self, topics: Vec<String>) -> Result<()> {
296 {
298 let mut cbs = self.callbacks.write().await;
299 for topic in &topics {
300 cbs.remove(topic);
301 }
302 }
303
304 {
306 let mut subs = self.subscriptions.write().await;
307 subs.retain(|t| !topics.contains(t));
308 }
309
310 let unsub_msg = WsRequest {
312 req_id: uuid::Uuid::new_v4().to_string(),
313 op: "unsubscribe".to_string(),
314 args: topics,
315 };
316
317 let msg =
318 serde_json::to_string(&unsub_msg).map_err(|e| BybitError::Parse(e.to_string()))?;
319
320 self.send(msg).await
321 }
322
323 async fn send(&self, msg: String) -> Result<()> {
325 if let Some(tx) = &self.tx {
326 tx.send(Message::Text(msg)).await.map_err(|_| {
327 BybitError::WebSocket(Box::new(
328 tokio_tungstenite::tungstenite::Error::AlreadyClosed,
329 ))
330 })?;
331 }
332 Ok(())
333 }
334
335 pub async fn is_connected(&self) -> bool {
337 *self.is_connected.read().await
338 }
339
340 pub async fn disconnect(&mut self) -> Result<()> {
342 *self.is_connected.write().await = false;
343 self.tx = None;
344 info!("WebSocket disconnected");
345 Ok(())
346 }
347}