ccxt_core/ws_client.rs
1//! WebSocket client module.
2//!
3//! Provides asynchronous WebSocket connection management, subscription handling,
4//! and heartbeat maintenance for cryptocurrency exchange streaming APIs.
5//!
6//! # Features
7//!
8//! - **Exponential Backoff Reconnection**: Configurable retry delays with jitter
9//! to prevent thundering herd effects during reconnection.
10//! - **Error Classification**: Distinguishes between transient (retryable) and
11//! permanent (non-retryable) errors for intelligent reconnection decisions.
12//! - **Cancellation Support**: Graceful cancellation of long-running operations
13//! via [`CancellationToken`](tokio_util::sync::CancellationToken).
14//! - **Subscription Limits**: Configurable maximum subscription count to prevent
15//! resource exhaustion.
16//! - **Lock-Free Statistics**: Atomic counters for connection statistics to
17//! prevent deadlocks in async contexts.
18//! - **Graceful Shutdown**: Clean shutdown with pending operation completion
19//! and resource cleanup.
20//!
21//! # Exponential Backoff Configuration
22//!
23//! The [`BackoffConfig`] struct controls how retry delays are calculated during
24//! reconnection attempts. The formula used is:
25//!
26//! ```text
27//! delay = min(base_delay * multiplier^attempt, max_delay) + jitter
28//! ```
29//!
30//! Where jitter is a random value in the range `[0, delay * jitter_factor]`.
31//!
32//! ## Default Configuration
33//!
34//! | Parameter | Default Value | Description |
35//! |-----------|---------------|-------------|
36//! | `base_delay` | 1 second | Initial delay before first retry |
37//! | `max_delay` | 60 seconds | Maximum delay cap |
38//! | `jitter_factor` | 0.25 (25%) | Random jitter to prevent thundering herd |
39//! | `multiplier` | 2.0 | Exponential growth factor |
40//!
41//! ## Example: Custom Backoff Configuration
42//!
43//! ```rust
44//! use ccxt_core::ws_client::{WsConfig, BackoffConfig};
45//! use std::time::Duration;
46//!
47//! let config = WsConfig {
48//! url: "wss://stream.example.com/ws".to_string(),
49//! backoff_config: BackoffConfig {
50//! base_delay: Duration::from_millis(500), // Start with 500ms
51//! max_delay: Duration::from_secs(30), // Cap at 30 seconds
52//! jitter_factor: 0.2, // 20% jitter
53//! multiplier: 2.0, // Double each attempt
54//! },
55//! ..Default::default()
56//! };
57//! ```
58//!
59//! ## Retry Delay Progression (with default config, no jitter)
60//!
61//! | Attempt | Delay |
62//! |---------|-------|
63//! | 0 | 1s |
64//! | 1 | 2s |
65//! | 2 | 4s |
66//! | 3 | 8s |
67//! | 4 | 16s |
68//! | 5 | 32s |
69//! | 6+ | 60s (capped) |
70//!
71//! # Cancellation Support
72//!
73//! Long-running operations like `connect`, `reconnect`, and `subscribe` can be
74//! cancelled using a [`CancellationToken`](tokio_util::sync::CancellationToken).
75//! This enables graceful shutdown and timeout handling.
76//!
77//! ## Example: Using CancellationToken
78//!
79//! ```rust,ignore
80//! use ccxt_core::ws_client::{WsClient, WsConfig};
81//! use tokio_util::sync::CancellationToken;
82//! use std::time::Duration;
83//!
84//! #[tokio::main]
85//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
86//! let client = WsClient::new(WsConfig {
87//! url: "wss://stream.example.com/ws".to_string(),
88//! ..Default::default()
89//! });
90//!
91//! // Create a cancellation token
92//! let token = CancellationToken::new();
93//! let token_clone = token.clone();
94//!
95//! // Set the token on the client
96//! client.set_cancel_token(token.clone()).await;
97//!
98//! // Spawn a task to cancel after 10 seconds
99//! tokio::spawn(async move {
100//! tokio::time::sleep(Duration::from_secs(10)).await;
101//! println!("Cancelling connection...");
102//! token_clone.cancel();
103//! });
104//!
105//! // Connect with cancellation support
106//! match client.connect_with_cancel(Some(token)).await {
107//! Ok(()) => println!("Connected successfully!"),
108//! Err(e) if e.as_cancelled().is_some() => {
109//! println!("Connection was cancelled");
110//! }
111//! Err(e) => println!("Connection failed: {}", e),
112//! }
113//!
114//! Ok(())
115//! }
116//! ```
117//!
118//! ## Sharing CancellationToken
119//!
120//! The `CancellationToken` can be cloned and shared across multiple operations.
121//! Cancelling any clone will cancel all operations using that token:
122//!
123//! ```rust,ignore
124//! use tokio_util::sync::CancellationToken;
125//!
126//! let token = CancellationToken::new();
127//!
128//! // Clone for different operations
129//! let connect_token = token.clone();
130//! let reconnect_token = token.clone();
131//!
132//! // Cancelling the original cancels all clones
133//! token.cancel();
134//!
135//! assert!(connect_token.is_cancelled());
136//! assert!(reconnect_token.is_cancelled());
137//! ```
138//!
139//! # Subscription Limits
140//!
141//! The [`WsConfig::max_subscriptions`] field limits the number of concurrent
142//! subscriptions to prevent resource exhaustion. When the limit is reached,
143//! new subscription attempts will fail with [`Error::ResourceExhausted`](crate::error::Error).
144//!
145//! ## Default Limit
146//!
147//! The default maximum is 100 subscriptions (see [`DEFAULT_MAX_SUBSCRIPTIONS`]).
148//!
149//! ## Example: Configuring Subscription Limits
150//!
151//! ```rust
152//! use ccxt_core::ws_client::{WsClient, WsConfig};
153//!
154//! let client = WsClient::new(WsConfig {
155//! url: "wss://stream.example.com/ws".to_string(),
156//! max_subscriptions: 50, // Limit to 50 subscriptions
157//! ..Default::default()
158//! });
159//!
160//! // Check current capacity
161//! assert_eq!(client.subscription_count(), 0);
162//! assert_eq!(client.remaining_capacity(), 50);
163//! ```
164//!
165//! ## Checking Capacity Before Subscribing
166//!
167//! ```rust,ignore
168//! use ccxt_core::ws_client::{WsClient, WsConfig};
169//!
170//! let client = WsClient::new(WsConfig::default());
171//!
172//! // Check if there's room for more subscriptions
173//! if client.remaining_capacity() > 0 {
174//! client.subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None).await?;
175//! } else {
176//! println!("No subscription capacity available");
177//! }
178//! ```
179//!
180//! # Error Classification
181//!
182//! The [`WsErrorKind`] enum classifies WebSocket errors into two categories:
183//!
184//! - **Transient**: Temporary errors that may recover with retry (network issues,
185//! server unavailable, connection resets).
186//! - **Permanent**: Errors that should not be retried (authentication failures,
187//! protocol errors, invalid credentials).
188//!
189//! ## Example: Handling Errors by Type
190//!
191//! ```rust
192//! use ccxt_core::ws_client::{WsError, WsErrorKind};
193//!
194//! fn handle_error(error: &WsError) {
195//! if error.is_transient() {
196//! println!("Transient error, will retry: {}", error.message());
197//! } else {
198//! println!("Permanent error, stopping: {}", error.message());
199//! }
200//! }
201//! ```
202//!
203//! # Graceful Shutdown
204//!
205//! The [`WsClient::shutdown`] method performs a complete graceful shutdown:
206//!
207//! 1. Cancels all pending reconnection attempts
208//! 2. Sends WebSocket close frame to the server
209//! 3. Waits for pending operations to complete (with timeout)
210//! 4. Clears all resources (subscriptions, channels, etc.)
211//! 5. Emits a [`WsEvent::Shutdown`] event
212//!
213//! ## Example: Graceful Shutdown
214//!
215//! ```rust,ignore
216//! use ccxt_core::ws_client::{WsClient, WsConfig, WsEvent};
217//! use std::sync::Arc;
218//!
219//! let client = WsClient::new(WsConfig {
220//! url: "wss://stream.example.com/ws".to_string(),
221//! shutdown_timeout: 5000, // 5 second timeout
222//! ..Default::default()
223//! });
224//!
225//! // Set up event callback
226//! client.set_event_callback(Arc::new(|event| {
227//! if let WsEvent::Shutdown = event {
228//! println!("Shutdown completed!");
229//! }
230//! })).await;
231//!
232//! // Connect and do work...
233//! client.connect().await?;
234//!
235//! // Gracefully shutdown
236//! client.shutdown().await;
237//! ```
238//!
239//! # Observability
240//!
241//! This module uses the `tracing` crate for structured logging. Key events:
242//! - Connection establishment and disconnection
243//! - Subscription and unsubscription events with stream names
244//! - Message parsing failures with raw message preview (truncated)
245//! - Reconnection attempts and outcomes with backoff delays
246//! - Ping/pong heartbeat events
247//! - Cancellation events
248//! - Shutdown progress
249//!
250//! # Lock Ordering Rules
251//!
252//! To prevent deadlocks, locks in this module are acquired in a consistent order:
253//!
254//! 1. `cancel_token` - Cancellation token mutex
255//! 2. `event_callback` - Event callback mutex
256//! 3. `write_tx` - Write channel mutex
257//! 4. `shutdown_tx` - Shutdown channel mutex
258//!
259//! **Important**: No locks are held across await points. All lock guards are
260//! dropped before any async operations.
261
262use crate::error::{Error, Result};
263use dashmap::DashMap;
264use futures_util::{SinkExt, StreamExt, stream::SplitSink};
265use rand::Rng;
266use serde::{Deserialize, Serialize};
267use serde_json::Value;
268use std::collections::HashMap;
269use std::sync::Arc;
270use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU8, AtomicU32, AtomicU64, Ordering};
271use tokio::net::TcpStream;
272use tokio::sync::{Mutex, RwLock, mpsc};
273use tokio::task::JoinHandle;
274use tokio::time::{Duration, interval};
275use tokio_tungstenite::{
276 MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message,
277};
278use tokio_util::sync::CancellationToken;
279use tracing::{debug, error, info, instrument, warn};
280
281// ==================== WebSocket Error Classification ====================
282
/// Retry-oriented classification of a WebSocket failure.
///
/// Every error surfaced by this module belongs to exactly one bucket:
/// - [`WsErrorKind::Transient`]: a temporary condition that may clear on retry
///   (network issues, server unavailable).
/// - [`WsErrorKind::Permanent`]: a condition that retrying will not fix
///   (authentication failures, protocol errors).
///
/// # Example
///
/// ```rust
/// use ccxt_core::ws_client::WsErrorKind;
///
/// let kind = WsErrorKind::Transient;
/// assert!(kind.is_transient());
///
/// let kind = WsErrorKind::Permanent;
/// assert!(!kind.is_transient());
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WsErrorKind {
    /// A temporary failure; a later attempt may succeed.
    ///
    /// Typical causes:
    /// - Network timeouts
    /// - Connection resets
    /// - Server unavailable (5xx errors)
    /// - Temporary connection failures
    Transient,

    /// A failure that will not go away by retrying.
    ///
    /// Typical causes:
    /// - Authentication failures (401/403)
    /// - Protocol errors
    /// - Invalid credentials
    /// - Invalid parameters
    Permanent,
}
320
321impl WsErrorKind {
322 /// Returns `true` if this is a transient error that may recover with retry.
323 ///
324 /// # Example
325 ///
326 /// ```rust
327 /// use ccxt_core::ws_client::WsErrorKind;
328 ///
329 /// assert!(WsErrorKind::Transient.is_transient());
330 /// assert!(!WsErrorKind::Permanent.is_transient());
331 /// ```
332 #[inline]
333 #[must_use]
334 pub fn is_transient(self) -> bool {
335 matches!(self, Self::Transient)
336 }
337
338 /// Returns `true` if this is a permanent error that should not be retried.
339 ///
340 /// # Example
341 ///
342 /// ```rust
343 /// use ccxt_core::ws_client::WsErrorKind;
344 ///
345 /// assert!(WsErrorKind::Permanent.is_permanent());
346 /// assert!(!WsErrorKind::Transient.is_permanent());
347 /// ```
348 #[inline]
349 #[must_use]
350 pub fn is_permanent(self) -> bool {
351 matches!(self, Self::Permanent)
352 }
353}
354
355impl std::fmt::Display for WsErrorKind {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 match self {
358 Self::Transient => write!(f, "Transient"),
359 Self::Permanent => write!(f, "Permanent"),
360 }
361 }
362}
363
364/// Extended WebSocket error with classification.
365///
366/// This struct wraps WebSocket errors with additional metadata including:
367/// - Error kind (transient or permanent)
368/// - Human-readable message
369/// - Optional source error for error chaining
370///
371/// # Example
372///
373/// ```rust
374/// use ccxt_core::ws_client::{WsError, WsErrorKind};
375///
376/// // Create a transient error
377/// let err = WsError::transient("Connection reset by peer");
378/// assert!(err.is_transient());
379/// assert_eq!(err.kind(), WsErrorKind::Transient);
380///
381/// // Create a permanent error
382/// let err = WsError::permanent("Authentication failed: invalid API key");
383/// assert!(!err.is_transient());
384/// assert_eq!(err.kind(), WsErrorKind::Permanent);
385/// ```
386#[derive(Debug)]
387pub struct WsError {
388 /// Error kind (transient or permanent)
389 kind: WsErrorKind,
390 /// Human-readable error message
391 message: String,
392 /// Original error source (if any)
393 source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
394}
395
396impl WsError {
397 /// Creates a new `WsError` with the specified kind and message.
398 ///
399 /// # Arguments
400 ///
401 /// * `kind` - The error classification (transient or permanent)
402 /// * `message` - Human-readable error message
403 ///
404 /// # Example
405 ///
406 /// ```rust
407 /// use ccxt_core::ws_client::{WsError, WsErrorKind};
408 ///
409 /// let err = WsError::new(WsErrorKind::Transient, "Connection timeout");
410 /// assert!(err.is_transient());
411 /// ```
412 pub fn new(kind: WsErrorKind, message: impl Into<String>) -> Self {
413 Self {
414 kind,
415 message: message.into(),
416 source: None,
417 }
418 }
419
420 /// Creates a new `WsError` with a source error.
421 ///
422 /// # Arguments
423 ///
424 /// * `kind` - The error classification (transient or permanent)
425 /// * `message` - Human-readable error message
426 /// * `source` - The underlying error that caused this error
427 ///
428 /// # Example
429 ///
430 /// ```rust
431 /// use ccxt_core::ws_client::{WsError, WsErrorKind};
432 /// use std::io;
433 ///
434 /// let io_err = io::Error::new(io::ErrorKind::ConnectionReset, "connection reset");
435 /// let err = WsError::with_source(
436 /// WsErrorKind::Transient,
437 /// "Connection lost",
438 /// io_err
439 /// );
440 /// assert!(err.source().is_some());
441 /// ```
442 pub fn with_source<E>(kind: WsErrorKind, message: impl Into<String>, source: E) -> Self
443 where
444 E: std::error::Error + Send + Sync + 'static,
445 {
446 Self {
447 kind,
448 message: message.into(),
449 source: Some(Box::new(source)),
450 }
451 }
452
453 /// Creates a transient error.
454 ///
455 /// Transient errors are temporary and may recover with retry.
456 ///
457 /// # Arguments
458 ///
459 /// * `message` - Human-readable error message
460 ///
461 /// # Example
462 ///
463 /// ```rust
464 /// use ccxt_core::ws_client::WsError;
465 ///
466 /// let err = WsError::transient("Network timeout");
467 /// assert!(err.is_transient());
468 /// ```
469 pub fn transient(message: impl Into<String>) -> Self {
470 Self::new(WsErrorKind::Transient, message)
471 }
472
473 /// Creates a transient error with a source.
474 ///
475 /// # Arguments
476 ///
477 /// * `message` - Human-readable error message
478 /// * `source` - The underlying error that caused this error
479 pub fn transient_with_source<E>(message: impl Into<String>, source: E) -> Self
480 where
481 E: std::error::Error + Send + Sync + 'static,
482 {
483 Self::with_source(WsErrorKind::Transient, message, source)
484 }
485
486 /// Creates a permanent error.
487 ///
488 /// Permanent errors should not be retried as they indicate
489 /// a fundamental issue that won't resolve with retries.
490 ///
491 /// # Arguments
492 ///
493 /// * `message` - Human-readable error message
494 ///
495 /// # Example
496 ///
497 /// ```rust
498 /// use ccxt_core::ws_client::WsError;
499 ///
500 /// let err = WsError::permanent("Invalid API key");
501 /// assert!(!err.is_transient());
502 /// ```
503 pub fn permanent(message: impl Into<String>) -> Self {
504 Self::new(WsErrorKind::Permanent, message)
505 }
506
507 /// Creates a permanent error with a source.
508 ///
509 /// # Arguments
510 ///
511 /// * `message` - Human-readable error message
512 /// * `source` - The underlying error that caused this error
513 pub fn permanent_with_source<E>(message: impl Into<String>, source: E) -> Self
514 where
515 E: std::error::Error + Send + Sync + 'static,
516 {
517 Self::with_source(WsErrorKind::Permanent, message, source)
518 }
519
520 /// Returns the error kind.
521 ///
522 /// # Example
523 ///
524 /// ```rust
525 /// use ccxt_core::ws_client::{WsError, WsErrorKind};
526 ///
527 /// let err = WsError::transient("timeout");
528 /// assert_eq!(err.kind(), WsErrorKind::Transient);
529 /// ```
530 #[inline]
531 #[must_use]
532 pub fn kind(&self) -> WsErrorKind {
533 self.kind
534 }
535
536 /// Returns the error message.
537 ///
538 /// # Example
539 ///
540 /// ```rust
541 /// use ccxt_core::ws_client::WsError;
542 ///
543 /// let err = WsError::transient("Connection timeout");
544 /// assert_eq!(err.message(), "Connection timeout");
545 /// ```
546 #[inline]
547 #[must_use]
548 pub fn message(&self) -> &str {
549 &self.message
550 }
551
552 /// Returns `true` if this is a transient error.
553 ///
554 /// Transient errors may recover with retry.
555 ///
556 /// # Example
557 ///
558 /// ```rust
559 /// use ccxt_core::ws_client::WsError;
560 ///
561 /// let err = WsError::transient("timeout");
562 /// assert!(err.is_transient());
563 ///
564 /// let err = WsError::permanent("auth failed");
565 /// assert!(!err.is_transient());
566 /// ```
567 #[inline]
568 #[must_use]
569 pub fn is_transient(&self) -> bool {
570 self.kind.is_transient()
571 }
572
573 /// Returns `true` if this is a permanent error.
574 ///
575 /// Permanent errors should not be retried.
576 ///
577 /// # Example
578 ///
579 /// ```rust
580 /// use ccxt_core::ws_client::WsError;
581 ///
582 /// let err = WsError::permanent("auth failed");
583 /// assert!(err.is_permanent());
584 ///
585 /// let err = WsError::transient("timeout");
586 /// assert!(!err.is_permanent());
587 /// ```
588 #[inline]
589 #[must_use]
590 pub fn is_permanent(&self) -> bool {
591 self.kind.is_permanent()
592 }
593
594 /// Returns the source error, if any.
595 ///
596 /// # Example
597 ///
598 /// ```rust
599 /// use ccxt_core::ws_client::WsError;
600 /// use std::io;
601 ///
602 /// let io_err = io::Error::new(io::ErrorKind::ConnectionReset, "reset");
603 /// let err = WsError::transient_with_source("Connection lost", io_err);
604 /// assert!(err.source().is_some());
605 /// ```
606 #[must_use]
607 pub fn source(&self) -> Option<&(dyn std::error::Error + Send + Sync + 'static)> {
608 self.source.as_deref()
609 }
610}
611
612impl std::fmt::Display for WsError {
613 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614 write!(f, "[{}] {}", self.kind, self.message)
615 }
616}
617
618impl std::error::Error for WsError {
619 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
620 self.source
621 .as_ref()
622 .map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
623 }
624}
625
626impl WsError {
627 /// Classifies a tungstenite WebSocket error.
628 ///
629 /// This method analyzes the error type and classifies it as either
630 /// transient (retryable) or permanent (non-retryable).
631 ///
632 /// # Classification Rules
633 ///
634 /// ## Transient Errors (retryable)
635 /// - IO errors (network issues, connection resets)
636 /// - `ConnectionClosed` (server closed connection)
637 /// - `AlreadyClosed` (connection was already closed)
638 /// - Server errors (5xx HTTP status codes)
639 ///
640 /// ## Permanent Errors (non-retryable)
641 /// - Protocol errors (WebSocket protocol violations)
642 /// - UTF-8 encoding errors
643 /// - Authentication errors (401/403 HTTP status codes)
644 /// - Client errors (4xx HTTP status codes, except 5xx)
645 ///
646 /// # Arguments
647 ///
648 /// * `err` - The tungstenite error to classify
649 ///
650 /// # Example
651 ///
652 /// ```rust,ignore
653 /// use ccxt_core::ws_client::WsError;
654 /// use tokio_tungstenite::tungstenite::Error as TungError;
655 ///
656 /// // IO errors are transient
657 /// let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset");
658 /// let ws_err = WsError::from_tungstenite(&TungError::Io(io_err));
659 /// assert!(ws_err.is_transient());
660 /// ```
661 pub fn from_tungstenite(err: &tokio_tungstenite::tungstenite::Error) -> Self {
662 use tokio_tungstenite::tungstenite::Error as TungError;
663
664 match err {
665 // Transient errors - network issues, temporary failures
666 TungError::Io(io_err) => {
667 let message = format!("IO error: {io_err}");
668 Self::transient_with_source(
669 message,
670 std::io::Error::new(io_err.kind(), io_err.to_string()),
671 )
672 }
673
674 TungError::ConnectionClosed => Self::transient("Connection closed by server"),
675
676 TungError::AlreadyClosed => Self::transient("Connection already closed"),
677
678 // Permanent errors - protocol violations, auth failures
679 TungError::Protocol(protocol_err) => {
680 let message = format!("Protocol error: {protocol_err}");
681 Self::permanent(message)
682 }
683
684 TungError::Utf8(_) => Self::permanent("UTF-8 encoding error in WebSocket message"),
685
686 TungError::Http(response) => {
687 let status = response.status();
688 let status_code = status.as_u16();
689
690 if status_code == 401 || status_code == 403 {
691 // Authentication errors are permanent
692 Self::permanent(format!("Authentication error: HTTP {status}"))
693 } else if status.is_server_error() {
694 // Server errors (5xx) are transient
695 Self::transient(format!("Server error: HTTP {status}"))
696 } else {
697 // Other client errors (4xx) are permanent
698 Self::permanent(format!("HTTP error: {status}"))
699 }
700 }
701
702 TungError::HttpFormat(http_err) => {
703 Self::permanent(format!("HTTP format error: {http_err}"))
704 }
705
706 TungError::Url(url_err) => Self::permanent(format!("Invalid URL: {url_err}")),
707
708 TungError::Tls(tls_err) => {
709 // TLS errors could be transient (certificate issues during handshake)
710 // or permanent (invalid certificates). We treat them as transient
711 // to allow retry with potential certificate refresh.
712 Self::transient(format!("TLS error: {tls_err}"))
713 }
714
715 TungError::Capacity(capacity_err) => {
716 // Capacity errors (message too large) are permanent
717 Self::permanent(format!("Capacity error: {capacity_err}"))
718 }
719
720 TungError::WriteBufferFull(msg) => {
721 // Write buffer full is transient - can retry after buffer drains
722 Self::transient(format!("Write buffer full: {msg:?}"))
723 }
724
725 TungError::AttackAttempt => {
726 // Attack attempts are permanent - don't retry
727 Self::permanent("Potential attack detected")
728 }
729 }
730 }
731
732 /// Classifies a generic error and wraps it in a `WsError`.
733 ///
734 /// This is a convenience method for wrapping arbitrary errors.
735 /// By default, unknown errors are classified as transient.
736 ///
737 /// # Arguments
738 ///
739 /// * `err` - The error to classify
740 ///
741 /// # Example
742 ///
743 /// ```rust
744 /// use ccxt_core::ws_client::WsError;
745 /// use ccxt_core::error::Error;
746 ///
747 /// let err = Error::network("Connection failed");
748 /// let ws_err = WsError::from_error(&err);
749 /// assert!(ws_err.is_transient());
750 /// ```
751 pub fn from_error(err: &Error) -> Self {
752 // Check for specific error types that indicate permanent failures
753 if err.as_authentication().is_some() {
754 return Self::permanent(format!("Authentication error: {err}"));
755 }
756
757 if err.as_cancelled().is_some() {
758 // Cancelled is a special case - not really transient or permanent
759 // but we treat it as permanent since retrying won't help
760 return Self::permanent(format!("Operation cancelled: {err}"));
761 }
762
763 if err.as_resource_exhausted().is_some() {
764 // Resource exhausted is permanent until resources are freed
765 return Self::permanent(format!("Resource exhausted: {err}"));
766 }
767
768 // Default to transient for network and other errors
769 Self::transient(format!("Error: {err}"))
770 }
771}
772
/// Tunable parameters for exponential-backoff reconnection.
///
/// These values determine how long the WebSocket client waits between
/// reconnection attempts after a connection failure.
///
/// # Example
///
/// ```
/// use ccxt_core::ws_client::BackoffConfig;
/// use std::time::Duration;
///
/// let config = BackoffConfig {
///     base_delay: Duration::from_millis(500),
///     max_delay: Duration::from_secs(30),
///     jitter_factor: 0.2,
///     multiplier: 2.0,
/// };
/// ```
#[derive(Debug, Clone)]
pub struct BackoffConfig {
    /// Delay used for the first retry (default: 1 second).
    ///
    /// Later retries grow from this value.
    pub base_delay: Duration,

