Skip to main content

bybit_rust_api/ws/
client.rs

//! WebSocket client for Bybit V5 with automatic reconnection and resubscription.
//!
//! # Features
//! - Connect to public or private WebSocket streams
//! - Subscribe to and unsubscribe from topics
//! - Automatic ping/pong heartbeat (every 20 s)
//! - Automatic reconnection with exponential backoff, giving up after a bounded
//!   number of consecutive failed attempts (the message stream then ends)
//! - Automatic re-subscription (and re-authentication) after reconnect
//! - Implements `futures::Stream` for async iteration
//!
//! # Example
//!
//! ```ignore
//! use bybit_rust_api::ws::{WsClient, topics};
//! use futures_util::StreamExt;
//!
//! let mut client = WsClient::connect("wss://stream.bybit.com/v5/public/linear").await?;
//! client.subscribe(vec![topics::orderbook(1, "BTCUSDT")]).await?;
//! while let Some(msg) = client.next().await {
//!     println!("{:?}", msg);
//! }
//! ```
23
24use crate::rest::errors::{BybitError, BybitResult};
25use crate::ws::messages::{WsMessage, WsRequest};
26use futures_util::stream::SplitSink;
27use futures_util::{SinkExt, Stream, StreamExt};
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::Duration;
32use tokio::net::TcpStream;
33use tokio::sync::{mpsc, Mutex};
34use tokio::time::{interval, sleep};
35use tokio_tungstenite::tungstenite::Message;
36use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
37
38type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
39
/// Consecutive failed connection attempts after which the background task
/// stops trying (and the message stream ends).
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
/// Initial reconnect backoff in milliseconds; it doubles with each attempt.
const RECONNECT_BASE_DELAY_MS: u64 = 500;
/// Ceiling for the reconnect backoff in milliseconds (30 seconds).
const RECONNECT_MAX_DELAY_MS: u64 = 30 * 1_000;
/// Seconds between heartbeat pings sent to the server.
const PING_INTERVAL_SECS: u64 = 20;
48
/// Credentials from the most recent `authenticate` call, kept by the
/// background task so they can be replayed after a reconnect.
#[derive(Clone)]
struct AuthParams {
    /// API key sent in the auth request.
    api_key: String,
    /// The `expires` value the signature was computed over; replayed verbatim.
    expires: u64,
    /// Caller-supplied signature for this key and expiry.
    signature: String,
}
56
57/// A WebSocket client for Bybit V5 streams with auto-reconnect.
58pub struct WsClient {
59    /// Channel to send commands to the connection task
60    command_tx: mpsc::UnboundedSender<Command>,
61    /// Receives parsed messages from the connection task
62    message_rx: mpsc::UnboundedReceiver<WsMessage>,
63    /// Handle to the connection task
64    _handle: Option<tokio::task::JoinHandle<()>>,
65    /// The WebSocket endpoint URL
66    url: String,
67    /// Currently subscribed topics (for resubscribe on reconnect)
68    subscribed_topics: Arc<Mutex<Vec<String>>>,
69}
70
/// Requests sent from [`WsClient`] to its background connection task.
enum Command {
    /// Send a subscribe request for these topics.
    Subscribe(Vec<String>),
    /// Send an unsubscribe request for these topics.
    Unsubscribe(Vec<String>),
    /// Send an auth request with these credentials (the task also keeps them
    /// for re-authentication after reconnecting).
    Authenticate {
        api_key: String,
        expires: u64,
        signature: String,
    },
}
80
81impl WsClient {
82    /// Connect to a Bybit WebSocket endpoint.
83    ///
84    /// Spawns a background task that manages the connection lifecycle
85    /// including automatic reconnection with exponential backoff.
86    pub async fn connect(url: &str) -> BybitResult<Self> {
87        let (command_tx, command_rx) = mpsc::unbounded_channel();
88        let (message_tx, message_rx) = mpsc::unbounded_channel();
89
90        let subscribed_topics = Arc::new(Mutex::new(Vec::new()));
91        let topics = subscribed_topics.clone();
92        let url_owned = url.to_string();
93
94        let handle = tokio::spawn(async move {
95            run_connection_loop(&url_owned, command_rx, message_tx, topics).await;
96        });
97
98        Ok(WsClient {
99            command_tx,
100            message_rx,
101            _handle: Some(handle),
102            url: url.to_string(),
103            subscribed_topics,
104        })
105    }
106
107    /// Subscribe to one or more topics.
108    ///
109    /// Topics are remembered for automatic re-subscription on reconnect.
110    pub async fn subscribe(&self, topics: Vec<String>) -> BybitResult<()> {
111        // Store topics for reconnect
112        {
113            let mut stored = self.subscribed_topics.lock().await;
114            for t in &topics {
115                if !stored.contains(t) {
116                    stored.push(t.clone());
117                }
118            }
119        }
120
121        self.command_tx
122            .send(Command::Subscribe(topics))
123            .map_err(|e| BybitError::Internal(format!("Subscribe channel closed: {}", e)))?;
124        Ok(())
125    }
126
127    /// Unsubscribe from one or more topics.
128    pub async fn unsubscribe(&self, topics: Vec<String>) -> BybitResult<()> {
129        // Remove from stored topics
130        {
131            let mut stored = self.subscribed_topics.lock().await;
132            stored.retain(|t| !topics.contains(t));
133        }
134
135        self.command_tx
136            .send(Command::Unsubscribe(topics))
137            .map_err(|e| BybitError::Internal(format!("Unsubscribe channel closed: {}", e)))?;
138        Ok(())
139    }
140
141    /// Authenticate for private channels.
142    ///
143    /// Auth params are remembered for automatic re-authentication on reconnect.
144    pub async fn authenticate(
145        &self,
146        api_key: &str,
147        expires: u64,
148        signature: &str,
149    ) -> BybitResult<()> {
150        self.command_tx
151            .send(Command::Authenticate {
152                api_key: api_key.to_string(),
153                expires,
154                signature: signature.to_string(),
155            })
156            .map_err(|e| BybitError::Internal(format!("Auth channel closed: {}", e)))?;
157        Ok(())
158    }
159
160    /// Get the WebSocket endpoint URL.
161    pub fn url(&self) -> &str {
162        &self.url
163    }
164
165    /// Close the WebSocket connection gracefully.
166    ///
167    /// After calling this, the stream will end and no more messages
168    /// will be received. The background task is notified to shut down.
169    pub fn close(&mut self) {
170        // Dropping command_tx will cause the connection handler to exit
171        self.command_tx = mpsc::unbounded_channel().0;
172        self._handle = None;
173    }
174}
175
176impl Drop for WsClient {
177    fn drop(&mut self) {
178        // Shutdown the background task by dropping the sender
179        // The JoinHandle will be cancelled on drop
180    }
181}
182
183impl Stream for WsClient {
184    type Item = WsMessage;
185
186    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187        self.message_rx.poll_recv(cx)
188    }
189}
190
191/// Connection loop: connects, handles messages, reconnects on failure.
192async fn run_connection_loop(
193    url: &str,
194    mut command_rx: mpsc::UnboundedReceiver<Command>,
195    message_tx: mpsc::UnboundedSender<WsMessage>,
196    subscribed_topics: Arc<Mutex<Vec<String>>>,
197) {
198    let mut auth_params: Option<AuthParams> = None;
199    let mut attempt = 0;
200
201    loop {
202        if attempt > 0 {
203            let delay_ms =
204                (RECONNECT_BASE_DELAY_MS * 2_u64.pow(attempt.min(6))).min(RECONNECT_MAX_DELAY_MS);
205            log::warn!(
206                "Reconnecting in {}ms (attempt {}/{})...",
207                delay_ms,
208                attempt,
209                MAX_RECONNECT_ATTEMPTS
210            );
211            sleep(Duration::from_millis(delay_ms)).await;
212        }
213
214        if attempt >= MAX_RECONNECT_ATTEMPTS {
215            log::error!("Max reconnect attempts reached. Giving up.");
216            break;
217        }
218
219        match connect_async(url).await {
220            Ok((ws_stream, _)) => {
221                log::info!("WebSocket connected to {}", url);
222                attempt = 0; // reset on successful connection
223
224                let (ws_write, ws_read) = ws_stream.split();
225                let ws_write = Arc::new(Mutex::new(ws_write));
226
227                // Re-authenticate if we had auth before
228                if let Some(ref auth) = auth_params {
229                    let req = WsRequest::auth(&auth.api_key, auth.expires, &auth.signature);
230                    send_command(&ws_write, &req).await;
231                }
232
233                // Re-subscribe all topics
234                {
235                    let topics = subscribed_topics.lock().await;
236                    if !topics.is_empty() {
237                        let req = WsRequest::subscribe(topics.clone());
238                        send_command(&ws_write, &req).await;
239                    }
240                }
241
242                // Run the connection until it fails
243                run_connection(
244                    ws_read,
245                    ws_write,
246                    &mut command_rx,
247                    &message_tx,
248                    &mut auth_params,
249                )
250                .await;
251            }
252            Err(e) => {
253                log::error!("Connection failed: {}", e);
254            }
255        }
256
257        attempt += 1;
258    }
259}
260
261/// Send a WS request through the writer.
262async fn send_command(writer: &Arc<Mutex<SplitSink<WsStream, Message>>>, req: &WsRequest) {
263    if let Ok(json) = serde_json::to_string(req) {
264        if let Ok(mut w) = writer.try_lock() {
265            let _ = w.send(Message::Text(json.into())).await;
266        }
267    }
268}
269
270/// Core connection handler: reads WS messages, processes commands, sends pings.
271async fn run_connection(
272    mut ws_read: futures_util::stream::SplitStream<WsStream>,
273    ws_write: Arc<Mutex<SplitSink<WsStream, Message>>>,
274    command_rx: &mut mpsc::UnboundedReceiver<Command>,
275    message_tx: &mpsc::UnboundedSender<WsMessage>,
276    auth_params: &mut Option<AuthParams>,
277) {
278    let mut ping_interval = interval(Duration::from_secs(PING_INTERVAL_SECS));
279
280    loop {
281        tokio::select! {
282            // Handle incoming commands
283            cmd = command_rx.recv() => {
284                match cmd {
285                    Some(Command::Subscribe(topics)) => {
286                        let req = WsRequest::subscribe(topics);
287                        send_command(&ws_write, &req).await;
288                    }
289                    Some(Command::Unsubscribe(topics)) => {
290                        let req = WsRequest::unsubscribe(topics);
291                        send_command(&ws_write, &req).await;
292                    }
293                    Some(Command::Authenticate { api_key, expires, signature }) => {
294                        *auth_params = Some(AuthParams {
295                            api_key: api_key.clone(),
296                            expires,
297                            signature: signature.clone(),
298                        });
299                        let req = WsRequest::auth(&api_key, expires, &signature);
300                        send_command(&ws_write, &req).await;
301                    }
302                    None => {
303                        // Command channel closed
304                        break;
305                    }
306                }
307            }
308
309            // Send ping every N seconds
310            _ = ping_interval.tick() => {
311                let ping = WsRequest::ping();
312                send_command(&ws_write, &ping).await;
313            }
314
315            // Read incoming messages
316            msg = ws_read.next() => {
317                match msg {
318                    Some(Ok(Message::Text(text))) => {
319                        match serde_json::from_str::<WsMessage>(&text) {
320                            Ok(parsed) => {
321                                if message_tx.send(parsed).is_err() {
322                                    break; // receiver dropped
323                                }
324                            }
325                            Err(e) => {
326                                log::warn!("Failed to parse WS message: {} -- raw: {}", e, text);
327                            }
328                        }
329                    }
330                    Some(Ok(Message::Ping(data))) => {
331                        if let Ok(mut writer) = ws_write.try_lock() {
332                            let _ = writer.send(Message::Pong(data)).await;
333                        }
334                    }
335                    Some(Ok(Message::Close(frame))) => {
336                        log::info!(
337                            "WebSocket closed by server: {:?}",
338                            frame.map(|f| f.reason.to_string())
339                        );
340                        break;
341                    }
342                    Some(Err(e)) => {
343                        log::error!("WebSocket error: {}", e);
344                        break;
345                    }
346                    None => {
347                        log::info!("WebSocket stream ended");
348                        break;
349                    }
350                    _ => {} // ignore binary, pong, etc.
351                }
352            }
353        }
354    }
355
356    log::info!("WebSocket connection handler exited");
357}