// ccxt_core/ws_client/mod.rs

1//! WebSocket client module.
2//!
3//! Provides asynchronous WebSocket connection management, subscription handling,
4//! and heartbeat maintenance for cryptocurrency exchange streaming APIs.
5
6mod config;
7mod error;
8mod event;
9mod message;
10mod reconnect;
11mod state;
12mod subscription;
13
14pub use config::{
15    BackoffConfig, BackoffStrategy, BackpressureStrategy, DEFAULT_MAX_SUBSCRIPTIONS,
16    DEFAULT_MESSAGE_CHANNEL_CAPACITY, DEFAULT_SHUTDOWN_TIMEOUT, DEFAULT_WRITE_CHANNEL_CAPACITY,
17    WsConfig,
18};
19pub use error::{WsError, WsErrorKind};
20pub use event::{WsEvent, WsEventCallback};
21pub use message::WsMessage;
22pub use reconnect::AutoReconnectCoordinator;
23pub use state::{WsConnectionState, WsStats, WsStatsSnapshot};
24pub use subscription::{Subscription, SubscriptionManager};
25
26use crate::error::{Error, Result};
27use derive_more::Debug;
28use futures_util::{SinkExt, StreamExt, stream::SplitSink};
29use serde_json::Value;
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
33use tokio::net::TcpStream;
34use tokio::sync::{Mutex, RwLock, mpsc};
35use tokio::time::{Duration, interval};
36use tokio_tungstenite::{
37    MaybeTlsStream, WebSocketStream, connect_async_with_config,
38    tungstenite::protocol::{Message, WebSocketConfig},
39};
40use tokio_util::sync::CancellationToken;
41use tracing::{debug, error, info, instrument, warn};
42
/// Type alias for the write half of a split client WebSocket stream
/// (a sink of tungstenite `Message`s over a plain or TLS TCP connection).
// NOTE(review): the alias is not referenced in the visible part of this module,
// which is presumably why the dead-code lint is silenced — confirm before removing.
#[allow(dead_code)]
type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
46
/// Async WebSocket client for exchange streaming APIs.
///
/// # Backpressure Handling
///
/// This client uses bounded channels to prevent memory exhaustion in high-frequency
/// trading scenarios. When the message channel is full, the configured
/// `BackpressureStrategy` determines how to handle new messages:
///
/// - `DropOldest`: Intended to remove the oldest message to make room (default).
///   NOTE(review): `send_with_backpressure` currently discards the *incoming*
///   message for this strategy as well, because the sender side cannot evict
///   items that are already queued.
/// - `DropNewest`: Discards the incoming message
/// - `Block`: Waits until space is available (may stall the read loop)
#[derive(Debug)]
pub struct WsClient {
    /// Connection settings supplied at construction time.
    config: WsConfig,
    /// Current `WsConnectionState`, stored as its `u8` encoding so that it can be
    /// shared with the background tasks.
    state: Arc<AtomicU8>,
    /// Registry of requested subscriptions, capped at `config.max_subscriptions`.
    subscription_manager: SubscriptionManager,
    /// Producer side of the bounded inbound-message channel (cloned into the read task).
    message_tx: mpsc::Sender<Value>,
    /// Consumer side of the inbound-message channel, drained by `receive`.
    message_rx: Arc<RwLock<mpsc::Receiver<Value>>>,
    /// Sender feeding the writer task; `None` until a message loop has been started
    /// and again after `disconnect` / `shutdown`.
    write_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
    /// Number of reconnection attempts since the last successful connect or reset.
    pub(crate) reconnect_count: AtomicU32,
    /// Signal channel that tells the writer task to send a Close frame and stop.
    shutdown_tx: Arc<Mutex<Option<mpsc::UnboundedSender<()>>>>,
    /// Shared traffic and heartbeat statistics.
    stats: Arc<WsStats>,
    /// Optional default cancellation token, used by the `*_with_cancel` methods
    /// when the caller does not pass one.
    cancel_token: Arc<Mutex<Option<CancellationToken>>>,
    /// Optional lifecycle event callback (omitted from `Debug` output).
    #[debug(skip)]
    event_callback: Arc<Mutex<Option<WsEventCallback>>>,
    /// Counter for dropped messages due to backpressure
    dropped_messages: Arc<AtomicU32>,
}
75
76impl WsClient {
77    /// Creates a new WebSocket client instance.
78    ///
79    /// The client uses bounded channels for message passing to prevent memory
80    /// exhaustion. Channel capacities are configured via `WsConfig`.
81    pub fn new(config: WsConfig) -> Self {
82        let (message_tx, message_rx) = mpsc::channel(config.message_channel_capacity);
83        let max_subscriptions = config.max_subscriptions;
84
85        Self {
86            config,
87            state: Arc::new(AtomicU8::new(WsConnectionState::Disconnected.as_u8())),
88            subscription_manager: SubscriptionManager::new(max_subscriptions),
89            message_tx,
90            message_rx: Arc::new(RwLock::new(message_rx)),
91            write_tx: Arc::new(RwLock::new(None)),
92            reconnect_count: AtomicU32::new(0),
93            shutdown_tx: Arc::new(Mutex::new(None)),
94            stats: Arc::new(WsStats::new()),
95            cancel_token: Arc::new(Mutex::new(None)),
96            event_callback: Arc::new(Mutex::new(None)),
97            dropped_messages: Arc::new(AtomicU32::new(0)),
98        }
99    }
100
101    /// Sets the event callback for connection lifecycle events.
102    pub async fn set_event_callback(&self, callback: WsEventCallback) {
103        *self.event_callback.lock().await = Some(callback);
104        debug!("Event callback set");
105    }
106
107    /// Clears the event callback.
108    pub async fn clear_event_callback(&self) {
109        *self.event_callback.lock().await = None;
110        debug!("Event callback cleared");
111    }
112
113    async fn emit_event(&self, event: WsEvent) {
114        let callback = self.event_callback.lock().await;
115        if let Some(ref cb) = *callback {
116            let cb = Arc::clone(cb);
117            drop(callback);
118            tokio::spawn(async move {
119                cb(event);
120            });
121        }
122    }
123
124    /// Sets the cancellation token for this client.
125    pub async fn set_cancel_token(&self, token: CancellationToken) {
126        *self.cancel_token.lock().await = Some(token);
127        debug!("Cancellation token set");
128    }
129
130    /// Clears the cancellation token.
131    pub async fn clear_cancel_token(&self) {
132        *self.cancel_token.lock().await = None;
133        debug!("Cancellation token cleared");
134    }
135
136    /// Returns a clone of the current cancellation token, if set.
137    pub async fn get_cancel_token(&self) -> Option<CancellationToken> {
138        self.cancel_token.lock().await.clone()
139    }
140
141    /// Establishes connection to the WebSocket server.
142    #[instrument(
143        name = "ws_connect",
144        skip(self),
145        fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
146    )]
147    pub async fn connect(&self) -> Result<()> {
148        if self.state() == WsConnectionState::Connected {
149            info!("WebSocket already connected");
150            return Ok(());
151        }
152
153        self.set_state(WsConnectionState::Connecting);
154
155        let url = self.config.url.clone();
156        info!("Initiating WebSocket connection");
157
158        let mut ws_config = WebSocketConfig::default();
159        if let Some(limit) = self.config.max_message_size {
160            ws_config.max_message_size = Some(limit);
161        }
162        if let Some(limit) = self.config.max_frame_size {
163            ws_config.max_frame_size = Some(limit);
164        }
165
166        match tokio::time::timeout(
167            Duration::from_millis(self.config.connect_timeout),
168            connect_async_with_config(&url, Some(ws_config), false),
169        )
170        .await
171        {
172            Ok(Ok((ws_stream, response))) => {
173                info!(
174                    status = response.status().as_u16(),
175                    "WebSocket connection established successfully"
176                );
177
178                self.set_state(WsConnectionState::Connected);
179                self.reconnect_count.store(0, Ordering::Release);
180                self.stats.record_connected();
181                self.start_message_loop(ws_stream).await;
182                self.resubscribe_all().await?;
183
184                Ok(())
185            }
186            Ok(Err(e)) => {
187                error!(error = %e, "WebSocket connection failed");
188                self.set_state(WsConnectionState::Error);
189                Err(Error::network(format!("WebSocket connection failed: {e}")))
190            }
191            Err(_) => {
192                error!(
193                    timeout_ms = self.config.connect_timeout,
194                    "WebSocket connection timeout"
195                );
196                self.set_state(WsConnectionState::Error);
197                Err(Error::timeout("WebSocket connection timeout"))
198            }
199        }
200    }
201
202    /// Establishes connection with cancellation support.
203    #[instrument(
204        name = "ws_connect_with_cancel",
205        skip(self, cancel_token),
206        fields(url = %self.config.url)
207    )]
208    pub async fn connect_with_cancel(&self, cancel_token: Option<CancellationToken>) -> Result<()> {
209        let token = if let Some(t) = cancel_token {
210            t
211        } else {
212            let internal_token = self.cancel_token.lock().await;
213            internal_token
214                .clone()
215                .unwrap_or_else(CancellationToken::new)
216        };
217
218        if self.state() == WsConnectionState::Connected {
219            info!("WebSocket already connected");
220            return Ok(());
221        }
222
223        self.set_state(WsConnectionState::Connecting);
224        let url = self.config.url.clone();
225
226        let mut ws_config = WebSocketConfig::default();
227        if let Some(limit) = self.config.max_message_size {
228            ws_config.max_message_size = Some(limit);
229        }
230        if let Some(limit) = self.config.max_frame_size {
231            ws_config.max_frame_size = Some(limit);
232        }
233
234        tokio::select! {
235            biased;
236            () = token.cancelled() => {
237                warn!("WebSocket connection cancelled");
238                self.set_state(WsConnectionState::Disconnected);
239                Err(Error::cancelled("WebSocket connection cancelled"))
240            }
241            result = tokio::time::timeout(
242                Duration::from_millis(self.config.connect_timeout),
243                connect_async_with_config(&url, Some(ws_config), false),
244            ) => {
245                match result {
246                    Ok(Ok((ws_stream, response))) => {
247                        info!(status = response.status().as_u16(), "WebSocket connected");
248                        self.set_state(WsConnectionState::Connected);
249                        self.reconnect_count.store(0, Ordering::Release);
250                        self.stats.record_connected();
251                        self.start_message_loop(ws_stream).await;
252                        self.resubscribe_all().await?;
253                        Ok(())
254                    }
255                    Ok(Err(e)) => {
256                        error!(error = %e, "WebSocket connection failed");
257                        self.set_state(WsConnectionState::Error);
258                        Err(Error::network(format!("WebSocket connection failed: {e}")))
259                    }
260                    Err(_) => {
261                        error!("WebSocket connection timeout");
262                        self.set_state(WsConnectionState::Error);
263                        Err(Error::timeout("WebSocket connection timeout"))
264                    }
265                }
266            }
267        }
268    }
269
270    /// Closes the WebSocket connection gracefully.
271    #[instrument(name = "ws_disconnect", skip(self))]
272    pub async fn disconnect(&self) -> Result<()> {
273        info!("Initiating WebSocket disconnect");
274
275        if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
276            let _ = tx.send(());
277        }
278
279        *self.write_tx.write().await = None;
280        self.set_state(WsConnectionState::Disconnected);
281
282        info!("WebSocket disconnected");
283        Ok(())
284    }
285
286    /// Gracefully shuts down the WebSocket client.
287    #[instrument(name = "ws_shutdown", skip(self))]
288    pub async fn shutdown(&self) {
289        info!("Initiating graceful shutdown");
290
291        {
292            let token_guard = self.cancel_token.lock().await;
293            if let Some(ref token) = *token_guard {
294                token.cancel();
295            }
296        }
297
298        self.set_state(WsConnectionState::Disconnected);
299
300        {
301            let write_tx_guard = self.write_tx.read().await;
302            if let Some(ref tx) = *write_tx_guard {
303                // Ignore send result - we're shutting down anyway
304                drop(tx.send(Message::Close(None)).await);
305            }
306        }
307
308        let shutdown_timeout = Duration::from_millis(self.config.shutdown_timeout);
309        let _ = tokio::time::timeout(shutdown_timeout, async {
310            if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
311                let _ = tx.send(());
312            }
313            tokio::time::sleep(Duration::from_millis(100)).await;
314        })
315        .await;
316
317        {
318            *self.write_tx.write().await = None;
319            *self.shutdown_tx.lock().await = None;
320            self.subscription_manager.clear();
321            self.reconnect_count.store(0, Ordering::Release);
322            self.dropped_messages.store(0, Ordering::Relaxed);
323            self.stats.reset();
324        }
325
326        self.emit_event(WsEvent::Shutdown).await;
327        info!("Graceful shutdown completed");
328    }
329
330    /// Attempts to reconnect to the WebSocket server.
331    #[instrument(name = "ws_reconnect", skip(self))]
332    pub async fn reconnect(&self) -> Result<()> {
333        let count = self.reconnect_count.fetch_add(1, Ordering::AcqRel) + 1;
334
335        if count > self.config.max_reconnect_attempts {
336            error!(attempts = count, "Max reconnect attempts reached");
337            return Err(Error::network("Max reconnect attempts reached"));
338        }
339
340        warn!(attempt = count, "Attempting WebSocket reconnection");
341        self.set_state(WsConnectionState::Reconnecting);
342
343        tokio::time::sleep(Duration::from_millis(self.config.reconnect_interval)).await;
344        self.connect().await
345    }
346
347    /// Attempts to reconnect with cancellation support.
348    #[instrument(name = "ws_reconnect_with_cancel", skip(self, cancel_token))]
349    pub async fn reconnect_with_cancel(
350        &self,
351        cancel_token: Option<CancellationToken>,
352    ) -> Result<()> {
353        let token = if let Some(t) = cancel_token {
354            t
355        } else {
356            let internal_token = self.cancel_token.lock().await;
357            internal_token
358                .clone()
359                .unwrap_or_else(CancellationToken::new)
360        };
361
362        let backoff = BackoffStrategy::new(self.config.backoff_config.clone());
363        self.set_state(WsConnectionState::Reconnecting);
364
365        loop {
366            if token.is_cancelled() {
367                self.set_state(WsConnectionState::Disconnected);
368                return Err(Error::cancelled("Reconnection cancelled"));
369            }
370
371            let attempt = self.reconnect_count.fetch_add(1, Ordering::AcqRel);
372
373            if attempt >= self.config.max_reconnect_attempts {
374                self.set_state(WsConnectionState::Error);
375                return Err(Error::network(format!(
376                    "Max reconnect attempts ({}) reached",
377                    self.config.max_reconnect_attempts
378                )));
379            }
380
381            let delay = backoff.calculate_delay(attempt);
382
383            tokio::select! {
384                biased;
385                () = token.cancelled() => {
386                    self.set_state(WsConnectionState::Disconnected);
387                    return Err(Error::cancelled("Reconnection cancelled during backoff"));
388                }
389                () = tokio::time::sleep(delay) => {}
390            }
391
392            match self.connect_with_cancel(Some(token.clone())).await {
393                Ok(()) => {
394                    self.reconnect_count.store(0, Ordering::Release);
395                    return Ok(());
396                }
397                Err(e) => {
398                    if e.as_cancelled().is_some() {
399                        self.set_state(WsConnectionState::Disconnected);
400                        return Err(e);
401                    }
402
403                    let ws_error = WsError::from_error(&e);
404                    if ws_error.is_permanent() {
405                        self.set_state(WsConnectionState::Error);
406                        return Err(e);
407                    }
408                }
409            }
410        }
411    }
412
413    /// Returns the current reconnection attempt count.
414    #[inline]
415    pub fn reconnect_count(&self) -> u32 {
416        self.reconnect_count.load(Ordering::Acquire)
417    }
418
419    /// Resets the reconnection attempt counter.
420    pub fn reset_reconnect_count(&self) {
421        self.reconnect_count.store(0, Ordering::Release);
422    }
423
424    /// Increments the reconnection attempt counter.
425    pub(crate) fn increment_reconnect_count(&self) {
426        self.reconnect_count.fetch_add(1, Ordering::AcqRel);
427    }
428
429    /// Returns a snapshot of connection statistics.
430    pub fn stats(&self) -> WsStatsSnapshot {
431        self.stats.snapshot()
432    }
433
434    /// Resets all connection statistics.
435    pub fn reset_stats(&self) {
436        self.stats.reset();
437    }
438
439    /// Calculates current connection latency in milliseconds.
440    pub fn latency(&self) -> Option<i64> {
441        let last_pong = self.stats.last_pong_time();
442        let last_ping = self.stats.last_ping_time();
443        if last_pong > 0 && last_ping > 0 {
444            Some(last_pong - last_ping)
445        } else {
446            None
447        }
448    }
449
450    /// Returns the number of messages dropped due to backpressure.
451    ///
452    /// This counter is incremented when the message channel is full and
453    /// messages are dropped according to the configured backpressure strategy.
454    pub fn dropped_messages(&self) -> u32 {
455        self.dropped_messages.load(Ordering::Relaxed)
456    }
457
458    /// Resets the dropped messages counter.
459    pub fn reset_dropped_messages(&self) {
460        self.dropped_messages.store(0, Ordering::Relaxed);
461    }
462
463    /// Creates an automatic reconnection coordinator.
464    pub fn create_auto_reconnect_coordinator(self: Arc<Self>) -> AutoReconnectCoordinator {
465        AutoReconnectCoordinator::new(self)
466    }
467
468    /// Subscribes to a WebSocket channel.
469    #[instrument(name = "ws_subscribe", skip(self, params), fields(channel = %channel))]
470    pub async fn subscribe(
471        &self,
472        channel: String,
473        symbol: Option<String>,
474        params: Option<HashMap<String, Value>>,
475    ) -> Result<()> {
476        let sub_key = Self::subscription_key(&channel, symbol.as_ref());
477        let subscription = Subscription {
478            channel: channel.clone(),
479            symbol: symbol.clone(),
480            params: params.clone(),
481        };
482
483        self.subscription_manager
484            .try_add(sub_key.clone(), subscription)?;
485
486        if self.state() == WsConnectionState::Connected {
487            self.send_subscribe_message(channel, symbol, params).await?;
488        }
489
490        Ok(())
491    }
492
493    /// Unsubscribes from a WebSocket channel.
494    #[instrument(name = "ws_unsubscribe", skip(self), fields(channel = %channel))]
495    pub async fn unsubscribe(&self, channel: String, symbol: Option<String>) -> Result<()> {
496        let sub_key = Self::subscription_key(&channel, symbol.as_ref());
497        self.subscription_manager.remove(&sub_key);
498
499        if self.state() == WsConnectionState::Connected {
500            self.send_unsubscribe_message(channel, symbol).await?;
501        }
502
503        Ok(())
504    }
505
506    /// Receives the next available message.
507    pub async fn receive(&self) -> Option<Value> {
508        let mut rx = self.message_rx.write().await;
509        rx.recv().await
510    }
511
512    /// Returns the current connection state.
513    #[inline]
514    pub fn state(&self) -> WsConnectionState {
515        WsConnectionState::from_u8(self.state.load(Ordering::Acquire))
516    }
517
518    /// Returns a reference to the WebSocket configuration.
519    #[inline]
520    pub fn config(&self) -> &WsConfig {
521        &self.config
522    }
523
524    /// Sets the connection state.
525    #[inline]
526    pub fn set_state(&self, state: WsConnectionState) {
527        self.state.store(state.as_u8(), Ordering::Release);
528    }
529
530    /// Checks whether the WebSocket is currently connected.
531    #[inline]
532    pub fn is_connected(&self) -> bool {
533        self.state() == WsConnectionState::Connected
534    }
535
536    /// Checks if subscribed to a specific channel.
537    pub fn is_subscribed(&self, channel: &str, symbol: Option<&String>) -> bool {
538        let sub_key = Self::subscription_key(channel, symbol);
539        self.subscription_manager.contains(&sub_key)
540    }
541
542    /// Returns the number of active subscriptions.
543    pub fn subscription_count(&self) -> usize {
544        self.subscription_manager.count()
545    }
546
547    /// Returns the remaining capacity for new subscriptions.
548    pub fn remaining_capacity(&self) -> usize {
549        self.subscription_manager.remaining_capacity()
550    }
551
552    /// Returns a list of all active subscription channel names.
553    ///
554    /// Each subscription is identified by its channel name, optionally combined
555    /// with a symbol in the format "channel:symbol" or just "channel".
556    pub fn subscriptions(&self) -> Vec<String> {
557        self.subscription_manager
558            .iter()
559            .map(|entry| {
560                let sub = entry.value();
561                match &sub.symbol {
562                    Some(sym) => format!("{}:{}", sub.channel, sym),
563                    None => sub.channel.clone(),
564                }
565            })
566            .collect()
567    }
568
569    /// Sends a raw WebSocket message.
570    ///
571    /// This method uses a bounded channel for sending. If the write channel is full,
572    /// it will wait until space is available.
573    #[instrument(name = "ws_send", skip(self, message))]
574    pub async fn send(&self, message: Message) -> Result<()> {
575        let tx = self.write_tx.read().await;
576
577        if let Some(sender) = tx.as_ref() {
578            sender
579                .send(message)
580                .await
581                .map_err(|e| Error::network(format!("Failed to send message: {e}")))?;
582            Ok(())
583        } else {
584            Err(Error::network("WebSocket not connected"))
585        }
586    }
587
588    /// Tries to send a raw WebSocket message without blocking.
589    ///
590    /// Returns an error if the channel is full or closed.
591    #[instrument(name = "ws_try_send", skip(self, message))]
592    pub fn try_send(&self, message: Message) -> Result<()> {
593        // Note: This is a sync method, so we can't use async lock
594        // We use try_read to avoid blocking
595        if let Ok(tx) = self.write_tx.try_read() {
596            if let Some(sender) = tx.as_ref() {
597                sender.try_send(message).map_err(|e| match e {
598                    mpsc::error::TrySendError::Full(_) => {
599                        Error::network("Write channel full (backpressure)")
600                    }
601                    mpsc::error::TrySendError::Closed(_) => {
602                        Error::network("WebSocket channel closed")
603                    }
604                })?;
605                Ok(())
606            } else {
607                Err(Error::network("WebSocket not connected"))
608            }
609        } else {
610            Err(Error::network("Write channel busy"))
611        }
612    }
613
614    /// Sends a text message.
615    #[instrument(name = "ws_send_text", skip(self, text))]
616    pub async fn send_text(&self, text: String) -> Result<()> {
617        self.send(Message::Text(text.into())).await
618    }
619
620    /// Sends a JSON-encoded message.
621    #[instrument(name = "ws_send_json", skip(self, json))]
622    pub async fn send_json(&self, json: &Value) -> Result<()> {
623        let text = serde_json::to_string(json).map_err(Error::from)?;
624        self.send_text(text).await
625    }
626
627    fn subscription_key(channel: &str, symbol: Option<&String>) -> String {
628        match symbol {
629            Some(s) => format!("{channel}:{s}"),
630            None => channel.to_string(),
631        }
632    }
633
634    async fn start_message_loop(&self, ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) {
635        let (write, mut read) = ws_stream.split();
636
637        // Use bounded channel for write operations to prevent memory exhaustion
638        let (write_tx, mut write_rx) = mpsc::channel::<Message>(self.config.write_channel_capacity);
639        *self.write_tx.write().await = Some(write_tx.clone());
640
641        // Shutdown channel remains unbounded as it's only used for signaling
642        let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>();
643        *self.shutdown_tx.lock().await = Some(shutdown_tx);
644
645        let state = Arc::clone(&self.state);
646        let message_tx = self.message_tx.clone();
647        let ping_interval_ms = self.config.ping_interval;
648        let backpressure_strategy = self.config.backpressure_strategy;
649        let dropped_messages = Arc::clone(&self.dropped_messages);
650
651        let write_handle = tokio::spawn(async move {
652            let mut write = write;
653            loop {
654                tokio::select! {
655                    Some(msg) = write_rx.recv() => {
656                        if let Err(e) = write.send(msg).await {
657                            error!(error = %e, "Failed to write message");
658                            break;
659                        }
660                    }
661                    _ = shutdown_rx.recv() => {
662                        let _ = write.send(Message::Close(None)).await;
663                        break;
664                    }
665                }
666            }
667        });
668
669        let state_clone = Arc::clone(&state);
670        let ws_stats = Arc::clone(&self.stats);
671        let read_handle = tokio::spawn(async move {
672            while let Some(msg_result) = read.next().await {
673                match msg_result {
674                    Ok(Message::Text(text)) => {
675                        ws_stats.record_received(text.len() as u64);
676                        if let Ok(json) = serde_json::from_str::<Value>(&text) {
677                            Self::send_with_backpressure(
678                                &message_tx,
679                                json,
680                                backpressure_strategy,
681                                &dropped_messages,
682                            )
683                            .await;
684                        }
685                    }
686                    Ok(Message::Binary(data)) => {
687                        ws_stats.record_received(data.len() as u64);
688                        if let Some(json) = String::from_utf8(data.to_vec())
689                            .ok()
690                            .and_then(|text| serde_json::from_str::<Value>(&text).ok())
691                        {
692                            Self::send_with_backpressure(
693                                &message_tx,
694                                json,
695                                backpressure_strategy,
696                                &dropped_messages,
697                            )
698                            .await;
699                        }
700                    }
701                    Ok(Message::Pong(_)) => {
702                        ws_stats.record_pong();
703                    }
704                    Ok(Message::Close(_)) => {
705                        state_clone
706                            .store(WsConnectionState::Disconnected.as_u8(), Ordering::Release);
707                        break;
708                    }
709                    Err(_) => {
710                        state_clone.store(WsConnectionState::Error.as_u8(), Ordering::Release);
711                        break;
712                    }
713                    _ => {}
714                }
715            }
716        });
717
718        if ping_interval_ms > 0 {
719            let write_tx_clone = write_tx.clone();
720            let ping_stats = Arc::clone(&self.stats);
721            let ping_state = Arc::clone(&state);
722            let pong_timeout_ms = self.config.pong_timeout;
723
724            tokio::spawn(async move {
725                let mut interval = interval(Duration::from_millis(ping_interval_ms));
726
727                loop {
728                    interval.tick().await;
729
730                    let now = chrono::Utc::now().timestamp_millis();
731                    let last_pong = ping_stats.last_pong_time();
732
733                    if last_pong > 0 {
734                        let elapsed = now - last_pong;
735                        #[allow(clippy::cast_possible_wrap)]
736                        if elapsed > pong_timeout_ms as i64 {
737                            // Pong timeout detected - log detailed diagnostics
738                            error!(
739                                pong_timeout_ms = pong_timeout_ms,
740                                elapsed_ms = elapsed,
741                                last_pong_time = last_pong,
742                                current_time = now,
743                                "WebSocket pong timeout detected - connection appears unresponsive (zombie connection)"
744                            );
745                            ping_state.store(WsConnectionState::Error.as_u8(), Ordering::Release);
746                            debug!(
747                                "WebSocket state set to Error due to pong timeout - AutoReconnectCoordinator will trigger reconnection if enabled"
748                            );
749                            break;
750                        }
751                    }
752
753                    ping_stats.record_ping();
754
755                    // Use try_send for ping to avoid blocking
756                    if write_tx_clone
757                        .try_send(Message::Ping(vec![].into()))
758                        .is_err()
759                    {
760                        debug!("WebSocket write channel closed, stopping ping loop");
761                        break;
762                    }
763                }
764            });
765        }
766
767        tokio::spawn(async move {
768            let _ = tokio::join!(write_handle, read_handle);
769        });
770    }
771
772    /// Sends a message with backpressure handling.
773    ///
774    /// This method implements the configured backpressure strategy when the
775    /// message channel is full.
776    async fn send_with_backpressure(
777        tx: &mpsc::Sender<Value>,
778        message: Value,
779        strategy: BackpressureStrategy,
780        dropped_counter: &Arc<AtomicU32>,
781    ) {
782        match strategy {
783            BackpressureStrategy::Block => {
784                // Block until space is available
785                if tx.send(message).await.is_err() {
786                    warn!("Message channel closed");
787                }
788            }
789            BackpressureStrategy::DropNewest => {
790                // Try to send, drop if full
791                match tx.try_send(message) {
792                    Ok(()) => {}
793                    Err(mpsc::error::TrySendError::Full(_)) => {
794                        let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
795                        if count % 100 == 1 {
796                            // Log every 100th drop to avoid log spam
797                            warn!(
798                                dropped_count = count,
799                                "Message channel full, dropping newest message (backpressure)"
800                            );
801                        }
802                    }
803                    Err(mpsc::error::TrySendError::Closed(_)) => {
804                        warn!("Message channel closed");
805                    }
806                }
807            }
808            BackpressureStrategy::DropOldest => {
809                // Try to send, if full, make room by receiving and discarding
810                match tx.try_send(message) {
811                    Ok(()) => {}
812                    Err(mpsc::error::TrySendError::Full(msg)) => {
813                        // Channel is full, we need to drop oldest
814                        // Since we can't directly remove from the channel,
815                        // we use a permit-based approach
816                        let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
817                        if count % 100 == 1 {
818                            warn!(
819                                dropped_count = count,
820                                "Message channel full, dropping oldest message (backpressure)"
821                            );
822                        }
823                        // For DropOldest, we actually drop the newest since we can't
824                        // remove from the receiver side. The semantic is that we
825                        // prioritize not blocking the read loop.
826                        // A true DropOldest would require a different data structure.
827                        drop(msg);
828                    }
829                    Err(mpsc::error::TrySendError::Closed(_)) => {
830                        warn!("Message channel closed");
831                    }
832                }
833            }
834        }
835    }
836
837    async fn send_subscribe_message(
838        &self,
839        channel: String,
840        symbol: Option<String>,
841        params: Option<HashMap<String, Value>>,
842    ) -> Result<()> {
843        let msg = WsMessage::Subscribe {
844            channel,
845            symbol,
846            params,
847        };
848        let json = serde_json::to_value(&msg).map_err(Error::from)?;
849        self.send_json(&json).await
850    }
851
852    async fn send_unsubscribe_message(
853        &self,
854        channel: String,
855        symbol: Option<String>,
856    ) -> Result<()> {
857        let msg = WsMessage::Unsubscribe { channel, symbol };
858        let json = serde_json::to_value(&msg).map_err(Error::from)?;
859        self.send_json(&json).await
860    }
861
862    pub(crate) async fn resubscribe_all(&self) -> Result<()> {
863        let subs = self.subscription_manager.collect_subscriptions();
864        for subscription in subs {
865            self.send_subscribe_message(
866                subscription.channel.clone(),
867                subscription.symbol.clone(),
868                subscription.params.clone(),
869            )
870            .await?;
871        }
872        Ok(())
873    }
874}
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration pointing at a placeholder endpoint.
    fn example_config() -> WsConfig {
        WsConfig {
            url: "wss://example.com/ws".to_string(),
            ..Default::default()
        }
    }

    /// The default backoff starts at one second and tops out at sixty.
    #[test]
    fn test_backoff_config_default() {
        let defaults = BackoffConfig::default();
        assert_eq!(defaults.base_delay, Duration::from_secs(1));
        assert_eq!(defaults.max_delay, Duration::from_secs(60));
    }

    /// Without jitter the delay doubles per attempt until it hits `max_delay`.
    #[test]
    fn test_backoff_strategy_exponential_growth_no_jitter() {
        let strategy = BackoffStrategy::new(BackoffConfig {
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            jitter_factor: 0.0,
            multiplier: 2.0,
        });

        // (attempt, expected delay in seconds); the last case expects the
        // 60s `max_delay`.
        for (attempt, expected_secs) in [(0, 1), (1, 2), (2, 4), (6, 60)] {
            assert_eq!(
                strategy.calculate_delay(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    /// Checks the default connect timeout and subscription limit.
    #[test]
    fn test_ws_config_default() {
        let defaults = WsConfig::default();
        assert_eq!(defaults.connect_timeout, 10_000);
        assert_eq!(defaults.max_subscriptions, DEFAULT_MAX_SUBSCRIPTIONS);
    }

    /// Keys are `channel:symbol` when a symbol is given, else just `channel`.
    #[test]
    fn test_subscription_key() {
        let symbol = "BTC/USDT".to_string();

        let with_symbol = WsClient::subscription_key("ticker", Some(&symbol));
        assert_eq!(with_symbol, "ticker:BTC/USDT");

        let channel_only = WsClient::subscription_key("trades", None);
        assert_eq!(channel_only, "trades");
    }

    /// A freshly constructed client reports itself as disconnected.
    #[tokio::test]
    async fn test_ws_client_creation() {
        let client = WsClient::new(example_config());

        assert_eq!(client.state(), WsConnectionState::Disconnected);
        assert!(!client.is_connected());
    }

    /// Subscribing on a new client succeeds and registers the subscription.
    #[tokio::test]
    async fn test_subscribe_adds_subscription() {
        let client = WsClient::new(example_config());

        let outcome = client
            .subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None)
            .await;

        assert!(outcome.is_ok());
        assert_eq!(client.subscription_count(), 1);

        let symbol = "BTC/USDT".to_string();
        assert!(client.is_subscribed("ticker", Some(&symbol)));
    }

    /// Raw state bytes map to the expected connection states.
    #[test]
    fn test_ws_connection_state_from_u8() {
        let cases = [
            (0, WsConnectionState::Disconnected),
            (1, WsConnectionState::Connecting),
            (2, WsConnectionState::Connected),
            (255, WsConnectionState::Error),
        ];

        for (raw, expected) in cases {
            assert_eq!(WsConnectionState::from_u8(raw), expected);
        }
    }

    /// Each error kind answers its own classification predicate.
    #[test]
    fn test_ws_error_kind() {
        let transient = WsErrorKind::Transient;
        let permanent = WsErrorKind::Permanent;

        assert!(transient.is_transient());
        assert!(permanent.is_permanent());
    }

    /// The `transient` / `permanent` constructors set the matching kind.
    #[test]
    fn test_ws_error_creation() {
        let transient = WsError::transient("Connection timeout");
        assert!(transient.is_transient());
        assert_eq!(transient.message(), "Connection timeout");

        let permanent = WsError::permanent("Invalid API key");
        assert!(permanent.is_permanent());
    }

    /// A new manager is empty; adding one entry below the limit succeeds.
    #[test]
    fn test_subscription_manager() {
        let manager = SubscriptionManager::new(2);
        assert_eq!(manager.max_subscriptions(), 2);
        assert_eq!(manager.count(), 0);
        assert!(!manager.is_full());

        let ticker = Subscription {
            channel: "ticker".to_string(),
            symbol: Some("BTC/USDT".to_string()),
            params: None,
        };
        let added = manager.try_add("ticker:BTC/USDT".to_string(), ticker);
        assert!(added.is_ok());
        assert_eq!(manager.count(), 1);
    }
}