    /// Upper bound on the computed delay (default: 60 seconds).
    ///
    /// The cap is applied before jitter is added.
    pub max_delay: Duration,

    /// Jitter factor (0.0 - 1.0, default: 0.25 for 25%).
    ///
    /// Random jitter in the range [0, delay * jitter_factor] is added to
    /// prevent a thundering herd of simultaneous reconnects.
    pub jitter_factor: f64,

    /// Exponential growth factor (default: 2.0).
    ///
    /// Each successive retry delay is multiplied by this factor.
    pub multiplier: f64,
}
814
815impl Default for BackoffConfig {
816 fn default() -> Self {
817 Self {
818 base_delay: Duration::from_secs(1),
819 max_delay: Duration::from_secs(60),
820 jitter_factor: 0.25,
821 multiplier: 2.0,
822 }
823 }
824}
825
826/// Calculates retry delay with exponential backoff and jitter.
827///
828/// The backoff strategy uses the formula:
829/// `min(base_delay * multiplier^attempt, max_delay) + jitter`
830///
831/// Where jitter is a random value in the range [0, delay * jitter_factor].
832///
833/// # Example
834///
835/// ```
836/// use ccxt_core::ws_client::{BackoffConfig, BackoffStrategy};
837///
838/// let strategy = BackoffStrategy::new(BackoffConfig::default());
839///
840/// // First attempt (attempt = 0): ~1 second + jitter
841/// let delay0 = strategy.calculate_delay(0);
842///
843/// // Second attempt (attempt = 1): ~2 seconds + jitter
844/// let delay1 = strategy.calculate_delay(1);
845///
846/// // Third attempt (attempt = 2): ~4 seconds + jitter
847/// let delay2 = strategy.calculate_delay(2);
848/// ```
849#[derive(Debug, Clone)]
850pub struct BackoffStrategy {
851 config: BackoffConfig,
852}
853
854impl BackoffStrategy {
855 /// Creates a new backoff strategy with the given configuration.
856 ///
857 /// # Arguments
858 ///
859 /// * `config` - The backoff configuration to use
860 pub fn new(config: BackoffConfig) -> Self {
861 Self { config }
862 }
863
864 /// Creates a new backoff strategy with default configuration.
865 pub fn with_defaults() -> Self {
866 Self::new(BackoffConfig::default())
867 }
868
869 /// Returns a reference to the underlying configuration.
870 pub fn config(&self) -> &BackoffConfig {
871 &self.config
872 }
873
874 /// Calculates delay for the given attempt number.
875 ///
876 /// Formula: `min(base_delay * multiplier^attempt, max_delay) + jitter`
877 ///
878 /// # Arguments
879 ///
880 /// * `attempt` - The attempt number (0-indexed)
881 ///
882 /// # Returns
883 ///
884 /// The calculated delay duration including jitter.
885 ///
886 /// # Example
887 ///
888 /// ```
889 /// use ccxt_core::ws_client::{BackoffConfig, BackoffStrategy};
890 /// use std::time::Duration;
891 ///
892 /// let config = BackoffConfig {
893 /// base_delay: Duration::from_secs(1),
894 /// max_delay: Duration::from_secs(60),
895 /// jitter_factor: 0.0, // No jitter for predictable results
896 /// multiplier: 2.0,
897 /// };
898 /// let strategy = BackoffStrategy::new(config);
899 ///
900 /// // With no jitter, delays are exactly: 1s, 2s, 4s, 8s, ...
901 /// assert_eq!(strategy.calculate_delay(0), Duration::from_secs(1));
902 /// assert_eq!(strategy.calculate_delay(1), Duration::from_secs(2));
903 /// assert_eq!(strategy.calculate_delay(2), Duration::from_secs(4));
904 /// ```
905 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
906 pub fn calculate_delay(&self, attempt: u32) -> Duration {
907 let base_ms = self.config.base_delay.as_millis() as f64;
908 let multiplier = self.config.multiplier;
909 let max_ms = self.config.max_delay.as_millis() as f64;
910
911 // Calculate exponential delay: base_delay * multiplier^attempt
912 let exponential_delay_ms = base_ms * multiplier.powi(attempt as i32);
913
914 // Cap at max_delay
915 let capped_delay_ms = exponential_delay_ms.min(max_ms);
916
917 // Add jitter: random value in [0, delay * jitter_factor]
918 let jitter_ms = if self.config.jitter_factor > 0.0 {
919 let jitter_range = capped_delay_ms * self.config.jitter_factor;
920 rand::rng().random::<f64>() * jitter_range
921 } else {
922 0.0
923 };
924
925 Duration::from_millis((capped_delay_ms + jitter_ms) as u64)
926 }
927
928 /// Calculates the base delay (without jitter) for the given attempt number.
929 ///
930 /// This is useful for testing or when you need predictable delay values.
931 ///
932 /// # Arguments
933 ///
934 /// * `attempt` - The attempt number (0-indexed)
935 ///
936 /// # Returns
937 ///
938 /// The calculated delay duration without jitter.
939 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
940 pub fn calculate_delay_without_jitter(&self, attempt: u32) -> Duration {
941 let base_ms = self.config.base_delay.as_millis() as f64;
942 let multiplier = self.config.multiplier;
943 let max_ms = self.config.max_delay.as_millis() as f64;
944
945 // Calculate exponential delay: base_delay * multiplier^attempt
946 let exponential_delay_ms = base_ms * multiplier.powi(attempt as i32);
947
948 // Cap at max_delay
949 let capped_delay_ms = exponential_delay_ms.min(max_ms);
950
951 Duration::from_millis(capped_delay_ms as u64)
952 }
953}
954
/// Lifecycle state of a WebSocket connection.
///
/// The enum is `#[repr(u8)]` with explicit discriminants so that a state can
/// be stored in an `AtomicU8`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsConnectionState {
    /// No connection is established.
    Disconnected = 0,
    /// A connection is being established.
    Connecting = 1,
    /// The connection is established.
    Connected = 2,
    /// A reconnection attempt is in progress.
    Reconnecting = 3,
    /// The connection is in an error state.
    Error = 4,
}
972
973impl WsConnectionState {
974 /// Converts a `u8` value to `WsConnectionState`.
975 ///
976 /// # Arguments
977 ///
978 /// * `value` - The u8 value to convert
979 ///
980 /// # Returns
981 ///
982 /// The corresponding `WsConnectionState`, defaulting to `Error` for unknown values.
983 #[inline]
984 pub fn from_u8(value: u8) -> Self {
985 match value {
986 0 => Self::Disconnected,
987 1 => Self::Connecting,
988 2 => Self::Connected,
989 3 => Self::Reconnecting,
990 _ => Self::Error,
991 }
992 }
993
994 /// Converts the `WsConnectionState` to its `u8` representation.
995 #[inline]
996 pub fn as_u8(self) -> u8 {
997 self as u8
998 }
999}
1000
/// WebSocket message types for exchange communication.
///
/// The serde attributes select an internally tagged representation: the
/// variant is identified by a `"type"` field whose value is the lowercase
/// variant name (`"subscribe"`, `"unsubscribe"`, `"ping"`, `"pong"`, `"auth"`,
/// `"custom"`), and the variant's own fields sit next to that tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum WsMessage {
    /// Subscribe to a channel
    Subscribe {
        /// Channel name
        channel: String,
        /// Optional trading pair symbol
        symbol: Option<String>,
        /// Additional parameters
        params: Option<HashMap<String, Value>>,
    },
    /// Unsubscribe from a channel
    Unsubscribe {
        /// Channel name
        channel: String,
        /// Optional trading pair symbol
        symbol: Option<String>,
    },
    /// Ping message for keepalive
    Ping {
        /// Timestamp in milliseconds
        timestamp: i64,
    },
    /// Pong response to ping
    Pong {
        /// Timestamp in milliseconds
        timestamp: i64,
    },
    /// Authentication message
    Auth {
        /// API key
        api_key: String,
        /// HMAC signature
        signature: String,
        /// Timestamp in milliseconds
        timestamp: i64,
    },
    /// Custom message payload
    ///
    /// NOTE(review): with an internally tagged enum, a newtype variant can
    /// presumably only be serialized when the wrapped `Value` is a JSON object
    /// (the tag has to be merged into it) — confirm against the serde docs and
    /// the callers that build `Custom` messages.
    Custom(Value),
}
1043
/// WebSocket connection configuration.
///
/// This struct contains all configuration options for WebSocket connections,
/// including connection timeouts, reconnection behavior, and resource limits.
///
/// All plain integer durations are expressed in milliseconds. The defaults
/// quoted on each field are the values produced by `WsConfig::default()`.
///
/// # Example
///
/// ```rust
/// use ccxt_core::ws_client::{WsConfig, BackoffConfig};
/// use std::time::Duration;
///
/// let config = WsConfig {
///     url: "wss://stream.example.com/ws".to_string(),
///     max_subscriptions: 50,
///     backoff_config: BackoffConfig {
///         base_delay: Duration::from_millis(500),
///         max_delay: Duration::from_secs(30),
///         ..Default::default()
///     },
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// WebSocket server URL
    ///
    /// Defaults to an empty string, so it must be set before connecting.
    pub url: String,
    /// Connection timeout in milliseconds
    ///
    /// Default: 10000.
    pub connect_timeout: u64,
    /// Ping interval in milliseconds
    ///
    /// Default: 30000.
    pub ping_interval: u64,
    /// Reconnection delay in milliseconds (legacy, use `backoff_config` for exponential backoff)
    ///
    /// Default: 5000.
    pub reconnect_interval: u64,
    /// Maximum reconnection attempts before giving up
    ///
    /// Default: 5.
    pub max_reconnect_attempts: u32,
    /// Enable automatic reconnection on disconnect
    ///
    /// Default: `true`.
    pub auto_reconnect: bool,
    /// Enable message compression
    ///
    /// Default: `false`.
    pub enable_compression: bool,
    /// Pong timeout in milliseconds
    ///
    /// Connection is considered dead if no pong received within this duration.
    ///
    /// Default: 90000.
    pub pong_timeout: u64,
    /// Exponential backoff configuration for reconnection.
    ///
    /// This configuration controls how retry delays are calculated during
    /// reconnection attempts. Uses exponential backoff with jitter to prevent
    /// thundering herd effects.
    ///
    /// # Default
    ///
    /// - `base_delay`: 1 second
    /// - `max_delay`: 60 seconds
    /// - `jitter_factor`: 0.25 (25%)
    /// - `multiplier`: 2.0
    pub backoff_config: BackoffConfig,
    /// Maximum number of subscriptions allowed.
    ///
    /// When this limit is reached, new subscription attempts will fail with
    /// `Error::ResourceExhausted`. This prevents resource exhaustion from
    /// too many concurrent subscriptions.
    ///
    /// # Default
    ///
    /// 100 subscriptions (see `DEFAULT_MAX_SUBSCRIPTIONS`)
    pub max_subscriptions: usize,
    /// Shutdown timeout in milliseconds.
    ///
    /// Maximum time to wait for pending operations to complete during
    /// graceful shutdown. After this timeout, the shutdown will proceed
    /// regardless of pending operations.
    ///
    /// # Default
    ///
    /// 5000 milliseconds (5 seconds, see `DEFAULT_SHUTDOWN_TIMEOUT`)
    pub shutdown_timeout: u64,
}
1120
/// Shutdown timeout, in milliseconds, used by `WsConfig::default()` (5 seconds).
pub const DEFAULT_SHUTDOWN_TIMEOUT: u64 = 5_000;
1123
1124impl Default for WsConfig {
1125 fn default() -> Self {
1126 Self {
1127 url: String::new(),
1128 connect_timeout: 10000,
1129 ping_interval: 30000,
1130 reconnect_interval: 5000,
1131 max_reconnect_attempts: 5,
1132 auto_reconnect: true,
1133 enable_compression: false,
1134 pong_timeout: 90000,
1135 backoff_config: BackoffConfig::default(),
1136 max_subscriptions: DEFAULT_MAX_SUBSCRIPTIONS,
1137 shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
1138 }
1139 }
1140}
1141
/// WebSocket subscription metadata.
///
/// Stored by `SubscriptionManager` under a string key. All fields are private
/// to this module.
#[derive(Debug, Clone)]
pub struct Subscription {
    // Channel name (same field name and type as `WsMessage::Subscribe::channel`).
    channel: String,
    // Optional trading pair symbol — presumably mirrors
    // `WsMessage::Subscribe::symbol`; confirm against the subscribe path.
    symbol: Option<String>,
    // Optional additional parameters — presumably forwarded with the
    // subscribe request; confirm against the subscribe path.
    params: Option<HashMap<String, Value>>,
}
1149
/// Default maximum number of subscriptions.
///
/// Used by `WsConfig::default()` for `max_subscriptions` and by
/// `SubscriptionManager::with_default_capacity()`.
pub const DEFAULT_MAX_SUBSCRIPTIONS: usize = 100;
1152
/// Subscription manager with capacity limits.
///
/// This struct manages WebSocket subscriptions with a configurable maximum limit
/// to prevent resource exhaustion. Subscriptions live in a `DashMap`, so every
/// method takes `&self` and no outer lock is required.
///
/// NOTE(review): the original documentation calls the map "lock-free";
/// `DashMap` is presumably a sharded-lock map rather than a strictly lock-free
/// structure — confirm against the dashmap documentation before relying on
/// that property.
///
/// # Example
///
/// ```rust
/// use ccxt_core::ws_client::SubscriptionManager;
///
/// let manager = SubscriptionManager::new(100);
/// assert_eq!(manager.count(), 0);
/// assert_eq!(manager.remaining_capacity(), 100);
/// assert_eq!(manager.max_subscriptions(), 100);
/// ```
#[derive(Debug)]
pub struct SubscriptionManager {
    /// Active subscriptions, indexed by subscription key (concurrent map)
    subscriptions: DashMap<String, Subscription>,
    /// Maximum allowed subscriptions
    max_subscriptions: usize,
}
1175
1176impl SubscriptionManager {
1177 /// Creates a new subscription manager with the specified maximum capacity.
1178 ///
1179 /// # Arguments
1180 ///
1181 /// * `max_subscriptions` - Maximum number of subscriptions allowed
1182 ///
1183 /// # Example
1184 ///
1185 /// ```rust
1186 /// use ccxt_core::ws_client::SubscriptionManager;
1187 ///
1188 /// let manager = SubscriptionManager::new(50);
1189 /// assert_eq!(manager.max_subscriptions(), 50);
1190 /// ```
1191 pub fn new(max_subscriptions: usize) -> Self {
1192 Self {
1193 subscriptions: DashMap::new(),
1194 max_subscriptions,
1195 }
1196 }
1197
1198 /// Creates a new subscription manager with the default maximum capacity (100).
1199 ///
1200 /// # Example
1201 ///
1202 /// ```rust
1203 /// use ccxt_core::ws_client::{SubscriptionManager, DEFAULT_MAX_SUBSCRIPTIONS};
1204 ///
1205 /// let manager = SubscriptionManager::with_default_capacity();
1206 /// assert_eq!(manager.max_subscriptions(), DEFAULT_MAX_SUBSCRIPTIONS);
1207 /// ```
1208 pub fn with_default_capacity() -> Self {
1209 Self::new(DEFAULT_MAX_SUBSCRIPTIONS)
1210 }
1211
1212 /// Returns the maximum number of subscriptions allowed.
1213 ///
1214 /// # Example
1215 ///
1216 /// ```rust
1217 /// use ccxt_core::ws_client::SubscriptionManager;
1218 ///
1219 /// let manager = SubscriptionManager::new(75);
1220 /// assert_eq!(manager.max_subscriptions(), 75);
1221 /// ```
1222 #[inline]
1223 #[must_use]
1224 pub fn max_subscriptions(&self) -> usize {
1225 self.max_subscriptions
1226 }
1227
1228 /// Attempts to add a subscription.
1229 ///
1230 /// Returns `Ok(())` if the subscription was added successfully, or
1231 /// `Err(Error::ResourceExhausted)` if the maximum capacity has been reached.
1232 ///
1233 /// If a subscription with the same key already exists, it will be replaced
1234 /// without counting against the capacity limit.
1235 ///
1236 /// # Arguments
1237 ///
1238 /// * `key` - Unique subscription key (typically "channel:symbol")
1239 /// * `subscription` - The subscription metadata
1240 ///
1241 /// # Errors
1242 ///
1243 /// Returns `Error::ResourceExhausted` if the subscription count has reached
1244 /// the maximum capacity and the key doesn't already exist.
1245 ///
1246 /// # Example
1247 ///
1248 /// ```rust
1249 /// use ccxt_core::ws_client::SubscriptionManager;
1250 ///
1251 /// let manager = SubscriptionManager::new(2);
1252 ///
1253 /// // First two subscriptions succeed
1254 /// // (Note: try_add requires internal Subscription type, this is conceptual)
1255 /// assert_eq!(manager.count(), 0);
1256 /// assert_eq!(manager.remaining_capacity(), 2);
1257 /// ```
1258 pub fn try_add(&self, key: String, subscription: Subscription) -> Result<()> {
1259 // If the key already exists, we're replacing, not adding
1260 if self.subscriptions.contains_key(&key) {
1261 self.subscriptions.insert(key, subscription);
1262 return Ok(());
1263 }
1264
1265 // Check capacity before adding new subscription
1266 if self.subscriptions.len() >= self.max_subscriptions {
1267 return Err(Error::resource_exhausted(format!(
1268 "Maximum subscriptions ({}) reached",
1269 self.max_subscriptions
1270 )));
1271 }
1272
1273 self.subscriptions.insert(key, subscription);
1274 Ok(())
1275 }
1276
1277 /// Removes a subscription by key.
1278 ///
1279 /// Returns the removed subscription if it existed, or `None` if not found.
1280 ///
1281 /// # Arguments
1282 ///
1283 /// * `key` - The subscription key to remove
1284 ///
1285 /// # Example
1286 ///
1287 /// ```rust
1288 /// use ccxt_core::ws_client::SubscriptionManager;
1289 ///
1290 /// let manager = SubscriptionManager::new(100);
1291 /// // After adding and removing a subscription:
1292 /// // let removed = manager.remove("ticker:BTC/USDT");
1293 /// ```
1294 pub fn remove(&self, key: &str) -> Option<Subscription> {
1295 self.subscriptions.remove(key).map(|(_, v)| v)
1296 }
1297
1298 /// Returns the current number of active subscriptions.
1299 ///
1300 /// This operation is lock-free and thread-safe.
1301 ///
1302 /// # Example
1303 ///
1304 /// ```rust
1305 /// use ccxt_core::ws_client::SubscriptionManager;
1306 ///
1307 /// let manager = SubscriptionManager::new(100);
1308 /// assert_eq!(manager.count(), 0);
1309 /// ```
1310 #[inline]
1311 #[must_use]
1312 pub fn count(&self) -> usize {
1313 self.subscriptions.len()
1314 }
1315
1316 /// Returns the remaining capacity for new subscriptions.
1317 ///
1318 /// This is calculated as `max_subscriptions - current_count`.
1319 ///
1320 /// # Example
1321 ///
1322 /// ```rust
1323 /// use ccxt_core::ws_client::SubscriptionManager;
1324 ///
1325 /// let manager = SubscriptionManager::new(100);
1326 /// assert_eq!(manager.remaining_capacity(), 100);
1327 /// ```
1328 #[inline]
1329 #[must_use]
1330 pub fn remaining_capacity(&self) -> usize {
1331 self.max_subscriptions
1332 .saturating_sub(self.subscriptions.len())
1333 }
1334
1335 /// Checks if a subscription exists for the given key.
1336 ///
1337 /// # Arguments
1338 ///
1339 /// * `key` - The subscription key to check
1340 ///
1341 /// # Example
1342 ///
1343 /// ```rust
1344 /// use ccxt_core::ws_client::SubscriptionManager;
1345 ///
1346 /// let manager = SubscriptionManager::new(100);
1347 /// assert!(!manager.contains("ticker:BTC/USDT"));
1348 /// ```
1349 #[inline]
1350 #[must_use]
1351 pub fn contains(&self, key: &str) -> bool {
1352 self.subscriptions.contains_key(key)
1353 }
1354
1355 /// Returns a reference to the subscription for the given key, if it exists.
1356 ///
1357 /// # Arguments
1358 ///
1359 /// * `key` - The subscription key to look up
1360 #[must_use]
1361 pub fn get(&self, key: &str) -> Option<dashmap::mapref::one::Ref<'_, String, Subscription>> {
1362 self.subscriptions.get(key)
1363 }
1364
1365 /// Clears all subscriptions.
1366 ///
1367 /// This removes all subscriptions from the manager, freeing all capacity.
1368 ///
1369 /// # Example
1370 ///
1371 /// ```rust
1372 /// use ccxt_core::ws_client::SubscriptionManager;
1373 ///
1374 /// let manager = SubscriptionManager::new(100);
1375 /// // After adding subscriptions:
1376 /// manager.clear();
1377 /// assert_eq!(manager.count(), 0);
1378 /// assert_eq!(manager.remaining_capacity(), 100);
1379 /// ```
1380 pub fn clear(&self) {
1381 self.subscriptions.clear();
1382 }
1383
1384 /// Returns an iterator over all subscriptions.
1385 ///
1386 /// The iterator yields `(key, subscription)` pairs.
1387 pub fn iter(
1388 &self,
1389 ) -> impl Iterator<Item = dashmap::mapref::multiple::RefMulti<'_, String, Subscription>> {
1390 self.subscriptions.iter()
1391 }
1392
1393 /// Collects all subscriptions into a vector.
1394 ///
1395 /// This is useful when you need to iterate over subscriptions while
1396 /// potentially modifying the manager.
1397 ///
1398 /// # Example
1399 ///
1400 /// ```rust
1401 /// use ccxt_core::ws_client::SubscriptionManager;
1402 ///
1403 /// let manager = SubscriptionManager::new(100);
1404 /// let subs = manager.collect_subscriptions();
1405 /// assert!(subs.is_empty());
1406 /// ```
1407 #[must_use]
1408 pub fn collect_subscriptions(&self) -> Vec<Subscription> {
1409 self.subscriptions
1410 .iter()
1411 .map(|entry| entry.value().clone())
1412 .collect()
1413 }
1414
1415 /// Checks if the manager is at full capacity.
1416 ///
1417 /// # Example
1418 ///
1419 /// ```rust
1420 /// use ccxt_core::ws_client::SubscriptionManager;
1421 ///
1422 /// let manager = SubscriptionManager::new(100);
1423 /// assert!(!manager.is_full());
1424 /// ```
1425 #[inline]
1426 #[must_use]
1427 pub fn is_full(&self) -> bool {
1428 self.subscriptions.len() >= self.max_subscriptions
1429 }
1430
1431 /// Checks if the manager has no subscriptions.
1432 ///
1433 /// # Example
1434 ///
1435 /// ```rust
1436 /// use ccxt_core::ws_client::SubscriptionManager;
1437 ///
1438 /// let manager = SubscriptionManager::new(100);
1439 /// assert!(manager.is_empty());
1440 /// ```
1441 #[inline]
1442 #[must_use]
1443 pub fn is_empty(&self) -> bool {
1444 self.subscriptions.is_empty()
1445 }
1446}
1447
1448impl Default for SubscriptionManager {
1449 fn default() -> Self {
1450 Self::with_default_capacity()
1451 }
1452}
1453
/// Type alias for WebSocket write half.
///
/// Names the sink half obtained by splitting a
/// `WebSocketStream<MaybeTlsStream<TcpStream>>`; it accepts `Message` values.
// The `dead_code` allowance suggests the alias is not referenced everywhere
// this module is compiled — NOTE(review): confirm it is still needed.
#[allow(dead_code)]
type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
1457
/// Async WebSocket client for exchange streaming APIs.
///
/// Created with `WsClient::new`, which starts in the `Disconnected` state with
/// no writer, shutdown channel, cancellation token or event callback installed.
pub struct WsClient {
    /// Connection configuration supplied to `WsClient::new`.
    config: WsConfig,
    /// Connection state (atomic for lock-free reads)
    ///
    /// Holds a `WsConnectionState` encoded as its `u8` discriminant.
    state: Arc<AtomicU8>,
    /// Subscription manager with capacity limits for lock-free concurrent access.
    ///
    /// Uses `SubscriptionManager` to enforce subscription limits and prevent
    /// resource exhaustion. The manager uses `DashMap` internally for lock-free
    /// concurrent access.
    subscription_manager: SubscriptionManager,

