ccxt_core/
ws_client.rs

1//! WebSocket client module.
2//!
3//! Provides asynchronous WebSocket connection management, subscription handling,
4//! and heartbeat maintenance for cryptocurrency exchange streaming APIs.
5//!
6//! # Observability
7//!
8//! This module uses the `tracing` crate for structured logging. Key events:
9//! - Connection establishment and disconnection
10//! - Subscription and unsubscription events with stream names
11//! - Message parsing failures with raw message preview (truncated)
12//! - Reconnection attempts and outcomes
13//! - Ping/pong heartbeat events
14
15use crate::error::{Error, Result};
16use futures_util::{SinkExt, StreamExt, stream::SplitSink};
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::sync::atomic::{AtomicBool, Ordering};
22use tokio::net::TcpStream;
23use tokio::sync::{Mutex, RwLock, mpsc};
24use tokio::task::JoinHandle;
25use tokio::time::{Duration, interval};
26use tokio_tungstenite::{
27    MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message,
28};
29use tracing::{debug, error, info, instrument, warn};
30
/// Lifecycle phase of a [`WsClient`] connection.
///
/// The value is shared between the client and its background tasks, which
/// move it to `Disconnected` or `Error` when the socket closes or fails.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsConnectionState {
    /// No connection is open.
    Disconnected,
    /// A connection attempt is in flight.
    Connecting,
    /// The handshake completed and the socket is usable.
    Connected,
    /// A reconnection attempt has been scheduled or is in flight.
    Reconnecting,
    /// The last connection attempt or the live connection failed.
    Error,
}
45
46/// WebSocket message types for exchange communication.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(tag = "type", rename_all = "lowercase")]
49pub enum WsMessage {
50    /// Subscribe to a channel
51    Subscribe {
52        /// Channel name
53        channel: String,
54        /// Optional trading pair symbol
55        symbol: Option<String>,
56        /// Additional parameters
57        params: Option<HashMap<String, Value>>,
58    },
59    /// Unsubscribe from a channel
60    Unsubscribe {
61        /// Channel name
62        channel: String,
63        /// Optional trading pair symbol
64        symbol: Option<String>,
65    },
66    /// Ping message for keepalive
67    Ping {
68        /// Timestamp in milliseconds
69        timestamp: i64,
70    },
71    /// Pong response to ping
72    Pong {
73        /// Timestamp in milliseconds
74        timestamp: i64,
75    },
76    /// Authentication message
77    Auth {
78        /// API key
79        api_key: String,
80        /// HMAC signature
81        signature: String,
82        /// Timestamp in milliseconds
83        timestamp: i64,
84    },
85    /// Custom message payload
86    Custom(Value),
87}
88
/// Settings that control how a [`WsClient`] connects and stays alive.
///
/// All durations are expressed in milliseconds.
#[derive(Clone, Debug)]
pub struct WsConfig {
    /// URL of the WebSocket endpoint.
    pub url: String,
    /// Maximum time to wait for the connection handshake, in milliseconds.
    pub connect_timeout: u64,
    /// Time between keepalive pings, in milliseconds; `0` disables the ping task.
    pub ping_interval: u64,
    /// Delay applied before each reconnection attempt, in milliseconds.
    pub reconnect_interval: u64,
    /// Number of reconnection attempts allowed before giving up.
    pub max_reconnect_attempts: u32,
    /// Whether the connection should be re-established automatically.
    pub auto_reconnect: bool,
    /// Whether message compression is requested.
    pub enable_compression: bool,
    /// Maximum age of the last pong, in milliseconds.
    ///
    /// The connection is treated as dead once the most recent pong is older
    /// than this; the check runs on every ping tick.
    pub pong_timeout: u64,
}

