ccxt_core/ws_client/
mod.rs

1//! WebSocket client module.
2//!
3//! Provides asynchronous WebSocket connection management, subscription handling,
4//! and heartbeat maintenance for cryptocurrency exchange streaming APIs.
5
6mod config;
7mod error;
8mod event;
9mod message;
10mod reconnect;
11mod state;
12mod subscription;
13
14pub use config::{
15    BackoffConfig, BackoffStrategy, BackpressureStrategy, DEFAULT_MAX_SUBSCRIPTIONS,
16    DEFAULT_MESSAGE_CHANNEL_CAPACITY, DEFAULT_SHUTDOWN_TIMEOUT, DEFAULT_WRITE_CHANNEL_CAPACITY,
17    WsConfig,
18};
19pub use error::{WsError, WsErrorKind};
20pub use event::{WsEvent, WsEventCallback};
21pub use message::WsMessage;
22pub use reconnect::AutoReconnectCoordinator;
23pub use state::{WsConnectionState, WsStats, WsStatsSnapshot};
24pub use subscription::{Subscription, SubscriptionManager};
25
26use crate::error::{Error, Result};
27use derive_more::Debug;
28use futures_util::{SinkExt, StreamExt, stream::SplitSink};
29use serde_json::Value;
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
33use tokio::net::TcpStream;
34use tokio::sync::{Mutex, RwLock, mpsc};
35use tokio::time::{Duration, interval};
36use tokio_tungstenite::{
37    MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message,
38};
39use tokio_util::sync::CancellationToken;
40use tracing::{debug, error, info, instrument, warn};
41
/// Type alias for WebSocket write half.
///
/// The sink half obtained by splitting a connected `WebSocketStream`.
// NOTE(review): not referenced in the visible part of this file (`start_message_loop`
// uses the split half directly), which is presumably why `dead_code` is allowed —
// confirm whether the alias is still needed.
#[allow(dead_code)]
type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
45
/// Async WebSocket client for exchange streaming APIs.
///
/// # Backpressure Handling
///
/// This client uses bounded channels to prevent memory exhaustion in high-frequency
/// trading scenarios. When the message channel is full, the configured
/// `BackpressureStrategy` determines how to handle new messages:
///
/// - `DropOldest`: Intended to remove the oldest message to make room (default).
///   Queued messages cannot be evicted from the sending side, so it currently
///   discards the incoming message exactly like `DropNewest` (only the log wording differs).
/// - `DropNewest`: Discards the incoming message
/// - `Block`: Waits until space is available (may stall the read loop)
///
/// Messages discarded by either drop strategy are counted; see `dropped_messages()`.
#[derive(Debug)]
pub struct WsClient {
    /// Connection URL, timeouts, channel capacities, heartbeat and reconnect settings.
    config: WsConfig,
    /// Current `WsConnectionState` stored as its `u8` encoding, shared (via `Arc`)
    /// with the background reader and heartbeat tasks so they can update it.
    state: Arc<AtomicU8>,
    /// Registered subscriptions keyed by `channel` or `channel:symbol`.
    subscription_manager: SubscriptionManager,
    /// Producer side of the incoming-message channel; cloned into the reader task.
    message_tx: mpsc::Sender<Value>,
    /// Consumer side of the incoming-message channel, drained by `receive()`.
    message_rx: Arc<RwLock<mpsc::Receiver<Value>>>,
    /// Sender feeding the active connection's writer task; `None` when not connected.
    write_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
    /// Number of reconnection attempts since the last successful connect.
    pub(crate) reconnect_count: AtomicU32,
    /// Signals the active writer task to send a Close frame and stop.
    shutdown_tx: Arc<Mutex<Option<mpsc::UnboundedSender<()>>>>,
    /// Traffic and ping/pong statistics, shared with the background tasks.
    stats: Arc<WsStats>,
    /// Token used by the `*_with_cancel` methods when no explicit token is passed;
    /// cancelled by `shutdown()`.
    cancel_token: Arc<Mutex<Option<CancellationToken>>>,
    /// Optional lifecycle-event callback. Excluded from the derived `Debug` output
    /// (presumably because the callback type is not `Debug` — confirm).
    #[debug(skip)]
    event_callback: Arc<Mutex<Option<WsEventCallback>>>,
    /// Counter for dropped messages due to backpressure
    dropped_messages: Arc<AtomicU32>,
}
74
75impl WsClient {
76    /// Creates a new WebSocket client instance.
77    ///
78    /// The client uses bounded channels for message passing to prevent memory
79    /// exhaustion. Channel capacities are configured via `WsConfig`.
80    pub fn new(config: WsConfig) -> Self {
81        let (message_tx, message_rx) = mpsc::channel(config.message_channel_capacity);
82        let max_subscriptions = config.max_subscriptions;
83
84        Self {
85            config,
86            state: Arc::new(AtomicU8::new(WsConnectionState::Disconnected.as_u8())),
87            subscription_manager: SubscriptionManager::new(max_subscriptions),
88            message_tx,
89            message_rx: Arc::new(RwLock::new(message_rx)),
90            write_tx: Arc::new(RwLock::new(None)),
91            reconnect_count: AtomicU32::new(0),
92            shutdown_tx: Arc::new(Mutex::new(None)),
93            stats: Arc::new(WsStats::new()),
94            cancel_token: Arc::new(Mutex::new(None)),
95            event_callback: Arc::new(Mutex::new(None)),
96            dropped_messages: Arc::new(AtomicU32::new(0)),
97        }
98    }
99
100    /// Sets the event callback for connection lifecycle events.
101    pub async fn set_event_callback(&self, callback: WsEventCallback) {
102        *self.event_callback.lock().await = Some(callback);
103        debug!("Event callback set");
104    }
105
106    /// Clears the event callback.
107    pub async fn clear_event_callback(&self) {
108        *self.event_callback.lock().await = None;
109        debug!("Event callback cleared");
110    }
111
112    async fn emit_event(&self, event: WsEvent) {
113        let callback = self.event_callback.lock().await;
114        if let Some(ref cb) = *callback {
115            let cb = Arc::clone(cb);
116            drop(callback);
117            tokio::spawn(async move {
118                cb(event);
119            });
120        }
121    }
122
123    /// Sets the cancellation token for this client.
124    pub async fn set_cancel_token(&self, token: CancellationToken) {
125        *self.cancel_token.lock().await = Some(token);
126        debug!("Cancellation token set");
127    }
128
129    /// Clears the cancellation token.
130    pub async fn clear_cancel_token(&self) {
131        *self.cancel_token.lock().await = None;
132        debug!("Cancellation token cleared");
133    }
134
135    /// Returns a clone of the current cancellation token, if set.
136    pub async fn get_cancel_token(&self) -> Option<CancellationToken> {
137        self.cancel_token.lock().await.clone()
138    }
139
140    /// Establishes connection to the WebSocket server.
141    #[instrument(
142        name = "ws_connect",
143        skip(self),
144        fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
145    )]
146    pub async fn connect(&self) -> Result<()> {
147        if self.state() == WsConnectionState::Connected {
148            info!("WebSocket already connected");
149            return Ok(());
150        }
151
152        self.set_state(WsConnectionState::Connecting);
153
154        let url = self.config.url.clone();
155        info!("Initiating WebSocket connection");
156
157        match tokio::time::timeout(
158            Duration::from_millis(self.config.connect_timeout),
159            connect_async(&url),
160        )
161        .await
162        {
163            Ok(Ok((ws_stream, response))) => {
164                info!(
165                    status = response.status().as_u16(),
166                    "WebSocket connection established successfully"
167                );
168
169                self.set_state(WsConnectionState::Connected);
170                self.reconnect_count.store(0, Ordering::Release);
171                self.stats.record_connected();
172                self.start_message_loop(ws_stream).await;
173                self.resubscribe_all().await?;
174
175                Ok(())
176            }
177            Ok(Err(e)) => {
178                error!(error = %e, "WebSocket connection failed");
179                self.set_state(WsConnectionState::Error);
180                Err(Error::network(format!("WebSocket connection failed: {e}")))
181            }
182            Err(_) => {
183                error!(
184                    timeout_ms = self.config.connect_timeout,
185                    "WebSocket connection timeout"
186                );
187                self.set_state(WsConnectionState::Error);
188                Err(Error::timeout("WebSocket connection timeout"))
189            }
190        }
191    }
192
193    /// Establishes connection with cancellation support.
194    #[instrument(
195        name = "ws_connect_with_cancel",
196        skip(self, cancel_token),
197        fields(url = %self.config.url)
198    )]
199    pub async fn connect_with_cancel(&self, cancel_token: Option<CancellationToken>) -> Result<()> {
200        let token = if let Some(t) = cancel_token {
201            t
202        } else {
203            let internal_token = self.cancel_token.lock().await;
204            internal_token
205                .clone()
206                .unwrap_or_else(CancellationToken::new)
207        };
208
209        if self.state() == WsConnectionState::Connected {
210            info!("WebSocket already connected");
211            return Ok(());
212        }
213
214        self.set_state(WsConnectionState::Connecting);
215        let url = self.config.url.clone();
216
217        tokio::select! {
218            biased;
219            () = token.cancelled() => {
220                warn!("WebSocket connection cancelled");
221                self.set_state(WsConnectionState::Disconnected);
222                Err(Error::cancelled("WebSocket connection cancelled"))
223            }
224            result = tokio::time::timeout(
225                Duration::from_millis(self.config.connect_timeout),
226                connect_async(&url),
227            ) => {
228                match result {
229                    Ok(Ok((ws_stream, response))) => {
230                        info!(status = response.status().as_u16(), "WebSocket connected");
231                        self.set_state(WsConnectionState::Connected);
232                        self.reconnect_count.store(0, Ordering::Release);
233                        self.stats.record_connected();
234                        self.start_message_loop(ws_stream).await;
235                        self.resubscribe_all().await?;
236                        Ok(())
237                    }
238                    Ok(Err(e)) => {
239                        error!(error = %e, "WebSocket connection failed");
240                        self.set_state(WsConnectionState::Error);
241                        Err(Error::network(format!("WebSocket connection failed: {e}")))
242                    }
243                    Err(_) => {
244                        error!("WebSocket connection timeout");
245                        self.set_state(WsConnectionState::Error);
246                        Err(Error::timeout("WebSocket connection timeout"))
247                    }
248                }
249            }
250        }
251    }
252
253    /// Closes the WebSocket connection gracefully.
254    #[instrument(name = "ws_disconnect", skip(self))]
255    pub async fn disconnect(&self) -> Result<()> {
256        info!("Initiating WebSocket disconnect");
257
258        if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
259            let _ = tx.send(());
260        }
261
262        *self.write_tx.write().await = None;
263        self.set_state(WsConnectionState::Disconnected);
264
265        info!("WebSocket disconnected");
266        Ok(())
267    }
268
269    /// Gracefully shuts down the WebSocket client.
270    #[instrument(name = "ws_shutdown", skip(self))]
271    pub async fn shutdown(&self) {
272        info!("Initiating graceful shutdown");
273
274        {
275            let token_guard = self.cancel_token.lock().await;
276            if let Some(ref token) = *token_guard {
277                token.cancel();
278            }
279        }
280
281        self.set_state(WsConnectionState::Disconnected);
282
283        {
284            let write_tx_guard = self.write_tx.read().await;
285            if let Some(ref tx) = *write_tx_guard {
286                // Ignore send result - we're shutting down anyway
287                drop(tx.send(Message::Close(None)).await);
288            }
289        }
290
291        let shutdown_timeout = Duration::from_millis(self.config.shutdown_timeout);
292        let _ = tokio::time::timeout(shutdown_timeout, async {
293            if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
294                let _ = tx.send(());
295            }
296            tokio::time::sleep(Duration::from_millis(100)).await;
297        })
298        .await;
299
300        {
301            *self.write_tx.write().await = None;
302            *self.shutdown_tx.lock().await = None;
303            self.subscription_manager.clear();
304            self.reconnect_count.store(0, Ordering::Release);
305            self.dropped_messages.store(0, Ordering::Relaxed);
306            self.stats.reset();
307        }
308
309        self.emit_event(WsEvent::Shutdown).await;
310        info!("Graceful shutdown completed");
311    }
312
313    /// Attempts to reconnect to the WebSocket server.
314    #[instrument(name = "ws_reconnect", skip(self))]
315    pub async fn reconnect(&self) -> Result<()> {
316        let count = self.reconnect_count.fetch_add(1, Ordering::AcqRel) + 1;
317
318        if count > self.config.max_reconnect_attempts {
319            error!(attempts = count, "Max reconnect attempts reached");
320            return Err(Error::network("Max reconnect attempts reached"));
321        }
322
323        warn!(attempt = count, "Attempting WebSocket reconnection");
324        self.set_state(WsConnectionState::Reconnecting);
325
326        tokio::time::sleep(Duration::from_millis(self.config.reconnect_interval)).await;
327        self.connect().await
328    }
329
330    /// Attempts to reconnect with cancellation support.
331    #[instrument(name = "ws_reconnect_with_cancel", skip(self, cancel_token))]
332    pub async fn reconnect_with_cancel(
333        &self,
334        cancel_token: Option<CancellationToken>,
335    ) -> Result<()> {
336        let token = if let Some(t) = cancel_token {
337            t
338        } else {
339            let internal_token = self.cancel_token.lock().await;
340            internal_token
341                .clone()
342                .unwrap_or_else(CancellationToken::new)
343        };
344
345        let backoff = BackoffStrategy::new(self.config.backoff_config.clone());
346        self.set_state(WsConnectionState::Reconnecting);
347
348        loop {
349            if token.is_cancelled() {
350                self.set_state(WsConnectionState::Disconnected);
351                return Err(Error::cancelled("Reconnection cancelled"));
352            }
353
354            let attempt = self.reconnect_count.fetch_add(1, Ordering::AcqRel);
355
356            if attempt >= self.config.max_reconnect_attempts {
357                self.set_state(WsConnectionState::Error);
358                return Err(Error::network(format!(
359                    "Max reconnect attempts ({}) reached",
360                    self.config.max_reconnect_attempts
361                )));
362            }
363
364            let delay = backoff.calculate_delay(attempt);
365
366            tokio::select! {
367                biased;
368                () = token.cancelled() => {
369                    self.set_state(WsConnectionState::Disconnected);
370                    return Err(Error::cancelled("Reconnection cancelled during backoff"));
371                }
372                () = tokio::time::sleep(delay) => {}
373            }
374
375            match self.connect_with_cancel(Some(token.clone())).await {
376                Ok(()) => {
377                    self.reconnect_count.store(0, Ordering::Release);
378                    return Ok(());
379                }
380                Err(e) => {
381                    if e.as_cancelled().is_some() {
382                        self.set_state(WsConnectionState::Disconnected);
383                        return Err(e);
384                    }
385
386                    let ws_error = WsError::from_error(&e);
387                    if ws_error.is_permanent() {
388                        self.set_state(WsConnectionState::Error);
389                        return Err(e);
390                    }
391                }
392            }
393        }
394    }
395
396    /// Returns the current reconnection attempt count.
397    #[inline]
398    pub fn reconnect_count(&self) -> u32 {
399        self.reconnect_count.load(Ordering::Acquire)
400    }
401
402    /// Resets the reconnection attempt counter.
403    pub fn reset_reconnect_count(&self) {
404        self.reconnect_count.store(0, Ordering::Release);
405    }
406
407    /// Increments the reconnection attempt counter.
408    pub(crate) fn increment_reconnect_count(&self) {
409        self.reconnect_count.fetch_add(1, Ordering::AcqRel);
410    }
411
412    /// Returns a snapshot of connection statistics.
413    pub fn stats(&self) -> WsStatsSnapshot {
414        self.stats.snapshot()
415    }
416
417    /// Resets all connection statistics.
418    pub fn reset_stats(&self) {
419        self.stats.reset();
420    }
421
422    /// Calculates current connection latency in milliseconds.
423    pub fn latency(&self) -> Option<i64> {
424        let last_pong = self.stats.last_pong_time();
425        let last_ping = self.stats.last_ping_time();
426        if last_pong > 0 && last_ping > 0 {
427            Some(last_pong - last_ping)
428        } else {
429            None
430        }
431    }
432
433    /// Returns the number of messages dropped due to backpressure.
434    ///
435    /// This counter is incremented when the message channel is full and
436    /// messages are dropped according to the configured backpressure strategy.
437    pub fn dropped_messages(&self) -> u32 {
438        self.dropped_messages.load(Ordering::Relaxed)
439    }
440
441    /// Resets the dropped messages counter.
442    pub fn reset_dropped_messages(&self) {
443        self.dropped_messages.store(0, Ordering::Relaxed);
444    }
445
446    /// Creates an automatic reconnection coordinator.
447    pub fn create_auto_reconnect_coordinator(self: Arc<Self>) -> AutoReconnectCoordinator {
448        AutoReconnectCoordinator::new(self)
449    }
450
451    /// Subscribes to a WebSocket channel.
452    #[instrument(name = "ws_subscribe", skip(self, params), fields(channel = %channel))]
453    pub async fn subscribe(
454        &self,
455        channel: String,
456        symbol: Option<String>,
457        params: Option<HashMap<String, Value>>,
458    ) -> Result<()> {
459        let sub_key = Self::subscription_key(&channel, symbol.as_ref());
460        let subscription = Subscription {
461            channel: channel.clone(),
462            symbol: symbol.clone(),
463            params: params.clone(),
464        };
465
466        self.subscription_manager
467            .try_add(sub_key.clone(), subscription)?;
468
469        if self.state() == WsConnectionState::Connected {
470            self.send_subscribe_message(channel, symbol, params).await?;
471        }
472
473        Ok(())
474    }
475
476    /// Unsubscribes from a WebSocket channel.
477    #[instrument(name = "ws_unsubscribe", skip(self), fields(channel = %channel))]
478    pub async fn unsubscribe(&self, channel: String, symbol: Option<String>) -> Result<()> {
479        let sub_key = Self::subscription_key(&channel, symbol.as_ref());
480        self.subscription_manager.remove(&sub_key);
481
482        if self.state() == WsConnectionState::Connected {
483            self.send_unsubscribe_message(channel, symbol).await?;
484        }
485
486        Ok(())
487    }
488
489    /// Receives the next available message.
490    pub async fn receive(&self) -> Option<Value> {
491        let mut rx = self.message_rx.write().await;
492        rx.recv().await
493    }
494
495    /// Returns the current connection state.
496    #[inline]
497    pub fn state(&self) -> WsConnectionState {
498        WsConnectionState::from_u8(self.state.load(Ordering::Acquire))
499    }
500
501    /// Returns a reference to the WebSocket configuration.
502    #[inline]
503    pub fn config(&self) -> &WsConfig {
504        &self.config
505    }
506
507    /// Sets the connection state.
508    #[inline]
509    pub fn set_state(&self, state: WsConnectionState) {
510        self.state.store(state.as_u8(), Ordering::Release);
511    }
512
513    /// Checks whether the WebSocket is currently connected.
514    #[inline]
515    pub fn is_connected(&self) -> bool {
516        self.state() == WsConnectionState::Connected
517    }
518
519    /// Checks if subscribed to a specific channel.
520    pub fn is_subscribed(&self, channel: &str, symbol: Option<&String>) -> bool {
521        let sub_key = Self::subscription_key(channel, symbol);
522        self.subscription_manager.contains(&sub_key)
523    }
524
525    /// Returns the number of active subscriptions.
526    pub fn subscription_count(&self) -> usize {
527        self.subscription_manager.count()
528    }
529
530    /// Returns the remaining capacity for new subscriptions.
531    pub fn remaining_capacity(&self) -> usize {
532        self.subscription_manager.remaining_capacity()
533    }
534
535    /// Returns a list of all active subscription channel names.
536    ///
537    /// Each subscription is identified by its channel name, optionally combined
538    /// with a symbol in the format "channel:symbol" or just "channel".
539    pub fn subscriptions(&self) -> Vec<String> {
540        self.subscription_manager
541            .iter()
542            .map(|entry| {
543                let sub = entry.value();
544                match &sub.symbol {
545                    Some(sym) => format!("{}:{}", sub.channel, sym),
546                    None => sub.channel.clone(),
547                }
548            })
549            .collect()
550    }
551
552    /// Sends a raw WebSocket message.
553    ///
554    /// This method uses a bounded channel for sending. If the write channel is full,
555    /// it will wait until space is available.
556    #[instrument(name = "ws_send", skip(self, message))]
557    pub async fn send(&self, message: Message) -> Result<()> {
558        let tx = self.write_tx.read().await;
559
560        if let Some(sender) = tx.as_ref() {
561            sender
562                .send(message)
563                .await
564                .map_err(|e| Error::network(format!("Failed to send message: {e}")))?;
565            Ok(())
566        } else {
567            Err(Error::network("WebSocket not connected"))
568        }
569    }
570
571    /// Tries to send a raw WebSocket message without blocking.
572    ///
573    /// Returns an error if the channel is full or closed.
574    #[instrument(name = "ws_try_send", skip(self, message))]
575    pub fn try_send(&self, message: Message) -> Result<()> {
576        // Note: This is a sync method, so we can't use async lock
577        // We use try_read to avoid blocking
578        if let Ok(tx) = self.write_tx.try_read() {
579            if let Some(sender) = tx.as_ref() {
580                sender.try_send(message).map_err(|e| match e {
581                    mpsc::error::TrySendError::Full(_) => {
582                        Error::network("Write channel full (backpressure)")
583                    }
584                    mpsc::error::TrySendError::Closed(_) => {
585                        Error::network("WebSocket channel closed")
586                    }
587                })?;
588                Ok(())
589            } else {
590                Err(Error::network("WebSocket not connected"))
591            }
592        } else {
593            Err(Error::network("Write channel busy"))
594        }
595    }
596
597    /// Sends a text message.
598    #[instrument(name = "ws_send_text", skip(self, text))]
599    pub async fn send_text(&self, text: String) -> Result<()> {
600        self.send(Message::Text(text.into())).await
601    }
602
603    /// Sends a JSON-encoded message.
604    #[instrument(name = "ws_send_json", skip(self, json))]
605    pub async fn send_json(&self, json: &Value) -> Result<()> {
606        let text = serde_json::to_string(json).map_err(Error::from)?;
607        self.send_text(text).await
608    }
609
610    fn subscription_key(channel: &str, symbol: Option<&String>) -> String {
611        match symbol {
612            Some(s) => format!("{channel}:{s}"),
613            None => channel.to_string(),
614        }
615    }
616
617    async fn start_message_loop(&self, ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) {
618        let (write, mut read) = ws_stream.split();
619
620        // Use bounded channel for write operations to prevent memory exhaustion
621        let (write_tx, mut write_rx) = mpsc::channel::<Message>(self.config.write_channel_capacity);
622        *self.write_tx.write().await = Some(write_tx.clone());
623
624        // Shutdown channel remains unbounded as it's only used for signaling
625        let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>();
626        *self.shutdown_tx.lock().await = Some(shutdown_tx);
627
628        let state = Arc::clone(&self.state);
629        let message_tx = self.message_tx.clone();
630        let ping_interval_ms = self.config.ping_interval;
631        let backpressure_strategy = self.config.backpressure_strategy;
632        let dropped_messages = Arc::clone(&self.dropped_messages);
633
634        let write_handle = tokio::spawn(async move {
635            let mut write = write;
636            loop {
637                tokio::select! {
638                    Some(msg) = write_rx.recv() => {
639                        if let Err(e) = write.send(msg).await {
640                            error!(error = %e, "Failed to write message");
641                            break;
642                        }
643                    }
644                    _ = shutdown_rx.recv() => {
645                        let _ = write.send(Message::Close(None)).await;
646                        break;
647                    }
648                }
649            }
650        });
651
652        let state_clone = Arc::clone(&state);
653        let ws_stats = Arc::clone(&self.stats);
654        let read_handle = tokio::spawn(async move {
655            while let Some(msg_result) = read.next().await {
656                match msg_result {
657                    Ok(Message::Text(text)) => {
658                        ws_stats.record_received(text.len() as u64);
659                        if let Ok(json) = serde_json::from_str::<Value>(&text) {
660                            Self::send_with_backpressure(
661                                &message_tx,
662                                json,
663                                backpressure_strategy,
664                                &dropped_messages,
665                            )
666                            .await;
667                        }
668                    }
669                    Ok(Message::Binary(data)) => {
670                        ws_stats.record_received(data.len() as u64);
671                        if let Some(json) = String::from_utf8(data.to_vec())
672                            .ok()
673                            .and_then(|text| serde_json::from_str::<Value>(&text).ok())
674                        {
675                            Self::send_with_backpressure(
676                                &message_tx,
677                                json,
678                                backpressure_strategy,
679                                &dropped_messages,
680                            )
681                            .await;
682                        }
683                    }
684                    Ok(Message::Pong(_)) => {
685                        ws_stats.record_pong();
686                    }
687                    Ok(Message::Close(_)) => {
688                        state_clone
689                            .store(WsConnectionState::Disconnected.as_u8(), Ordering::Release);
690                        break;
691                    }
692                    Err(_) => {
693                        state_clone.store(WsConnectionState::Error.as_u8(), Ordering::Release);
694                        break;
695                    }
696                    _ => {}
697                }
698            }
699        });
700
701        if ping_interval_ms > 0 {
702            let write_tx_clone = write_tx.clone();
703            let ping_stats = Arc::clone(&self.stats);
704            let ping_state = Arc::clone(&state);
705            let pong_timeout_ms = self.config.pong_timeout;
706
707            tokio::spawn(async move {
708                let mut interval = interval(Duration::from_millis(ping_interval_ms));
709
710                loop {
711                    interval.tick().await;
712
713                    let now = chrono::Utc::now().timestamp_millis();
714                    let last_pong = ping_stats.last_pong_time();
715
716                    if last_pong > 0 {
717                        let elapsed = now - last_pong;
718                        #[allow(clippy::cast_possible_wrap)]
719                        if elapsed > pong_timeout_ms as i64 {
720                            // Pong timeout detected - log detailed diagnostics
721                            error!(
722                                pong_timeout_ms = pong_timeout_ms,
723                                elapsed_ms = elapsed,
724                                last_pong_time = last_pong,
725                                current_time = now,
726                                "WebSocket pong timeout detected - connection appears unresponsive (zombie connection)"
727                            );
728                            ping_state.store(WsConnectionState::Error.as_u8(), Ordering::Release);
729                            debug!(
730                                "WebSocket state set to Error due to pong timeout - AutoReconnectCoordinator will trigger reconnection if enabled"
731                            );
732                            break;
733                        }
734                    }
735
736                    ping_stats.record_ping();
737
738                    // Use try_send for ping to avoid blocking
739                    if write_tx_clone
740                        .try_send(Message::Ping(vec![].into()))
741                        .is_err()
742                    {
743                        debug!("WebSocket write channel closed, stopping ping loop");
744                        break;
745                    }
746                }
747            });
748        }
749
750        tokio::spawn(async move {
751            let _ = tokio::join!(write_handle, read_handle);
752        });
753    }
754
755    /// Sends a message with backpressure handling.
756    ///
757    /// This method implements the configured backpressure strategy when the
758    /// message channel is full.
759    async fn send_with_backpressure(
760        tx: &mpsc::Sender<Value>,
761        message: Value,
762        strategy: BackpressureStrategy,
763        dropped_counter: &Arc<AtomicU32>,
764    ) {
765        match strategy {
766            BackpressureStrategy::Block => {
767                // Block until space is available
768                if tx.send(message).await.is_err() {
769                    warn!("Message channel closed");
770                }
771            }
772            BackpressureStrategy::DropNewest => {
773                // Try to send, drop if full
774                match tx.try_send(message) {
775                    Ok(()) => {}
776                    Err(mpsc::error::TrySendError::Full(_)) => {
777                        let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
778                        if count % 100 == 1 {
779                            // Log every 100th drop to avoid log spam
780                            warn!(
781                                dropped_count = count,
782                                "Message channel full, dropping newest message (backpressure)"
783                            );
784                        }
785                    }
786                    Err(mpsc::error::TrySendError::Closed(_)) => {
787                        warn!("Message channel closed");
788                    }
789                }
790            }
791            BackpressureStrategy::DropOldest => {
792                // Try to send, if full, make room by receiving and discarding
793                match tx.try_send(message) {
794                    Ok(()) => {}
795                    Err(mpsc::error::TrySendError::Full(msg)) => {
796                        // Channel is full, we need to drop oldest
797                        // Since we can't directly remove from the channel,
798                        // we use a permit-based approach
799                        let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
800                        if count % 100 == 1 {
801                            warn!(
802                                dropped_count = count,
803                                "Message channel full, dropping oldest message (backpressure)"
804                            );
805                        }
806                        // For DropOldest, we actually drop the newest since we can't
807                        // remove from the receiver side. The semantic is that we
808                        // prioritize not blocking the read loop.
809                        // A true DropOldest would require a different data structure.
810                        drop(msg);
811                    }
812                    Err(mpsc::error::TrySendError::Closed(_)) => {
813                        warn!("Message channel closed");
814                    }
815                }
816            }
817        }
818    }
819
820    async fn send_subscribe_message(
821        &self,
822        channel: String,
823        symbol: Option<String>,
824        params: Option<HashMap<String, Value>>,
825    ) -> Result<()> {
826        let msg = WsMessage::Subscribe {
827            channel,
828            symbol,
829            params,
830        };
831        let json = serde_json::to_value(&msg).map_err(Error::from)?;
832        self.send_json(&json).await
833    }
834
835    async fn send_unsubscribe_message(
836        &self,
837        channel: String,
838        symbol: Option<String>,
839    ) -> Result<()> {
840        let msg = WsMessage::Unsubscribe { channel, symbol };
841        let json = serde_json::to_value(&msg).map_err(Error::from)?;
842        self.send_json(&json).await
843    }
844
845    pub(crate) async fn resubscribe_all(&self) -> Result<()> {
846        let subs = self.subscription_manager.collect_subscriptions();
847        for subscription in subs {
848            self.send_subscribe_message(
849                subscription.channel.clone(),
850                subscription.symbol.clone(),
851                subscription.params.clone(),
852            )
853            .await?;
854        }
855        Ok(())
856    }
857}
858
#[cfg(test)]
mod tests {
    use super::*;