    /// Sending half of the unbounded `Value` channel created in `WsClient::new`.
    ///
    /// NOTE(review): presumably fed with parsed incoming messages by the
    /// message loop — confirm against `start_message_loop`.
    message_tx: mpsc::UnboundedSender<Value>,
    /// Receiving half of the same channel, shared behind an `RwLock`.
    message_rx: Arc<RwLock<mpsc::UnboundedReceiver<Value>>>,

    /// Sender for outgoing `Message` frames; `None` until one is installed.
    ///
    /// NOTE(review): presumably populated when the message loop starts —
    /// confirm against `start_message_loop`.
    write_tx: Arc<Mutex<Option<mpsc::UnboundedSender<Message>>>>,

    /// Reconnection counter (atomic for lock-free access)
    ///
    /// Reset to zero by `connect` after a successful handshake.
    reconnect_count: AtomicU32,

    /// Sender of unit shutdown signals; `None` until one is installed.
    ///
    /// NOTE(review): presumably used to stop the background message loop —
    /// confirm against the shutdown/disconnect path.
    shutdown_tx: Arc<Mutex<Option<mpsc::UnboundedSender<()>>>>,

    /// Connection statistics (lock-free atomic access)
    stats: Arc<WsStats>,

    /// Cancellation token for graceful shutdown and operation cancellation.
    ///
    /// This token can be used to cancel long-running operations like connect,
    /// reconnect, and subscribe. When cancelled, operations will return
    /// `Error::Cancelled`.
    cancel_token: Arc<Mutex<Option<CancellationToken>>>,

