1use crate::rest::errors::{BybitError, BybitResult};
25use crate::ws::messages::{WsMessage, WsRequest};
26use futures_util::stream::SplitSink;
27use futures_util::{SinkExt, Stream, StreamExt};
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::Duration;
32use tokio::net::TcpStream;
33use tokio::sync::{mpsc, Mutex};
34use tokio::time::{interval, sleep};
35use tokio_tungstenite::tungstenite::Message;
36use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
37
/// Concrete socket type handled by this module: a WebSocket running over a TCP
/// stream that may or may not be wrapped in TLS (the stream obtained from
/// `connect_async` and later split into read/write halves).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
39
/// Attempt count (consecutive, reset after a successful connect) at which the
/// background task stops trying to reconnect.
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
/// Base backoff delay in milliseconds; scaled by `2^attempt` (exponent capped at 6).
const RECONNECT_BASE_DELAY_MS: u64 = 500;
/// Upper bound on any single backoff delay, in milliseconds (30 seconds).
const RECONNECT_MAX_DELAY_MS: u64 = 30 * 1_000;
/// Period between application-level `ping` requests, in seconds.
const PING_INTERVAL_SECS: u64 = 20;
48
/// Credentials from the most recent authentication command, retained by the
/// background task so they can be sent again after a reconnect.
#[derive(Clone)]
struct AuthParams {
    /// API key sent in the auth request.
    api_key: String,
    /// Signature sent alongside `api_key` and `expires`.
    signature: String,
    /// Expiry value sent in the auth request (presumably a timestamp — confirm).
    expires: u64,
}
56
/// Client handle for a WebSocket connection that is driven by a spawned
/// background task. Commands are forwarded to that task over a channel, and
/// parsed messages come back through this type's `Stream` implementation.
pub struct WsClient {
    /// Sends subscribe / unsubscribe / authenticate commands to the background task.
    command_tx: mpsc::UnboundedSender<Command>,
    /// Receives parsed messages from the background task; polled by `Stream::poll_next`.
    message_rx: mpsc::UnboundedReceiver<WsMessage>,
    /// Join handle of the background task (`None` once the client has been closed).
    _handle: Option<tokio::task::JoinHandle<()>>,
    /// URL this client was created with.
    url: String,
    /// Topics currently subscribed; shared with the background task, which
    /// re-subscribes to them after each (re)connect.
    subscribed_topics: Arc<Mutex<Vec<String>>>,
}
70
/// Requests sent from a `WsClient` handle to its background connection task.
enum Command {
    /// Authenticate the connection. The task keeps these credentials and
    /// re-sends them whenever it reconnects.
    Authenticate {
        api_key: String,
        signature: String,
        expires: u64,
    },
    /// Subscribe the live connection to the given topics.
    Subscribe(Vec<String>),
    /// Unsubscribe the live connection from the given topics.
    Unsubscribe(Vec<String>),
}
80
81impl WsClient {
82 pub async fn connect(url: &str) -> BybitResult<Self> {
87 let (command_tx, command_rx) = mpsc::unbounded_channel();
88 let (message_tx, message_rx) = mpsc::unbounded_channel();
89
90 let subscribed_topics = Arc::new(Mutex::new(Vec::new()));
91 let topics = subscribed_topics.clone();
92 let url_owned = url.to_string();
93
94 let handle = tokio::spawn(async move {
95 run_connection_loop(&url_owned, command_rx, message_tx, topics).await;
96 });
97
98 Ok(WsClient {
99 command_tx,
100 message_rx,
101 _handle: Some(handle),
102 url: url.to_string(),
103 subscribed_topics,
104 })
105 }
106
107 pub async fn subscribe(&self, topics: Vec<String>) -> BybitResult<()> {
111 {
113 let mut stored = self.subscribed_topics.lock().await;
114 for t in &topics {
115 if !stored.contains(t) {
116 stored.push(t.clone());
117 }
118 }
119 }
120
121 self.command_tx
122 .send(Command::Subscribe(topics))
123 .map_err(|e| BybitError::Internal(format!("Subscribe channel closed: {}", e)))?;
124 Ok(())
125 }
126
127 pub async fn unsubscribe(&self, topics: Vec<String>) -> BybitResult<()> {
129 {
131 let mut stored = self.subscribed_topics.lock().await;
132 stored.retain(|t| !topics.contains(t));
133 }
134
135 self.command_tx
136 .send(Command::Unsubscribe(topics))
137 .map_err(|e| BybitError::Internal(format!("Unsubscribe channel closed: {}", e)))?;
138 Ok(())
139 }
140
141 pub async fn authenticate(
145 &self,
146 api_key: &str,
147 expires: u64,
148 signature: &str,
149 ) -> BybitResult<()> {
150 self.command_tx
151 .send(Command::Authenticate {
152 api_key: api_key.to_string(),
153 expires,
154 signature: signature.to_string(),
155 })
156 .map_err(|e| BybitError::Internal(format!("Auth channel closed: {}", e)))?;
157 Ok(())
158 }
159
160 pub fn url(&self) -> &str {
162 &self.url
163 }
164
165 pub fn close(&mut self) {
170 self.command_tx = mpsc::unbounded_channel().0;
172 self._handle = None;
173 }
174}
175
176impl Drop for WsClient {
177 fn drop(&mut self) {
178 }
181}
182
183impl Stream for WsClient {
184 type Item = WsMessage;
185
186 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187 self.message_rx.poll_recv(cx)
188 }
189}
190
191async fn run_connection_loop(
193 url: &str,
194 mut command_rx: mpsc::UnboundedReceiver<Command>,
195 message_tx: mpsc::UnboundedSender<WsMessage>,
196 subscribed_topics: Arc<Mutex<Vec<String>>>,
197) {
198 let mut auth_params: Option<AuthParams> = None;
199 let mut attempt = 0;
200
201 loop {
202 if attempt > 0 {
203 let delay_ms =
204 (RECONNECT_BASE_DELAY_MS * 2_u64.pow(attempt.min(6))).min(RECONNECT_MAX_DELAY_MS);
205 log::warn!(
206 "Reconnecting in {}ms (attempt {}/{})...",
207 delay_ms,
208 attempt,
209 MAX_RECONNECT_ATTEMPTS
210 );
211 sleep(Duration::from_millis(delay_ms)).await;
212 }
213
214 if attempt >= MAX_RECONNECT_ATTEMPTS {
215 log::error!("Max reconnect attempts reached. Giving up.");
216 break;
217 }
218
219 match connect_async(url).await {
220 Ok((ws_stream, _)) => {
221 log::info!("WebSocket connected to {}", url);
222 attempt = 0; let (ws_write, ws_read) = ws_stream.split();
225 let ws_write = Arc::new(Mutex::new(ws_write));
226
227 if let Some(ref auth) = auth_params {
229 let req = WsRequest::auth(&auth.api_key, auth.expires, &auth.signature);
230 send_command(&ws_write, &req).await;
231 }
232
233 {
235 let topics = subscribed_topics.lock().await;
236 if !topics.is_empty() {
237 let req = WsRequest::subscribe(topics.clone());
238 send_command(&ws_write, &req).await;
239 }
240 }
241
242 run_connection(
244 ws_read,
245 ws_write,
246 &mut command_rx,
247 &message_tx,
248 &mut auth_params,
249 )
250 .await;
251 }
252 Err(e) => {
253 log::error!("Connection failed: {}", e);
254 }
255 }
256
257 attempt += 1;
258 }
259}
260
261async fn send_command(writer: &Arc<Mutex<SplitSink<WsStream, Message>>>, req: &WsRequest) {
263 if let Ok(json) = serde_json::to_string(req) {
264 if let Ok(mut w) = writer.try_lock() {
265 let _ = w.send(Message::Text(json.into())).await;
266 }
267 }
268}
269
270async fn run_connection(
272 mut ws_read: futures_util::stream::SplitStream<WsStream>,
273 ws_write: Arc<Mutex<SplitSink<WsStream, Message>>>,
274 command_rx: &mut mpsc::UnboundedReceiver<Command>,
275 message_tx: &mpsc::UnboundedSender<WsMessage>,
276 auth_params: &mut Option<AuthParams>,
277) {
278 let mut ping_interval = interval(Duration::from_secs(PING_INTERVAL_SECS));
279
280 loop {
281 tokio::select! {
282 cmd = command_rx.recv() => {
284 match cmd {
285 Some(Command::Subscribe(topics)) => {
286 let req = WsRequest::subscribe(topics);
287 send_command(&ws_write, &req).await;
288 }
289 Some(Command::Unsubscribe(topics)) => {
290 let req = WsRequest::unsubscribe(topics);
291 send_command(&ws_write, &req).await;
292 }
293 Some(Command::Authenticate { api_key, expires, signature }) => {
294 *auth_params = Some(AuthParams {
295 api_key: api_key.clone(),
296 expires,
297 signature: signature.clone(),
298 });
299 let req = WsRequest::auth(&api_key, expires, &signature);
300 send_command(&ws_write, &req).await;
301 }
302 None => {
303 break;
305 }
306 }
307 }
308
309 _ = ping_interval.tick() => {
311 let ping = WsRequest::ping();
312 send_command(&ws_write, &ping).await;
313 }
314
315 msg = ws_read.next() => {
317 match msg {
318 Some(Ok(Message::Text(text))) => {
319 match serde_json::from_str::<WsMessage>(&text) {
320 Ok(parsed) => {
321 if message_tx.send(parsed).is_err() {
322 break; }
324 }
325 Err(e) => {
326 log::warn!("Failed to parse WS message: {} -- raw: {}", e, text);
327 }
328 }
329 }
330 Some(Ok(Message::Ping(data))) => {
331 if let Ok(mut writer) = ws_write.try_lock() {
332 let _ = writer.send(Message::Pong(data)).await;
333 }
334 }
335 Some(Ok(Message::Close(frame))) => {
336 log::info!(
337 "WebSocket closed by server: {:?}",
338 frame.map(|f| f.reason.to_string())
339 );
340 break;
341 }
342 Some(Err(e)) => {
343 log::error!("WebSocket error: {}", e);
344 break;
345 }
346 None => {
347 log::info!("WebSocket stream ended");
348 break;
349 }
350 _ => {} }
352 }
353 }
354 }
355
356 log::info!("WebSocket connection handler exited");
357}