// bybit_client/ws/client.rs

1//! WebSocket client implementation.
2
3use std::collections::HashSet;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::{mpsc, RwLock};
11use tokio::time::interval;
12use tokio_tungstenite::{
13    connect_async,
14    tungstenite::Message,
15    MaybeTlsStream, WebSocketStream,
16};
17use tracing::{debug, error, info, warn};
18
19use crate::auth;
20use crate::config::{ClientConfig, Environment};
21use crate::error::BybitError;
22use crate::ws::types::*;
23
24/// Default ping interval in seconds.
25const DEFAULT_PING_INTERVAL_SECS: u64 = 20;
26
27/// Default reconnect delay in seconds.
28const DEFAULT_RECONNECT_DELAY_SECS: u64 = 5;
29
30/// Maximum reconnect attempts before giving up.
31const MAX_RECONNECT_ATTEMPTS: u32 = 10;
32
33/// WebSocket client for public and private streams.
34pub struct WsClient {
35    /// Configuration.
36    #[allow(dead_code)]
37    config: ClientConfig,
38    /// Channel type.
39    channel: WsChannel,
40    /// Subscribed topics.
41    subscribed_topics: Arc<RwLock<HashSet<String>>>,
42    /// Message sender.
43    #[allow(dead_code)]
44    message_tx: mpsc::UnboundedSender<WsMessage>,
45    /// Command sender for internal commands.
46    command_tx: mpsc::UnboundedSender<WsCommand>,
47    /// Whether the client is connected.
48    connected: Arc<AtomicBool>,
49    /// Whether the client is running.
50    running: Arc<AtomicBool>,
51}
52
/// Requests passed from a [`WsClient`] handle to its background task.
enum WsCommand {
    /// Send a subscribe operation for the given topics.
    Subscribe(Vec<String>),
    /// Send an unsubscribe operation for the given topics.
    Unsubscribe(Vec<String>),
    /// Send the given text as a frame, verbatim (no caller is visible in
    /// this file, hence the lint allowance).
    #[allow(dead_code)]
    SendRaw(String),
    /// Send a close frame and end the current session.
    Disconnect,
}
61
62impl WsClient {
63    /// Create a new WebSocket client for public streams.
64    ///
65    /// # Example
66    ///
67    /// ```no_run
68    /// use bybit_client::ws::{WsClient, WsChannel};
69    ///
70    /// #[tokio::main]
71    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
72    ///     let (client, mut receiver) = WsClient::connect_public(WsChannel::PublicLinear).await?;
73    ///
74    ///     client.subscribe(&["orderbook.50.BTCUSDT"]).await?;
75    ///
76    ///     while let Some(msg) = receiver.recv().await {
77    ///         println!("Received: {:?}", msg);
78    ///     }
79    ///
80    ///     Ok(())
81    /// }
82    /// ```
83    pub async fn connect_public(
84        channel: WsChannel,
85    ) -> Result<(Self, mpsc::UnboundedReceiver<WsMessage>), BybitError> {
86        Self::connect_with_config(ClientConfig::public_only(), channel).await
87    }
88
89    /// Create a new WebSocket client for private streams (requires authentication).
90    ///
91    /// # Example
92    ///
93    /// ```no_run
94    /// use bybit_client::ws::{WsClient, WsChannel};
95    ///
96    /// #[tokio::main]
97    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
98    ///     let (client, mut receiver) = WsClient::connect_private(
99    ///         "api_key",
100    ///         "api_secret",
101    ///     ).await?;
102    ///
103    ///     client.subscribe(&["position", "order", "execution"]).await?;
104    ///
105    ///     while let Some(msg) = receiver.recv().await {
106    ///         println!("Received: {:?}", msg);
107    ///     }
108    ///
109    ///     Ok(())
110    /// }
111    /// ```
112    pub async fn connect_private(
113        api_key: impl Into<String>,
114        api_secret: impl Into<String>,
115    ) -> Result<(Self, mpsc::UnboundedReceiver<WsMessage>), BybitError> {
116        let config = ClientConfig::new(api_key, api_secret);
117        Self::connect_with_config(config, WsChannel::Private).await
118    }
119
120    /// Create a new WebSocket client with custom configuration.
121    pub async fn connect_with_config(
122        config: ClientConfig,
123        channel: WsChannel,
124    ) -> Result<(Self, mpsc::UnboundedReceiver<WsMessage>), BybitError> {
125        if channel.requires_auth() && !config.has_credentials() {
126            return Err(BybitError::Config(
127                "Authentication required for private WebSocket channels".to_string(),
128            ));
129        }
130
131        let (message_tx, message_rx) = mpsc::unbounded_channel();
132        let (command_tx, command_rx) = mpsc::unbounded_channel();
133        let subscribed_topics = Arc::new(RwLock::new(HashSet::new()));
134        let connected = Arc::new(AtomicBool::new(false));
135        let running = Arc::new(AtomicBool::new(true));
136
137        let client = WsClient {
138            config: config.clone(),
139            channel,
140            subscribed_topics: subscribed_topics.clone(),
141            message_tx: message_tx.clone(),
142            command_tx,
143            connected: connected.clone(),
144            running: running.clone(),
145        };
146
147        tokio::spawn(Self::run_ws_loop(
148            config,
149            channel,
150            subscribed_topics,
151            message_tx,
152            command_rx,
153            connected,
154            running,
155        ));
156
157        Ok((client, message_rx))
158    }
159
160    /// Subscribe to topics.
161    pub async fn subscribe(&self, topics: &[&str]) -> Result<(), BybitError> {
162        if topics.is_empty() {
163            return Ok(());
164        }
165
166        let topics: Vec<String> = topics.iter().map(|t| t.to_string()).collect();
167
168        {
169            let mut subscribed = self.subscribed_topics.write().await;
170            for topic in &topics {
171                subscribed.insert(topic.clone());
172            }
173        }
174
175        self.command_tx
176            .send(WsCommand::Subscribe(topics))
177            .map_err(|_| BybitError::WebSocket("Failed to send subscribe command".to_string()))?;
178
179        Ok(())
180    }
181
182    /// Unsubscribe from topics.
183    pub async fn unsubscribe(&self, topics: &[&str]) -> Result<(), BybitError> {
184        if topics.is_empty() {
185            return Ok(());
186        }
187
188        let topics: Vec<String> = topics.iter().map(|t| t.to_string()).collect();
189
190        {
191            let mut subscribed = self.subscribed_topics.write().await;
192            for topic in &topics {
193                subscribed.remove(topic);
194            }
195        }
196
197        self.command_tx
198            .send(WsCommand::Unsubscribe(topics))
199            .map_err(|_| {
200                BybitError::WebSocket("Failed to send unsubscribe command".to_string())
201            })?;
202
203        Ok(())
204    }
205
206    /// Check if the client is connected.
207    pub fn is_connected(&self) -> bool {
208        self.connected.load(Ordering::SeqCst)
209    }
210
211    /// Disconnect and stop the client.
212    pub fn disconnect(&self) {
213        self.running.store(false, Ordering::SeqCst);
214        let _ = self.command_tx.send(WsCommand::Disconnect);
215    }
216
217    /// Get the list of subscribed topics.
218    pub async fn subscribed_topics(&self) -> Vec<String> {
219        self.subscribed_topics
220            .read()
221            .await
222            .iter()
223            .cloned()
224            .collect()
225    }
226
227    /// Get the WebSocket channel.
228    pub fn channel(&self) -> WsChannel {
229        self.channel
230    }
231
232    /// Build the WebSocket URL.
233    fn build_ws_url(config: &ClientConfig, channel: WsChannel) -> String {
234        let base = match config.environment {
235            Environment::Production => "wss://stream.bybit.com",
236            Environment::Testnet => "wss://stream-testnet.bybit.com",
237            Environment::Demo => "wss://stream-demo.bybit.com",
238        };
239        format!("{}{}", base, channel.path())
240    }
241
242    /// Main WebSocket loop.
243    async fn run_ws_loop(
244        config: ClientConfig,
245        channel: WsChannel,
246        subscribed_topics: Arc<RwLock<HashSet<String>>>,
247        message_tx: mpsc::UnboundedSender<WsMessage>,
248        mut command_rx: mpsc::UnboundedReceiver<WsCommand>,
249        connected: Arc<AtomicBool>,
250        running: Arc<AtomicBool>,
251    ) {
252        let mut reconnect_attempts = 0;
253
254        while running.load(Ordering::SeqCst) {
255            let url = Self::build_ws_url(&config, channel);
256            info!("Connecting to WebSocket: {}", url);
257
258            match Self::connect_and_run(
259                &url,
260                &config,
261                channel,
262                &subscribed_topics,
263                &message_tx,
264                &mut command_rx,
265                &connected,
266                &running,
267            )
268            .await
269            {
270                Ok(()) => {
271                    info!("WebSocket connection closed normally");
272                    break;
273                }
274                Err(e) => {
275                    error!("WebSocket error: {}", e);
276                    connected.store(false, Ordering::SeqCst);
277
278                    if !running.load(Ordering::SeqCst) {
279                        break;
280                    }
281
282                    reconnect_attempts += 1;
283                    if reconnect_attempts >= MAX_RECONNECT_ATTEMPTS {
284                        error!(
285                            "Max reconnect attempts ({}) reached, giving up",
286                            MAX_RECONNECT_ATTEMPTS
287                        );
288                        break;
289                    }
290
291                    let delay = Duration::from_secs(DEFAULT_RECONNECT_DELAY_SECS);
292                    warn!(
293                        "Reconnecting in {} seconds (attempt {}/{})",
294                        delay.as_secs(),
295                        reconnect_attempts,
296                        MAX_RECONNECT_ATTEMPTS
297                    );
298                    tokio::time::sleep(delay).await;
299                }
300            }
301        }
302
303        connected.store(false, Ordering::SeqCst);
304        info!("WebSocket task ended");
305    }
306
307    /// Connect and run the WebSocket.
308    #[allow(clippy::too_many_arguments)]
309    async fn connect_and_run(
310        url: &str,
311        config: &ClientConfig,
312        channel: WsChannel,
313        subscribed_topics: &Arc<RwLock<HashSet<String>>>,
314        message_tx: &mpsc::UnboundedSender<WsMessage>,
315        command_rx: &mut mpsc::UnboundedReceiver<WsCommand>,
316        connected: &Arc<AtomicBool>,
317        running: &Arc<AtomicBool>,
318    ) -> Result<(), BybitError> {
319        let (ws_stream, _) = tokio::time::timeout(Duration::from_secs(30), connect_async(url))
320            .await
321            .map_err(|_| BybitError::WebSocket("Connection timeout".to_string()))?
322            .map_err(|e| BybitError::WebSocket(format!("Connection failed: {}", e)))?;
323
324        info!("WebSocket connected");
325        connected.store(true, Ordering::SeqCst);
326
327        let (mut write, mut read) = ws_stream.split();
328
329        if channel.requires_auth() {
330            Self::authenticate(&mut write, config).await?;
331        }
332
333        {
334            let topics: Vec<String> = subscribed_topics.read().await.iter().cloned().collect();
335            if !topics.is_empty() {
336                info!("Re-subscribing to {} topics", topics.len());
337                let op = WsOperation::subscribe(topics);
338                let msg = serde_json::to_string(&op)
339                    .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
340                write
341                    .send(Message::Text(msg.into()))
342                    .await
343                    .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
344            }
345        }
346
347        let mut ping_interval = interval(Duration::from_secs(DEFAULT_PING_INTERVAL_SECS));
348
349        loop {
350            tokio::select! {
351                msg = read.next() => {
352                    match msg {
353                        Some(Ok(Message::Text(text))) => {
354                            if let Some(ws_msg) = Self::parse_message(text.as_str()) {
355                                if message_tx.send(ws_msg).is_err() {
356                                    debug!("Message receiver dropped");
357                                    break;
358                                }
359                            }
360                        }
361                        Some(Ok(Message::Ping(data))) => {
362                            debug!("Received ping");
363                            write.send(Message::Pong(data)).await
364                                .map_err(|e| BybitError::WebSocket(format!("Pong error: {}", e)))?;
365                        }
366                        Some(Ok(Message::Pong(_))) => {
367                            debug!("Received pong");
368                        }
369                        Some(Ok(Message::Close(frame))) => {
370                            info!("Received close frame: {:?}", frame);
371                            break;
372                        }
373                        Some(Err(e)) => {
374                            return Err(BybitError::WebSocket(format!("Read error: {}", e)));
375                        }
376                        None => {
377                            info!("WebSocket stream ended");
378                            break;
379                        }
380                        _ => {}
381                    }
382                }
383
384                cmd = command_rx.recv() => {
385                    match cmd {
386                        Some(WsCommand::Subscribe(topics)) => {
387                            let op = WsOperation::subscribe(topics);
388                            let msg = serde_json::to_string(&op)
389                                .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
390                            write.send(Message::Text(msg.into())).await
391                                .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
392                        }
393                        Some(WsCommand::Unsubscribe(topics)) => {
394                            let op = WsOperation::unsubscribe(topics);
395                            let msg = serde_json::to_string(&op)
396                                .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
397                            write.send(Message::Text(msg.into())).await
398                                .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
399                        }
400                        Some(WsCommand::SendRaw(text)) => {
401                            write.send(Message::Text(text.into())).await
402                                .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
403                        }
404                        Some(WsCommand::Disconnect) | None => {
405                            info!("Disconnect requested");
406                            let _ = write.send(Message::Close(None)).await;
407                            break;
408                        }
409                    }
410                }
411
412                _ = ping_interval.tick() => {
413                    let op = WsOperation::ping();
414                    let msg = serde_json::to_string(&op)
415                        .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
416                    write.send(Message::Text(msg.into())).await
417                        .map_err(|e| BybitError::WebSocket(format!("Ping error: {}", e)))?;
418                    debug!("Sent ping");
419                }
420
421                _ = tokio::time::sleep(Duration::from_millis(100)) => {
422                    if !running.load(Ordering::SeqCst) {
423                        info!("Stop requested");
424                        break;
425                    }
426                }
427            }
428        }
429
430        Ok(())
431    }
432
433    /// Authenticate the private WebSocket connection.
434    async fn authenticate(
435        write: &mut futures_util::stream::SplitSink<
436            WebSocketStream<MaybeTlsStream<TcpStream>>,
437            Message,
438        >,
439        config: &ClientConfig,
440    ) -> Result<(), BybitError> {
441        let api_key = config.api_key.as_ref().ok_or_else(|| {
442            BybitError::Config("API key required for authentication".to_string())
443        })?;
444
445        let api_secret = config.get_secret().ok_or_else(|| {
446            BybitError::Config("API secret required for authentication".to_string())
447        })?;
448
449        let expires = auth::current_timestamp_ms() + 10_000;
450        let signature = auth::sign_ws_auth(expires, api_secret);
451
452        let op = WsOperation::auth(api_key, expires, &signature);
453
454        let msg = serde_json::to_string(&op)
455            .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
456
457        write
458            .send(Message::Text(msg.into()))
459            .await
460            .map_err(|e| BybitError::WebSocket(format!("Auth send error: {}", e)))?;
461
462        info!("Sent authentication request");
463        Ok(())
464    }
465
466    /// Parse an incoming WebSocket message.
467    fn parse_message(text: &str) -> Option<WsMessage> {
468        let value: serde_json::Value = match serde_json::from_str(text) {
469            Ok(v) => v,
470            Err(e) => {
471                warn!("Failed to parse WebSocket message: {}", e);
472                return Some(WsMessage::Raw(text.to_string()));
473            }
474        };
475
476        if value.get("op").and_then(|v| v.as_str()) == Some("pong") {
477            if let Ok(pong) = serde_json::from_value(value.clone()) {
478                return Some(WsMessage::Pong(pong));
479            }
480        }
481
482        if value.get("success").is_some() && value.get("topic").is_none() {
483            if let Ok(response) = serde_json::from_value(value.clone()) {
484                return Some(WsMessage::OperationResponse(response));
485            }
486        }
487
488        if let Some(topic) = value.get("topic").and_then(|v| v.as_str()) {
489            if topic.starts_with("orderbook.") {
490                if let Ok(msg) = serde_json::from_value(value) {
491                    return Some(WsMessage::Orderbook(Box::new(msg)));
492                }
493            } else if topic.starts_with("publicTrade.") {
494                if let Ok(msg) = serde_json::from_value(value) {
495                    return Some(WsMessage::Trade(Box::new(msg)));
496                }
497            } else if topic.starts_with("tickers.") {
498                if let Ok(msg) = serde_json::from_value(value) {
499                    return Some(WsMessage::Ticker(Box::new(msg)));
500                }
501            } else if topic.starts_with("kline.") {
502                if let Ok(msg) = serde_json::from_value(value) {
503                    return Some(WsMessage::Kline(Box::new(msg)));
504                }
505            } else if topic.starts_with("liquidation.") {
506                if let Ok(msg) = serde_json::from_value(value) {
507                    return Some(WsMessage::Liquidation(Box::new(msg)));
508                }
509            }
510            else if topic == "position" || topic.starts_with("position.") {
511                if let Ok(msg) = serde_json::from_value(value) {
512                    return Some(WsMessage::Position(Box::new(msg)));
513                }
514            } else if topic == "order" || topic.starts_with("order.") {
515                if let Ok(msg) = serde_json::from_value(value) {
516                    return Some(WsMessage::Order(Box::new(msg)));
517                }
518            } else if topic == "execution.fast" {
519                if let Ok(msg) = serde_json::from_value(value) {
520                    return Some(WsMessage::ExecutionFast(Box::new(msg)));
521                }
522            } else if topic == "execution" || topic.starts_with("execution.") {
523                if let Ok(msg) = serde_json::from_value(value) {
524                    return Some(WsMessage::Execution(Box::new(msg)));
525                }
526            } else if topic == "wallet" {
527                if let Ok(msg) = serde_json::from_value(value) {
528                    return Some(WsMessage::Wallet(Box::new(msg)));
529                }
530            } else if topic == "greeks" {
531                if let Ok(msg) = serde_json::from_value(value) {
532                    return Some(WsMessage::Greeks(Box::new(msg)));
533                }
534            }
535        }
536
537        Some(WsMessage::Raw(text.to_string()))
538    }
539}
540
541impl Drop for WsClient {
542    fn drop(&mut self) {
543        self.disconnect();
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// The URL is the environment host followed by the channel path.
    #[test]
    fn test_build_ws_url() {
        let production = ClientConfig::public_only();
        assert_eq!(
            WsClient::build_ws_url(&production, WsChannel::PublicLinear),
            "wss://stream.bybit.com/v5/public/linear"
        );

        let testnet = production.testnet();
        assert_eq!(
            WsClient::build_ws_url(&testnet, WsChannel::PublicLinear),
            "wss://stream-testnet.bybit.com/v5/public/linear"
        );
    }

    /// A frame with `"op":"pong"` is classified as a pong.
    #[test]
    fn test_parse_pong() {
        let frame = r#"{"success":true,"ret_msg":"pong","conn_id":"abc123","op":"pong"}"#;
        let parsed = WsClient::parse_message(frame);
        assert!(matches!(parsed, Some(WsMessage::Pong(_))));
    }

    /// A frame with `success` and no `topic` is an operation response.
    #[test]
    fn test_parse_operation_response() {
        let frame = r#"{"success":true,"ret_msg":"subscribe","conn_id":"abc123"}"#;
        let parsed = WsClient::parse_message(frame);
        assert!(matches!(parsed, Some(WsMessage::OperationResponse(_))));
    }
}