Skip to main content

bybit_api/websocket/
client.rs

1//! WebSocket client implementation.
2
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::net::TcpStream;
7use tokio::sync::{mpsc, RwLock};
8use tokio::time::{interval, Duration};
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10use tracing::{debug, error, info, warn};
11
12use crate::auth::{generate_ws_signature, get_timestamp};
13use crate::config::WsConfig;
14use crate::error::{BybitError, Result};
15use crate::websocket::models::*;
16
17type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
18type Callback = Arc<dyn Fn(WsMessage) + Send + Sync>;
19
20/// WebSocket client for Bybit streaming API.
21pub struct BybitWebSocket {
22    config: WsConfig,
23    subscriptions: Arc<RwLock<Vec<String>>>,
24    callbacks: Arc<RwLock<HashMap<String, Callback>>>,
25    tx: Option<mpsc::Sender<Message>>,
26    is_connected: Arc<RwLock<bool>>,
27}
28
29impl BybitWebSocket {
30    /// Create a new public WebSocket client.
31    pub fn public(url: &str) -> Self {
32        Self {
33            config: WsConfig::public(url),
34            subscriptions: Arc::new(RwLock::new(Vec::new())),
35            callbacks: Arc::new(RwLock::new(HashMap::new())),
36            tx: None,
37            is_connected: Arc::new(RwLock::new(false)),
38        }
39    }
40
41    /// Create a new private WebSocket client.
42    pub fn private(api_key: &str, api_secret: &str, url: &str) -> Self {
43        Self {
44            config: WsConfig::private(api_key, api_secret).with_url(url),
45            subscriptions: Arc::new(RwLock::new(Vec::new())),
46            callbacks: Arc::new(RwLock::new(HashMap::new())),
47            tx: None,
48            is_connected: Arc::new(RwLock::new(false)),
49        }
50    }
51
52    /// Connect to the WebSocket server.
53    pub async fn connect(&mut self) -> Result<()> {
54        let url = &self.config.url;
55        info!(url = %url, "Connecting to WebSocket");
56
57        let (ws_stream, _) = connect_async(url)
58            .await
59            .map_err(|e| BybitError::WebSocket(Box::new(e)))?;
60
61        let (write, read) = ws_stream.split();
62
63        // Create channel for sending messages
64        let (tx, mut rx) = mpsc::channel::<Message>(100);
65        self.tx = Some(tx.clone());
66
67        // Set connected flag
68        *self.is_connected.write().await = true;
69
70        // Spawn write task
71        let write = Arc::new(tokio::sync::Mutex::new(write));
72        let write_clone = write.clone();
73        tokio::spawn(async move {
74            while let Some(msg) = rx.recv().await {
75                let mut w = write_clone.lock().await;
76                if let Err(e) = w.send(msg).await {
77                    error!("Failed to send message: {}", e);
78                    break;
79                }
80            }
81        });
82
83        // Authenticate if private channel
84        if self.config.api_key.is_some() {
85            self.authenticate().await?;
86        }
87
88        // Spawn ping task
89        let tx_ping = tx.clone();
90        let ping_interval = self.config.ping_interval;
91        tokio::spawn(async move {
92            let mut interval = interval(Duration::from_secs(ping_interval));
93            loop {
94                interval.tick().await;
95                let ping = WsPing::new();
96                let msg = serde_json::to_string(&ping).unwrap_or_default();
97                if tx_ping.send(Message::Text(msg)).await.is_err() {
98                    break;
99                }
100                debug!("Ping sent");
101            }
102        });
103
104        // Spawn read task
105        let callbacks = self.callbacks.clone();
106        let is_connected = self.is_connected.clone();
107        let subscriptions = self.subscriptions.clone();
108        let config = self.config.clone();
109        let tx_reconnect = tx.clone();
110
111        tokio::spawn(async move {
112            Self::handle_messages(
113                read,
114                callbacks,
115                is_connected,
116                subscriptions,
117                config,
118                tx_reconnect,
119            )
120            .await;
121        });
122
123        info!("WebSocket connected");
124        Ok(())
125    }
126
127    /// Handle incoming messages.
128    async fn handle_messages(
129        mut read: futures_util::stream::SplitStream<WsStream>,
130        callbacks: Arc<RwLock<HashMap<String, Callback>>>,
131        is_connected: Arc<RwLock<bool>>,
132        _subscriptions: Arc<RwLock<Vec<String>>>,
133        _config: WsConfig,
134        _tx: mpsc::Sender<Message>,
135    ) {
136        while let Some(msg_result) = read.next().await {
137            match msg_result {
138                Ok(Message::Text(text)) => {
139                    // Try to parse as JSON
140                    let json: serde_json::Value = match serde_json::from_str(&text) {
141                        Ok(v) => v,
142                        Err(e) => {
143                            warn!(
144                                "Failed to parse message: {}, text: {}",
145                                e,
146                                &text[..text.len().min(200)]
147                            );
148                            continue; // Don't panic, continue processing
149                        }
150                    };
151
152                    // Handle different message types
153                    if is_pong(&json) {
154                        debug!("Pong received");
155                        continue;
156                    }
157
158                    if is_auth_response(&json) {
159                        if json
160                            .get("success")
161                            .and_then(|v| v.as_bool())
162                            .unwrap_or(false)
163                        {
164                            info!("Authentication successful");
165                        } else {
166                            error!("Authentication failed: {:?}", json);
167                        }
168                        continue;
169                    }
170
171                    if is_subscription_response(&json) {
172                        if json
173                            .get("success")
174                            .and_then(|v| v.as_bool())
175                            .unwrap_or(false)
176                        {
177                            debug!("Subscription successful");
178                        } else {
179                            warn!("Subscription failed: {:?}", json);
180                        }
181                        continue;
182                    }
183
184                    // Handle data message
185                    if is_data_message(&json) {
186                        if let Ok(ws_msg) = serde_json::from_value::<WsMessage>(json) {
187                            let cbs = callbacks.read().await;
188                            if let Some(callback) = cbs.get(&ws_msg.topic) {
189                                callback(ws_msg.clone());
190                            } else {
191                                // Try to find matching callback by prefix
192                                for (topic, callback) in cbs.iter() {
193                                    if ws_msg
194                                        .topic
195                                        .starts_with(topic.split('.').next().unwrap_or(""))
196                                    {
197                                        callback(ws_msg.clone());
198                                        break;
199                                    }
200                                }
201                            }
202                        }
203                    }
204                }
205                Ok(Message::Ping(_)) => {
206                    debug!("Received ping frame");
207                    // Tungstenite handles pong automatically
208                }
209                Ok(Message::Close(_)) => {
210                    info!("WebSocket closed");
211                    *is_connected.write().await = false;
212                    break;
213                }
214                Err(e) => {
215                    error!("WebSocket error: {}", e);
216                    *is_connected.write().await = false;
217                    break;
218                }
219                _ => {}
220            }
221        }
222    }
223
224    /// Authenticate with the server (for private channels).
225    async fn authenticate(&self) -> Result<()> {
226        let api_key = self
227            .config
228            .api_key
229            .as_ref()
230            .ok_or_else(|| BybitError::Auth("API key not set".into()))?;
231        let api_secret = self
232            .config
233            .api_secret
234            .as_ref()
235            .ok_or_else(|| BybitError::Auth("API secret not set".into()))?;
236
237        let expires = get_timestamp() + 10000;
238        let signature = generate_ws_signature(api_secret, expires);
239
240        let auth_msg = WsAuthRequest {
241            req_id: uuid::Uuid::new_v4().to_string(),
242            op: "auth".to_string(),
243            args: vec![
244                serde_json::Value::String(api_key.clone()),
245                serde_json::Value::Number(expires.into()),
246                serde_json::Value::String(signature),
247            ],
248        };
249
250        let msg = serde_json::to_string(&auth_msg).map_err(|e| BybitError::Parse(e.to_string()))?;
251
252        self.send(msg).await?;
253        info!("Authentication request sent");
254        Ok(())
255    }
256
257    /// Subscribe to topics.
258    ///
259    /// # Arguments
260    /// * `topics` - List of topics to subscribe
261    /// * `callback` - Callback function for received messages
262    pub async fn subscribe<F>(&mut self, topics: Vec<String>, callback: F) -> Result<()>
263    where
264        F: Fn(WsMessage) + Send + Sync + 'static,
265    {
266        let callback = Arc::new(callback) as Callback;
267
268        // Register callbacks
269        {
270            let mut cbs = self.callbacks.write().await;
271            for topic in &topics {
272                cbs.insert(topic.clone(), callback.clone());
273            }
274        }
275
276        // Store subscriptions
277        {
278            let mut subs = self.subscriptions.write().await;
279            subs.extend(topics.clone());
280        }
281
282        // Send subscription request
283        let sub_msg = WsRequest {
284            req_id: uuid::Uuid::new_v4().to_string(),
285            op: "subscribe".to_string(),
286            args: topics,
287        };
288
289        let msg = serde_json::to_string(&sub_msg).map_err(|e| BybitError::Parse(e.to_string()))?;
290
291        self.send(msg).await
292    }
293
294    /// Unsubscribe from topics.
295    pub async fn unsubscribe(&mut self, topics: Vec<String>) -> Result<()> {
296        // Remove callbacks
297        {
298            let mut cbs = self.callbacks.write().await;
299            for topic in &topics {
300                cbs.remove(topic);
301            }
302        }
303
304        // Remove from subscriptions
305        {
306            let mut subs = self.subscriptions.write().await;
307            subs.retain(|t| !topics.contains(t));
308        }
309
310        // Send unsubscribe request
311        let unsub_msg = WsRequest {
312            req_id: uuid::Uuid::new_v4().to_string(),
313            op: "unsubscribe".to_string(),
314            args: topics,
315        };
316
317        let msg =
318            serde_json::to_string(&unsub_msg).map_err(|e| BybitError::Parse(e.to_string()))?;
319
320        self.send(msg).await
321    }
322
323    /// Send a message.
324    async fn send(&self, msg: String) -> Result<()> {
325        if let Some(tx) = &self.tx {
326            tx.send(Message::Text(msg)).await.map_err(|_| {
327                BybitError::WebSocket(Box::new(
328                    tokio_tungstenite::tungstenite::Error::AlreadyClosed,
329                ))
330            })?;
331        }
332        Ok(())
333    }
334
335    /// Check if connected.
336    pub async fn is_connected(&self) -> bool {
337        *self.is_connected.read().await
338    }
339
340    /// Disconnect from the server.
341    pub async fn disconnect(&mut self) -> Result<()> {
342        *self.is_connected.write().await = false;
343        self.tx = None;
344        info!("WebSocket disconnected");
345        Ok(())
346    }
347}