impl Default for WsConfig {
    /// Builds a configuration with an empty URL, a 10 s connect timeout,
    /// 30 s pings, a 90 s pong timeout and up to five reconnect attempts
    /// spaced 5 s apart.
    fn default() -> Self {
        Self {
            url: String::new(),
            connect_timeout: 10_000,
            ping_interval: 30_000,
            reconnect_interval: 5_000,
            max_reconnect_attempts: 5,
            auto_reconnect: true,
            enable_compression: false,
            pong_timeout: 90_000,
        }
    }
}
126
127/// WebSocket subscription metadata.
128#[derive(Debug, Clone)]
129pub struct Subscription {
130    channel: String,
131    symbol: Option<String>,
132    params: Option<HashMap<String, Value>>,
133}
134
135/// Type alias for WebSocket write half.
136#[allow(dead_code)]
137type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
138
139/// Async WebSocket client for exchange streaming APIs.
140pub struct WsClient {
141    config: WsConfig,
142    state: Arc<RwLock<WsConnectionState>>,
143    subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
144
145    message_tx: mpsc::UnboundedSender<Value>,
146    message_rx: Arc<RwLock<mpsc::UnboundedReceiver<Value>>>,
147
148    write_tx: Arc<Mutex<Option<mpsc::UnboundedSender<Message>>>>,
149
150    reconnect_count: Arc<RwLock<u32>>,
151
152    shutdown_tx: Arc<Mutex<Option<mpsc::UnboundedSender<()>>>>,
153
154    stats: Arc<RwLock<WsStats>>,
155}
156
/// Counters and timestamps describing a [`WsClient`] connection.
///
/// All timestamps are Unix times in milliseconds; `0` means "not recorded".
#[derive(Clone, Debug, Default)]
pub struct WsStats {
    /// Number of messages received.
    pub messages_received: u64,
    /// Number of messages sent.
    pub messages_sent: u64,
    /// Number of payload bytes received.
    pub bytes_received: u64,
    /// Number of payload bytes sent.
    pub bytes_sent: u64,
    /// Time the most recent message arrived, in milliseconds.
    pub last_message_time: i64,
    /// Time the most recent ping was issued, in milliseconds.
    pub last_ping_time: i64,
    /// Time the most recent pong arrived, in milliseconds.
    pub last_pong_time: i64,
    /// Time the current connection was established, in milliseconds.
    pub connected_at: i64,
    /// Number of reconnection attempts.
    pub reconnect_attempts: u32,
}
179
180impl WsClient {
181    /// Creates a new WebSocket client instance.
182    ///
183    /// # Arguments
184    ///
185    /// * `config` - WebSocket connection configuration
186    ///
187    /// # Returns
188    ///
189    /// A new `WsClient` instance ready to connect
190    pub fn new(config: WsConfig) -> Self {
191        let (message_tx, message_rx) = mpsc::unbounded_channel();
192
193        Self {
194            config,
195            state: Arc::new(RwLock::new(WsConnectionState::Disconnected)),
196            subscriptions: Arc::new(RwLock::new(HashMap::new())),
197            message_tx,
198            message_rx: Arc::new(RwLock::new(message_rx)),
199            write_tx: Arc::new(Mutex::new(None)),
200            reconnect_count: Arc::new(RwLock::new(0)),
201            shutdown_tx: Arc::new(Mutex::new(None)),
202            stats: Arc::new(RwLock::new(WsStats::default())),
203        }
204    }
205
206    /// Establishes connection to the WebSocket server.
207    ///
208    /// Returns immediately if already connected. Automatically starts message
209    /// processing loop and resubscribes to previous channels on success.
210    ///
211    /// # Errors
212    ///
213    /// Returns error if:
214    /// - Connection timeout exceeded
215    /// - Network error occurs
216    /// - Server rejects connection
217    #[instrument(
218        name = "ws_connect",
219        skip(self),
220        fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
221    )]
222    pub async fn connect(&self) -> Result<()> {
223        {
224            let state = self.state.read().await;
225            if *state == WsConnectionState::Connected {
226                info!("WebSocket already connected");
227                return Ok(());
228            }
229        }
230
231        {
232            let mut state = self.state.write().await;
233            *state = WsConnectionState::Connecting;
234        }
235
236        let url = self.config.url.clone();
237        info!("Initiating WebSocket connection");
238
239        match tokio::time::timeout(
240            Duration::from_millis(self.config.connect_timeout),
241            connect_async(&url),
242        )
243        .await
244        {
245            Ok(Ok((ws_stream, response))) => {
246                info!(
247                    status = response.status().as_u16(),
248                    "WebSocket connection established successfully"
249                );
250
251                *self.state.write().await = WsConnectionState::Connected;
252                *self.reconnect_count.write().await = 0;
253
254                {
255                    let mut stats = self.stats.write().await;
256                    stats.connected_at = chrono::Utc::now().timestamp_millis();
257                }
258
259                self.start_message_loop(ws_stream).await;
260
261                self.resubscribe_all().await?;
262
263                Ok(())
264            }
265            Ok(Err(e)) => {
266                error!(
267                    error = %e,
268                    error_debug = ?e,
269                    "WebSocket connection failed"
270                );
271                *self.state.write().await = WsConnectionState::Error;
272                Err(Error::network(format!(
273                    "WebSocket connection failed: {}",
274                    e
275                )))
276            }
277            Err(_) => {
278                error!(
279                    timeout_ms = self.config.connect_timeout,
280                    "WebSocket connection timeout exceeded"
281                );
282                *self.state.write().await = WsConnectionState::Error;
283                Err(Error::timeout("WebSocket connection timeout"))
284            }
285        }
286    }
287
288    /// Closes the WebSocket connection gracefully.
289    ///
290    /// Sends shutdown signal to background tasks and clears internal state.
291    #[instrument(name = "ws_disconnect", skip(self))]
292    pub async fn disconnect(&self) -> Result<()> {
293        info!("Initiating WebSocket disconnect");
294
295        if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
296            let _ = tx.send(());
297            debug!("Shutdown signal sent to background tasks");
298        }
299
300        *self.write_tx.lock().await = None;
301
302        let mut state = self.state.write().await;
303        *state = WsConnectionState::Disconnected;
304
305        info!("WebSocket disconnected successfully");
306        Ok(())
307    }
308
309    /// Attempts to reconnect to the WebSocket server.
310    ///
311    /// Respects `max_reconnect_attempts` configuration and waits for
312    /// `reconnect_interval` before attempting connection.
313    ///
314    /// # Errors
315    ///
316    /// Returns error if maximum reconnection attempts exceeded or connection fails.
317    #[instrument(
318        name = "ws_reconnect",
319        skip(self),
320        fields(
321            max_attempts = self.config.max_reconnect_attempts,
322            reconnect_interval_ms = self.config.reconnect_interval
323        )
324    )]
325    pub async fn reconnect(&self) -> Result<()> {
326        let mut count = self.reconnect_count.write().await;
327
328        if *count >= self.config.max_reconnect_attempts {
329            error!(
330                attempts = *count,
331                max = self.config.max_reconnect_attempts,
332                "Max reconnect attempts reached, giving up"
333            );
334            return Err(Error::network("Max reconnect attempts reached"));
335        }
336
337        *count += 1;
338
339        warn!(
340            attempt = *count,
341            max = self.config.max_reconnect_attempts,
342            delay_ms = self.config.reconnect_interval,
343            "Attempting WebSocket reconnection"
344        );
345
346        *self.state.write().await = WsConnectionState::Reconnecting;
347
348        tokio::time::sleep(Duration::from_millis(self.config.reconnect_interval)).await;
349
350        self.connect().await
351    }
352
353    /// Returns the current reconnection attempt count.
354    pub async fn reconnect_count(&self) -> u32 {
355        *self.reconnect_count.read().await
356    }
357
358    /// Resets the reconnection attempt counter to zero.
359    pub async fn reset_reconnect_count(&self) {
360        *self.reconnect_count.write().await = 0;
361        debug!("Reconnect count reset");
362    }
363
364    /// Returns a snapshot of connection statistics.
365    pub async fn stats(&self) -> WsStats {
366        self.stats.read().await.clone()
367    }
368
369    /// Resets all connection statistics to default values.
370    pub async fn reset_stats(&self) {
371        *self.stats.write().await = WsStats::default();
372        debug!("Stats reset");
373    }
374
375    /// Calculates current connection latency in milliseconds.
376    ///
377    /// # Returns
378    ///
379    /// Time difference between last pong and ping, or `None` if no data available.
380    pub async fn latency(&self) -> Option<i64> {
381        let stats = self.stats.read().await;
382        if stats.last_pong_time > 0 && stats.last_ping_time > 0 {
383            Some(stats.last_pong_time - stats.last_ping_time)
384        } else {
385            None
386        }
387    }
388
389    /// Creates an automatic reconnection coordinator.
390    ///
391    /// # Returns
392    ///
393    /// A new [`AutoReconnectCoordinator`] instance for managing reconnection logic.
394    pub fn create_auto_reconnect_coordinator(self: Arc<Self>) -> AutoReconnectCoordinator {
395        AutoReconnectCoordinator::new(self)
396    }
397
398    /// Subscribes to a WebSocket channel.
399    ///
400    /// Subscription is persisted and automatically reestablished on reconnection.
401    ///
402    /// # Arguments
403    ///
404    /// * `channel` - Channel name to subscribe to
405    /// * `symbol` - Optional trading pair symbol
406    /// * `params` - Optional additional subscription parameters
407    ///
408    /// # Errors
409    ///
410    /// Returns error if subscription message fails to send.
411    #[instrument(
412        name = "ws_subscribe",
413        skip(self, params),
414        fields(channel = %channel, symbol = ?symbol)
415    )]
416    pub async fn subscribe(
417        &self,
418        channel: String,
419        symbol: Option<String>,
420        params: Option<HashMap<String, Value>>,
421    ) -> Result<()> {
422        let sub_key = Self::subscription_key(&channel, &symbol);
423        let subscription = Subscription {
424            channel: channel.clone(),
425            symbol: symbol.clone(),
426            params: params.clone(),
427        };
428
429        {
430            let mut subs = self.subscriptions.write().await;
431            subs.insert(sub_key.clone(), subscription);
432        }
433
434        info!(subscription_key = %sub_key, "Subscription registered");
435
436        let state = *self.state.read().await;
437        if state == WsConnectionState::Connected {
438            self.send_subscribe_message(channel, symbol, params).await?;
439            info!(subscription_key = %sub_key, "Subscription message sent");
440        } else {
441            debug!(
442                subscription_key = %sub_key,
443                state = ?state,
444                "Subscription queued (not connected)"
445            );
446        }
447
448        Ok(())
449    }
450
451    /// Unsubscribes from a WebSocket channel.
452    ///
453    /// Removes subscription from internal state and sends unsubscribe message if connected.
454    ///
455    /// # Arguments
456    ///
457    /// * `channel` - Channel name to unsubscribe from
458    /// * `symbol` - Optional trading pair symbol
459    ///
460    /// # Errors
461    ///
462    /// Returns error if unsubscribe message fails to send.
463    #[instrument(
464        name = "ws_unsubscribe",
465        skip(self),
466        fields(channel = %channel, symbol = ?symbol)
467    )]
468    pub async fn unsubscribe(&self, channel: String, symbol: Option<String>) -> Result<()> {
469        let sub_key = Self::subscription_key(&channel, &symbol);
470
471        {
472            let mut subs = self.subscriptions.write().await;
473            subs.remove(&sub_key);
474        }
475
476        info!(subscription_key = %sub_key, "Subscription removed");
477
478        let state = *self.state.read().await;
479        if state == WsConnectionState::Connected {
480            self.send_unsubscribe_message(channel, symbol).await?;
481            info!(subscription_key = %sub_key, "Unsubscribe message sent");
482        }
483
484        Ok(())
485    }
486
487    /// Receives the next available message from the WebSocket stream.
488    ///
489    /// # Returns
490    ///
491    /// The received JSON message, or `None` if the channel is closed.
492    pub async fn receive(&self) -> Option<Value> {
493        let mut rx = self.message_rx.write().await;
494        rx.recv().await
495    }
496
497    /// Returns the current connection state.
498    pub async fn state(&self) -> WsConnectionState {
499        *self.state.read().await
500    }
501
502    /// Checks whether the WebSocket is currently connected.
503    pub async fn is_connected(&self) -> bool {
504        *self.state.read().await == WsConnectionState::Connected
505    }
506
507    /// Sends a raw WebSocket message.
508    ///
509    /// # Arguments
510    ///
511    /// * `message` - WebSocket message to send
512    ///
513    /// # Errors
514    ///
515    /// Returns error if not connected or message transmission fails.
516    #[instrument(name = "ws_send", skip(self, message))]
517    pub async fn send(&self, message: Message) -> Result<()> {
518        let tx = self.write_tx.lock().await;
519
520        if let Some(sender) = tx.as_ref() {
521            sender.send(message).map_err(|e| {
522                error!(
523                    error = %e,
524                    "Failed to send WebSocket message"
525                );
526                Error::network(format!("Failed to send message: {}", e))
527            })?;
528            debug!("WebSocket message sent successfully");
529            Ok(())
530        } else {
531            warn!("WebSocket not connected, cannot send message");
532            Err(Error::network("WebSocket not connected"))
533        }
534    }
535
536    /// Sends a text message over the WebSocket connection.
537    ///
538    /// # Arguments
539    ///
540    /// * `text` - Text content to send
541    ///
542    /// # Errors
543    ///
544    /// Returns error if not connected or transmission fails.
545    #[instrument(name = "ws_send_text", skip(self, text), fields(text_len = text.len()))]
546    pub async fn send_text(&self, text: String) -> Result<()> {
547        self.send(Message::Text(text.into())).await
548    }
549
550    /// Sends a JSON-encoded message over the WebSocket connection.
551    ///
552    /// # Arguments
553    ///
554    /// * `json` - JSON value to serialize and send
555    ///
556    /// # Errors
557    ///
558    /// Returns error if serialization fails, not connected, or transmission fails.
559    #[instrument(name = "ws_send_json", skip(self, json))]
560    pub async fn send_json(&self, json: &Value) -> Result<()> {
561        let text = serde_json::to_string(json).map_err(|e| {
562            error!(error = %e, "Failed to serialize JSON for WebSocket");
563            Error::from(e)
564        })?;
565        self.send_text(text).await
566    }
567
568    /// Generates a unique subscription key from channel and symbol.
569    fn subscription_key(channel: &str, symbol: &Option<String>) -> String {
570        match symbol {
571            Some(s) => format!("{}:{}", channel, s),
572            None => channel.to_string(),
573        }
574    }
575
576    /// Starts the WebSocket message processing loop.
577    ///
578    /// Spawns separate tasks for reading and writing messages, handling shutdown signals.
579    async fn start_message_loop(&self, ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) {
580        let (write, mut read) = ws_stream.split();
581
582        let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Message>();
583        *self.write_tx.lock().await = Some(write_tx.clone());
584
585        let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>();
586        *self.shutdown_tx.lock().await = Some(shutdown_tx);
587
588        let state = Arc::clone(&self.state);
589        let message_tx = self.message_tx.clone();
590        let ping_interval_ms = self.config.ping_interval;
591
592        info!("Starting WebSocket message loop");
593
594        let write_handle = tokio::spawn(async move {
595            let mut write = write;
596            loop {
597                tokio::select! {
598                    Some(msg) = write_rx.recv() => {
599                        if let Err(e) = write.send(msg).await {
600                            error!(error = %e, "Failed to write message");
601                            break;
602                        }
603                    }
604                    _ = shutdown_rx.recv() => {
605                        debug!("Write task received shutdown signal");
606                        let _ = write.send(Message::Close(None)).await;
607                        break;
608                    }
609                }
610            }
611            debug!("Write task terminated");
612        });
613
614        let state_clone = Arc::clone(&state);
615        let ws_stats = Arc::clone(&self.stats);
616        let read_handle = tokio::spawn(async move {
617            debug!("Starting WebSocket read task");
618            while let Some(msg_result) = read.next().await {
619                match msg_result {
620                    Ok(Message::Text(text)) => {
621                        debug!(len = text.len(), "Received text message");
622
623                        {
624                            let mut stats_guard = ws_stats.write().await;
625                            stats_guard.messages_received += 1;
626                            stats_guard.bytes_received += text.len() as u64;
627                            stats_guard.last_message_time = chrono::Utc::now().timestamp_millis();
628                        }
629
630                        match serde_json::from_str::<Value>(&text) {
631                            Ok(json) => {
632                                let _ = message_tx.send(json);
633                            }
634                            Err(e) => {
635                                // Log parse failure with truncated raw message preview
636                                let raw_preview: String = text.chars().take(200).collect();
637                                warn!(
638                                    error = %e,
639                                    raw_message_preview = %raw_preview,
640                                    raw_message_len = text.len(),
641                                    "Failed to parse WebSocket text message as JSON"
642                                );
643                            }
644                        }
645                    }
646                    Ok(Message::Binary(data)) => {
647                        debug!(len = data.len(), "Received binary message");
648
649                        {
650                            let mut stats_guard = ws_stats.write().await;
651                            stats_guard.messages_received += 1;
652                            stats_guard.bytes_received += data.len() as u64;
653                            stats_guard.last_message_time = chrono::Utc::now().timestamp_millis();
654                        }
655
656                        match String::from_utf8(data.to_vec()) {
657                            Ok(text) => {
658                                match serde_json::from_str::<Value>(&text) {
659                                    Ok(json) => {
660                                        let _ = message_tx.send(json);
661                                    }
662                                    Err(e) => {
663                                        // Log parse failure with truncated raw message preview
664                                        let raw_preview: String = text.chars().take(200).collect();
665                                        warn!(
666                                            error = %e,
667                                            raw_message_preview = %raw_preview,
668                                            raw_message_len = text.len(),
669                                            "Failed to parse WebSocket binary message as JSON"
670                                        );
671                                    }
672                                }
673                            }
674                            Err(e) => {
675                                // Log UTF-8 decode failure with hex preview
676                                let hex_preview: String = data
677                                    .iter()
678                                    .take(50)
679                                    .map(|b| format!("{:02x}", b))
680                                    .collect::<Vec<_>>()
681                                    .join(" ");
682                                warn!(
683                                    error = %e,
684                                    hex_preview = %hex_preview,
685                                    data_len = data.len(),
686                                    "Failed to decode WebSocket binary message as UTF-8"
687                                );
688                            }
689                        }
690                    }
691                    Ok(Message::Ping(_)) => {
692                        debug!("Received ping, auto-responding with pong");
693                    }
694                    Ok(Message::Pong(_)) => {
695                        debug!("Received pong");
696
697                        {
698                            let mut stats_guard = ws_stats.write().await;
699                            stats_guard.last_pong_time = chrono::Utc::now().timestamp_millis();
700                        }
701                    }
702                    Ok(Message::Close(frame)) => {
703                        info!(
704                            close_frame = ?frame,
705                            "Received WebSocket close frame"
706                        );
707                        *state_clone.write().await = WsConnectionState::Disconnected;
708                        break;
709                    }
710                    Err(e) => {
711                        error!(
712                            error = %e,
713                            error_debug = ?e,
714                            "WebSocket read error"
715                        );
716                        *state_clone.write().await = WsConnectionState::Error;
717                        break;
718                    }
719                    _ => {
720                        debug!("Received other WebSocket message type");
721                    }
722                }
723            }
724            debug!("WebSocket read task terminated");
725        });
726
727        if ping_interval_ms > 0 {
728            let write_tx_clone = write_tx.clone();
729            let ping_stats = Arc::clone(&self.stats);
730            let ping_state = Arc::clone(&state);
731            let pong_timeout_ms = self.config.pong_timeout;
732
733            tokio::spawn(async move {
734                let mut interval = interval(Duration::from_millis(ping_interval_ms));
735                debug!(
736                    interval_ms = ping_interval_ms,
737                    timeout_ms = pong_timeout_ms,
738                    "Starting ping task with timeout detection"
739                );
740
741                loop {
742                    interval.tick().await;
743
744                    let now = chrono::Utc::now().timestamp_millis();
745                    let last_pong = {
746                        let stats_guard = ping_stats.read().await;
747                        stats_guard.last_pong_time
748                    };
749
750                    if last_pong > 0 {
751                        let elapsed = now - last_pong;
752                        #[allow(clippy::cast_possible_wrap)]
753                        if elapsed > pong_timeout_ms as i64 {
754                            warn!(
755                                elapsed_ms = elapsed,
756                                timeout_ms = pong_timeout_ms,
757                                "Pong timeout detected, marking connection as error"
758                            );
759                            *ping_state.write().await = WsConnectionState::Error;
760                            break;
761                        }
762                    }
763
764                    {
765                        let mut stats_guard = ping_stats.write().await;
766                        stats_guard.last_ping_time = now;
767                    }
768
769                    if write_tx_clone.send(Message::Ping(vec![].into())).is_err() {
770                        debug!("Ping task: write channel closed");
771                        break;
772                    }
773                    debug!("Sent ping");
774                }
775                debug!("Ping task terminated");
776            });
777        }
778
779        tokio::spawn(async move {
780            let _ = tokio::join!(write_handle, read_handle);
781            info!("All WebSocket tasks completed");
782        });
783    }
784
785    /// Sends a subscription message to the WebSocket server.
786    #[instrument(
787        name = "ws_send_subscribe",
788        skip(self, params),
789        fields(channel = %channel, symbol = ?symbol)
790    )]
791    async fn send_subscribe_message(
792        &self,
793        channel: String,
794        symbol: Option<String>,
795        params: Option<HashMap<String, Value>>,
796    ) -> Result<()> {
797        let msg = WsMessage::Subscribe {
798            channel: channel.clone(),
799            symbol: symbol.clone(),
800            params,
801        };
802
803        let json = serde_json::to_value(&msg).map_err(|e| {
804            error!(error = %e, "Failed to serialize subscribe message");
805            Error::from(e)
806        })?;
807
808        debug!("Sending subscribe message to server");
809
810        self.send_json(&json).await?;
811        info!("Subscribe message sent successfully");
812        Ok(())
813    }
814
815    /// Sends an unsubscribe message to the WebSocket server.
816    #[instrument(
817        name = "ws_send_unsubscribe",
818        skip(self),
819        fields(channel = %channel, symbol = ?symbol)
820    )]
821    async fn send_unsubscribe_message(
822        &self,
823        channel: String,
824        symbol: Option<String>,
825    ) -> Result<()> {
826        let msg = WsMessage::Unsubscribe {
827            channel: channel.clone(),
828            symbol: symbol.clone(),
829        };
830
831        let json = serde_json::to_value(&msg).map_err(|e| {
832            error!(error = %e, "Failed to serialize unsubscribe message");
833            Error::from(e)
834        })?;
835
836        debug!("Sending unsubscribe message to server");
837
838        self.send_json(&json).await?;
839        info!("Unsubscribe message sent successfully");
840        Ok(())
841    }
842
843    /// Resubscribes to all previously subscribed channels.
844    async fn resubscribe_all(&self) -> Result<()> {
845        let subs = self.subscriptions.read().await;
846        for subscription in subs.values() {
847            self.send_subscribe_message(
848                subscription.channel.clone(),
849                subscription.symbol.clone(),
850                subscription.params.clone(),
851            )
852            .await?;
853        }
854        Ok(())
855    }
856}
/// WebSocket connection event types.
///
/// Values of this type are passed to a [`WsEventCallback`]. The enum derives
/// `PartialEq`/`Eq` so events can be compared directly (for example in tests
/// or when filtering inside a callback).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsEvent {
    /// Connection established successfully
    Connected,
    /// Connection closed
    Disconnected,
    /// Reconnection in progress
    Reconnecting {
        /// Current reconnection attempt number
        attempt: u32,
    },
    /// Reconnection succeeded
    ReconnectSuccess,
    /// Reconnection failed
    ReconnectFailed {
        /// Error message
        error: String,
    },
    /// Subscriptions restored after reconnection
    SubscriptionRestored,
}
879
/// Event callback function type.
///
/// A shared (`Arc`) closure that receives each [`WsEvent`] by value. The
/// `Send + Sync` bounds allow the same callback to be cloned into and invoked
/// from a spawned task.
pub type WsEventCallback = Arc<dyn Fn(WsEvent) + Send + Sync>;
882
/// Automatic reconnection coordinator for WebSocket connections.
///
/// Monitors connection state and triggers reconnection attempts when disconnected.
///
/// Monitoring runs in a background Tokio task that is spawned by `start` and
/// aborted by `stop`.
pub struct AutoReconnectCoordinator {
    /// Client whose state is polled and which is asked to reconnect.
    client: Arc<WsClient>,
    /// Run flag: set by `start`, cleared by `stop`, and checked by the
    /// background loop before each state check.
    enabled: Arc<AtomicBool>,
    /// Handle of the spawned reconnect loop; `None` when no handle is stored.
    reconnect_task: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Optional callback that receives [`WsEvent`]s from the reconnect loop.
    event_callback: Option<WsEventCallback>,
}
892
893impl AutoReconnectCoordinator {
894    /// Creates a new automatic reconnection coordinator.
895    ///
896    /// # Arguments
897    ///
898    /// * `client` - Arc reference to the WebSocket client
899    pub fn new(client: Arc<WsClient>) -> Self {
900        Self {
901            client,
902            enabled: Arc::new(AtomicBool::new(false)),
903            reconnect_task: Arc::new(Mutex::new(None)),
904            event_callback: None,
905        }
906    }
907
908    /// 设置事件回调
909    ///
910    /// # Arguments
911    /// * `callback` - 事件回调函数
912    ///
913    /// # Returns
914    /// Self,用于链式调用
915    pub fn with_callback(mut self, callback: WsEventCallback) -> Self {
916        self.event_callback = Some(callback);
917        self
918    }
919
920    /// Starts the automatic reconnection coordinator.
921    ///
922    /// Begins monitoring connection state and automatically reconnects on disconnect.
923    pub async fn start(&self) {
924        if self.enabled.swap(true, Ordering::SeqCst) {
925            info!("Auto-reconnect already started");
926            return;
927        }
928
929        info!("Starting auto-reconnect coordinator");
930
931        let client = Arc::clone(&self.client);
932        let enabled = Arc::clone(&self.enabled);
933        let callback = self.event_callback.clone();
934
935        let handle = tokio::spawn(async move {
936            Self::reconnect_loop(client, enabled, callback).await;
937        });
938
939        *self.reconnect_task.lock().await = Some(handle);
940    }
941
942    /// Stops the automatic reconnection coordinator.
943    ///
944    /// Halts monitoring and reconnection tasks.
945    pub async fn stop(&self) {
946        if !self.enabled.swap(false, Ordering::SeqCst) {
947            info!("Auto-reconnect already stopped");
948            return;
949        }
950
951        info!("Stopping auto-reconnect coordinator");
952
953        let mut task = self.reconnect_task.lock().await;
954        if let Some(handle) = task.take() {
955            handle.abort();
956        }
957    }
958
959    /// Internal reconnection loop.
960    ///
961    /// Continuously monitors connection state and triggers reconnection
962    /// when `Error` or `Disconnected` state is detected.
963    async fn reconnect_loop(
964        client: Arc<WsClient>,
965        enabled: Arc<AtomicBool>,
966        callback: Option<WsEventCallback>,
967    ) {
968        let mut check_interval = interval(Duration::from_secs(1));
969
970        loop {
971            check_interval.tick().await;
972
973            if !enabled.load(Ordering::SeqCst) {
974                debug!("Auto-reconnect disabled, exiting loop");
975                break;
976            }
977
978            let state = client.state().await;
979
980            if matches!(
981                state,
982                WsConnectionState::Disconnected | WsConnectionState::Error
983            ) {
984                let attempt = client.reconnect_count().await;
985
986                info!(
987                    attempt = attempt + 1,
988                    state = ?state,
989                    "Connection lost, attempting reconnect"
990                );
991
992                if let Some(ref cb) = callback {
993                    cb(WsEvent::Reconnecting {
994                        attempt: attempt + 1,
995                    });
996                }
997
998                match client.reconnect().await {
999                    Ok(_) => {
1000                        info!("Reconnection successful");
1001
1002                        if let Some(ref cb) = callback {
1003                            cb(WsEvent::ReconnectSuccess);
1004                        }
1005
1006                        match client.resubscribe_all().await {
1007                            Ok(_) => {
1008                                info!("Subscriptions restored");
1009                                if let Some(ref cb) = callback {
1010                                    cb(WsEvent::SubscriptionRestored);
1011                                }
1012                            }
1013                            Err(e) => {
1014                                error!(error = %e, "Failed to restore subscriptions");
1015                            }
1016                        }
1017                    }
1018                    Err(e) => {
1019                        error!(error = %e, "Reconnection failed");
1020
1021                        if let Some(ref cb) = callback {
1022                            cb(WsEvent::ReconnectFailed {
1023                                error: e.to_string(),
1024                            });
1025                        }
1026
1027                        tokio::time::sleep(Duration::from_secs(5)).await;
1028                    }
1029                }
1030            }
1031        }
1032
1033        info!("Auto-reconnect loop terminated");
1034    }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the configuration shared by the client tests: default settings
    /// with only the endpoint URL filled in.
    fn test_config() -> WsConfig {
        WsConfig {
            url: "wss://example.com/ws".to_string(),
            ..Default::default()
        }
    }

    /// `WsConfig::default()` exposes the expected timeouts, limits and flags.
    #[test]
    fn test_ws_config_default() {
        let defaults = WsConfig::default();

        assert_eq!(defaults.connect_timeout, 10000);
        assert_eq!(defaults.ping_interval, 30000);
        assert_eq!(defaults.reconnect_interval, 5000);
        assert_eq!(defaults.max_reconnect_attempts, 5);
        assert!(defaults.auto_reconnect);
        assert!(!defaults.enable_compression);
        assert_eq!(defaults.pong_timeout, 90000);
    }

    /// The subscription key is `channel:symbol`, or just `channel` without a symbol.
    #[test]
    fn test_subscription_key() {
        let symbol = Some("BTC/USDT".to_string());
        let with_symbol = WsClient::subscription_key("ticker", &symbol);
        assert_eq!(with_symbol, "ticker:BTC/USDT");

        let without_symbol = WsClient::subscription_key("trades", &None);
        assert_eq!(without_symbol, "trades");
    }

    /// A freshly created client reports the disconnected state.
    #[tokio::test]
    async fn test_ws_client_creation() {
        let client = WsClient::new(test_config());

        assert_eq!(client.state().await, WsConnectionState::Disconnected);
        assert!(!client.is_connected().await);
    }

    /// Subscribing registers exactly one entry under the expected key.
    #[tokio::test]
    async fn test_subscribe_adds_subscription() {
        let client = WsClient::new(test_config());

        let outcome = client
            .subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None)
            .await;
        assert!(outcome.is_ok());

        let registered = client.subscriptions.read().await;
        assert_eq!(registered.len(), 1);
        assert!(registered.contains_key("ticker:BTC/USDT"));
    }

    /// Unsubscribing removes the entry that a prior subscribe added.
    #[tokio::test]
    async fn test_unsubscribe_removes_subscription() {
        let client = WsClient::new(test_config());

        client
            .subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None)
            .await
            .unwrap();

        let outcome = client
            .unsubscribe("ticker".to_string(), Some("BTC/USDT".to_string()))
            .await;
        assert!(outcome.is_ok());

        let registered = client.subscriptions.read().await;
        assert!(registered.is_empty());
    }

    /// A subscribe message serializes with a lowercase `type` tag and its channel.
    #[test]
    fn test_ws_message_serialization() {
        let serialized = serde_json::to_string(&WsMessage::Subscribe {
            channel: "ticker".to_string(),
            symbol: Some("BTC/USDT".to_string()),
            params: None,
        })
        .unwrap();

        assert!(serialized.contains("\"type\":\"subscribe\""));
        assert!(serialized.contains("\"channel\":\"ticker\""));
    }
}