1use std::collections::HashSet;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::{mpsc, RwLock};
11use tokio::time::interval;
12use tokio_tungstenite::{
13 connect_async,
14 tungstenite::Message,
15 MaybeTlsStream, WebSocketStream,
16};
17use tracing::{debug, error, info, warn};
18
19use crate::auth;
20use crate::config::{ClientConfig, Environment};
21use crate::error::BybitError;
22use crate::ws::types::*;
23
/// Seconds between application-level `ping` operations sent on an open session.
const DEFAULT_PING_INTERVAL_SECS: u64 = 20;

/// Seconds the background task waits before trying to reconnect.
const DEFAULT_RECONNECT_DELAY_SECS: u64 = 5;

/// Upper bound on how many times the background task retries a failed
/// connection before giving up.
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
32
/// Handle to a Bybit v5 WebSocket connection that is driven by a background tokio task.
///
/// Built by `WsClient::connect_public`, `WsClient::connect_private` or
/// `WsClient::connect_with_config`, each of which also returns the receiver on which
/// parsed `WsMessage`s are delivered. Dropping the handle calls `WsClient::disconnect`.
pub struct WsClient {
    // Configuration the client was created with; the background task owns its own clone.
    // Not read through the handle by any method yet, hence the allowance.
    #[allow(dead_code)]
    config: ClientConfig,
    // Which Bybit stream this client is attached to (returned by `channel()`).
    channel: WsChannel,
    // Topics the user asked for; shared with the background task, which re-sends
    // them after every (re)connect.
    subscribed_topics: Arc<RwLock<HashSet<String>>>,
    // Clone of the sender given to the background task; never read via the handle.
    // NOTE(review): while this clone is alive the receiver returned by `connect_*`
    // cannot observe channel closure, even after the background task has exited —
    // confirm this is intended.
    #[allow(dead_code)]
    message_tx: mpsc::UnboundedSender<WsMessage>,
    // Commands (subscribe / unsubscribe / raw send / disconnect) for the background task.
    command_tx: mpsc::UnboundedSender<WsCommand>,
    // Set by the background task after a successful handshake; read by `is_connected()`.
    connected: Arc<AtomicBool>,
    // Cleared by `disconnect()`; the background task stops once it observes `false`.
    running: Arc<AtomicBool>,
}
52
/// Requests sent from a `WsClient` handle to its background connection task.
///
/// Derives `Debug` so commands can be included in log output.
#[derive(Debug)]
enum WsCommand {
    /// Send a `subscribe` operation for these topics on the current session.
    Subscribe(Vec<String>),
    /// Send an `unsubscribe` operation for these topics on the current session.
    Unsubscribe(Vec<String>),
    /// Send this text frame verbatim. No public method constructs it yet, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    SendRaw(String),
    /// Send a close frame and end the current session.
    Disconnect,
}
61
62impl WsClient {
63 pub async fn connect_public(
84 channel: WsChannel,
85 ) -> Result<(Self, mpsc::UnboundedReceiver<WsMessage>), BybitError> {
86 Self::connect_with_config(ClientConfig::public_only(), channel).await
87 }
88
89 pub async fn connect_private(
113 api_key: impl Into<String>,
114 api_secret: impl Into<String>,
115 ) -> Result<(Self, mpsc::UnboundedReceiver<WsMessage>), BybitError> {
116 let config = ClientConfig::new(api_key, api_secret);
117 Self::connect_with_config(config, WsChannel::Private).await
118 }
119
120 pub async fn connect_with_config(
122 config: ClientConfig,
123 channel: WsChannel,
124 ) -> Result<(Self, mpsc::UnboundedReceiver<WsMessage>), BybitError> {
125 if channel.requires_auth() && !config.has_credentials() {
126 return Err(BybitError::Config(
127 "Authentication required for private WebSocket channels".to_string(),
128 ));
129 }
130
131 let (message_tx, message_rx) = mpsc::unbounded_channel();
132 let (command_tx, command_rx) = mpsc::unbounded_channel();
133 let subscribed_topics = Arc::new(RwLock::new(HashSet::new()));
134 let connected = Arc::new(AtomicBool::new(false));
135 let running = Arc::new(AtomicBool::new(true));
136
137 let client = WsClient {
138 config: config.clone(),
139 channel,
140 subscribed_topics: subscribed_topics.clone(),
141 message_tx: message_tx.clone(),
142 command_tx,
143 connected: connected.clone(),
144 running: running.clone(),
145 };
146
147 tokio::spawn(Self::run_ws_loop(
148 config,
149 channel,
150 subscribed_topics,
151 message_tx,
152 command_rx,
153 connected,
154 running,
155 ));
156
157 Ok((client, message_rx))
158 }
159
160 pub async fn subscribe(&self, topics: &[&str]) -> Result<(), BybitError> {
162 if topics.is_empty() {
163 return Ok(());
164 }
165
166 let topics: Vec<String> = topics.iter().map(|t| t.to_string()).collect();
167
168 {
169 let mut subscribed = self.subscribed_topics.write().await;
170 for topic in &topics {
171 subscribed.insert(topic.clone());
172 }
173 }
174
175 self.command_tx
176 .send(WsCommand::Subscribe(topics))
177 .map_err(|_| BybitError::WebSocket("Failed to send subscribe command".to_string()))?;
178
179 Ok(())
180 }
181
182 pub async fn unsubscribe(&self, topics: &[&str]) -> Result<(), BybitError> {
184 if topics.is_empty() {
185 return Ok(());
186 }
187
188 let topics: Vec<String> = topics.iter().map(|t| t.to_string()).collect();
189
190 {
191 let mut subscribed = self.subscribed_topics.write().await;
192 for topic in &topics {
193 subscribed.remove(topic);
194 }
195 }
196
197 self.command_tx
198 .send(WsCommand::Unsubscribe(topics))
199 .map_err(|_| {
200 BybitError::WebSocket("Failed to send unsubscribe command".to_string())
201 })?;
202
203 Ok(())
204 }
205
206 pub fn is_connected(&self) -> bool {
208 self.connected.load(Ordering::SeqCst)
209 }
210
211 pub fn disconnect(&self) {
213 self.running.store(false, Ordering::SeqCst);
214 let _ = self.command_tx.send(WsCommand::Disconnect);
215 }
216
217 pub async fn subscribed_topics(&self) -> Vec<String> {
219 self.subscribed_topics
220 .read()
221 .await
222 .iter()
223 .cloned()
224 .collect()
225 }
226
227 pub fn channel(&self) -> WsChannel {
229 self.channel
230 }
231
232 fn build_ws_url(config: &ClientConfig, channel: WsChannel) -> String {
234 let base = match config.environment {
235 Environment::Production => "wss://stream.bybit.com",
236 Environment::Testnet => "wss://stream-testnet.bybit.com",
237 Environment::Demo => "wss://stream-demo.bybit.com",
238 };
239 format!("{}{}", base, channel.path())
240 }
241
242 async fn run_ws_loop(
244 config: ClientConfig,
245 channel: WsChannel,
246 subscribed_topics: Arc<RwLock<HashSet<String>>>,
247 message_tx: mpsc::UnboundedSender<WsMessage>,
248 mut command_rx: mpsc::UnboundedReceiver<WsCommand>,
249 connected: Arc<AtomicBool>,
250 running: Arc<AtomicBool>,
251 ) {
252 let mut reconnect_attempts = 0;
253
254 while running.load(Ordering::SeqCst) {
255 let url = Self::build_ws_url(&config, channel);
256 info!("Connecting to WebSocket: {}", url);
257
258 match Self::connect_and_run(
259 &url,
260 &config,
261 channel,
262 &subscribed_topics,
263 &message_tx,
264 &mut command_rx,
265 &connected,
266 &running,
267 )
268 .await
269 {
270 Ok(()) => {
271 info!("WebSocket connection closed normally");
272 break;
273 }
274 Err(e) => {
275 error!("WebSocket error: {}", e);
276 connected.store(false, Ordering::SeqCst);
277
278 if !running.load(Ordering::SeqCst) {
279 break;
280 }
281
282 reconnect_attempts += 1;
283 if reconnect_attempts >= MAX_RECONNECT_ATTEMPTS {
284 error!(
285 "Max reconnect attempts ({}) reached, giving up",
286 MAX_RECONNECT_ATTEMPTS
287 );
288 break;
289 }
290
291 let delay = Duration::from_secs(DEFAULT_RECONNECT_DELAY_SECS);
292 warn!(
293 "Reconnecting in {} seconds (attempt {}/{})",
294 delay.as_secs(),
295 reconnect_attempts,
296 MAX_RECONNECT_ATTEMPTS
297 );
298 tokio::time::sleep(delay).await;
299 }
300 }
301 }
302
303 connected.store(false, Ordering::SeqCst);
304 info!("WebSocket task ended");
305 }
306
307 #[allow(clippy::too_many_arguments)]
309 async fn connect_and_run(
310 url: &str,
311 config: &ClientConfig,
312 channel: WsChannel,
313 subscribed_topics: &Arc<RwLock<HashSet<String>>>,
314 message_tx: &mpsc::UnboundedSender<WsMessage>,
315 command_rx: &mut mpsc::UnboundedReceiver<WsCommand>,
316 connected: &Arc<AtomicBool>,
317 running: &Arc<AtomicBool>,
318 ) -> Result<(), BybitError> {
319 let (ws_stream, _) = tokio::time::timeout(Duration::from_secs(30), connect_async(url))
320 .await
321 .map_err(|_| BybitError::WebSocket("Connection timeout".to_string()))?
322 .map_err(|e| BybitError::WebSocket(format!("Connection failed: {}", e)))?;
323
324 info!("WebSocket connected");
325 connected.store(true, Ordering::SeqCst);
326
327 let (mut write, mut read) = ws_stream.split();
328
329 if channel.requires_auth() {
330 Self::authenticate(&mut write, config).await?;
331 }
332
333 {
334 let topics: Vec<String> = subscribed_topics.read().await.iter().cloned().collect();
335 if !topics.is_empty() {
336 info!("Re-subscribing to {} topics", topics.len());
337 let op = WsOperation::subscribe(topics);
338 let msg = serde_json::to_string(&op)
339 .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
340 write
341 .send(Message::Text(msg.into()))
342 .await
343 .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
344 }
345 }
346
347 let mut ping_interval = interval(Duration::from_secs(DEFAULT_PING_INTERVAL_SECS));
348
349 loop {
350 tokio::select! {
351 msg = read.next() => {
352 match msg {
353 Some(Ok(Message::Text(text))) => {
354 if let Some(ws_msg) = Self::parse_message(text.as_str()) {
355 if message_tx.send(ws_msg).is_err() {
356 debug!("Message receiver dropped");
357 break;
358 }
359 }
360 }
361 Some(Ok(Message::Ping(data))) => {
362 debug!("Received ping");
363 write.send(Message::Pong(data)).await
364 .map_err(|e| BybitError::WebSocket(format!("Pong error: {}", e)))?;
365 }
366 Some(Ok(Message::Pong(_))) => {
367 debug!("Received pong");
368 }
369 Some(Ok(Message::Close(frame))) => {
370 info!("Received close frame: {:?}", frame);
371 break;
372 }
373 Some(Err(e)) => {
374 return Err(BybitError::WebSocket(format!("Read error: {}", e)));
375 }
376 None => {
377 info!("WebSocket stream ended");
378 break;
379 }
380 _ => {}
381 }
382 }
383
384 cmd = command_rx.recv() => {
385 match cmd {
386 Some(WsCommand::Subscribe(topics)) => {
387 let op = WsOperation::subscribe(topics);
388 let msg = serde_json::to_string(&op)
389 .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
390 write.send(Message::Text(msg.into())).await
391 .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
392 }
393 Some(WsCommand::Unsubscribe(topics)) => {
394 let op = WsOperation::unsubscribe(topics);
395 let msg = serde_json::to_string(&op)
396 .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
397 write.send(Message::Text(msg.into())).await
398 .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
399 }
400 Some(WsCommand::SendRaw(text)) => {
401 write.send(Message::Text(text.into())).await
402 .map_err(|e| BybitError::WebSocket(format!("Send error: {}", e)))?;
403 }
404 Some(WsCommand::Disconnect) | None => {
405 info!("Disconnect requested");
406 let _ = write.send(Message::Close(None)).await;
407 break;
408 }
409 }
410 }
411
412 _ = ping_interval.tick() => {
413 let op = WsOperation::ping();
414 let msg = serde_json::to_string(&op)
415 .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
416 write.send(Message::Text(msg.into())).await
417 .map_err(|e| BybitError::WebSocket(format!("Ping error: {}", e)))?;
418 debug!("Sent ping");
419 }
420
421 _ = tokio::time::sleep(Duration::from_millis(100)) => {
422 if !running.load(Ordering::SeqCst) {
423 info!("Stop requested");
424 break;
425 }
426 }
427 }
428 }
429
430 Ok(())
431 }
432
433 async fn authenticate(
435 write: &mut futures_util::stream::SplitSink<
436 WebSocketStream<MaybeTlsStream<TcpStream>>,
437 Message,
438 >,
439 config: &ClientConfig,
440 ) -> Result<(), BybitError> {
441 let api_key = config.api_key.as_ref().ok_or_else(|| {
442 BybitError::Config("API key required for authentication".to_string())
443 })?;
444
445 let api_secret = config.get_secret().ok_or_else(|| {
446 BybitError::Config("API secret required for authentication".to_string())
447 })?;
448
449 let expires = auth::current_timestamp_ms() + 10_000;
450 let signature = auth::sign_ws_auth(expires, api_secret);
451
452 let op = WsOperation::auth(api_key, expires, &signature);
453
454 let msg = serde_json::to_string(&op)
455 .map_err(|e| BybitError::WebSocket(format!("Serialize error: {}", e)))?;
456
457 write
458 .send(Message::Text(msg.into()))
459 .await
460 .map_err(|e| BybitError::WebSocket(format!("Auth send error: {}", e)))?;
461
462 info!("Sent authentication request");
463 Ok(())
464 }
465
466 fn parse_message(text: &str) -> Option<WsMessage> {
468 let value: serde_json::Value = match serde_json::from_str(text) {
469 Ok(v) => v,
470 Err(e) => {
471 warn!("Failed to parse WebSocket message: {}", e);
472 return Some(WsMessage::Raw(text.to_string()));
473 }
474 };
475
476 if value.get("op").and_then(|v| v.as_str()) == Some("pong") {
477 if let Ok(pong) = serde_json::from_value(value.clone()) {
478 return Some(WsMessage::Pong(pong));
479 }
480 }
481
482 if value.get("success").is_some() && value.get("topic").is_none() {
483 if let Ok(response) = serde_json::from_value(value.clone()) {
484 return Some(WsMessage::OperationResponse(response));
485 }
486 }
487
488 if let Some(topic) = value.get("topic").and_then(|v| v.as_str()) {
489 if topic.starts_with("orderbook.") {
490 if let Ok(msg) = serde_json::from_value(value) {
491 return Some(WsMessage::Orderbook(Box::new(msg)));
492 }
493 } else if topic.starts_with("publicTrade.") {
494 if let Ok(msg) = serde_json::from_value(value) {
495 return Some(WsMessage::Trade(Box::new(msg)));
496 }
497 } else if topic.starts_with("tickers.") {
498 if let Ok(msg) = serde_json::from_value(value) {
499 return Some(WsMessage::Ticker(Box::new(msg)));
500 }
501 } else if topic.starts_with("kline.") {
502 if let Ok(msg) = serde_json::from_value(value) {
503 return Some(WsMessage::Kline(Box::new(msg)));
504 }
505 } else if topic.starts_with("liquidation.") {
506 if let Ok(msg) = serde_json::from_value(value) {
507 return Some(WsMessage::Liquidation(Box::new(msg)));
508 }
509 }
510 else if topic == "position" || topic.starts_with("position.") {
511 if let Ok(msg) = serde_json::from_value(value) {
512 return Some(WsMessage::Position(Box::new(msg)));
513 }
514 } else if topic == "order" || topic.starts_with("order.") {
515 if let Ok(msg) = serde_json::from_value(value) {
516 return Some(WsMessage::Order(Box::new(msg)));
517 }
518 } else if topic == "execution.fast" {
519 if let Ok(msg) = serde_json::from_value(value) {
520 return Some(WsMessage::ExecutionFast(Box::new(msg)));
521 }
522 } else if topic == "execution" || topic.starts_with("execution.") {
523 if let Ok(msg) = serde_json::from_value(value) {
524 return Some(WsMessage::Execution(Box::new(msg)));
525 }
526 } else if topic == "wallet" {
527 if let Ok(msg) = serde_json::from_value(value) {
528 return Some(WsMessage::Wallet(Box::new(msg)));
529 }
530 } else if topic == "greeks" {
531 if let Ok(msg) = serde_json::from_value(value) {
532 return Some(WsMessage::Greeks(Box::new(msg)));
533 }
534 }
535 }
536
537 Some(WsMessage::Raw(text.to_string()))
538 }
539}
540
541impl Drop for WsClient {
542 fn drop(&mut self) {
543 self.disconnect();
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_ws_url() {
        let config = ClientConfig::public_only();
        let url = WsClient::build_ws_url(&config, WsChannel::PublicLinear);
        assert_eq!(url, "wss://stream.bybit.com/v5/public/linear");

        let testnet = config.testnet();
        let url = WsClient::build_ws_url(&testnet, WsChannel::PublicLinear);
        assert_eq!(url, "wss://stream-testnet.bybit.com/v5/public/linear");
    }

    #[test]
    fn test_build_ws_url_demo() {
        let mut config = ClientConfig::public_only();
        config.environment = Environment::Demo;
        let url = WsClient::build_ws_url(&config, WsChannel::PublicLinear);
        assert_eq!(url, "wss://stream-demo.bybit.com/v5/public/linear");
    }

    #[test]
    fn test_parse_pong() {
        let text = r#"{"success":true,"ret_msg":"pong","conn_id":"abc123","op":"pong"}"#;
        let msg = WsClient::parse_message(text);
        assert!(matches!(msg, Some(WsMessage::Pong(_))));
    }

    #[test]
    fn test_parse_operation_response() {
        let text = r#"{"success":true,"ret_msg":"subscribe","conn_id":"abc123"}"#;
        let msg = WsClient::parse_message(text);
        assert!(matches!(msg, Some(WsMessage::OperationResponse(_))));
    }

    #[test]
    fn test_parse_invalid_json_is_raw() {
        let msg = WsClient::parse_message("not json");
        assert!(matches!(msg, Some(WsMessage::Raw(ref s)) if s == "not json"));
    }

    #[test]
    fn test_parse_unknown_topic_is_raw() {
        let text = r#"{"topic":"unknown.topic","data":[]}"#;
        let msg = WsClient::parse_message(text);
        assert!(matches!(msg, Some(WsMessage::Raw(ref s)) if s == text));
    }
}