    /// Config pointing at a placeholder endpoint; these tests never connect.
    fn example_config() -> WsConfig {
        WsConfig {
            url: "wss://example.com/ws".to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_backoff_config_default() {
        let defaults = BackoffConfig::default();
        assert_eq!(defaults.base_delay, Duration::from_secs(1));
        assert_eq!(defaults.max_delay, Duration::from_secs(60));
    }

    #[test]
    fn test_backoff_strategy_exponential_growth_no_jitter() {
        let strategy = BackoffStrategy::new(BackoffConfig {
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            jitter_factor: 0.0,
            multiplier: 2.0,
        });

        // (attempt, expected seconds). Attempt 6 would be 2^6 = 64s, which
        // exceeds max_delay, so 60s is expected.
        let expectations = [(0, 1), (1, 2), (2, 4), (6, 60)];
        for (attempt, secs) in expectations {
            assert_eq!(strategy.calculate_delay(attempt), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_ws_config_default() {
        let defaults = WsConfig::default();
        assert_eq!(defaults.connect_timeout, 10000);
        assert_eq!(defaults.max_subscriptions, DEFAULT_MAX_SUBSCRIPTIONS);
    }

    #[test]
    fn test_subscription_key() {
        let symbol = "BTC/USDT".to_string();
        assert_eq!(
            WsClient::subscription_key("ticker", Some(&symbol)),
            "ticker:BTC/USDT"
        );
        assert_eq!(WsClient::subscription_key("trades", None), "trades");
    }

    #[tokio::test]
    async fn test_ws_client_creation() {
        let client = WsClient::new(example_config());
        assert_eq!(client.state(), WsConnectionState::Disconnected);
        assert!(!client.is_connected());
    }

    #[tokio::test]
    async fn test_subscribe_adds_subscription() {
        let client = WsClient::new(example_config());
        let outcome = client
            .subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None)
            .await;
        assert!(outcome.is_ok());
        assert_eq!(client.subscription_count(), 1);
        assert!(client.is_subscribed("ticker", Some(&"BTC/USDT".to_string())));
    }

    #[test]
    fn test_ws_connection_state_from_u8() {
        let cases = [
            (0, WsConnectionState::Disconnected),
            (1, WsConnectionState::Connecting),
            (2, WsConnectionState::Connected),
            (255, WsConnectionState::Error),
        ];
        for (raw, expected) in cases {
            assert_eq!(WsConnectionState::from_u8(raw), expected);
        }
    }

    #[test]
    fn test_ws_error_kind() {
        assert!(WsErrorKind::Transient.is_transient());
        assert!(WsErrorKind::Permanent.is_permanent());
    }

    #[test]
    fn test_ws_error_creation() {
        let transient = WsError::transient("Connection timeout");
        assert!(transient.is_transient());
        assert_eq!(transient.message(), "Connection timeout");

        let permanent = WsError::permanent("Invalid API key");
        assert!(permanent.is_permanent());
    }

    #[test]
    fn test_subscription_manager() {
        let manager = SubscriptionManager::new(2);
        assert_eq!(manager.max_subscriptions(), 2);
        assert_eq!(manager.count(), 0);
        assert!(!manager.is_full());

        let ticker = Subscription {
            channel: "ticker".to_string(),
            symbol: Some("BTC/USDT".to_string()),
            params: None,
        };
        assert!(manager.try_add("ticker:BTC/USDT".to_string(), ticker).is_ok());
        assert_eq!(manager.count(), 1);
    }
}