    /// Optional event callback for connection lifecycle events.
    ///
    /// This callback is invoked for events like `Shutdown`, `Connected`, etc.
    /// The callback is stored as an `Arc` to allow sharing across tasks.
    event_callback: Arc<Mutex<Option<WsEventCallback>>>,
}
1496
/// WebSocket connection statistics (lock-free).
///
/// This struct uses atomic types for all fields to enable lock-free concurrent access.
/// This prevents potential deadlocks when stats are accessed across await points.
///
/// Every counter and timestamp starts at zero; a timestamp of `0` therefore
/// means "never recorded".
///
/// # Thread Safety
///
/// All operations on `WsStats` are thread-safe and lock-free. Multiple tasks can
/// read and update statistics concurrently without blocking.
///
/// # Example
///
/// ```rust
/// use ccxt_core::ws_client::WsStats;
///
/// let stats = WsStats::new();
///
/// // Record a received message
/// stats.record_received(1024);
///
/// // Get a snapshot of current stats
/// let snapshot = stats.snapshot();
/// assert_eq!(snapshot.messages_received, 1);
/// assert_eq!(snapshot.bytes_received, 1024);
/// ```
#[derive(Debug)]
pub struct WsStats {
    /// Total messages received
    messages_received: AtomicU64,
    /// Total messages sent
    messages_sent: AtomicU64,
    /// Total bytes received
    bytes_received: AtomicU64,
    /// Total bytes sent
    bytes_sent: AtomicU64,
    /// Last message timestamp in milliseconds (updated by `record_received`)
    last_message_time: AtomicI64,
    /// Last ping timestamp in milliseconds (updated by `record_ping`)
    last_ping_time: AtomicI64,
    /// Last pong timestamp in milliseconds (updated by `record_pong`)
    last_pong_time: AtomicI64,
    /// Connection established timestamp in milliseconds (updated by `record_connected`)
    connected_at: AtomicI64,
    /// Number of reconnection attempts
    reconnect_attempts: AtomicU32,
}
1543
1544impl WsStats {
1545 /// Creates a new `WsStats` instance with all counters initialized to zero.
1546 ///
1547 /// # Example
1548 ///
1549 /// ```rust
1550 /// use ccxt_core::ws_client::WsStats;
1551 ///
1552 /// let stats = WsStats::new();
1553 /// let snapshot = stats.snapshot();
1554 /// assert_eq!(snapshot.messages_received, 0);
1555 /// ```
1556 pub fn new() -> Self {
1557 Self {
1558 messages_received: AtomicU64::new(0),
1559 messages_sent: AtomicU64::new(0),
1560 bytes_received: AtomicU64::new(0),
1561 bytes_sent: AtomicU64::new(0),
1562 last_message_time: AtomicI64::new(0),
1563 last_ping_time: AtomicI64::new(0),
1564 last_pong_time: AtomicI64::new(0),
1565 connected_at: AtomicI64::new(0),
1566 reconnect_attempts: AtomicU32::new(0),
1567 }
1568 }
1569
1570 /// Records a received message.
1571 ///
1572 /// Increments the message count, adds to bytes received, and updates
1573 /// the last message timestamp.
1574 ///
1575 /// # Arguments
1576 ///
1577 /// * `bytes` - Number of bytes received in the message
1578 ///
1579 /// # Example
1580 ///
1581 /// ```rust
1582 /// use ccxt_core::ws_client::WsStats;
1583 ///
1584 /// let stats = WsStats::new();
1585 /// stats.record_received(512);
1586 /// stats.record_received(256);
1587 ///
1588 /// let snapshot = stats.snapshot();
1589 /// assert_eq!(snapshot.messages_received, 2);
1590 /// assert_eq!(snapshot.bytes_received, 768);
1591 /// ```
1592 pub fn record_received(&self, bytes: u64) {
1593 self.messages_received.fetch_add(1, Ordering::Relaxed);
1594 self.bytes_received.fetch_add(bytes, Ordering::Relaxed);
1595 self.last_message_time
1596 .store(chrono::Utc::now().timestamp_millis(), Ordering::Relaxed);
1597 }
1598
1599 /// Records a sent message.
1600 ///
1601 /// Increments the message count and adds to bytes sent.
1602 ///
1603 /// # Arguments
1604 ///
1605 /// * `bytes` - Number of bytes sent in the message
1606 ///
1607 /// # Example
1608 ///
1609 /// ```rust
1610 /// use ccxt_core::ws_client::WsStats;
1611 ///
1612 /// let stats = WsStats::new();
1613 /// stats.record_sent(128);
1614 ///
1615 /// let snapshot = stats.snapshot();
1616 /// assert_eq!(snapshot.messages_sent, 1);
1617 /// assert_eq!(snapshot.bytes_sent, 128);
1618 /// ```
1619 pub fn record_sent(&self, bytes: u64) {
1620 self.messages_sent.fetch_add(1, Ordering::Relaxed);
1621 self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
1622 }
1623
1624 /// Records a ping sent.
1625 ///
1626 /// Updates the last ping timestamp.
1627 ///
1628 /// # Example
1629 ///
1630 /// ```rust
1631 /// use ccxt_core::ws_client::WsStats;
1632 ///
1633 /// let stats = WsStats::new();
1634 /// stats.record_ping();
1635 ///
1636 /// let snapshot = stats.snapshot();
1637 /// assert!(snapshot.last_ping_time > 0);
1638 /// ```
1639 pub fn record_ping(&self) {
1640 self.last_ping_time
1641 .store(chrono::Utc::now().timestamp_millis(), Ordering::Relaxed);
1642 }
1643
1644 /// Records a pong received.
1645 ///
1646 /// Updates the last pong timestamp.
1647 ///
1648 /// # Example
1649 ///
1650 /// ```rust
1651 /// use ccxt_core::ws_client::WsStats;
1652 ///
1653 /// let stats = WsStats::new();
1654 /// stats.record_pong();
1655 ///
1656 /// let snapshot = stats.snapshot();
1657 /// assert!(snapshot.last_pong_time > 0);
1658 /// ```
1659 pub fn record_pong(&self) {
1660 self.last_pong_time
1661 .store(chrono::Utc::now().timestamp_millis(), Ordering::Relaxed);
1662 }
1663
1664 /// Records a connection established.
1665 ///
1666 /// Updates the connected_at timestamp.
1667 ///
1668 /// # Example
1669 ///
1670 /// ```rust
1671 /// use ccxt_core::ws_client::WsStats;
1672 ///
1673 /// let stats = WsStats::new();
1674 /// stats.record_connected();
1675 ///
1676 /// let snapshot = stats.snapshot();
1677 /// assert!(snapshot.connected_at > 0);
1678 /// ```
1679 pub fn record_connected(&self) {
1680 self.connected_at
1681 .store(chrono::Utc::now().timestamp_millis(), Ordering::Relaxed);
1682 }
1683
1684 /// Increments the reconnection attempt counter.
1685 ///
1686 /// # Returns
1687 ///
1688 /// The new reconnection attempt count after incrementing.
1689 ///
1690 /// # Example
1691 ///
1692 /// ```rust
1693 /// use ccxt_core::ws_client::WsStats;
1694 ///
1695 /// let stats = WsStats::new();
1696 /// let count = stats.increment_reconnect_attempts();
1697 /// assert_eq!(count, 1);
1698 ///
1699 /// let count = stats.increment_reconnect_attempts();
1700 /// assert_eq!(count, 2);
1701 /// ```
1702 pub fn increment_reconnect_attempts(&self) -> u32 {
1703 self.reconnect_attempts.fetch_add(1, Ordering::Relaxed) + 1
1704 }
1705
1706 /// Resets the reconnection attempt counter to zero.
1707 ///
1708 /// # Example
1709 ///
1710 /// ```rust
1711 /// use ccxt_core::ws_client::WsStats;
1712 ///
1713 /// let stats = WsStats::new();
1714 /// stats.increment_reconnect_attempts();
1715 /// stats.increment_reconnect_attempts();
1716 ///
1717 /// stats.reset_reconnect_attempts();
1718 ///
1719 /// let snapshot = stats.snapshot();
1720 /// assert_eq!(snapshot.reconnect_attempts, 0);
1721 /// ```
1722 pub fn reset_reconnect_attempts(&self) {
1723 self.reconnect_attempts.store(0, Ordering::Relaxed);
1724 }
1725
1726 /// Returns the last pong timestamp.
1727 ///
1728 /// This is useful for checking connection health without creating a full snapshot.
1729 ///
1730 /// # Returns
1731 ///
1732 /// The last pong timestamp in milliseconds, or 0 if no pong has been received.
1733 ///
1734 /// # Example
1735 ///
1736 /// ```rust
1737 /// use ccxt_core::ws_client::WsStats;
1738 ///
1739 /// let stats = WsStats::new();
1740 /// assert_eq!(stats.last_pong_time(), 0);
1741 ///
1742 /// stats.record_pong();
1743 /// assert!(stats.last_pong_time() > 0);
1744 /// ```
1745 pub fn last_pong_time(&self) -> i64 {
1746 self.last_pong_time.load(Ordering::Relaxed)
1747 }
1748
1749 /// Returns the last ping timestamp.
1750 ///
1751 /// This is useful for calculating latency without creating a full snapshot.
1752 ///
1753 /// # Returns
1754 ///
1755 /// The last ping timestamp in milliseconds, or 0 if no ping has been sent.
1756 pub fn last_ping_time(&self) -> i64 {
1757 self.last_ping_time.load(Ordering::Relaxed)
1758 }
1759
1760 /// Creates an immutable snapshot of current statistics.
1761 ///
1762 /// The snapshot captures all statistics at a point in time. Note that since
1763 /// each field is read independently, the snapshot may not represent a perfectly
1764 /// consistent state if updates are happening concurrently. However, this is
1765 /// acceptable for statistics purposes.
1766 ///
1767 /// # Returns
1768 ///
1769 /// A `WsStatsSnapshot` containing the current values of all statistics.
1770 ///
1771 /// # Example
1772 ///
1773 /// ```rust
1774 /// use ccxt_core::ws_client::WsStats;
1775 ///
1776 /// let stats = WsStats::new();
1777 /// stats.record_received(100);
1778 /// stats.record_sent(50);
1779 ///
1780 /// let snapshot = stats.snapshot();
1781 /// assert_eq!(snapshot.messages_received, 1);
1782 /// assert_eq!(snapshot.bytes_received, 100);
1783 /// assert_eq!(snapshot.messages_sent, 1);
1784 /// assert_eq!(snapshot.bytes_sent, 50);
1785 /// ```
1786 pub fn snapshot(&self) -> WsStatsSnapshot {
1787 WsStatsSnapshot {
1788 messages_received: self.messages_received.load(Ordering::Relaxed),
1789 messages_sent: self.messages_sent.load(Ordering::Relaxed),
1790 bytes_received: self.bytes_received.load(Ordering::Relaxed),
1791 bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
1792 last_message_time: self.last_message_time.load(Ordering::Relaxed),
1793 last_ping_time: self.last_ping_time.load(Ordering::Relaxed),
1794 last_pong_time: self.last_pong_time.load(Ordering::Relaxed),
1795 connected_at: self.connected_at.load(Ordering::Relaxed),
1796 reconnect_attempts: self.reconnect_attempts.load(Ordering::Relaxed),
1797 }
1798 }
1799
1800 /// Resets all statistics to their default values.
1801 ///
1802 /// # Example
1803 ///
1804 /// ```rust
1805 /// use ccxt_core::ws_client::WsStats;
1806 ///
1807 /// let stats = WsStats::new();
1808 /// stats.record_received(100);
1809 /// stats.record_sent(50);
1810 ///
1811 /// stats.reset();
1812 ///
1813 /// let snapshot = stats.snapshot();
1814 /// assert_eq!(snapshot.messages_received, 0);
1815 /// assert_eq!(snapshot.bytes_received, 0);
1816 /// ```
1817 pub fn reset(&self) {
1818 self.messages_received.store(0, Ordering::Relaxed);
1819 self.messages_sent.store(0, Ordering::Relaxed);
1820 self.bytes_received.store(0, Ordering::Relaxed);
1821 self.bytes_sent.store(0, Ordering::Relaxed);
1822 self.last_message_time.store(0, Ordering::Relaxed);
1823 self.last_ping_time.store(0, Ordering::Relaxed);
1824 self.last_pong_time.store(0, Ordering::Relaxed);
1825 self.connected_at.store(0, Ordering::Relaxed);
1826 self.reconnect_attempts.store(0, Ordering::Relaxed);
1827 }
1828}
1829
1830impl Default for WsStats {
1831 fn default() -> Self {
1832 Self::new()
1833 }
1834}
1835
/// Immutable snapshot of WebSocket connection statistics.
///
/// This struct provides a point-in-time view of the connection statistics.
/// It is created by calling `WsStats::snapshot()` and can be safely cloned
/// and passed around without affecting the underlying atomic counters.
///
/// The derived `Default` yields all-zero values, matching a freshly created
/// `WsStats`. A timestamp of `0` means the corresponding event was never
/// recorded.
///
/// # Example
///
/// ```rust
/// use ccxt_core::ws_client::WsStats;
///
/// let stats = WsStats::new();
/// stats.record_received(1024);
///
/// let snapshot = stats.snapshot();
/// let snapshot_clone = snapshot.clone();
///
/// assert_eq!(snapshot.messages_received, snapshot_clone.messages_received);
/// ```
#[derive(Debug, Clone, Default)]
pub struct WsStatsSnapshot {
    /// Total messages received
    pub messages_received: u64,
    /// Total messages sent
    pub messages_sent: u64,
    /// Total bytes received
    pub bytes_received: u64,
    /// Total bytes sent
    pub bytes_sent: u64,
    /// Last message timestamp in milliseconds (0 if none was received)
    pub last_message_time: i64,
    /// Last ping timestamp in milliseconds (0 if none was sent)
    pub last_ping_time: i64,
    /// Last pong timestamp in milliseconds (0 if none was received)
    pub last_pong_time: i64,
    /// Connection established timestamp in milliseconds (0 if never connected)
    pub connected_at: i64,
    /// Number of reconnection attempts
    pub reconnect_attempts: u32,
}
1876
1877impl WsClient {
1878 /// Creates a new WebSocket client instance.
1879 ///
1880 /// # Arguments
1881 ///
1882 /// * `config` - WebSocket connection configuration
1883 ///
1884 /// # Returns
1885 ///
1886 /// A new `WsClient` instance ready to connect
1887 pub fn new(config: WsConfig) -> Self {
1888 let (message_tx, message_rx) = mpsc::unbounded_channel();
1889 let max_subscriptions = config.max_subscriptions;
1890
1891 Self {
1892 config,
1893 state: Arc::new(AtomicU8::new(WsConnectionState::Disconnected.as_u8())),
1894 subscription_manager: SubscriptionManager::new(max_subscriptions),
1895 message_tx,
1896 message_rx: Arc::new(RwLock::new(message_rx)),
1897 write_tx: Arc::new(Mutex::new(None)),
1898 reconnect_count: AtomicU32::new(0),
1899 shutdown_tx: Arc::new(Mutex::new(None)),
1900 stats: Arc::new(WsStats::new()),
1901 cancel_token: Arc::new(Mutex::new(None)),
1902 event_callback: Arc::new(Mutex::new(None)),
1903 }
1904 }
1905
1906 /// Sets the event callback for connection lifecycle events.
1907 ///
1908 /// The callback will be invoked for events like `Shutdown`, `Connected`, etc.
1909 /// This allows the application to react to connection state changes.
1910 ///
1911 /// # Arguments
1912 ///
1913 /// * `callback` - The event callback function
1914 ///
1915 /// # Example
1916 ///
1917 /// ```rust,ignore
1918 /// use ccxt_core::ws_client::{WsClient, WsConfig, WsEvent};
1919 /// use std::sync::Arc;
1920 ///
1921 /// let client = WsClient::new(WsConfig::default());
1922 ///
1923 /// client.set_event_callback(Arc::new(|event| {
1924 /// match event {
1925 /// WsEvent::Shutdown => println!("Client shutdown"),
1926 /// WsEvent::Connected => println!("Client connected"),
1927 /// _ => {}
1928 /// }
1929 /// })).await;
1930 /// ```
1931 pub async fn set_event_callback(&self, callback: WsEventCallback) {
1932 *self.event_callback.lock().await = Some(callback);
1933 debug!("Event callback set");
1934 }
1935
1936 /// Clears the event callback.
1937 ///
1938 /// After calling this method, no events will be emitted to the callback.
1939 pub async fn clear_event_callback(&self) {
1940 *self.event_callback.lock().await = None;
1941 debug!("Event callback cleared");
1942 }
1943
1944 /// Emits an event to the registered callback (if any).
1945 ///
1946 /// This method is used internally to notify the application about
1947 /// connection lifecycle events.
1948 ///
1949 /// # Arguments
1950 ///
1951 /// * `event` - The event to emit
1952 async fn emit_event(&self, event: WsEvent) {
1953 let callback = self.event_callback.lock().await;
1954 if let Some(ref cb) = *callback {
1955 // Clone the callback to release the lock before invoking
1956 let cb = Arc::clone(cb);
1957 drop(callback);
1958 // Invoke callback asynchronously to avoid blocking
1959 tokio::spawn(async move {
1960 cb(event);
1961 });
1962 }
1963 }
1964
1965 /// Sets the cancellation token for this client.
1966 ///
1967 /// The cancellation token can be used to cancel long-running operations
1968 /// like connect, reconnect, and subscribe. When the token is cancelled,
1969 /// these operations will return `Error::Cancelled`.
1970 ///
1971 /// # Arguments
1972 ///
1973 /// * `token` - The cancellation token to use
1974 ///
1975 /// # Example
1976 ///
1977 /// ```rust,ignore
1978 /// use ccxt_core::ws_client::{WsClient, WsConfig};
1979 /// use tokio_util::sync::CancellationToken;
1980 ///
1981 /// let client = WsClient::new(WsConfig::default());
1982 /// let token = CancellationToken::new();
1983 ///
1984 /// // Set the cancellation token
1985 /// client.set_cancel_token(token.clone()).await;
1986 ///
1987 /// // Later, cancel all operations
1988 /// token.cancel();
1989 /// ```
1990 pub async fn set_cancel_token(&self, token: CancellationToken) {
1991 *self.cancel_token.lock().await = Some(token);
1992 debug!("Cancellation token set");
1993 }
1994
1995 /// Clears the cancellation token.
1996 ///
1997 /// After calling this method, operations will no longer be cancellable
1998 /// via the previously set token.
1999 pub async fn clear_cancel_token(&self) {
2000 *self.cancel_token.lock().await = None;
2001 debug!("Cancellation token cleared");
2002 }
2003
2004 /// Returns a clone of the current cancellation token, if set.
2005 ///
2006 /// This is useful for sharing the token with other components or
2007 /// for checking if a token is currently set.
2008 ///
2009 /// # Returns
2010 ///
2011 /// `Some(CancellationToken)` if a token is set, `None` otherwise.
2012 pub async fn get_cancel_token(&self) -> Option<CancellationToken> {
2013 self.cancel_token.lock().await.clone()
2014 }
2015
2016 /// Establishes connection to the WebSocket server.
2017 ///
2018 /// Returns immediately if already connected. Automatically starts message
2019 /// processing loop and resubscribes to previous channels on success.
2020 ///
2021 /// # Errors
2022 ///
2023 /// Returns error if:
2024 /// - Connection timeout exceeded
2025 /// - Network error occurs
2026 /// - Server rejects connection
2027 #[instrument(
2028 name = "ws_connect",
2029 skip(self),
2030 fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
2031 )]
2032 pub async fn connect(&self) -> Result<()> {
2033 // Lock-free state check
2034 if self.state() == WsConnectionState::Connected {
2035 info!("WebSocket already connected");
2036 return Ok(());
2037 }
2038
2039 // Lock-free state update
2040 self.set_state(WsConnectionState::Connecting);
2041
2042 let url = self.config.url.clone();
2043 info!("Initiating WebSocket connection");
2044
2045 match tokio::time::timeout(
2046 Duration::from_millis(self.config.connect_timeout),
2047 connect_async(&url),
2048 )
2049 .await
2050 {
2051 Ok(Ok((ws_stream, response))) => {
2052 info!(
2053 status = response.status().as_u16(),
2054 "WebSocket connection established successfully"
2055 );
2056
2057 self.set_state(WsConnectionState::Connected);
2058 // Lock-free reconnect count reset
2059 self.reconnect_count.store(0, Ordering::Release);
2060
2061 // Lock-free stats update
2062 self.stats.record_connected();
2063
2064 self.start_message_loop(ws_stream).await;
2065
2066 self.resubscribe_all().await?;
2067
2068 Ok(())
2069 }
2070 Ok(Err(e)) => {
2071 error!(
2072 error = %e,
2073 error_debug = ?e,
2074 "WebSocket connection failed"
2075 );
2076 self.set_state(WsConnectionState::Error);
2077 Err(Error::network(format!("WebSocket connection failed: {e}")))
2078 }
2079 Err(_) => {
2080 error!(
2081 timeout_ms = self.config.connect_timeout,
2082 "WebSocket connection timeout exceeded"
2083 );
2084 self.set_state(WsConnectionState::Error);
2085 Err(Error::timeout("WebSocket connection timeout"))
2086 }
2087 }
2088 }
2089
2090 /// Establishes connection to the WebSocket server with cancellation support.
2091 ///
2092 /// This method is similar to [`connect`](Self::connect), but accepts an optional
2093 /// `CancellationToken` that can be used to cancel the connection attempt.
2094 ///
2095 /// If no token is provided, the method will use the client's internal token
2096 /// (if set via [`set_cancel_token`](Self::set_cancel_token)).
2097 ///
2098 /// # Arguments
2099 ///
2100 /// * `cancel_token` - Optional cancellation token. If `None`, uses the client's
2101 /// internal token (if set).
2102 ///
2103 /// # Errors
2104 ///
2105 /// Returns error if:
2106 /// - Connection timeout exceeded
2107 /// - Network error occurs
2108 /// - Server rejects connection
2109 /// - Operation was cancelled via the cancellation token
2110 ///
2111 /// # Example
2112 ///
2113 /// ```rust,ignore
2114 /// use ccxt_core::ws_client::{WsClient, WsConfig};
2115 /// use tokio_util::sync::CancellationToken;
2116 ///
2117 /// let client = WsClient::new(WsConfig {
2118 /// url: "wss://stream.example.com/ws".to_string(),
2119 /// ..Default::default()
2120 /// });
2121 ///
2122 /// let token = CancellationToken::new();
2123 /// let token_clone = token.clone();
2124 ///
2125 /// // Spawn a task to cancel after 5 seconds
2126 /// tokio::spawn(async move {
2127 /// tokio::time::sleep(Duration::from_secs(5)).await;
2128 /// token_clone.cancel();
2129 /// });
2130 ///
2131 /// // Connect with cancellation support
2132 /// match client.connect_with_cancel(Some(token)).await {
2133 /// Ok(()) => println!("Connected!"),
2134 /// Err(e) if e.as_cancelled().is_some() => println!("Connection cancelled"),
2135 /// Err(e) => println!("Connection failed: {}", e),
2136 /// }
2137 /// ```
2138 #[instrument(
2139 name = "ws_connect_with_cancel",
2140 skip(self, cancel_token),
2141 fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
2142 )]
2143 pub async fn connect_with_cancel(&self, cancel_token: Option<CancellationToken>) -> Result<()> {
2144 // Use provided token, or fall back to client's internal token, or create a new one
2145 let token = if let Some(t) = cancel_token {
2146 t
2147 } else {
2148 let internal_token = self.cancel_token.lock().await;
2149 internal_token
2150 .clone()
2151 .unwrap_or_else(CancellationToken::new)
2152 };
2153
2154 // Lock-free state check
2155 if self.state() == WsConnectionState::Connected {
2156 info!("WebSocket already connected");
2157 return Ok(());
2158 }
2159
2160 // Lock-free state update
2161 self.set_state(WsConnectionState::Connecting);
2162
2163 let url = self.config.url.clone();
2164 info!("Initiating WebSocket connection with cancellation support");
2165
2166 // Use tokio::select! to race between connection and cancellation
2167 tokio::select! {
2168 biased;
2169
2170 // Check for cancellation first
2171 () = token.cancelled() => {
2172 warn!("WebSocket connection cancelled");
2173 self.set_state(WsConnectionState::Disconnected);
2174 Err(Error::cancelled("WebSocket connection cancelled"))
2175 }
2176
2177 // Attempt connection with timeout
2178 result = tokio::time::timeout(
2179 Duration::from_millis(self.config.connect_timeout),
2180 connect_async(&url),
2181 ) => {
2182 match result {
2183 Ok(Ok((ws_stream, response))) => {
2184 info!(
2185 status = response.status().as_u16(),
2186 "WebSocket connection established successfully"
2187 );
2188
2189 self.set_state(WsConnectionState::Connected);
2190 // Lock-free reconnect count reset
2191 self.reconnect_count.store(0, Ordering::Release);
2192
2193 // Lock-free stats update
2194 self.stats.record_connected();
2195
2196 self.start_message_loop(ws_stream).await;
2197
2198 self.resubscribe_all().await?;
2199
2200 Ok(())
2201 }
2202 Ok(Err(e)) => {
2203 error!(
2204 error = %e,
2205 error_debug = ?e,
2206 "WebSocket connection failed"
2207 );
2208 self.set_state(WsConnectionState::Error);
2209 Err(Error::network(format!("WebSocket connection failed: {e}")))
2210 }
2211 Err(_) => {
2212 error!(
2213 timeout_ms = self.config.connect_timeout,
2214 "WebSocket connection timeout exceeded"
2215 );
2216 self.set_state(WsConnectionState::Error);
2217 Err(Error::timeout("WebSocket connection timeout"))
2218 }
2219 }
2220 }
2221 }
2222 }
2223
2224 /// Closes the WebSocket connection gracefully.
2225 ///
2226 /// Sends shutdown signal to background tasks and clears internal state.
2227 #[instrument(name = "ws_disconnect", skip(self))]
2228 pub async fn disconnect(&self) -> Result<()> {
2229 info!("Initiating WebSocket disconnect");
2230
2231 if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
2232 let _ = tx.send(());
2233 debug!("Shutdown signal sent to background tasks");
2234 }
2235
2236 *self.write_tx.lock().await = None;
2237
2238 // Lock-free state update
2239 self.set_state(WsConnectionState::Disconnected);
2240
2241 info!("WebSocket disconnected successfully");
2242 Ok(())
2243 }
2244
2245 /// Gracefully shuts down the WebSocket client.
2246 ///
2247 /// This method performs a complete shutdown of the WebSocket client:
2248 /// 1. Cancels all pending reconnection attempts
2249 /// 2. Sends WebSocket close frame to the server
2250 /// 3. Waits for pending operations to complete (with timeout)
2251 /// 4. Clears all resources (subscriptions, channels, etc.)
2252 /// 5. Emits a `Shutdown` event
2253 ///
2254 /// # Behavior
2255 ///
2256 /// - If a cancellation token is set, it will be cancelled to stop any
2257 /// in-progress reconnection attempts.
2258 /// - The method waits up to `shutdown_timeout` milliseconds for pending
2259 /// operations to complete before proceeding with cleanup.
2260 /// - All subscriptions are cleared during shutdown.
2261 /// - The connection state is set to `Disconnected`.
2262 ///
2263 /// # Errors
2264 ///
2265 /// This method does not return errors. All cleanup operations are
2266 /// performed on a best-effort basis.
2267 ///
2268 /// # Example
2269 ///
2270 /// ```rust,ignore
2271 /// use ccxt_core::ws_client::{WsClient, WsConfig, WsEvent};
2272 /// use std::sync::Arc;
2273 ///
2274 /// let client = WsClient::new(WsConfig {
2275 /// url: "wss://stream.example.com/ws".to_string(),
2276 /// shutdown_timeout: 5000, // 5 seconds
2277 /// ..Default::default()
2278 /// });
2279 ///
2280 /// // Set up event callback to know when shutdown completes
2281 /// client.set_event_callback(Arc::new(|event| {
2282 /// if let WsEvent::Shutdown = event {
2283 /// println!("Shutdown completed");
2284 /// }
2285 /// })).await;
2286 ///
2287 /// // Connect and do work...
2288 /// client.connect().await?;
2289 ///
2290 /// // Gracefully shutdown
2291 /// client.shutdown().await;
2292 /// ```
2293 #[instrument(name = "ws_shutdown", skip(self), fields(timeout_ms = self.config.shutdown_timeout))]
2294 pub async fn shutdown(&self) {
2295 info!("Initiating graceful shutdown");
2296
2297 // 1. Cancel any pending reconnection attempts (Requirements 7.2)
2298 {
2299 let token_guard = self.cancel_token.lock().await;
2300 if let Some(ref token) = *token_guard {
2301 info!("Cancelling pending reconnection attempts");
2302 token.cancel();
2303 }
2304 }
2305
2306 // 2. Set state to disconnected to prevent new operations
2307 self.set_state(WsConnectionState::Disconnected);
2308
2309 // 3. Send WebSocket close frame (Requirements 7.3)
2310 {
2311 let write_tx_guard = self.write_tx.lock().await;
2312 if let Some(ref tx) = *write_tx_guard {
2313 info!("Sending WebSocket close frame");
2314 // Send close frame - ignore errors as connection may already be closed
2315 let _ = tx.send(Message::Close(None));
2316 }
2317 }
2318
2319 // 4. Wait for pending operations to complete with timeout (Requirements 7.4)
2320 let shutdown_timeout = Duration::from_millis(self.config.shutdown_timeout);
2321 let shutdown_result = tokio::time::timeout(shutdown_timeout, async {
2322 // Signal background tasks to stop
2323 if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
2324 let _ = tx.send(());
2325 debug!("Shutdown signal sent to background tasks");
2326 }
2327
2328 // Give tasks a moment to process the shutdown signal
2329 tokio::time::sleep(Duration::from_millis(100)).await;
2330 })
2331 .await;
2332
2333 match shutdown_result {
2334 Ok(()) => {
2335 debug!("Pending operations completed within timeout");
2336 }
2337 Err(_) => {
2338 warn!(
2339 timeout_ms = self.config.shutdown_timeout,
2340 "Shutdown timeout exceeded, proceeding with cleanup"
2341 );
2342 }
2343 }
2344
2345 // 5. Clear resources (Requirements 7.1)
2346 {
2347 // Clear write channel
2348 *self.write_tx.lock().await = None;
2349
2350 // Clear shutdown channel
2351 *self.shutdown_tx.lock().await = None;
2352
2353 // Clear subscriptions using SubscriptionManager
2354 self.subscription_manager.clear();
2355 debug!("Subscriptions cleared");
2356
2357 // Reset reconnect count
2358 self.reconnect_count.store(0, Ordering::Release);
2359
2360 // Reset stats
2361 self.stats.reset();
2362 debug!("Statistics reset");
2363 }
2364
2365 info!("Shutdown cleanup completed");
2366
2367 // 6. Emit Shutdown event (Requirements 7.5)
2368 self.emit_event(WsEvent::Shutdown).await;
2369
2370 info!("Graceful shutdown completed");
2371 }
2372
2373 /// Attempts to reconnect to the WebSocket server.
2374 ///
2375 /// Respects `max_reconnect_attempts` configuration and waits for
2376 /// `reconnect_interval` before attempting connection.
2377 ///
2378 /// # Errors
2379 ///
2380 /// Returns error if maximum reconnection attempts exceeded or connection fails.
2381 #[instrument(
2382 name = "ws_reconnect",
2383 skip(self),
2384 fields(
2385 max_attempts = self.config.max_reconnect_attempts,
2386 reconnect_interval_ms = self.config.reconnect_interval
2387 )
2388 )]
2389 pub async fn reconnect(&self) -> Result<()> {
2390 // Lock-free atomic increment and check
2391 let count = self.reconnect_count.fetch_add(1, Ordering::AcqRel) + 1;
2392
2393 if count > self.config.max_reconnect_attempts {
2394 error!(
2395 attempts = count,
2396 max = self.config.max_reconnect_attempts,
2397 "Max reconnect attempts reached, giving up"
2398 );
2399 return Err(Error::network("Max reconnect attempts reached"));
2400 }
2401
2402 warn!(
2403 attempt = count,
2404 max = self.config.max_reconnect_attempts,
2405 delay_ms = self.config.reconnect_interval,
2406 "Attempting WebSocket reconnection"
2407 );
2408
2409 // Lock-free state update
2410 self.set_state(WsConnectionState::Reconnecting);
2411
2412 tokio::time::sleep(Duration::from_millis(self.config.reconnect_interval)).await;
2413
2414 self.connect().await
2415 }
2416
2417 /// Attempts to reconnect to the WebSocket server with cancellation support.
2418 ///
2419 /// This method uses exponential backoff for retry delays and classifies errors
2420 /// to determine if reconnection should continue. It supports cancellation via
2421 /// an optional `CancellationToken`.
2422 ///
2423 /// # Arguments
2424 ///
2425 /// * `cancel_token` - Optional cancellation token. If `None`, uses the client's
2426 /// internal token (if set via [`set_cancel_token`](Self::set_cancel_token)).
2427 ///
2428 /// # Behavior
2429 ///
2430 /// 1. Calculates retry delay using exponential backoff strategy
2431 /// 2. Waits for the calculated delay (can be cancelled)
2432 /// 3. Attempts to connect (can be cancelled)
2433 /// 4. On success: resets reconnect counter and returns `Ok(())`
2434 /// 5. On transient error: continues retry loop
2435 /// 6. On permanent error: stops retrying and returns error
2436 /// 7. On max attempts reached: returns error
2437 ///
2438 /// # Errors
2439 ///
2440 /// Returns error if:
2441 /// - Maximum reconnection attempts exceeded
2442 /// - Permanent error occurs (authentication failure, protocol error)
2443 /// - Operation was cancelled via the cancellation token
2444 ///
2445 /// # Example
2446 ///
2447 /// ```rust,ignore
2448 /// use ccxt_core::ws_client::{WsClient, WsConfig};
2449 /// use tokio_util::sync::CancellationToken;
2450 ///
2451 /// let client = WsClient::new(WsConfig {
2452 /// url: "wss://stream.example.com/ws".to_string(),
2453 /// max_reconnect_attempts: 5,
2454 /// ..Default::default()
2455 /// });
2456 ///
2457 /// let token = CancellationToken::new();
2458 ///
2459 /// // Reconnect with cancellation support
2460 /// match client.reconnect_with_cancel(Some(token)).await {
2461 /// Ok(()) => println!("Reconnected!"),
2462 /// Err(e) if e.as_cancelled().is_some() => println!("Reconnection cancelled"),
2463 /// Err(e) => println!("Reconnection failed: {}", e),
2464 /// }
2465 /// ```
2466 #[instrument(
2467 name = "ws_reconnect_with_cancel",
2468 skip(self, cancel_token),
2469 fields(
2470 max_attempts = self.config.max_reconnect_attempts,
2471 )
2472 )]
2473 pub async fn reconnect_with_cancel(
2474 &self,
2475 cancel_token: Option<CancellationToken>,
2476 ) -> Result<()> {
2477 // Use provided token, or fall back to client's internal token, or create a new one
2478 let token = if let Some(t) = cancel_token {
2479 t
2480 } else {
2481 let internal_token = self.cancel_token.lock().await;
2482 internal_token
2483 .clone()
2484 .unwrap_or_else(CancellationToken::new)
2485 };
2486
2487 // Create backoff strategy from config
2488 let backoff = BackoffStrategy::new(self.config.backoff_config.clone());
2489
2490 // Lock-free state update
2491 self.set_state(WsConnectionState::Reconnecting);
2492
2493 loop {
2494 // Check for cancellation before each attempt
2495 if token.is_cancelled() {
2496 warn!("Reconnection cancelled before attempt");
2497 self.set_state(WsConnectionState::Disconnected);
2498 return Err(Error::cancelled("Reconnection cancelled"));
2499 }
2500
2501 // Lock-free atomic increment and check
2502 let attempt = self.reconnect_count.fetch_add(1, Ordering::AcqRel);
2503
2504 if attempt >= self.config.max_reconnect_attempts {
2505 error!(
2506 attempts = attempt + 1,
2507 max = self.config.max_reconnect_attempts,
2508 "Max reconnect attempts reached, giving up"
2509 );
2510 self.set_state(WsConnectionState::Error);
2511 return Err(Error::network(format!(
2512 "Max reconnect attempts ({}) reached",
2513 self.config.max_reconnect_attempts
2514 )));
2515 }
2516
2517 // Calculate delay using backoff strategy
2518 let delay = backoff.calculate_delay(attempt);
2519
2520 #[allow(clippy::cast_possible_truncation)]
2521 {
2522 info!(
2523 attempt = attempt + 1,
2524 max = self.config.max_reconnect_attempts,
2525 delay_ms = delay.as_millis() as u64,
2526 "Attempting WebSocket reconnection with exponential backoff"
2527 );
2528 }
2529
2530 // Wait for delay with cancellation support
2531 tokio::select! {
2532 biased;
2533
2534 () = token.cancelled() => {
2535 warn!("Reconnection cancelled during backoff delay");
2536 self.set_state(WsConnectionState::Disconnected);
2537 return Err(Error::cancelled("Reconnection cancelled during backoff"));
2538 }
2539
2540 () = tokio::time::sleep(delay) => {
2541 // Delay completed, proceed with connection attempt
2542 }
2543 }
2544
2545 // Attempt connection with cancellation support
2546 match self.connect_with_cancel(Some(token.clone())).await {
2547 Ok(()) => {
2548 info!(attempt = attempt + 1, "Reconnection successful");
2549 // Reset reconnect count on success
2550 self.reconnect_count.store(0, Ordering::Release);
2551 return Ok(());
2552 }
2553 Err(e) => {
2554 // Check if cancelled
2555 if e.as_cancelled().is_some() {
2556 warn!("Reconnection cancelled during connection attempt");
2557 self.set_state(WsConnectionState::Disconnected);
2558 return Err(e);
2559 }
2560
2561 // Classify the error
2562 let ws_error = WsError::from_error(&e);
2563
2564 if ws_error.is_permanent() {
2565 error!(
2566 attempt = attempt + 1,
2567 error = %e,
2568 "Permanent error during reconnection, stopping retry"
2569 );
2570 self.set_state(WsConnectionState::Error);
2571 return Err(e);
2572 }
2573
2574 // Transient error - log and continue retry loop
2575 warn!(
2576 attempt = attempt + 1,
2577 error = %e,
2578 "Transient error during reconnection, will retry"
2579 );
2580 // Continue to next iteration
2581 }
2582 }
2583 }
2584 }
2585
2586 /// Returns the current reconnection attempt count (lock-free).
2587 #[inline]
2588 pub fn reconnect_count(&self) -> u32 {
2589 self.reconnect_count.load(Ordering::Acquire)
2590 }
2591
2592 /// Resets the reconnection attempt counter to zero (lock-free).
2593 pub fn reset_reconnect_count(&self) {
2594 self.reconnect_count.store(0, Ordering::Release);
2595 debug!("Reconnect count reset");
2596 }
2597
2598 /// Returns a snapshot of connection statistics.
2599 ///
2600 /// This method is lock-free and can be called from any context without
2601 /// blocking or risking deadlocks.
2602 ///
2603 /// # Returns
2604 ///
2605 /// A `WsStatsSnapshot` containing the current values of all statistics.
2606 pub fn stats(&self) -> WsStatsSnapshot {
2607 self.stats.snapshot()
2608 }
2609
2610 /// Resets all connection statistics to default values.
2611 ///
2612 /// This method is lock-free and can be called from any context.
2613 pub fn reset_stats(&self) {
2614 self.stats.reset();
2615 debug!("Stats reset");
2616 }
2617
2618 /// Calculates current connection latency in milliseconds.
2619 ///
2620 /// This method is lock-free and can be called from any context.
2621 ///
2622 /// # Returns
2623 ///
2624 /// Time difference between last pong and ping, or `None` if no data available.
2625 pub fn latency(&self) -> Option<i64> {
2626 let last_pong = self.stats.last_pong_time();
2627 let last_ping = self.stats.last_ping_time();
2628 if last_pong > 0 && last_ping > 0 {
2629 Some(last_pong - last_ping)
2630 } else {
2631 None
2632 }
2633 }
2634
2635 /// Creates an automatic reconnection coordinator.
2636 ///
2637 /// # Returns
2638 ///
2639 /// A new [`AutoReconnectCoordinator`] instance for managing reconnection logic.
2640 pub fn create_auto_reconnect_coordinator(self: Arc<Self>) -> AutoReconnectCoordinator {
2641 AutoReconnectCoordinator::new(self)
2642 }
2643
2644 /// Subscribes to a WebSocket channel.
2645 ///
2646 /// Subscription is persisted and automatically reestablished on reconnection.
2647 /// The subscription count is limited by `max_subscriptions` in `WsConfig`.
2648 ///
2649 /// # Arguments
2650 ///
2651 /// * `channel` - Channel name to subscribe to
2652 /// * `symbol` - Optional trading pair symbol
2653 /// * `params` - Optional additional subscription parameters
2654 ///
2655 /// # Errors
2656 ///
2657 /// Returns error if:
2658 /// - Maximum subscription limit is reached (`Error::ResourceExhausted`)
2659 /// - Subscription message fails to send
2660 ///
2661 /// # Example
2662 ///
2663 /// ```rust,ignore
2664 /// use ccxt_core::ws_client::{WsClient, WsConfig};
2665 ///
2666 /// let client = WsClient::new(WsConfig {
2667 /// url: "wss://stream.example.com/ws".to_string(),
2668 /// max_subscriptions: 50,
2669 /// ..Default::default()
2670 /// });
2671 ///
2672 /// // Subscribe to a channel
2673 /// client.subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None).await?;
2674 ///
2675 /// // Check remaining capacity
2676 /// println!("Remaining capacity: {}", client.remaining_capacity());
2677 /// ```
2678 #[instrument(
2679 name = "ws_subscribe",
2680 skip(self, params),
2681 fields(channel = %channel, symbol = ?symbol)
2682 )]
2683 pub async fn subscribe(
2684 &self,
2685 channel: String,
2686 symbol: Option<String>,
2687 params: Option<HashMap<String, Value>>,
2688 ) -> Result<()> {
2689 let sub_key = Self::subscription_key(&channel, symbol.as_ref());
2690 let subscription = Subscription {
2691 channel: channel.clone(),
2692 symbol: symbol.clone(),
2693 params: params.clone(),
2694 };
2695
2696 // Use SubscriptionManager to enforce capacity limits (Requirements 4.2)
2697 self.subscription_manager
2698 .try_add(sub_key.clone(), subscription)?;
2699
2700 info!(subscription_key = %sub_key, "Subscription registered");
2701
2702 // Lock-free state check
2703 let state = self.state();
2704 if state == WsConnectionState::Connected {
2705 self.send_subscribe_message(channel, symbol, params).await?;
2706 info!(subscription_key = %sub_key, "Subscription message sent");
2707 } else {
2708 debug!(
2709 subscription_key = %sub_key,
2710 state = ?state,
2711 "Subscription queued (not connected)"
2712 );
2713 }
2714
2715 Ok(())
2716 }
2717
2718 /// Unsubscribes from a WebSocket channel.
2719 ///
2720 /// Removes subscription from internal state and sends unsubscribe message if connected.
2721 /// The subscription slot is immediately freed for new subscriptions (Requirements 4.5).
2722 ///
2723 /// # Arguments
2724 ///
2725 /// * `channel` - Channel name to unsubscribe from
2726 /// * `symbol` - Optional trading pair symbol
2727 ///
2728 /// # Errors
2729 ///
2730 /// Returns error if unsubscribe message fails to send.
2731 #[instrument(
2732 name = "ws_unsubscribe",
2733 skip(self),
2734 fields(channel = %channel, symbol = ?symbol)
2735 )]
2736 pub async fn unsubscribe(&self, channel: String, symbol: Option<String>) -> Result<()> {
2737 let sub_key = Self::subscription_key(&channel, symbol.as_ref());
2738
2739 // Use SubscriptionManager to remove subscription (Requirements 4.5)
2740 self.subscription_manager.remove(&sub_key);
2741
2742 info!(subscription_key = %sub_key, "Subscription removed");
2743
2744 // Lock-free state check
2745 let state = self.state();
2746 if state == WsConnectionState::Connected {
2747 self.send_unsubscribe_message(channel, symbol).await?;
2748 info!(subscription_key = %sub_key, "Unsubscribe message sent");
2749 }
2750
2751 Ok(())
2752 }
2753
2754 /// Receives the next available message from the WebSocket stream.
2755 ///
2756 /// # Returns
2757 ///
2758 /// The received JSON message, or `None` if the channel is closed.
2759 pub async fn receive(&self) -> Option<Value> {
2760 let mut rx = self.message_rx.write().await;
2761 rx.recv().await
2762 }
2763
2764 /// Returns the current connection state (lock-free).
2765 #[inline]
2766 pub fn state(&self) -> WsConnectionState {
2767 WsConnectionState::from_u8(self.state.load(Ordering::Acquire))
2768 }
2769
2770 /// Returns a reference to the WebSocket configuration.
2771 ///
2772 /// # Example
2773 ///
2774 /// ```rust
2775 /// use ccxt_core::ws_client::{WsClient, WsConfig};
2776 ///
2777 /// let config = WsConfig {
2778 /// url: "wss://stream.example.com/ws".to_string(),
2779 /// max_reconnect_attempts: 10,
2780 /// ..Default::default()
2781 /// };
2782 /// let client = WsClient::new(config);
2783 ///
2784 /// assert_eq!(client.config().max_reconnect_attempts, 10);
2785 /// ```
2786 #[inline]
2787 pub fn config(&self) -> &WsConfig {
2788 &self.config
2789 }
2790
2791 /// Sets the connection state (lock-free).
2792 ///
2793 /// This method is used internally and by the `AutoReconnectCoordinator`
2794 /// to update the connection state.
2795 ///
2796 /// # Arguments
2797 ///
2798 /// * `state` - The new connection state
2799 #[inline]
2800 pub fn set_state(&self, state: WsConnectionState) {
2801 self.state.store(state.as_u8(), Ordering::Release);
2802 }
2803
2804 /// Checks whether the WebSocket is currently connected (lock-free).
2805 #[inline]
2806 pub fn is_connected(&self) -> bool {
2807 self.state() == WsConnectionState::Connected
2808 }
2809
2810 /// Checks if subscribed to a specific channel (lock-free).
2811 ///
2812 /// # Arguments
2813 ///
2814 /// * `channel` - Channel name to check
2815 /// * `symbol` - Optional trading pair symbol
2816 ///
2817 /// # Returns
2818 ///
2819 /// `true` if subscribed to the channel, `false` otherwise.
2820 pub fn is_subscribed(&self, channel: &str, symbol: Option<&String>) -> bool {
2821 let sub_key = Self::subscription_key(channel, symbol);
2822 self.subscription_manager.contains(&sub_key)
2823 }
2824
2825 /// Returns the number of active subscriptions (lock-free).
2826 ///
2827 /// This method delegates to the internal `SubscriptionManager` to get
2828 /// the current subscription count.
2829 ///
2830 /// # Example
2831 ///
2832 /// ```rust
2833 /// use ccxt_core::ws_client::{WsClient, WsConfig};
2834 ///
2835 /// let client = WsClient::new(WsConfig::default());
2836 /// assert_eq!(client.subscription_count(), 0);
2837 /// ```
2838 pub fn subscription_count(&self) -> usize {
2839 self.subscription_manager.count()
2840 }
2841
2842 /// Returns the remaining capacity for new subscriptions (lock-free).
2843 ///
2844 /// This is calculated as `max_subscriptions - current_count`.
2845 /// Use this method to check if there's room for more subscriptions
2846 /// before attempting to subscribe.
2847 ///
2848 /// # Example
2849 ///
2850 /// ```rust
2851 /// use ccxt_core::ws_client::{WsClient, WsConfig};
2852 ///
2853 /// let client = WsClient::new(WsConfig {
2854 /// max_subscriptions: 50,
2855 /// ..Default::default()
2856 /// });
2857 /// assert_eq!(client.remaining_capacity(), 50);
2858 /// ```
2859 pub fn remaining_capacity(&self) -> usize {
2860 self.subscription_manager.remaining_capacity()
2861 }
2862
2863 /// Sends a raw WebSocket message.
2864 ///
2865 /// # Arguments
2866 ///
2867 /// * `message` - WebSocket message to send
2868 ///
2869 /// # Errors
2870 ///
2871 /// Returns error if not connected or message transmission fails.
2872 #[instrument(name = "ws_send", skip(self, message))]
2873 pub async fn send(&self, message: Message) -> Result<()> {
2874 let tx = self.write_tx.lock().await;
2875
2876 if let Some(sender) = tx.as_ref() {
2877 sender.send(message).map_err(|e| {
2878 error!(
2879 error = %e,
2880 "Failed to send WebSocket message"
2881 );
2882 Error::network(format!("Failed to send message: {e}"))
2883 })?;
2884 debug!("WebSocket message sent successfully");
2885 Ok(())
2886 } else {
2887 warn!("WebSocket not connected, cannot send message");
2888 Err(Error::network("WebSocket not connected"))
2889 }
2890 }
2891
2892 /// Sends a text message over the WebSocket connection.
2893 ///
2894 /// # Arguments
2895 ///
2896 /// * `text` - Text content to send
2897 ///
2898 /// # Errors
2899 ///
2900 /// Returns error if not connected or transmission fails.
2901 #[instrument(name = "ws_send_text", skip(self, text), fields(text_len = text.len()))]
2902 pub async fn send_text(&self, text: String) -> Result<()> {
2903 self.send(Message::Text(text.into())).await
2904 }
2905
2906 /// Sends a JSON-encoded message over the WebSocket connection.
2907 ///
2908 /// # Arguments
2909 ///
2910 /// * `json` - JSON value to serialize and send
2911 ///
2912 /// # Errors
2913 ///
2914 /// Returns error if serialization fails, not connected, or transmission fails.
2915 #[instrument(name = "ws_send_json", skip(self, json))]
2916 pub async fn send_json(&self, json: &Value) -> Result<()> {
2917 let text = serde_json::to_string(json).map_err(|e| {
2918 error!(error = %e, "Failed to serialize JSON for WebSocket");
2919 Error::from(e)
2920 })?;
2921 self.send_text(text).await
2922 }
2923
/// Builds the lookup key that identifies a subscription.
///
/// With a symbol the key is `"{channel}:{symbol}"`; without one it is just the
/// channel name.
fn subscription_key(channel: &str, symbol: Option<&String>) -> String {
    symbol.map_or_else(|| channel.to_string(), |s| format!("{channel}:{s}"))
}
2931
2932 /// Starts the WebSocket message processing loop.
2933 ///
2934 /// Spawns separate tasks for reading and writing messages, handling shutdown signals.
2935 async fn start_message_loop(&self, ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) {
2936 let (write, mut read) = ws_stream.split();
2937
2938 let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Message>();
2939 *self.write_tx.lock().await = Some(write_tx.clone());
2940
2941 let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>();
2942 *self.shutdown_tx.lock().await = Some(shutdown_tx);
2943
2944 let state = Arc::clone(&self.state);
2945 let message_tx = self.message_tx.clone();
2946 let ping_interval_ms = self.config.ping_interval;
2947
2948 info!("Starting WebSocket message loop");
2949
2950 let write_handle = tokio::spawn(async move {
2951 let mut write = write;
2952 loop {
2953 tokio::select! {
2954 Some(msg) = write_rx.recv() => {
2955 if let Err(e) = write.send(msg).await {
2956 error!(error = %e, "Failed to write message");
2957 break;
2958 }
2959 }
2960 _ = shutdown_rx.recv() => {
2961 debug!("Write task received shutdown signal");
2962 let _ = write.send(Message::Close(None)).await;
2963 break;
2964 }
2965 }
2966 }
2967 debug!("Write task terminated");
2968 });
2969
2970 let state_clone = Arc::clone(&state);
2971 let ws_stats = Arc::clone(&self.stats);
2972 let read_handle = tokio::spawn(async move {
2973 debug!("Starting WebSocket read task");
2974 while let Some(msg_result) = read.next().await {
2975 match msg_result {
2976 Ok(Message::Text(text)) => {
2977 debug!(len = text.len(), "Received text message");
2978
2979 // Lock-free stats update
2980 ws_stats.record_received(text.len() as u64);
2981
2982 match serde_json::from_str::<Value>(&text) {
2983 Ok(json) => {
2984 let _ = message_tx.send(json);
2985 }
2986 Err(e) => {
2987 // Log parse failure with truncated raw message preview
2988 let raw_preview: String = text.chars().take(200).collect();
2989 warn!(
2990 error = %e,
2991 raw_message_preview = %raw_preview,
2992 raw_message_len = text.len(),
2993 "Failed to parse WebSocket text message as JSON"
2994 );
2995 }
2996 }
2997 }
2998 Ok(Message::Binary(data)) => {
2999 debug!(len = data.len(), "Received binary message");
3000
3001 // Lock-free stats update
3002 ws_stats.record_received(data.len() as u64);
3003
3004 match String::from_utf8(data.to_vec()) {
3005 Ok(text) => {
3006 match serde_json::from_str::<Value>(&text) {
3007 Ok(json) => {
3008 let _ = message_tx.send(json);
3009 }
3010 Err(e) => {
3011 // Log parse failure with truncated raw message preview
3012 let raw_preview: String = text.chars().take(200).collect();
3013 warn!(
3014 error = %e,
3015 raw_message_preview = %raw_preview,
3016 raw_message_len = text.len(),
3017 "Failed to parse WebSocket binary message as JSON"
3018 );
3019 }
3020 }
3021 }
3022 Err(e) => {
3023 // Log UTF-8 decode failure with hex preview
3024 let hex_preview: String = data
3025 .iter()
3026 .take(50)
3027 .map(|b| format!("{b:02x}"))
3028 .collect::<Vec<_>>()
3029 .join(" ");
3030 warn!(
3031 error = %e,
3032 hex_preview = %hex_preview,
3033 data_len = data.len(),
3034 "Failed to decode WebSocket binary message as UTF-8"
3035 );
3036 }
3037 }
3038 }
3039 Ok(Message::Ping(_)) => {
3040 debug!("Received ping, auto-responding with pong");
3041 }
3042 Ok(Message::Pong(_)) => {
3043 debug!("Received pong");
3044
3045 // Lock-free stats update
3046 ws_stats.record_pong();
3047 }
3048 Ok(Message::Close(frame)) => {
3049 info!(
3050 close_frame = ?frame,
3051 "Received WebSocket close frame"
3052 );
3053 // Lock-free state update
3054 state_clone
3055 .store(WsConnectionState::Disconnected.as_u8(), Ordering::Release);
3056 break;
3057 }
3058 Err(e) => {
3059 error!(
3060 error = %e,
3061 error_debug = ?e,
3062 "WebSocket read error"
3063 );
3064 // Lock-free state update
3065 state_clone.store(WsConnectionState::Error.as_u8(), Ordering::Release);
3066 break;
3067 }
3068 _ => {
3069 debug!("Received other WebSocket message type");
3070 }
3071 }
3072 }
3073 debug!("WebSocket read task terminated");
3074 });
3075
3076 if ping_interval_ms > 0 {
3077 let write_tx_clone = write_tx.clone();
3078 let ping_stats = Arc::clone(&self.stats);
3079 let ping_state = Arc::clone(&state);
3080 let pong_timeout_ms = self.config.pong_timeout;
3081
3082 tokio::spawn(async move {
3083 let mut interval = interval(Duration::from_millis(ping_interval_ms));
3084 debug!(
3085 interval_ms = ping_interval_ms,
3086 timeout_ms = pong_timeout_ms,
3087 "Starting ping task with timeout detection"
3088 );
3089
3090 loop {
3091 interval.tick().await;
3092
3093 let now = chrono::Utc::now().timestamp_millis();
3094 // Lock-free stats read
3095 let last_pong = ping_stats.last_pong_time();
3096
3097 if last_pong > 0 {
3098 let elapsed = now - last_pong;
3099 #[allow(clippy::cast_possible_wrap)]
3100 if elapsed > pong_timeout_ms as i64 {
3101 warn!(
3102 elapsed_ms = elapsed,
3103 timeout_ms = pong_timeout_ms,
3104 "Pong timeout detected, marking connection as error"
3105 );
3106 // Lock-free state update
3107 ping_state.store(WsConnectionState::Error.as_u8(), Ordering::Release);
3108 break;
3109 }
3110 }
3111
3112 // Lock-free stats update
3113 ping_stats.record_ping();
3114
3115 if write_tx_clone.send(Message::Ping(vec![].into())).is_err() {
3116 debug!("Ping task: write channel closed");
3117 break;
3118 }
3119 debug!("Sent ping");
3120 }
3121 debug!("Ping task terminated");
3122 });
3123 }
3124
3125 tokio::spawn(async move {
3126 let _ = tokio::join!(write_handle, read_handle);
3127 info!("All WebSocket tasks completed");
3128 });
3129 }
3130
3131 /// Sends a subscription message to the WebSocket server.
3132 #[instrument(
3133 name = "ws_send_subscribe",
3134 skip(self, params),
3135 fields(channel = %channel, symbol = ?symbol)
3136 )]
3137 async fn send_subscribe_message(
3138 &self,
3139 channel: String,
3140 symbol: Option<String>,
3141 params: Option<HashMap<String, Value>>,
3142 ) -> Result<()> {
3143 let msg = WsMessage::Subscribe {
3144 channel: channel.clone(),
3145 symbol: symbol.clone(),
3146 params,
3147 };
3148
3149 let json = serde_json::to_value(&msg).map_err(|e| {
3150 error!(error = %e, "Failed to serialize subscribe message");
3151 Error::from(e)
3152 })?;
3153
3154 debug!("Sending subscribe message to server");
3155
3156 self.send_json(&json).await?;
3157 info!("Subscribe message sent successfully");
3158 Ok(())
3159 }
3160
3161 /// Sends an unsubscribe message to the WebSocket server.
3162 #[instrument(
3163 name = "ws_send_unsubscribe",
3164 skip(self),
3165 fields(channel = %channel, symbol = ?symbol)
3166 )]
3167 async fn send_unsubscribe_message(
3168 &self,
3169 channel: String,
3170 symbol: Option<String>,
3171 ) -> Result<()> {
3172 let msg = WsMessage::Unsubscribe {
3173 channel: channel.clone(),
3174 symbol: symbol.clone(),
3175 };
3176
3177 let json = serde_json::to_value(&msg).map_err(|e| {
3178 error!(error = %e, "Failed to serialize unsubscribe message");
3179 Error::from(e)
3180 })?;
3181
3182 debug!("Sending unsubscribe message to server");
3183
3184 self.send_json(&json).await?;
3185 info!("Unsubscribe message sent successfully");
3186 Ok(())
3187 }
3188
3189 /// Resubscribes to all previously subscribed channels.
3190 async fn resubscribe_all(&self) -> Result<()> {
3191 // Collect subscriptions using SubscriptionManager to avoid holding reference during async calls
3192 let subs = self.subscription_manager.collect_subscriptions();
3193
3194 for subscription in subs {
3195 self.send_subscribe_message(
3196 subscription.channel.clone(),
3197 subscription.symbol.clone(),
3198 subscription.params.clone(),
3199 )
3200 .await?;
3201 }
3202 Ok(())
3203 }
3204}
/// WebSocket connection event types.
///
/// These events are emitted during the WebSocket connection lifecycle to notify
/// listeners about state changes, reconnection attempts, and errors.
///
/// Events can be compared with `==`, which makes them convenient to assert on
/// in tests and to filter in callbacks.
///
/// # Example
///
/// ```rust
/// use ccxt_core::ws_client::WsEvent;
///
/// fn handle_event(event: WsEvent) {
///     match event {
///         WsEvent::Connecting => println!("Connecting..."),
///         WsEvent::Connected => println!("Connected!"),
///         WsEvent::Disconnected => println!("Disconnected"),
///         WsEvent::Reconnecting { attempt, delay, error } => {
///             println!("Reconnecting (attempt {}), delay: {:?}, error: {:?}",
///                      attempt, delay, error);
///         }
///         WsEvent::ReconnectSuccess => println!("Reconnected successfully"),
///         WsEvent::ReconnectFailed { attempt, error, is_permanent } => {
///             println!("Reconnect failed (attempt {}): {}, permanent: {}",
///                      attempt, error, is_permanent);
///         }
///         WsEvent::ReconnectExhausted { total_attempts, last_error } => {
///             println!("All {} reconnect attempts exhausted: {}",
///                      total_attempts, last_error);
///         }
///         WsEvent::SubscriptionRestored => println!("Subscriptions restored"),
///         WsEvent::PermanentError { error } => {
///             println!("Permanent error (no retry): {}", error);
///         }
///         WsEvent::Shutdown => println!("Shutdown complete"),
///     }
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsEvent {
    /// Connection attempt started.
    ///
    /// Emitted when the client begins establishing a WebSocket connection.
    Connecting,

    /// Connection established successfully.
    ///
    /// Emitted when the WebSocket handshake completes and the connection is ready.
    Connected,

    /// Connection closed.
    ///
    /// Emitted when the WebSocket connection is closed, either gracefully or due to an error.
    Disconnected,

    /// Reconnection in progress.
    ///
    /// Emitted before each reconnection attempt with details about the attempt.
    Reconnecting {
        /// Current reconnection attempt number (1-indexed).
        attempt: u32,
        /// Delay before this reconnection attempt (calculated by backoff strategy).
        delay: Duration,
        /// Error that triggered the reconnection (if any).
        error: Option<String>,
    },

    /// Reconnection succeeded.
    ///
    /// Emitted when a reconnection attempt successfully establishes a new connection.
    ReconnectSuccess,

    /// Single reconnection attempt failed.
    ///
    /// Emitted when a reconnection attempt fails. More attempts may follow
    /// unless `is_permanent` is true or max attempts is reached.
    ReconnectFailed {
        /// The attempt number that failed (1-indexed).
        attempt: u32,
        /// Error message describing the failure.
        error: String,
        /// Whether this is a permanent error that should not be retried.
        ///
        /// If `true`, no further reconnection attempts will be made.
        is_permanent: bool,
    },

    /// All reconnection attempts exhausted.
    ///
    /// Emitted when the maximum number of reconnection attempts has been reached
    /// without successfully reconnecting. No further automatic reconnection will occur.
    ReconnectExhausted {
        /// Total number of reconnection attempts made.
        total_attempts: u32,
        /// The last error encountered.
        last_error: String,
    },

    /// Subscriptions restored after reconnection.
    ///
    /// Emitted after a successful reconnection when all previous subscriptions
    /// have been re-established.
    SubscriptionRestored,

    /// Permanent error occurred (no retry).
    ///
    /// Emitted when a permanent error is encountered that cannot be recovered
    /// through retries (e.g., authentication failure, invalid credentials).
    PermanentError {
        /// Error message describing the permanent failure.
        error: String,
    },

    /// Shutdown completed.
    ///
    /// Emitted when the WebSocket client has completed its graceful shutdown
    /// process, including cancelling pending operations and closing connections.
    Shutdown,
}
3323
3324impl WsEvent {
3325 /// Returns `true` if this is a `Connecting` event.
3326 #[inline]
3327 #[must_use]
3328 pub fn is_connecting(&self) -> bool {
3329 matches!(self, Self::Connecting)
3330 }
3331
3332 /// Returns `true` if this is a `Connected` event.
3333 #[inline]
3334 #[must_use]
3335 pub fn is_connected(&self) -> bool {
3336 matches!(self, Self::Connected)
3337 }
3338
3339 /// Returns `true` if this is a `Disconnected` event.
3340 #[inline]
3341 #[must_use]
3342 pub fn is_disconnected(&self) -> bool {
3343 matches!(self, Self::Disconnected)
3344 }
3345
3346 /// Returns `true` if this is a `Reconnecting` event.
3347 #[inline]
3348 #[must_use]
3349 pub fn is_reconnecting(&self) -> bool {
3350 matches!(self, Self::Reconnecting { .. })
3351 }
3352
3353 /// Returns `true` if this is a `ReconnectSuccess` event.
3354 #[inline]
3355 #[must_use]
3356 pub fn is_reconnect_success(&self) -> bool {
3357 matches!(self, Self::ReconnectSuccess)
3358 }
3359
3360 /// Returns `true` if this is a `ReconnectFailed` event.
3361 #[inline]
3362 #[must_use]
3363 pub fn is_reconnect_failed(&self) -> bool {
3364 matches!(self, Self::ReconnectFailed { .. })
3365 }
3366
3367 /// Returns `true` if this is a `ReconnectExhausted` event.
3368 #[inline]
3369 #[must_use]
3370 pub fn is_reconnect_exhausted(&self) -> bool {
3371 matches!(self, Self::ReconnectExhausted { .. })
3372 }
3373
3374 /// Returns `true` if this is a `SubscriptionRestored` event.
3375 #[inline]
3376 #[must_use]
3377 pub fn is_subscription_restored(&self) -> bool {
3378 matches!(self, Self::SubscriptionRestored)
3379 }
3380
3381 /// Returns `true` if this is a `PermanentError` event.
3382 #[inline]
3383 #[must_use]
3384 pub fn is_permanent_error(&self) -> bool {
3385 matches!(self, Self::PermanentError { .. })
3386 }
3387
3388 /// Returns `true` if this is a `Shutdown` event.
3389 #[inline]
3390 #[must_use]
3391 pub fn is_shutdown(&self) -> bool {
3392 matches!(self, Self::Shutdown)
3393 }
3394
3395 /// Returns `true` if this event indicates an error condition.
3396 ///
3397 /// This includes `ReconnectFailed`, `ReconnectExhausted`, and `PermanentError`.
3398 #[inline]
3399 #[must_use]
3400 pub fn is_error(&self) -> bool {
3401 matches!(
3402 self,
3403 Self::ReconnectFailed { .. }
3404 | Self::ReconnectExhausted { .. }
3405 | Self::PermanentError { .. }
3406 )
3407 }
3408
3409 /// Returns `true` if this event indicates a terminal state.
3410 ///
3411 /// Terminal states are those where no further automatic recovery will occur:
3412 /// `ReconnectExhausted`, `PermanentError`, and `Shutdown`.
3413 #[inline]
3414 #[must_use]
3415 pub fn is_terminal(&self) -> bool {
3416 matches!(
3417 self,
3418 Self::ReconnectExhausted { .. } | Self::PermanentError { .. } | Self::Shutdown
3419 )
3420 }
3421}
3422
3423impl std::fmt::Display for WsEvent {
3424 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3425 match self {
3426 Self::Connecting => write!(f, "Connecting"),
3427 Self::Connected => write!(f, "Connected"),
3428 Self::Disconnected => write!(f, "Disconnected"),
3429 Self::Reconnecting {
3430 attempt,
3431 delay,
3432 error,
3433 } => {
3434 write!(f, "Reconnecting (attempt {attempt}, delay: {delay:?}")?;
3435 if let Some(err) = error {
3436 write!(f, ", error: {err}")?;
3437 }
3438 write!(f, ")")
3439 }
3440 Self::ReconnectSuccess => write!(f, "ReconnectSuccess"),
3441 Self::ReconnectFailed {
3442 attempt,
3443 error,
3444 is_permanent,
3445 } => {
3446 write!(
3447 f,
3448 "ReconnectFailed (attempt {attempt}, error: {error}, permanent: {is_permanent})"
3449 )
3450 }
3451 Self::ReconnectExhausted {
3452 total_attempts,
3453 last_error,
3454 } => {
3455 write!(
3456 f,
3457 "ReconnectExhausted (attempts: {total_attempts}, last error: {last_error})"
3458 )
3459 }
3460 Self::SubscriptionRestored => write!(f, "SubscriptionRestored"),
3461 Self::PermanentError { error } => write!(f, "PermanentError: {error}"),
3462 Self::Shutdown => write!(f, "Shutdown"),
3463 }
3464 }
3465}
3466
3467/// Event callback function type.
3468pub type WsEventCallback = Arc<dyn Fn(WsEvent) + Send + Sync>;
3469
3470/// Automatic reconnection coordinator for WebSocket connections.
3471///
3472/// Monitors connection state and triggers reconnection attempts when disconnected.
3473/// Uses exponential backoff strategy for retry delays and classifies errors to
3474/// determine if reconnection should continue.
3475///
3476/// # Features
3477///
3478/// - **Exponential Backoff**: Uses configurable exponential backoff with jitter
3479/// to prevent thundering herd effects during reconnection.
3480/// - **Error Classification**: Distinguishes between transient and permanent errors,
3481/// stopping reconnection attempts for permanent errors.
3482/// - **Cancellation Support**: Supports graceful cancellation via `CancellationToken`.
3483/// - **Event Callbacks**: Notifies registered callbacks about reconnection events.
3484///
3485/// # Example
3486///
3487/// ```rust,ignore
3488/// use ccxt_core::ws_client::{WsClient, WsConfig, AutoReconnectCoordinator, WsEvent};
3489/// use std::sync::Arc;
3490///
3491/// let client = Arc::new(WsClient::new(WsConfig::default()));
3492/// let coordinator = client.clone().create_auto_reconnect_coordinator()
3493/// .with_callback(Arc::new(|event| {
3494/// match event {
3495/// WsEvent::Reconnecting { attempt, delay, .. } => {
3496/// println!("Reconnecting attempt {} with delay {:?}", attempt, delay);
3497/// }
3498/// WsEvent::ReconnectSuccess => println!("Reconnected!"),
3499/// _ => {}
3500/// }
3501/// }));
3502///
3503/// coordinator.start().await;
3504/// ```
3505pub struct AutoReconnectCoordinator {
3506 /// The WebSocket client to manage reconnection for
3507 client: Arc<WsClient>,
3508 /// Whether auto-reconnect is enabled
3509 enabled: Arc<AtomicBool>,
3510 /// Handle to the reconnection task
3511 reconnect_task: Arc<Mutex<Option<JoinHandle<()>>>>,
3512 /// Optional event callback for reconnection events
3513 event_callback: Option<WsEventCallback>,
3514 /// Cancellation token for stopping reconnection
3515 cancel_token: Arc<Mutex<Option<CancellationToken>>>,
3516}
3517
3518impl AutoReconnectCoordinator {
3519 /// Creates a new automatic reconnection coordinator.
3520 ///
3521 /// The coordinator uses the backoff configuration from the client's `WsConfig`.
3522 ///
3523 /// # Arguments
3524 ///
3525 /// * `client` - Arc reference to the WebSocket client
3526 ///
3527 /// # Example
3528 ///
3529 /// ```rust,ignore
3530 /// use ccxt_core::ws_client::{WsClient, WsConfig, AutoReconnectCoordinator};
3531 /// use std::sync::Arc;
3532 ///
3533 /// let client = Arc::new(WsClient::new(WsConfig::default()));
3534 /// let coordinator = AutoReconnectCoordinator::new(client);
3535 /// ```
3536 pub fn new(client: Arc<WsClient>) -> Self {
3537 Self {
3538 client,
3539 enabled: Arc::new(AtomicBool::new(false)),
3540 reconnect_task: Arc::new(Mutex::new(None)),
3541 event_callback: None,
3542 cancel_token: Arc::new(Mutex::new(None)),
3543 }
3544 }
3545
3546 /// Sets the event callback for reconnection events.
3547 ///
3548 /// # Arguments
3549 ///
3550 /// * `callback` - Event callback function
3551 ///
3552 /// # Returns
3553 ///
3554 /// Self for method chaining
3555 ///
3556 /// # Example
3557 ///
3558 /// ```rust,ignore
3559 /// use ccxt_core::ws_client::{AutoReconnectCoordinator, WsEvent};
3560 /// use std::sync::Arc;
3561 ///
3562 /// let coordinator = coordinator.with_callback(Arc::new(|event| {
3563 /// println!("Event: {:?}", event);
3564 /// }));
3565 /// ```
3566 pub fn with_callback(mut self, callback: WsEventCallback) -> Self {
3567 self.event_callback = Some(callback);
3568 self
3569 }
3570
3571 /// Sets the cancellation token for stopping reconnection.
3572 ///
3573 /// When the token is cancelled, the reconnection loop will stop gracefully.
3574 ///
3575 /// # Arguments
3576 ///
3577 /// * `token` - Cancellation token
3578 ///
3579 /// # Returns
3580 ///
3581 /// Self for method chaining
3582 ///
3583 /// # Example
3584 ///
3585 /// ```rust,ignore
3586 /// use ccxt_core::ws_client::AutoReconnectCoordinator;
3587 /// use tokio_util::sync::CancellationToken;
3588 ///
3589 /// let token = CancellationToken::new();
3590 /// let coordinator = coordinator.with_cancel_token(token.clone());
3591 ///
3592 /// // Later, to stop reconnection:
3593 /// token.cancel();
3594 /// ```
3595 pub fn with_cancel_token(self, token: CancellationToken) -> Self {
3596 // We need to store it synchronously during construction
3597 // The actual storage happens in start()
3598 let _ = self.cancel_token.try_lock().map(|mut guard| {
3599 *guard = Some(token);
3600 });
3601 self
3602 }
3603
3604 /// Sets the cancellation token asynchronously.
3605 ///
3606 /// # Arguments
3607 ///
3608 /// * `token` - Cancellation token
3609 pub async fn set_cancel_token(&self, token: CancellationToken) {
3610 let mut guard = self.cancel_token.lock().await;
3611 *guard = Some(token);
3612 }
3613
3614 /// Returns whether auto-reconnect is currently enabled.
3615 ///
3616 /// # Returns
3617 ///
3618 /// `true` if auto-reconnect is enabled, `false` otherwise
3619 #[inline]
3620 pub fn is_enabled(&self) -> bool {
3621 self.enabled.load(Ordering::SeqCst)
3622 }
3623
3624 /// Starts the automatic reconnection coordinator.
3625 ///
3626 /// Begins monitoring connection state and automatically reconnects on disconnect.
3627 /// Uses exponential backoff for retry delays.
3628 ///
3629 /// # Behavior
3630 ///
3631 /// 1. Monitors connection state every second
3632 /// 2. When disconnected or in error state, calculates delay using exponential backoff
3633 /// 3. Emits `Reconnecting` event before each attempt
3634 /// 4. Attempts reconnection
3635 /// 5. On success: resets reconnect counter, emits `ReconnectSuccess`, restores subscriptions
3636 /// 6. On transient error: waits for backoff delay and retries
3637 /// 7. On permanent error: emits `PermanentError` and stops
3638 /// 8. On max attempts: emits `ReconnectExhausted` and stops
3639 pub async fn start(&self) {
3640 if self.enabled.swap(true, Ordering::SeqCst) {
3641 info!("Auto-reconnect already started");
3642 return;
3643 }
3644
3645 info!("Starting auto-reconnect coordinator with exponential backoff");
3646
3647 let client = Arc::clone(&self.client);
3648 let enabled = Arc::clone(&self.enabled);
3649 let callback = self.event_callback.clone();
3650 let cancel_token = self.cancel_token.lock().await.clone();
3651
3652 // Create backoff strategy from client config
3653 let backoff_config = client.config().backoff_config.clone();
3654
3655 let handle = tokio::spawn(async move {
3656 Self::reconnect_loop(client, enabled, callback, backoff_config, cancel_token).await;
3657 });
3658
3659 *self.reconnect_task.lock().await = Some(handle);
3660 }
3661
3662 /// Stops the automatic reconnection coordinator.
3663 ///
3664 /// Halts monitoring and reconnection tasks. If a cancellation token was set,
3665 /// it will be cancelled to stop any in-progress reconnection attempts.
3666 pub async fn stop(&self) {
3667 if !self.enabled.swap(false, Ordering::SeqCst) {
3668 info!("Auto-reconnect already stopped");
3669 return;
3670 }
3671
3672 info!("Stopping auto-reconnect coordinator");
3673
3674 // Cancel any in-progress reconnection
3675 if let Some(token) = self.cancel_token.lock().await.as_ref() {
3676 token.cancel();
3677 }
3678
3679 let mut task = self.reconnect_task.lock().await;
3680 if let Some(handle) = task.take() {
3681 handle.abort();
3682 }
3683 }
3684
3685 /// Internal reconnection loop with exponential backoff.
3686 ///
3687 /// Continuously monitors connection state and triggers reconnection
3688 /// when `Error` or `Disconnected` state is detected. Uses exponential
3689 /// backoff for retry delays and classifies errors to determine if
3690 /// reconnection should continue.
3691 ///
3692 /// # Arguments
3693 ///
3694 /// * `client` - The WebSocket client
3695 /// * `enabled` - Atomic flag indicating if auto-reconnect is enabled
3696 /// * `callback` - Optional event callback
3697 /// * `backoff_config` - Configuration for exponential backoff
3698 /// * `cancel_token` - Optional cancellation token
3699 async fn reconnect_loop(
3700 client: Arc<WsClient>,
3701 enabled: Arc<AtomicBool>,
3702 callback: Option<WsEventCallback>,
3703 backoff_config: BackoffConfig,
3704 cancel_token: Option<CancellationToken>,
3705 ) {
3706 let mut check_interval = interval(Duration::from_secs(1));
3707 let backoff = BackoffStrategy::new(backoff_config);
3708 let mut last_error: Option<String> = None;
3709
3710 loop {
3711 // Check for cancellation
3712 if cancel_token
3713 .as_ref()
3714 .is_some_and(CancellationToken::is_cancelled)
3715 {
3716 info!("Auto-reconnect cancelled via token");
3717 break;
3718 }
3719
3720 check_interval.tick().await;
3721
3722 if !enabled.load(Ordering::SeqCst) {
3723 debug!("Auto-reconnect disabled, exiting loop");
3724 break;
3725 }
3726
3727 // Lock-free state check
3728 let state = client.state();
3729
3730 if matches!(
3731 state,
3732 WsConnectionState::Disconnected | WsConnectionState::Error
3733 ) {
3734 // Lock-free reconnect count access
3735 let attempt = client.reconnect_count();
3736
3737 // Check if max attempts reached
3738 let max_attempts = client.config().max_reconnect_attempts;
3739 if attempt >= max_attempts {
3740 error!(
3741 attempts = attempt,
3742 max = max_attempts,
3743 "Max reconnect attempts reached, stopping auto-reconnect"
3744 );
3745
3746 if let Some(ref cb) = callback {
3747 cb(WsEvent::ReconnectExhausted {
3748 total_attempts: attempt,
3749 last_error: last_error
3750 .clone()
3751 .unwrap_or_else(|| "Unknown error".to_string()),
3752 });
3753 }
3754 break;
3755 }
3756
3757 // Calculate delay using exponential backoff strategy
3758 let delay = backoff.calculate_delay(attempt);
3759
3760 #[allow(clippy::cast_possible_truncation)]
3761 {
3762 info!(
3763 attempt = attempt + 1,
3764 max = max_attempts,
3765 delay_ms = delay.as_millis() as u64,
3766 state = ?state,
3767 "Connection lost, attempting reconnect with exponential backoff"
3768 );
3769 }
3770
3771 if let Some(ref cb) = callback {
3772 cb(WsEvent::Reconnecting {
3773 attempt: attempt + 1,
3774 delay,
3775 error: last_error.clone(),
3776 });
3777 }
3778
3779 // Wait for backoff delay with cancellation support
3780 if let Some(ref token) = cancel_token {
3781 tokio::select! {
3782 biased;
3783
3784 () = token.cancelled() => {
3785 info!("Auto-reconnect cancelled during backoff delay");
3786 break;
3787 }
3788
3789 () = tokio::time::sleep(delay) => {
3790 // Delay completed, proceed with reconnection
3791 }
3792 }
3793 } else {
3794 tokio::time::sleep(delay).await;
3795 }
3796
3797 // Check cancellation again after delay
3798 if cancel_token
3799 .as_ref()
3800 .is_some_and(CancellationToken::is_cancelled)
3801 {
3802 info!("Auto-reconnect cancelled after backoff delay");
3803 break;
3804 }
3805
3806 // Attempt reconnection using connect() directly to avoid double-counting
3807 // The reconnect_count is managed by this loop
3808 client.set_state(WsConnectionState::Reconnecting);
3809
3810 match client.connect().await {
3811 Ok(()) => {
3812 info!(attempt = attempt + 1, "Reconnection successful");
3813
3814 // Reset reconnect count on success (Requirements 1.4)
3815 client.reset_reconnect_count();
3816 last_error = None;
3817
3818 if let Some(ref cb) = callback {
3819 cb(WsEvent::ReconnectSuccess);
3820 }
3821
3822 // Restore subscriptions
3823 match client.resubscribe_all().await {
3824 Ok(()) => {
3825 info!("Subscriptions restored");
3826 if let Some(ref cb) = callback {
3827 cb(WsEvent::SubscriptionRestored);
3828 }
3829 }
3830 Err(e) => {
3831 error!(error = %e, "Failed to restore subscriptions");
3832 }
3833 }
3834 }
3835 Err(e) => {
3836 let error_msg = e.to_string();
3837 let ws_error = WsError::from_error(&e);
3838 let is_permanent = ws_error.is_permanent();
3839
3840 error!(
3841 attempt = attempt + 1,
3842 error = %e,
3843 is_permanent = is_permanent,
3844 "Reconnection failed"
3845 );
3846
3847 last_error = Some(error_msg.clone());
3848
3849 // Increment reconnect count for failed attempt
3850 client.reconnect_count.fetch_add(1, Ordering::AcqRel);
3851
3852 if let Some(ref cb) = callback {
3853 cb(WsEvent::ReconnectFailed {
3854 attempt: attempt + 1,
3855 error: error_msg,
3856 is_permanent,
3857 });
3858 }
3859
3860 // Don't retry if it's a permanent error (Requirements 2.2, 2.3, 2.5)
3861 if is_permanent {
3862 if let Some(ref cb) = callback {
3863 cb(WsEvent::PermanentError {
3864 error: e.to_string(),
3865 });
3866 }
3867 break;
3868 }
3869
3870 // For transient errors, the loop will continue and calculate
3871 // a new backoff delay on the next iteration
3872 }
3873 }
3874 }
3875 }
3876
3877 info!("Auto-reconnect loop terminated");
3878 }
3879}
3880
3881#[cfg(test)]
3882mod tests {
3883 use super::*;
3884
3885 // ==================== BackoffConfig Tests ====================
3886
3887 #[test]
3888 fn test_backoff_config_default() {
3889 let config = BackoffConfig::default();
3890 assert_eq!(config.base_delay, Duration::from_secs(1));
3891 assert_eq!(config.max_delay, Duration::from_secs(60));
3892 assert!((config.jitter_factor - 0.25).abs() < f64::EPSILON);
3893 assert!((config.multiplier - 2.0).abs() < f64::EPSILON);
3894 }
3895
3896 #[test]
3897 fn test_backoff_config_custom() {
3898 let config = BackoffConfig {
3899 base_delay: Duration::from_millis(500),
3900 max_delay: Duration::from_secs(30),
3901 jitter_factor: 0.1,
3902 multiplier: 3.0,
3903 };
3904 assert_eq!(config.base_delay, Duration::from_millis(500));
3905 assert_eq!(config.max_delay, Duration::from_secs(30));
3906 assert!((config.jitter_factor - 0.1).abs() < f64::EPSILON);
3907 assert!((config.multiplier - 3.0).abs() < f64::EPSILON);
3908 }
3909
3910 // ==================== BackoffStrategy Tests ====================
3911
3912 #[test]
3913 fn test_backoff_strategy_with_defaults() {
3914 let strategy = BackoffStrategy::with_defaults();
3915 assert_eq!(strategy.config().base_delay, Duration::from_secs(1));
3916 assert_eq!(strategy.config().max_delay, Duration::from_secs(60));
3917 }
3918
3919 #[test]
3920 fn test_backoff_strategy_exponential_growth_no_jitter() {
3921 let config = BackoffConfig {
3922 base_delay: Duration::from_secs(1),
3923 max_delay: Duration::from_secs(60),
3924 jitter_factor: 0.0, // No jitter for predictable results
3925 multiplier: 2.0,
3926 };
3927 let strategy = BackoffStrategy::new(config);
3928
3929 // Verify exponential growth: 1s, 2s, 4s, 8s, 16s, 32s, 60s (capped)
3930 assert_eq!(strategy.calculate_delay(0), Duration::from_secs(1));
3931 assert_eq!(strategy.calculate_delay(1), Duration::from_secs(2));
3932 assert_eq!(strategy.calculate_delay(2), Duration::from_secs(4));
3933 assert_eq!(strategy.calculate_delay(3), Duration::from_secs(8));
3934 assert_eq!(strategy.calculate_delay(4), Duration::from_secs(16));
3935 assert_eq!(strategy.calculate_delay(5), Duration::from_secs(32));
3936 // At attempt 6, 1 * 2^6 = 64s, but capped at 60s
3937 assert_eq!(strategy.calculate_delay(6), Duration::from_secs(60));
3938 // Further attempts stay at max
3939 assert_eq!(strategy.calculate_delay(10), Duration::from_secs(60));
3940 }
3941
3942 #[test]
3943 fn test_backoff_strategy_max_delay_cap() {
3944 let config = BackoffConfig {
3945 base_delay: Duration::from_secs(10),
3946 max_delay: Duration::from_secs(30),
3947 jitter_factor: 0.0,
3948 multiplier: 2.0,
3949 };
3950 let strategy = BackoffStrategy::new(config);
3951
3952 // 10s, 20s, 30s (capped), 30s (capped)
3953 assert_eq!(strategy.calculate_delay(0), Duration::from_secs(10));
3954 assert_eq!(strategy.calculate_delay(1), Duration::from_secs(20));
3955 assert_eq!(strategy.calculate_delay(2), Duration::from_secs(30)); // 40s capped to 30s
3956 assert_eq!(strategy.calculate_delay(3), Duration::from_secs(30)); // 80s capped to 30s
3957 }
3958
3959 #[test]
3960 fn test_backoff_strategy_jitter_bounds() {
3961 let config = BackoffConfig {
3962 base_delay: Duration::from_secs(1),
3963 max_delay: Duration::from_secs(60),
3964 jitter_factor: 0.25,
3965 multiplier: 2.0,
3966 };
3967 let strategy = BackoffStrategy::new(config);
3968
3969 // Run multiple times to test jitter randomness
3970 for _ in 0..100 {
3971 let delay = strategy.calculate_delay(0);
3972 // Base delay is 1s, jitter is [0, 0.25s]
3973 // So delay should be in [1s, 1.25s]
3974 assert!(delay >= Duration::from_secs(1));
3975 assert!(delay <= Duration::from_millis(1250));
3976 }
3977
3978 for _ in 0..100 {
3979 let delay = strategy.calculate_delay(2);
3980 // Base delay is 4s, jitter is [0, 1s]
3981 // So delay should be in [4s, 5s]
3982 assert!(delay >= Duration::from_secs(4));
3983 assert!(delay <= Duration::from_secs(5));
3984 }
3985 }
3986
3987 #[test]
3988 fn test_backoff_strategy_calculate_delay_without_jitter() {
3989 let config = BackoffConfig {
3990 base_delay: Duration::from_secs(1),
3991 max_delay: Duration::from_secs(60),
3992 jitter_factor: 0.25, // Has jitter configured
3993 multiplier: 2.0,
3994 };
3995 let strategy = BackoffStrategy::new(config);
3996
3997 // calculate_delay_without_jitter should always return the same value
3998 assert_eq!(
3999 strategy.calculate_delay_without_jitter(0),
4000 Duration::from_secs(1)
4001 );
4002 assert_eq!(
4003 strategy.calculate_delay_without_jitter(1),
4004 Duration::from_secs(2)
4005 );
4006 assert_eq!(
4007 strategy.calculate_delay_without_jitter(2),
4008 Duration::from_secs(4)
4009 );
4010 }
4011
4012 #[test]
4013 fn test_backoff_strategy_custom_multiplier() {
4014 let config = BackoffConfig {
4015 base_delay: Duration::from_secs(1),
4016 max_delay: Duration::from_secs(100),
4017 jitter_factor: 0.0,
4018 multiplier: 3.0, // Triple each time
4019 };
4020 let strategy = BackoffStrategy::new(config);
4021
4022 // 1s, 3s, 9s, 27s, 81s, 100s (capped)
4023 assert_eq!(strategy.calculate_delay(0), Duration::from_secs(1));
4024 assert_eq!(strategy.calculate_delay(1), Duration::from_secs(3));
4025 assert_eq!(strategy.calculate_delay(2), Duration::from_secs(9));
4026 assert_eq!(strategy.calculate_delay(3), Duration::from_secs(27));
4027 assert_eq!(strategy.calculate_delay(4), Duration::from_secs(81));
4028 assert_eq!(strategy.calculate_delay(5), Duration::from_secs(100)); // 243s capped
4029 }
4030
4031 #[test]
4032 fn test_backoff_strategy_millisecond_precision() {
4033 let config = BackoffConfig {
4034 base_delay: Duration::from_millis(100),
4035 max_delay: Duration::from_secs(10),
4036 jitter_factor: 0.0,
4037 multiplier: 2.0,
4038 };
4039 let strategy = BackoffStrategy::new(config);
4040
4041 // 100ms, 200ms, 400ms, 800ms, 1600ms, ...
4042 assert_eq!(strategy.calculate_delay(0), Duration::from_millis(100));
4043 assert_eq!(strategy.calculate_delay(1), Duration::from_millis(200));
4044 assert_eq!(strategy.calculate_delay(2), Duration::from_millis(400));
4045 assert_eq!(strategy.calculate_delay(3), Duration::from_millis(800));
4046 assert_eq!(strategy.calculate_delay(4), Duration::from_millis(1600));
4047 }
4048
4049 // ==================== WsConfig Tests ====================
4050
#[test]
fn test_ws_config_default() {
    let defaults = WsConfig::default();

    // Timing and reconnection knobs.
    assert_eq!(defaults.connect_timeout, 10000);
    assert_eq!(defaults.ping_interval, 30000);
    assert_eq!(defaults.reconnect_interval, 5000);
    assert_eq!(defaults.max_reconnect_attempts, 5);
    assert!(defaults.auto_reconnect);
    assert!(!defaults.enable_compression);
    assert_eq!(defaults.pong_timeout, 90000);

    // Resilience-related limits.
    assert_eq!(defaults.max_subscriptions, DEFAULT_MAX_SUBSCRIPTIONS);
    assert_eq!(defaults.shutdown_timeout, DEFAULT_SHUTDOWN_TIMEOUT);

    // Nested backoff defaults; floats are compared with an epsilon tolerance.
    let backoff = &defaults.backoff_config;
    assert_eq!(backoff.base_delay, Duration::from_secs(1));
    assert_eq!(backoff.max_delay, Duration::from_secs(60));
    assert!((backoff.jitter_factor - 0.25).abs() < f64::EPSILON);
    assert!((backoff.multiplier - 2.0).abs() < f64::EPSILON);
}
4070
#[test]
fn test_subscription_key() {
    // With a symbol the key is "<channel>:<symbol>".
    let symbol = "BTC/USDT".to_string();
    assert_eq!(
        WsClient::subscription_key("ticker", Some(&symbol)),
        "ticker:BTC/USDT"
    );

    // Without a symbol the key is just the channel name.
    assert_eq!(WsClient::subscription_key("trades", None), "trades");
}
4079
4080 #[tokio::test]
4081 async fn test_ws_client_creation() {
4082 let config = WsConfig {
4083 url: "wss://example.com/ws".to_string(),
4084 ..Default::default()
4085 };
4086
4087 let client = WsClient::new(config);
4088 // state() is now lock-free (no await needed)
4089 assert_eq!(client.state(), WsConnectionState::Disconnected);
4090 // is_connected() is now lock-free (no await needed)
4091 assert!(!client.is_connected());
4092 }
4093
4094 #[tokio::test]
4095 async fn test_subscribe_adds_subscription() {
4096 let config = WsConfig {
4097 url: "wss://example.com/ws".to_string(),
4098 ..Default::default()
4099 };
4100
4101 let client = WsClient::new(config);
4102
4103 let result = client
4104 .subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None)
4105 .await;
4106 assert!(result.is_ok());
4107
4108 // Use DashMap API (lock-free)
4109 assert_eq!(client.subscription_count(), 1);
4110 assert!(client.is_subscribed("ticker", Some(&"BTC/USDT".to_string())));
4111 }
4112
4113 #[tokio::test]
4114 async fn test_unsubscribe_removes_subscription() {
4115 let config = WsConfig {
4116 url: "wss://example.com/ws".to_string(),
4117 ..Default::default()
4118 };
4119
4120 let client = WsClient::new(config);
4121
4122 client
4123 .subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None)
4124 .await
4125 .unwrap();
4126
4127 let result = client
4128 .unsubscribe("ticker".to_string(), Some("BTC/USDT".to_string()))
4129 .await;
4130 assert!(result.is_ok());
4131
4132 // Use DashMap API (lock-free)
4133 assert_eq!(client.subscription_count(), 0);
4134 assert!(!client.is_subscribed("ticker", Some(&"BTC/USDT".to_string())));
4135 }
4136
#[test]
fn test_ws_message_serialization() {
    let json = serde_json::to_string(&WsMessage::Subscribe {
        channel: "ticker".to_string(),
        symbol: Some("BTC/USDT".to_string()),
        params: None,
    })
    .unwrap();

    // The serialized form must carry the variant tag and the channel field.
    for fragment in ["\"type\":\"subscribe\"", "\"channel\":\"ticker\""] {
        assert!(json.contains(fragment));
    }
}
4149
#[test]
fn test_ws_connection_state_from_u8() {
    // (raw value, expected state); anything outside 0..=4 must map to Error.
    let cases = [
        (0, WsConnectionState::Disconnected),
        (1, WsConnectionState::Connecting),
        (2, WsConnectionState::Connected),
        (3, WsConnectionState::Reconnecting),
        (4, WsConnectionState::Error),
        (5, WsConnectionState::Error),
        (255, WsConnectionState::Error),
    ];

    for (raw, expected) in cases {
        assert_eq!(WsConnectionState::from_u8(raw), expected);
    }
}
4167
#[test]
fn test_ws_connection_state_as_u8() {
    // (state, expected numeric encoding)
    let cases = [
        (WsConnectionState::Disconnected, 0),
        (WsConnectionState::Connecting, 1),
        (WsConnectionState::Connected, 2),
        (WsConnectionState::Reconnecting, 3),
        (WsConnectionState::Error, 4),
    ];

    for (state, code) in cases {
        assert_eq!(state.as_u8(), code);
    }
}
4176
#[test]
fn test_reconnect_count_lock_free() {
    let client = WsClient::new(WsConfig {
        url: "wss://example.com/ws".to_string(),
        ..WsConfig::default()
    });

    // A new client has never reconnected.
    assert_eq!(client.reconnect_count(), 0);

    // Resetting the counter leaves it at zero.
    client.reset_reconnect_count();
    assert_eq!(client.reconnect_count(), 0);
}
4193
4194 #[tokio::test]
4195 async fn test_concurrent_subscription_operations() {
4196 use std::sync::Arc;
4197
4198 let config = WsConfig {
4199 url: "wss://example.com/ws".to_string(),
4200 ..Default::default()
4201 };
4202
4203 let client = Arc::new(WsClient::new(config));
4204
4205 // Spawn multiple tasks that add subscriptions concurrently
4206 let mut handles = vec![];
4207
4208 for i in 0..10 {
4209 let client_clone = Arc::clone(&client);
4210 let handle = tokio::spawn(async move {
4211 let channel = format!("channel{}", i);
4212 let symbol = Some(format!("SYMBOL{}/USDT", i));
4213 client_clone.subscribe(channel, symbol, None).await.unwrap();
4214 });
4215 handles.push(handle);
4216 }
4217
4218 // Wait for all tasks to complete
4219 for handle in handles {
4220 handle.await.unwrap();
4221 }
4222
4223 // All subscriptions should be present
4224 assert_eq!(client.subscription_count(), 10);
4225
4226 // Verify each subscription exists
4227 for i in 0..10 {
4228 assert!(
4229 client.is_subscribed(&format!("channel{}", i), Some(&format!("SYMBOL{}/USDT", i)))
4230 );
4231 }
4232 }
4233
4234 #[tokio::test]
4235 async fn test_concurrent_state_access() {
4236 use std::sync::Arc;
4237
4238 let config = WsConfig {
4239 url: "wss://example.com/ws".to_string(),
4240 ..Default::default()
4241 };
4242
4243 let client = Arc::new(WsClient::new(config));
4244
4245 // Spawn multiple tasks that read state concurrently
4246 let mut handles = vec![];
4247
4248 for _ in 0..100 {
4249 let client_clone = Arc::clone(&client);
4250 let handle = tokio::spawn(async move {
4251 // Lock-free state access should not panic or deadlock
4252 let _ = client_clone.state();
4253 let _ = client_clone.is_connected();
4254 });
4255 handles.push(handle);
4256 }
4257
4258 // Wait for all tasks to complete
4259 for handle in handles {
4260 handle.await.unwrap();
4261 }
4262
4263 // State should still be Disconnected
4264 assert_eq!(client.state(), WsConnectionState::Disconnected);
4265 }
4266
4267 #[tokio::test]
4268 async fn test_concurrent_subscription_add_remove() {
4269 use std::sync::Arc;
4270
4271 let config = WsConfig {
4272 url: "wss://example.com/ws".to_string(),
4273 ..Default::default()
4274 };
4275
4276 let client = Arc::new(WsClient::new(config));
4277
4278 // Add some initial subscriptions
4279 for i in 0..5 {
4280 client
4281 .subscribe(
4282 format!("channel{}", i),
4283 Some(format!("SYM{}/USDT", i)),
4284 None,
4285 )
4286 .await
4287 .unwrap();
4288 }
4289
4290 // Spawn tasks that add and remove subscriptions concurrently
4291 let mut handles = vec![];
4292
4293 // Add new subscriptions
4294 for i in 5..10 {
4295 let client_clone = Arc::clone(&client);
4296 let handle = tokio::spawn(async move {
4297 client_clone
4298 .subscribe(
4299 format!("channel{}", i),
4300 Some(format!("SYM{}/USDT", i)),
4301 None,
4302 )
4303 .await
4304 .unwrap();
4305 });
4306 handles.push(handle);
4307 }
4308
4309 // Remove some existing subscriptions
4310 for i in 0..3 {
4311 let client_clone = Arc::clone(&client);
4312 let handle = tokio::spawn(async move {
4313 client_clone
4314 .unsubscribe(format!("channel{}", i), Some(format!("SYM{}/USDT", i)))
4315 .await
4316 .unwrap();
4317 });
4318 handles.push(handle);
4319 }
4320
4321 // Wait for all tasks to complete
4322 for handle in handles {
4323 handle.await.unwrap();
4324 }
4325
4326 // Should have 7 subscriptions (5 initial + 5 added - 3 removed)
4327 assert_eq!(client.subscription_count(), 7);
4328 }
4329
4330 // ==================== WsErrorKind Tests ====================
4331
#[test]
fn test_ws_error_kind_transient() {
    let transient = WsErrorKind::Transient;

    // The two predicates are mutually exclusive.
    assert!(transient.is_transient());
    assert!(!transient.is_permanent());

    // Display renders the bare variant name.
    assert_eq!(transient.to_string(), "Transient");
}
4339
#[test]
fn test_ws_error_kind_permanent() {
    let permanent = WsErrorKind::Permanent;

    // The two predicates are mutually exclusive.
    assert!(!permanent.is_transient());
    assert!(permanent.is_permanent());

    // Display renders the bare variant name.
    assert_eq!(permanent.to_string(), "Permanent");
}
4347
#[test]
fn test_ws_error_kind_equality() {
    let (transient, permanent) = (WsErrorKind::Transient, WsErrorKind::Permanent);

    // Each variant equals itself and differs from the other.
    assert_eq!(transient, WsErrorKind::Transient);
    assert_eq!(permanent, WsErrorKind::Permanent);
    assert_ne!(transient, permanent);
}
4354
#[test]
// `WsErrorKind` is `Copy`, so clippy would prefer a plain copy, but this test
// exists specifically to exercise the `Clone` implementation.
#[allow(clippy::clone_on_copy)]
fn test_ws_error_kind_clone() {
    let kind = WsErrorKind::Transient;

    // Explicit clone: goes through `Clone::clone` rather than an implicit copy.
    let cloned = kind.clone();
    assert_eq!(kind, cloned);

    // Implicit copy: `kind` stays usable after being assigned elsewhere.
    let copied = kind;
    assert_eq!(kind, copied);
}
4361
4362 // ==================== WsError Tests ====================
4363
#[test]
fn test_ws_error_transient_creation() {
    let message = "Connection timeout";
    let error = WsError::transient(message);

    // Classification.
    assert!(error.is_transient());
    assert!(!error.is_permanent());
    assert_eq!(error.kind(), WsErrorKind::Transient);

    // Payload: the message is preserved and no source is attached.
    assert_eq!(error.message(), message);
    assert!(error.source().is_none());
}
4373
#[test]
fn test_ws_error_permanent_creation() {
    let message = "Invalid API key";
    let error = WsError::permanent(message);

    // Classification.
    assert!(!error.is_transient());
    assert!(error.is_permanent());
    assert_eq!(error.kind(), WsErrorKind::Permanent);

    // Payload: the message is preserved and no source is attached.
    assert_eq!(error.message(), message);
    assert!(error.source().is_none());
}
4383
#[test]
fn test_ws_error_new() {
    // The generic constructor stores exactly the kind and message it is given.
    let error = WsError::new(WsErrorKind::Transient, "Custom error");

    assert_eq!(error.kind(), WsErrorKind::Transient);
    assert_eq!(error.message(), "Custom error");
}
4390
#[test]
fn test_ws_error_with_source() {
    let cause = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset");
    let error = WsError::with_source(WsErrorKind::Transient, "Connection lost", cause);

    // Kind and message come from the arguments; the cause is retained.
    assert!(error.is_transient());
    assert_eq!(error.message(), "Connection lost");
    assert!(error.source().is_some());
}
4400
#[test]
fn test_ws_error_transient_with_source() {
    let cause = std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout");
    let error = WsError::transient_with_source("Network timeout", cause);

    // The shortcut constructor marks the error transient and keeps the cause.
    assert!(error.is_transient());
    assert!(error.source().is_some());
}
4409
#[test]
fn test_ws_error_permanent_with_source() {
    let cause = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "denied");
    let error = WsError::permanent_with_source("Access denied", cause);

    // The shortcut constructor marks the error permanent and keeps the cause.
    assert!(error.is_permanent());
    assert!(error.source().is_some());
}
4418
#[test]
fn test_ws_error_display() {
    let rendered = WsError::transient("Connection timeout").to_string();

    // Display output must mention both the kind and the message.
    for fragment in ["Transient", "Connection timeout"] {
        assert!(rendered.contains(fragment));
    }
}
4426
#[test]
fn test_ws_error_display_permanent() {
    let rendered = WsError::permanent("Auth failed").to_string();

    // Display output must mention both the kind and the message.
    for fragment in ["Permanent", "Auth failed"] {
        assert!(rendered.contains(fragment));
    }
}
4434
#[test]
fn test_ws_error_std_error_trait() {
    // Accepting a trait object forces the call to go through the
    // `std::error::Error` implementation.
    fn has_source(error: &dyn std::error::Error) -> bool {
        error.source().is_some()
    }

    let cause = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset");
    let error = WsError::transient_with_source("Connection lost", cause);

    assert!(has_source(&error));
}
4444
4445 // ==================== WsError::from_tungstenite Tests ====================
4446
#[test]
fn test_ws_error_from_tungstenite_connection_closed() {
    use tokio_tungstenite::tungstenite::Error as TungError;

    // A closed connection is expected to be classified as retryable.
    let converted = WsError::from_tungstenite(&TungError::ConnectionClosed);

    assert!(converted.is_transient());
    assert!(converted.message().contains("Connection closed"));
}
4457
#[test]
fn test_ws_error_from_tungstenite_already_closed() {
    use tokio_tungstenite::tungstenite::Error as TungError;

    // Operating on an already-closed socket is expected to be retryable.
    let converted = WsError::from_tungstenite(&TungError::AlreadyClosed);

    assert!(converted.is_transient());
    assert!(converted.message().contains("already closed"));
}
4468
#[test]
fn test_ws_error_from_tungstenite_io_error() {
    use tokio_tungstenite::tungstenite::Error as TungError;

    let wrapped = TungError::Io(std::io::Error::new(
        std::io::ErrorKind::ConnectionReset,
        "connection reset",
    ));
    let converted = WsError::from_tungstenite(&wrapped);

    // I/O failures are expected to be classified as retryable.
    assert!(converted.is_transient());
    assert!(converted.message().contains("IO error"));
}
4480
#[test]
fn test_ws_error_from_tungstenite_utf8_error() {
    use tokio_tungstenite::tungstenite::Error as TungError;

    let malformed = TungError::Utf8("invalid utf8".to_string());
    let converted = WsError::from_tungstenite(&malformed);

    // Malformed text payloads are expected to be classified as non-retryable.
    assert!(converted.is_permanent());
    assert!(converted.message().contains("UTF-8"));
}
4491
#[test]
fn test_ws_error_from_tungstenite_attack_attempt() {
    use tokio_tungstenite::tungstenite::Error as TungError;

    // A detected attack is expected to be classified as non-retryable.
    let converted = WsError::from_tungstenite(&TungError::AttackAttempt);

    assert!(converted.is_permanent());
    assert!(converted.message().contains("attack"));
}
4502
4503 // ==================== WsError::from_error Tests ====================
4504
#[test]
fn test_ws_error_from_error_authentication() {
    let converted = WsError::from_error(&Error::authentication("Invalid API key"));

    // Authentication failures are expected to map to a permanent error.
    assert!(converted.is_permanent());
    assert!(converted.message().contains("Authentication"));
}
4513
#[test]
fn test_ws_error_from_error_cancelled() {
    let converted = WsError::from_error(&Error::cancelled("Operation cancelled"));

    // A cancelled operation is expected to map to a permanent error.
    assert!(converted.is_permanent());
    assert!(converted.message().contains("cancelled"));
}
4522
#[test]
fn test_ws_error_from_error_resource_exhausted() {
    let converted =
        WsError::from_error(&Error::resource_exhausted("Max subscriptions reached"));

    // Resource exhaustion is expected to map to a permanent error.
    assert!(converted.is_permanent());
    assert!(converted.message().contains("exhausted"));
}
4531
#[test]
fn test_ws_error_from_error_network() {
    let converted = WsError::from_error(&Error::network("Connection failed"));

    // Network errors are expected to map to a transient (retryable) error.
    assert!(converted.is_transient());
}
4540
4541 // ==================== SubscriptionManager Tests ====================
4542
#[test]
fn test_subscription_manager_new() {
    let capacity = 50;
    let manager = SubscriptionManager::new(capacity);

    // A new manager reports its limit and holds nothing.
    assert_eq!(manager.max_subscriptions(), capacity);
    assert_eq!(manager.count(), 0);
    assert_eq!(manager.remaining_capacity(), capacity);
    assert!(manager.is_empty());
    assert!(!manager.is_full());
}
4552
#[test]
fn test_subscription_manager_with_default_capacity() {
    let manager = SubscriptionManager::with_default_capacity();

    // The default limit applies, and the manager starts out empty.
    assert_eq!(manager.max_subscriptions(), DEFAULT_MAX_SUBSCRIPTIONS);
    assert_eq!(manager.count(), 0);
    assert_eq!(manager.remaining_capacity(), DEFAULT_MAX_SUBSCRIPTIONS);
}
4560
#[test]
fn test_subscription_manager_default_trait() {
    // `Default` must use the library-wide default subscription limit.
    let limit = SubscriptionManager::default().max_subscriptions();
    assert_eq!(limit, DEFAULT_MAX_SUBSCRIPTIONS);
}
4566
#[test]
fn test_subscription_manager_try_add_success() {
    let manager = SubscriptionManager::new(10);
    let key = "ticker:BTC/USDT";

    let outcome = manager.try_add(
        key.to_string(),
        Subscription {
            channel: "ticker".to_string(),
            symbol: Some("BTC/USDT".to_string()),
            params: None,
        },
    );

    // The insert succeeds and consumes exactly one of the ten slots.
    assert!(outcome.is_ok());
    assert_eq!(manager.count(), 1);
    assert_eq!(manager.remaining_capacity(), 9);
    assert!(manager.contains(key));
}
4583
#[test]
fn test_subscription_manager_try_add_at_capacity() {
    // Builds a ticker subscription for the given symbol.
    let ticker = |symbol: &str| Subscription {
        channel: "ticker".to_string(),
        symbol: Some(symbol.to_string()),
        params: None,
    };

    let manager = SubscriptionManager::new(2);

    // The first two inserts use up the whole capacity.
    assert!(manager
        .try_add("ticker:BTC/USDT".to_string(), ticker("BTC/USDT"))
        .is_ok());
    assert!(manager
        .try_add("ticker:ETH/USDT".to_string(), ticker("ETH/USDT"))
        .is_ok());

    // A third, distinct key must be rejected ...
    let rejected = manager.try_add("ticker:SOL/USDT".to_string(), ticker("SOL/USDT"));
    assert!(rejected.is_err());

    // ... with a ResourceExhausted error that mentions the limit.
    let error = rejected.unwrap_err();
    assert!(error.as_resource_exhausted().is_some());
    assert!(error.to_string().contains("Maximum subscriptions"));

    // The failed insert must not have changed the bookkeeping.
    assert_eq!(manager.count(), 2);
    assert_eq!(manager.remaining_capacity(), 0);
    assert!(manager.is_full());
}
4623
#[test]
fn test_subscription_manager_try_add_replace_existing() {
    // Builds a ticker subscription for the given symbol with the given params.
    let ticker = |symbol: &str, params| Subscription {
        channel: "ticker".to_string(),
        symbol: Some(symbol.to_string()),
        params,
    };

    let manager = SubscriptionManager::new(2);

    // Fill both slots.
    manager
        .try_add("ticker:BTC/USDT".to_string(), ticker("BTC/USDT", None))
        .unwrap();
    manager
        .try_add("ticker:ETH/USDT".to_string(), ticker("ETH/USDT", None))
        .unwrap();
    assert!(manager.is_full());

    // Re-adding under an existing key (with different params) must be
    // accepted even though the manager is full.
    let outcome = manager.try_add(
        "ticker:BTC/USDT".to_string(),
        ticker("BTC/USDT", Some(HashMap::new())),
    );
    assert!(outcome.is_ok());

    // The entry count is unchanged by the re-add.
    assert_eq!(manager.count(), 2);
}
4661
#[test]
fn test_subscription_manager_remove() {
    let manager = SubscriptionManager::new(10);
    let key = "ticker:BTC/USDT";

    // Arrange: a single stored subscription.
    manager
        .try_add(
            key.to_string(),
            Subscription {
                channel: "ticker".to_string(),
                symbol: Some("BTC/USDT".to_string()),
                params: None,
            },
        )
        .unwrap();
    assert_eq!(manager.count(), 1);
    assert_eq!(manager.remaining_capacity(), 9);

    // Act: removing it hands back the stored entry.
    assert!(manager.remove(key).is_some());

    // Assert: the manager is back to its pristine state.
    assert_eq!(manager.count(), 0);
    assert_eq!(manager.remaining_capacity(), 10);
    assert!(!manager.contains(key));
    assert!(manager.is_empty());
}
4688
#[test]
fn test_subscription_manager_remove_nonexistent() {
    let manager = SubscriptionManager::new(10);

    // Removing an unknown key yields nothing and leaves the count at zero.
    assert!(manager.remove("nonexistent").is_none());
    assert_eq!(manager.count(), 0);
}
4697
#[test]
fn test_subscription_manager_remove_frees_slot() {
    // Builds a ticker subscription for the given symbol.
    let ticker = |symbol: &str| Subscription {
        channel: "ticker".to_string(),
        symbol: Some(symbol.to_string()),
        params: None,
    };

    let manager = SubscriptionManager::new(2);

    // Fill both slots.
    manager
        .try_add("ticker:BTC/USDT".to_string(), ticker("BTC/USDT"))
        .unwrap();
    manager
        .try_add("ticker:ETH/USDT".to_string(), ticker("ETH/USDT"))
        .unwrap();
    assert!(manager.is_full());

    // Dropping one entry ...
    manager.remove("ticker:BTC/USDT");

    // ... must make room for a new, different key.
    let outcome = manager.try_add("ticker:SOL/USDT".to_string(), ticker("SOL/USDT"));
    assert!(outcome.is_ok());
    assert_eq!(manager.count(), 2);
}
4736
#[test]
fn test_subscription_manager_clear() {
    let manager = SubscriptionManager::new(10);

    // Populate five distinct entries.
    (0..5).for_each(|i| {
        let key = format!("channel{}:SYM{}/USDT", i, i);
        let entry = Subscription {
            channel: format!("channel{}", i),
            symbol: Some(format!("SYM{}/USDT", i)),
            params: None,
        };
        manager.try_add(key, entry).unwrap();
    });
    assert_eq!(manager.count(), 5);

    // Clearing drops everything and restores the full capacity.
    manager.clear();
    assert_eq!(manager.count(), 0);
    assert_eq!(manager.remaining_capacity(), 10);
    assert!(manager.is_empty());
}
4762
#[test]
fn test_subscription_manager_get() {
    let manager = SubscriptionManager::new(10);
    manager
        .try_add(
            "ticker:BTC/USDT".to_string(),
            Subscription {
                channel: "ticker".to_string(),
                symbol: Some("BTC/USDT".to_string()),
                params: None,
            },
        )
        .unwrap();

    // A stored key returns the subscription that was inserted.
    let found = manager.get("ticker:BTC/USDT");
    assert!(found.is_some());
    assert_eq!(found.unwrap().channel, "ticker");

    // An unknown key returns nothing.
    let missing = manager.get("nonexistent");
    assert!(missing.is_none());
}
4785
#[test]
fn test_subscription_manager_collect_subscriptions() {
    let manager = SubscriptionManager::new(10);

    // Store three distinct entries.
    (0..3).for_each(|i| {
        let key = format!("channel{}:SYM{}/USDT", i, i);
        let entry = Subscription {
            channel: format!("channel{}", i),
            symbol: Some(format!("SYM{}/USDT", i)),
            params: None,
        };
        manager.try_add(key, entry).unwrap();
    });

    // The collected snapshot must contain all of them.
    let snapshot = manager.collect_subscriptions();
    assert_eq!(snapshot.len(), 3);
}
4804
#[test]
fn test_subscription_manager_concurrent_operations() {
    use std::sync::Arc;
    use std::thread;

    let manager = Arc::new(SubscriptionManager::new(100));
    let mut handles = vec![];

    // Ten threads, each inserting five uniquely keyed subscriptions.
    for i in 0..10 {
        let manager_clone = Arc::clone(&manager);
        let handle = thread::spawn(move || {
            for j in 0..5 {
                let sub = Subscription {
                    channel: format!("channel{}_{}", i, j),
                    symbol: Some(format!("SYM{}_{}/USDT", i, j)),
                    params: None,
                };
                // 50 unique keys always fit in a capacity of 100, so any
                // failure here is a real bug. Fail at the failing insert
                // rather than discarding the result and leaving only a
                // count mismatch at the end.
                manager_clone
                    .try_add(format!("channel{}_{}:SYM{}_{}/USDT", i, j, i, j), sub)
                    .expect("insert below capacity must succeed");
            }
        });
        handles.push(handle);
    }

    // A panic inside a worker surfaces here through `join`.
    for handle in handles {
        handle.join().unwrap();
    }

    // Should have 50 subscriptions (10 threads * 5 each)
    assert_eq!(manager.count(), 50);
    assert_eq!(manager.remaining_capacity(), 50);
}
4839
#[test]
fn test_subscription_manager_concurrent_add_remove() {
    use std::sync::Arc;
    use std::thread;

    let manager = Arc::new(SubscriptionManager::new(100));

    // Pre-populate with 20 subscriptions (keys 0..20).
    for i in 0..20 {
        let sub = Subscription {
            channel: format!("channel{}", i),
            symbol: Some(format!("SYM{}/USDT", i)),
            params: None,
        };
        manager
            .try_add(format!("channel{}:SYM{}/USDT", i, i), sub)
            .unwrap();
    }

    let mut handles = vec![];

    // Threads that add new subscriptions (keys 20..30).
    for i in 20..30 {
        let manager_clone = Arc::clone(&manager);
        let handle = thread::spawn(move || {
            let sub = Subscription {
                channel: format!("channel{}", i),
                symbol: Some(format!("SYM{}/USDT", i)),
                params: None,
            };
            // At most 30 entries can ever exist against a capacity of 100,
            // so a rejected insert is a real failure and must not be ignored.
            manager_clone
                .try_add(format!("channel{}:SYM{}/USDT", i, i), sub)
                .expect("insert below capacity must succeed");
        });
        handles.push(handle);
    }

    // Threads that remove pre-populated subscriptions (keys 0..10).
    for i in 0..10 {
        let manager_clone = Arc::clone(&manager);
        let handle = thread::spawn(move || {
            // Each key was inserted before any thread started and is removed
            // by exactly one thread, so the removal must find an entry.
            let removed = manager_clone.remove(&format!("channel{}:SYM{}/USDT", i, i));
            assert!(removed.is_some());
        });
        handles.push(handle);
    }

    // A panic inside a worker surfaces here through `join`.
    for handle in handles {
        handle.join().unwrap();
    }

    // Should have 20 subscriptions (20 initial + 10 added - 10 removed)
    assert_eq!(manager.count(), 20);
}
4892
4893 // ==================== CancellationToken Tests ====================
4894
4895 #[tokio::test]
4896 async fn test_ws_client_set_cancel_token() {
4897 let config = WsConfig {
4898 url: "wss://example.com/ws".to_string(),
4899 ..Default::default()
4900 };
4901
4902 let client = WsClient::new(config);
4903
4904 // Initially no token
4905 assert!(client.get_cancel_token().await.is_none());
4906
4907 // Set a token
4908 let token = CancellationToken::new();
4909 client.set_cancel_token(token.clone()).await;
4910
4911 // Token should be set
4912 let retrieved = client.get_cancel_token().await;
4913 assert!(retrieved.is_some());
4914 }
4915
4916 #[tokio::test]
4917 async fn test_ws_client_clear_cancel_token() {
4918 let config = WsConfig {
4919 url: "wss://example.com/ws".to_string(),
4920 ..Default::default()
4921 };
4922
4923 let client = WsClient::new(config);
4924
4925 // Set a token
4926 let token = CancellationToken::new();
4927 client.set_cancel_token(token).await;
4928 assert!(client.get_cancel_token().await.is_some());
4929
4930 // Clear the token
4931 client.clear_cancel_token().await;
4932 assert!(client.get_cancel_token().await.is_none());
4933 }
4934
4935 #[tokio::test]
4936 async fn test_cancellation_token_sharing() {
4937 // Test that CancellationToken is properly shared (cloning shares state)
4938 let token = CancellationToken::new();
4939 let token_clone = token.clone();
4940
4941 // Neither should be cancelled initially
4942 assert!(!token.is_cancelled());
4943 assert!(!token_clone.is_cancelled());
4944
4945 // Cancel the clone
4946 token_clone.cancel();
4947
4948 // Both should be cancelled (shared state)
4949 assert!(token.is_cancelled());
4950 assert!(token_clone.is_cancelled());
4951 }
4952
4953 #[tokio::test]
4954 async fn test_cancellation_token_child() {
4955 // Test child token behavior
4956 let parent = CancellationToken::new();
4957 let child = parent.child_token();
4958
4959 // Neither should be cancelled initially
4960 assert!(!parent.is_cancelled());
4961 assert!(!child.is_cancelled());
4962
4963 // Cancel the parent
4964 parent.cancel();
4965
4966 // Both should be cancelled
4967 assert!(parent.is_cancelled());
4968 assert!(child.is_cancelled());
4969 }
4970
4971 #[tokio::test]
4972 async fn test_cancellation_token_child_independent() {
4973 // Test that cancelling child doesn't cancel parent
4974 let parent = CancellationToken::new();
4975 let child = parent.child_token();
4976
4977 // Cancel the child
4978 child.cancel();
4979
4980 // Child should be cancelled, parent should not
4981 assert!(!parent.is_cancelled());
4982 assert!(child.is_cancelled());
4983 }
4984}