Skip to main content

nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Event-driven state notification via `Notify` for immediate wakeup on transitions.
23//! - Heartbeat support.
24//! - Rate limiting integration.
25
26use std::{
27    collections::VecDeque,
28    fmt::Debug,
29    sync::{
30        Arc, OnceLock,
31        atomic::{AtomicBool, AtomicU8, Ordering},
32    },
33    time::Duration,
34};
35
36use futures_util::{SinkExt, StreamExt};
37use http::HeaderName;
38use nautilus_core::CleanDrop;
39use nautilus_cryptography::providers::install_cryptographic_provider;
40#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
41use rustls::ClientConfig;
42#[cfg(feature = "transport-sockudo")]
43use sockudo_ws::{
44    Config as SockudoConfig, Http1, Role, Stream as SockudoStream,
45    WebSocketStream as SockudoWebSocketStream,
46};
47#[cfg(feature = "transport-sockudo")]
48use tokio::io::{AsyncRead, AsyncWrite};
49#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
50use tokio_rustls::TlsConnector;
51#[cfg(feature = "turmoil")]
52use tokio_tungstenite::MaybeTlsStream;
53#[cfg(feature = "turmoil")]
54use tokio_tungstenite::client_async;
55#[cfg(not(feature = "turmoil"))]
56use tokio_tungstenite::connect_async_with_config;
57use tokio_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue};
58use ustr::Ustr;
59
60#[cfg(not(feature = "turmoil"))]
61use super::proxy::{ProxiedStream, ProxyKind, WsTarget, tunnel_via_proxy};
62use super::{
63    auth::{AuthState, AuthTracker},
64    config::{TransportBackend, WebSocketConfig},
65    consts::{
66        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
67        GRACEFUL_SHUTDOWN_TIMEOUT_SECS,
68    },
69    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
70};
71#[cfg(feature = "turmoil")]
72use crate::net::TcpConnector;
73#[cfg(feature = "transport-sockudo")]
74use crate::net::TcpStream;
75#[cfg(feature = "transport-sockudo")]
76use crate::transport::sockudo::{
77    PrefixedIo, SockudoTransport, client_handshake_with_headers, validate_extra_headers,
78};
79use crate::{
80    RECONNECTED,
81    backoff::ExponentialBackoff,
82    dst,
83    error::SendError,
84    logging::{log_task_aborted, log_task_started, log_task_stopped},
85    mode::ConnectionMode,
86    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
87    transport::{BoxedWsTransport, Message, TransportError, tungstenite::TungsteniteTransport},
88};
89
/// Inner state of a websocket client that connects to a server to read and
/// send messages.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The connection is split into read and write halves. When a message handler
/// is supplied, the read half is moved into a tokio task which keeps receiving
/// messages from the server and hands them to that handler; without a handler
/// no read task is spawned (stream mode, where the reader is handled
/// externally). The write half is moved into a dedicated write task, and this
/// struct keeps the sending side of an unbounded channel of [`WriterCommand`]s
/// that feeds it. This way the client struct can be used to write data to the
/// server from multiple scopes/tasks.
///
/// The client also maintains a heartbeat if given a duration in seconds.
/// It's preferable to set the duration slightly lower - heartbeat more
/// frequently - than the required amount.
pub struct WebSocketClientInner {
    /// Connection configuration (URL, headers, backend, heartbeat and reconnect settings).
    config: WebSocketConfig,
    /// The function to handle incoming messages (stored separately from config).
    message_handler: Option<MessageHandler>,
    /// The handler for incoming pings (stored separately from config).
    ping_handler: Option<PingHandler>,
    /// Handle of the task consuming the read half; `None` when no message handler was supplied.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task that owns the write half and consumes `WriterCommand`s.
    write_task: tokio::task::JoinHandle<()>,
    /// Sending side of the unbounded command channel feeding the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task; `None` when no heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current [`ConnectionMode`], stored as its `u8` representation and shared with the tasks.
    connection_mode: Arc<AtomicU8>,
    /// Notifier shared with the spawned tasks for connection state transitions.
    state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound on the duration of a single reconnection attempt.
    reconnect_timeout: Duration,
    /// Exponential backoff built from the reconnect delay settings.
    backoff: ExponentialBackoff,
    /// True if this is a stream-based client (created via `connect_stream`).
    /// Stream-based clients disable auto-reconnect because the reader is
    /// owned by the caller and cannot be replaced during reconnection.
    is_stream_mode: bool,
    /// Maximum number of reconnection attempts before giving up (None = unlimited).
    reconnect_max_attempts: Option<u32>,
    /// Current count of consecutive reconnection attempts.
    reconnection_attempt_count: u32,
    /// Shared auth tracker invalidated on connection drops.
    auth_tracker: Arc<OnceLock<AuthTracker>>,
    /// Controls whether buffered replay waits for the next authenticated session.
    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
}
132
/// Decision about what to do with messages buffered across a reconnect.
///
/// NOTE(review): the code that selects and consumes these variants is outside
/// this section; the per-variant notes below are inferred from the names and
/// should be confirmed against the write task.
enum ReconnectBufferAction {
    /// Presumably: replay the buffered messages on the new connection.
    Drain,
    /// Presumably: keep the buffer and defer replay (e.g. until authenticated).
    Wait,
    /// Presumably: drop the buffered messages without sending them.
    Discard,
}
138
139impl WebSocketClientInner {
140    /// Create an inner websocket client with an existing writer.
141    ///
142    /// This is used for stream mode where the reader is owned by the caller.
143    ///
144    /// # Errors
145    ///
146    /// Returns an error if the exponential backoff configuration is invalid.
147    #[expect(
148        clippy::unused_async,
149        reason = "async signature for consistency with connect-based constructors"
150    )]
151    pub async fn new_with_writer(
152        config: WebSocketConfig,
153        writer: MessageWriter,
154    ) -> Result<Self, TransportError> {
155        install_cryptographic_provider();
156
157        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
158        let state_notify = Arc::new(tokio::sync::Notify::new());
159
160        // Note: We don't spawn a read task here since the reader is handled externally
161        let read_task = None;
162
163        // Stream mode ignores reconnect settings, use harmless defaults
164        let backoff = ExponentialBackoff::new(
165            Duration::from_secs(2),
166            Duration::from_secs(30),
167            1.5,
168            100,
169            true,
170        )
171        .map_err(|e| {
172            TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
173        })?;
174
175        let auth_tracker = Arc::new(OnceLock::new());
176        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
177
178        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
179        let write_task = Self::spawn_write_task(
180            connection_mode.clone(),
181            state_notify.clone(),
182            writer,
183            writer_rx,
184            Arc::clone(&auth_tracker),
185            Arc::clone(&reconnect_buffer_waits_for_auth),
186        );
187
188        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
189            Some(Self::spawn_heartbeat_task(
190                connection_mode.clone(),
191                heartbeat_interval,
192                config.heartbeat_msg.clone(),
193                writer_tx.clone(),
194            ))
195        } else {
196            None
197        };
198
199        let reconnect_max_attempts = None; // Stream mode does not reconnect
200        let reconnect_timeout = Duration::from_secs(10);
201
202        Ok(Self {
203            config,
204            message_handler: None, // Stream mode has no handler
205            ping_handler: None,
206            writer_tx,
207            connection_mode,
208            state_notify,
209            reconnect_timeout,
210            heartbeat_task,
211            read_task,
212            write_task,
213            backoff,
214            is_stream_mode: true,
215            reconnect_max_attempts,
216            reconnection_attempt_count: 0,
217            auth_tracker,
218            reconnect_buffer_waits_for_auth,
219        })
220    }
221
222    /// Create an inner websocket client.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if:
227    /// - The connection to the server fails.
228    /// - The exponential backoff configuration is invalid.
229    pub async fn connect_url(
230        config: WebSocketConfig,
231        message_handler: Option<MessageHandler>,
232        ping_handler: Option<PingHandler>,
233    ) -> Result<Self, TransportError> {
234        install_cryptographic_provider();
235
236        if config.heartbeat == Some(0) {
237            return Err(TransportError::Io(std::io::Error::new(
238                std::io::ErrorKind::InvalidInput,
239                "Heartbeat interval cannot be zero",
240            )));
241        }
242
243        if config.idle_timeout_ms == Some(0) {
244            return Err(TransportError::Io(std::io::Error::new(
245                std::io::ErrorKind::InvalidInput,
246                "Idle timeout cannot be zero",
247            )));
248        }
249
250        // Capture whether we're in stream mode before moving config
251        let is_stream_mode = message_handler.is_none();
252        let reconnect_max_attempts = config.reconnect_max_attempts;
253
254        let (writer, reader) = Box::pin(Self::connect_with_server(
255            &config.url,
256            config.headers.clone(),
257            config.backend,
258            config.proxy_url.as_deref(),
259        ))
260        .await?;
261
262        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
263        let state_notify = Arc::new(tokio::sync::Notify::new());
264
265        let read_task = if message_handler.is_some() {
266            Some(Self::spawn_message_handler_task(
267                connection_mode.clone(),
268                state_notify.clone(),
269                reader,
270                message_handler.as_ref(),
271                ping_handler.as_ref(),
272                config.idle_timeout_ms,
273            ))
274        } else {
275            None
276        };
277
278        let auth_tracker = Arc::new(OnceLock::new());
279        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
280
281        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
282        let write_task = Self::spawn_write_task(
283            connection_mode.clone(),
284            state_notify.clone(),
285            writer,
286            writer_rx,
287            Arc::clone(&auth_tracker),
288            Arc::clone(&reconnect_buffer_waits_for_auth),
289        );
290
291        // Optionally spawn a heartbeat task to periodically ping server
292        let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
293            Self::spawn_heartbeat_task(
294                connection_mode.clone(),
295                heartbeat_secs,
296                config.heartbeat_msg.clone(),
297                writer_tx.clone(),
298            )
299        });
300
301        let reconnect_timeout =
302            Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
303        let backoff = ExponentialBackoff::new(
304            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
305            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
306            config.reconnect_backoff_factor.unwrap_or(1.5),
307            config.reconnect_jitter_ms.unwrap_or(100),
308            true, // immediate-first
309        )
310        .map_err(|e| {
311            TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
312        })?;
313
314        Ok(Self {
315            config,
316            message_handler,
317            ping_handler,
318            read_task,
319            write_task,
320            writer_tx,
321            heartbeat_task,
322            connection_mode,
323            state_notify,
324            reconnect_timeout,
325            backoff,
326            // Set stream mode when no message handler (reader not managed by client)
327            is_stream_mode,
328            reconnect_max_attempts,
329            reconnection_attempt_count: 0,
330            auth_tracker,
331            reconnect_buffer_waits_for_auth,
332        })
333    }
334
335    /// Connect to the server and return the split halves of the active transport.
336    ///
337    /// Dispatches on `backend` to the matching backend helper. The
338    /// [`TransportBackend::Tungstenite`] backend is always available; the
339    /// [`TransportBackend::Sockudo`] backend requires the `transport-sockudo`
340    /// Cargo feature (enabled by default) and uses a custom HTTP/1.1 handshake
341    /// path for upgrade headers.
342    ///
343    /// When `proxy_url` is `Some`, the Tungstenite backend establishes an HTTP
344    /// `CONNECT` tunnel through the proxy before performing the WebSocket
345    /// handshake. The Sockudo backend does not yet support proxying; when it
346    /// is selected together with a proxy URL, this method logs a warning and
347    /// transparently falls back to Tungstenite so omitted-backend Python
348    /// configurations keep working.
349    ///
350    /// # Errors
351    ///
352    /// Returns a [`TransportError`] if the URL is invalid, headers fail to
353    /// parse, the TCP / TLS layer cannot be established, the proxy refuses
354    /// the tunnel, or the WebSocket handshake is rejected by the peer. When
355    /// the Sockudo backend is selected without the `transport-sockudo`
356    /// feature, returns [`TransportError::Other`].
357    #[inline]
358    pub async fn connect_with_server(
359        url: &str,
360        headers: Vec<(String, String)>,
361        backend: TransportBackend,
362        proxy_url: Option<&str>,
363    ) -> Result<(MessageWriter, MessageReader), TransportError> {
364        // Sockudo does not yet support proxy tunnels. When a proxy URL is supplied,
365        // route through Tungstenite so configurations that rely on the runtime
366        // default keep working (notably the Python `WebSocketConfig` binding,
367        // which exposes `proxy_url` but no `backend` selector).
368        if matches!(backend, TransportBackend::Sockudo)
369            && let Some(proxy) = proxy_url
370        {
371            log::warn!("Sockudo backend does not support proxy_url; falling back to Tungstenite");
372            return Box::pin(Self::connect_tungstenite_via_proxy(url, headers, proxy)).await;
373        }
374
375        match backend {
376            TransportBackend::Tungstenite => match proxy_url {
377                Some(proxy) => {
378                    Box::pin(Self::connect_tungstenite_via_proxy(url, headers, proxy)).await
379                }
380                None => Self::connect_tungstenite(url, headers).await,
381            },
382            TransportBackend::Sockudo => {
383                #[cfg(feature = "transport-sockudo")]
384                {
385                    Self::connect_sockudo(url, headers).await
386                }
387                #[cfg(not(feature = "transport-sockudo"))]
388                {
389                    Err(TransportError::Other(
390                        "sockudo backend selected but the transport-sockudo \
391                         Cargo feature is not enabled"
392                            .to_string(),
393                    ))
394                }
395            }
396        }
397    }
398
399    /// Connects with the server creating a tokio-tungstenite websocket stream.
400    /// Production version that uses `connect_async_with_config` convenience helper.
401    #[inline]
402    #[cfg(not(feature = "turmoil"))]
403    async fn connect_tungstenite(
404        url: &str,
405        headers: Vec<(String, String)>,
406    ) -> Result<(MessageWriter, MessageReader), TransportError> {
407        let mut request = url.into_client_request().map_err(TransportError::from)?;
408        let req_headers = request.headers_mut();
409
410        for (key, val) in headers {
411            let header_value = HeaderValue::from_str(&val)
412                .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
413            let header_name: HeaderName = key
414                .parse()
415                .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
416            req_headers.insert(header_name, header_value);
417        }
418
419        let (stream, _resp) = connect_async_with_config(request, None, true)
420            .await
421            .map_err(TransportError::from)?;
422        let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
423        Ok(transport.split())
424    }
425
426    /// Connects via an HTTP `CONNECT` proxy and performs the WebSocket
427    /// handshake over the resulting tunnel.
428    ///
429    /// Recognised but unsupported proxy schemes (currently SOCKS) log a
430    /// warning and fall back to a direct connection so existing REST proxy
431    /// configs remain usable. Only available in production builds; the
432    /// turmoil simulator does not model arbitrary outbound TCP via a proxy.
433    #[inline]
434    #[cfg(not(feature = "turmoil"))]
435    async fn connect_tungstenite_via_proxy(
436        url: &str,
437        headers: Vec<(String, String)>,
438        proxy_url: &str,
439    ) -> Result<(MessageWriter, MessageReader), TransportError> {
440        let proxy = match ProxyKind::parse(proxy_url)? {
441            ProxyKind::Http(target) => target,
442            ProxyKind::Unsupported { scheme } => {
443                log::warn!(
444                    "WebSocket proxy_url scheme '{scheme}' is not yet supported; \
445                     connecting without a WebSocket proxy"
446                );
447                return Self::connect_tungstenite(url, headers).await;
448            }
449        };
450
451        let mut request = url.into_client_request().map_err(TransportError::from)?;
452        let req_headers = request.headers_mut();
453
454        for (key, val) in headers {
455            let header_value = HeaderValue::from_str(&val)
456                .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
457            let header_name: HeaderName = key
458                .parse()
459                .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
460            req_headers.insert(header_name, header_value);
461        }
462
463        let target = WsTarget::parse(url)?;
464        let stream = tunnel_via_proxy(&target, &proxy).await?;
465
466        // Each ProxiedStream variant carries a distinct concrete stream type,
467        // so we monomorphize the handshake through `proxied_ws_handshake`
468        // rather than duplicating the body four times. The arms are
469        // syntactically identical post-deref, but each call instantiates a
470        // different generic; the `match_same_arms` lint is a false positive
471        // here. The futures are boxed because `client_async` produces a
472        // large state machine.
473        #[allow(clippy::match_same_arms)]
474        let transport: BoxedWsTransport = match stream {
475            ProxiedStream::Plain(tcp) => Box::pin(proxied_ws_handshake(request, tcp)).await?,
476            ProxiedStream::PlainOverTlsProxy(s) => {
477                Box::pin(proxied_ws_handshake(request, *s)).await?
478            }
479            ProxiedStream::Tls(s) => Box::pin(proxied_ws_handshake(request, *s)).await?,
480            ProxiedStream::TlsOverTlsProxy(s) => {
481                Box::pin(proxied_ws_handshake(request, *s)).await?
482            }
483        };
484
485        Ok(transport.split())
486    }
487
488    /// Turmoil simulator variant: HTTP `CONNECT` tunneling is not supported
489    /// under the simulator so any proxy URL is rejected up front.
490    #[inline]
491    #[cfg(feature = "turmoil")]
492    #[expect(
493        clippy::unused_async,
494        reason = "signature mirrors the production variant; both are awaited in the dispatcher"
495    )]
496    async fn connect_tungstenite_via_proxy(
497        _url: &str,
498        _headers: Vec<(String, String)>,
499        _proxy_url: &str,
500    ) -> Result<(MessageWriter, MessageReader), TransportError> {
501        Err(TransportError::Other(
502            "proxy_url is not supported under the turmoil simulator".to_string(),
503        ))
504    }
505
506    /// Connects with the server creating a tokio-tungstenite websocket stream.
507    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
508    #[inline]
509    #[cfg(feature = "turmoil")]
510    async fn connect_tungstenite(
511        url: &str,
512        headers: Vec<(String, String)>,
513    ) -> Result<(MessageWriter, MessageReader), TransportError> {
514        let mut request = url.into_client_request().map_err(TransportError::from)?;
515        let req_headers = request.headers_mut();
516
517        for (key, val) in headers {
518            let header_value = HeaderValue::from_str(&val)
519                .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
520            let header_name: HeaderName = key
521                .parse()
522                .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
523            req_headers.insert(header_name, header_value);
524        }
525
526        let uri = request.uri();
527        let scheme = uri.scheme_str().unwrap_or("ws");
528        let host = uri
529            .host()
530            .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;
531
532        // Determine port: use explicit port if specified, otherwise default based on scheme
533        let port = uri
534            .port_u16()
535            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
536
537        let addr = format!("{host}:{port}");
538
539        // Use the connector to get a turmoil-compatible stream
540        let connector = crate::net::RealTcpConnector;
541        let tcp_stream = connector.connect(&addr).await?;
542        if let Err(e) = tcp_stream.set_nodelay(true) {
543            log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
544        }
545
546        // Wrap stream appropriately based on scheme
547        let maybe_tls_stream = if scheme == "wss" {
548            // Build TLS config with webpki roots
549            let mut root_store = rustls::RootCertStore::empty();
550            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
551
552            let config = ClientConfig::builder()
553                .with_root_certificates(root_store)
554                .with_no_client_auth();
555
556            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
557            let domain = rustls::pki_types::ServerName::try_from(host.to_string())
558                .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
559
560            let tls_stream = tls_connector
561                .connect(domain, tcp_stream)
562                .await
563                .map_err(TransportError::Io)?;
564            MaybeTlsStream::Rustls(tls_stream)
565        } else {
566            MaybeTlsStream::Plain(tcp_stream)
567        };
568
569        // Use client_async with the stream (plain or TLS)
570        let (stream, _resp) = client_async(request, maybe_tls_stream)
571            .await
572            .map_err(TransportError::from)?;
573        let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
574        Ok(transport.split())
575    }
576
577    /// Connects with the server using the sockudo-ws backend.
578    ///
579    /// Uses a local HTTP/1.1 handshake helper so error logging and stream
580    /// construction stay in our hands regardless of header count.
581    ///
582    /// Under the turmoil simulator, only plaintext `ws://` is supported (the
583    /// simulator does not model TLS), so a `wss://` URL returns
584    /// [`TransportError::Tls`] up front.
585    #[inline]
586    #[cfg(feature = "transport-sockudo")]
587    async fn connect_sockudo(
588        url: &str,
589        headers: Vec<(String, String)>,
590    ) -> Result<(MessageWriter, MessageReader), TransportError> {
591        let target = SockudoTarget::parse(url)?;
592        validate_extra_headers(&headers).map_err(TransportError::from)?;
593
594        #[cfg(feature = "turmoil")]
595        if target.is_tls {
596            return Err(TransportError::Tls(
597                "wss:// is not supported under the turmoil simulator; use ws://".to_string(),
598            ));
599        }
600
601        let tcp_stream = TcpStream::connect((target.host.as_str(), target.port))
602            .await
603            .map_err(TransportError::Io)?;
604
605        if let Err(e) = tcp_stream.set_nodelay(true) {
606            log::warn!("Failed to enable TCP_NODELAY for sockudo client: {e:?}");
607        }
608
609        #[cfg(not(feature = "turmoil"))]
610        if target.is_tls {
611            let mut root_store = rustls::RootCertStore::empty();
612            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
613            let config = ClientConfig::builder()
614                .with_root_certificates(root_store)
615                .with_no_client_auth();
616            let connector = TlsConnector::from(std::sync::Arc::new(config));
617            let domain = rustls::pki_types::ServerName::try_from(target.host.clone())
618                .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
619            let tls_stream = connector
620                .connect(domain, tcp_stream)
621                .await
622                .map_err(TransportError::Io)?;
623            return Self::finish_sockudo_handshake(tls_stream, &target, &headers).await;
624        }
625
626        Self::finish_sockudo_handshake(tcp_stream, &target, &headers).await
627    }
628
629    #[cfg(feature = "transport-sockudo")]
630    async fn finish_sockudo_handshake<S>(
631        mut stream: S,
632        target: &SockudoTarget,
633        headers: &[(String, String)],
634    ) -> Result<(MessageWriter, MessageReader), TransportError>
635    where
636        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
637    {
638        // Use our helper for both paths: uniform error logging, and we own
639        // stream construction since sockudo's high-level client drops the
640        // handshake leftover.
641        let handshake = client_handshake_with_headers(
642            &mut stream,
643            &target.host_header,
644            &target.path,
645            None,
646            headers,
647        )
648        .await
649        .map_err(TransportError::from)?;
650
651        // Reading the HTTP 101 may also read the first WebSocket frame prefix;
652        // replay it only when present so the ordinary path stays unwrapped.
653        let stream = match handshake.leftover {
654            Some(prefix) => SockudoStream::<Http1>::new(PrefixedIo::new(stream, prefix)),
655            None => SockudoStream::<Http1>::new(stream),
656        };
657        let ws = SockudoWebSocketStream::from_raw(stream, Role::Client, SockudoConfig::default());
658        let transport: BoxedWsTransport = Box::pin(SockudoTransport::new(ws));
659        Ok(transport.split())
660    }
661}
662
663/// Complete the WebSocket handshake over a stream that has already been
664/// tunneled through an HTTP `CONNECT` proxy. Generic over the concrete
665/// stream type so the four [`super::proxy::ProxiedStream`] variants share
666/// a single body.
667#[cfg(not(feature = "turmoil"))]
668async fn proxied_ws_handshake<S>(
669    request: tokio_tungstenite::tungstenite::handshake::client::Request,
670    stream: S,
671) -> Result<BoxedWsTransport, TransportError>
672where
673    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
674{
675    let (ws, _resp) = tokio_tungstenite::client_async(request, stream)
676        .await
677        .map_err(TransportError::from)?;
678    Ok(Box::pin(TungsteniteTransport::new(ws)))
679}
680
/// Parsed components of a `ws://` / `wss://` URL needed by the sockudo backend.
///
/// Sockudo's HTTP/1.1 client passes the `host` argument verbatim as the
/// HTTP `Host:` header, so it must include the explicit port when one is
/// present in the URL (RFC 7230 section 5.4). The DNS / SNI lookup uses the bare
/// host without the port.
#[cfg(feature = "transport-sockudo")]
#[derive(Debug, PartialEq, Eq)]
struct SockudoTarget {
    /// Bare host used for socket dialing and TLS SNI (no port, no IPv6 brackets).
    host: String,
    /// Value to send as the HTTP `Host:` header. Includes `:port` only when
    /// the URL specifies a non-default port explicitly.
    host_header: String,
    /// Port to dial: the explicit port, otherwise 443 for `wss` and 80 for `ws`.
    port: u16,
    /// Request target for the upgrade: the URL path plus `?query` when present.
    path: String,
    /// True for `wss://` URLs.
    is_tls: bool,
}

#[cfg(feature = "transport-sockudo")]
impl SockudoTarget {
    /// Parses a `ws://` or `wss://` URL into the pieces the sockudo backend needs.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::InvalidUrl`] if the URL cannot be parsed, uses
    /// a scheme other than `ws` / `wss`, or has no hostname.
    fn parse(url: &str) -> Result<Self, TransportError> {
        let parsed =
            url::Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;

        let is_tls = match parsed.scheme() {
            "ws" => false,
            "wss" => true,
            other => {
                return Err(TransportError::InvalidUrl(format!(
                    "expected ws:// or wss:// scheme, was {other}"
                )));
            }
        };

        let raw_host = parsed
            .host_str()
            .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;

        // url::Url stores IPv6 hosts in their bracketed form (e.g. `[::1]`).
        // Brackets are correct for the HTTP `Host:` header but invalid for
        // DNS/TCP and TLS SNI, so we keep two representations: a bracketed
        // `host_header` for the upgrade, and a bare `host` for socket dialing.
        // Stripping via `strip_prefix`/`strip_suffix` cannot panic on a
        // malformed single-bracket host, unlike slicing by index.
        let host = raw_host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(raw_host)
            .to_string();

        let explicit_port = parsed.port();
        let port = explicit_port.unwrap_or(if is_tls { 443 } else { 80 });
        let host_header = match explicit_port {
            Some(p) => format!("{raw_host}:{p}"),
            None => raw_host.to_string(),
        };

        // Normalise an empty path to `/`, then always carry the query along so
        // it is never dropped, whatever path form the parser reports.
        let mut path = match parsed.path() {
            "" => "/".to_string(),
            p => p.to_string(),
        };

        if let Some(query) = parsed.query() {
            path.push('?');
            path.push_str(query);
        }

        Ok(Self {
            host,
            host_header,
            port,
            path,
            is_tls,
        })
    }
}
758
759impl WebSocketClientInner {
760    /// Reconnect with server.
761    ///
762    /// Make a new connection with server. Use the new read and write halves
763    /// to update self writer and read and heartbeat tasks.
764    ///
765    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
766    /// because the reader is owned by the caller and cannot be replaced. Stream users
767    /// should handle disconnections by creating a new connection.
768    ///
769    /// # Errors
770    ///
771    /// Returns an error if:
772    /// - The reconnection attempt times out.
773    /// - The connection to the server fails.
774    pub async fn reconnect(&mut self) -> Result<(), TransportError> {
775        log::debug!("Reconnecting");
776
777        if self.is_stream_mode {
778            log::warn!(
779                "Auto-reconnect disabled for stream-based WebSocket client; \
780                stream users must manually reconnect by creating a new connection"
781            );
782            // Transition to CLOSED state to stop reconnection attempts
783            self.connection_mode
784                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
785            return Ok(());
786        }
787
788        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
789            log::debug!("Reconnect aborted due to disconnect state");
790            return Ok(());
791        }
792
793        dst::time::timeout(self.reconnect_timeout, async {
794            // Attempt to connect; abort early if a disconnect was requested
795            let (new_writer, reader) = Self::connect_with_server(
796                &self.config.url,
797                self.config.headers.clone(),
798                self.config.backend,
799                self.config.proxy_url.as_deref(),
800            )
801            .await?;
802
803            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
804                log::debug!("Reconnect aborted mid-flight (after connect)");
805                return Ok(());
806            }
807
808            // Use a oneshot channel to synchronize the writer swap before transitioning
809            // back to ACTIVE. Buffered messages stay in the writer task and replay later.
810            let (tx, rx) = tokio::sync::oneshot::channel();
811            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
812                log::error!("{e}");
813                return Err(TransportError::Io(std::io::Error::new(
814                    std::io::ErrorKind::BrokenPipe,
815                    format!("Failed to send update command: {e}"),
816                )));
817            }
818
819            // Wait for writer to confirm it accepted the new socket
820            match rx.await {
821                Ok(true) => log::debug!("Writer confirmed socket update"),
822                Ok(false) => {
823                    log::warn!("Writer rejected socket update, aborting reconnect");
824                    return Err(TransportError::Io(std::io::Error::other(
825                        "Failed to update reconnection writer",
826                    )));
827                }
828                Err(e) => {
829                    log::error!("Writer dropped update channel: {e}");
830                    return Err(TransportError::Io(std::io::Error::new(
831                        std::io::ErrorKind::BrokenPipe,
832                        "Writer task dropped response channel",
833                    )));
834                }
835            }
836
837            // Delay before closing connection
838            dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
839
840            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
841                log::debug!("Reconnect aborted mid-flight (after delay)");
842                return Ok(());
843            }
844
845            if let Some(ref read_task) = self.read_task.take()
846                && !read_task.is_finished()
847            {
848                read_task.abort();
849                log_task_aborted("read");
850            }
851
852            // Atomically transition from Reconnect to Active
853            // This prevents race condition where disconnect could be requested between check and store
854            if self
855                .connection_mode
856                .compare_exchange(
857                    ConnectionMode::Reconnect.as_u8(),
858                    ConnectionMode::Active.as_u8(),
859                    Ordering::SeqCst,
860                    Ordering::SeqCst,
861                )
862                .is_err()
863            {
864                log::debug!("Reconnect aborted (state changed during reconnect)");
865                return Ok(());
866            }
867
868            self.read_task = if self.message_handler.is_some() {
869                Some(Self::spawn_message_handler_task(
870                    self.connection_mode.clone(),
871                    self.state_notify.clone(),
872                    reader,
873                    self.message_handler.as_ref(),
874                    self.ping_handler.as_ref(),
875                    self.config.idle_timeout_ms,
876                ))
877            } else {
878                None
879            };
880
881            log::debug!("Reconnect succeeded");
882            Ok(())
883        })
884        .await
885        .map_err(|_| {
886            TransportError::Io(std::io::Error::new(
887                std::io::ErrorKind::TimedOut,
888                format!(
889                    "reconnection timed out after {}s",
890                    self.reconnect_timeout.as_secs_f64()
891                ),
892            ))
893        })?
894    }
895
896    /// Check if the client is still alive.
897    ///
898    /// Returns `true` if both the read and write tasks are still running.
899    /// There may be some delay between the connection closing and the
900    /// client detecting it.
901    #[inline]
902    #[must_use]
903    pub fn is_alive(&self) -> bool {
904        match &self.read_task {
905            Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
906            None => !self.write_task.is_finished(),
907        }
908    }
909
910    fn spawn_message_handler_task(
911        connection_state: Arc<AtomicU8>,
912        state_notify: Arc<tokio::sync::Notify>,
913        mut reader: MessageReader,
914        message_handler: Option<&MessageHandler>,
915        ping_handler: Option<&PingHandler>,
916        idle_timeout_ms: Option<u64>,
917    ) -> tokio::task::JoinHandle<()> {
918        log::debug!("Started message handler task 'read'");
919
920        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
921        let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
922
923        // Clone Arc handlers for the async task
924        let message_handler = message_handler.cloned();
925        let ping_handler = ping_handler.cloned();
926
927        tokio::task::spawn(async move {
928            let mut last_data_time = dst::time::Instant::now();
929
930            loop {
931                if !ConnectionMode::from_atomic(&connection_state).is_active() {
932                    break;
933                }
934
935                match dst::time::timeout(check_interval, reader.next()).await {
936                    Ok(Some(Ok(Message::Binary(data)))) => {
937                        log::trace!("Received message <binary> {} bytes", data.len());
938                        last_data_time = dst::time::Instant::now();
939
940                        if let Some(ref handler) = message_handler {
941                            handler(Message::Binary(data));
942                        }
943                    }
944                    Ok(Some(Ok(Message::Text(data)))) => {
945                        log::trace!("Received message: {data:?}");
946                        last_data_time = dst::time::Instant::now();
947
948                        if let Some(ref handler) = message_handler {
949                            handler(Message::Text(data));
950                        }
951                    }
952                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
953                        log::trace!("Received ping: {ping_data:?}");
954                        // Do not reset last_data_time: pings are keep-alive frames, not application
955                        // data, so a peer that emits only pings must still trip the idle timeout.
956
957                        if let Some(ref handler) = ping_handler {
958                            handler(ping_data.to_vec());
959                        }
960                    }
961                    Ok(Some(Ok(Message::Pong(_)))) => {
962                        log::trace!("Received pong");
963                        // Do not reset last_data_time: pongs are keep-alive replies (not data)
964                    }
965                    Ok(Some(Ok(Message::Close(_)))) => {
966                        log::debug!("Received close message - terminating");
967                        break;
968                    }
969                    Ok(Some(Err(e))) => {
970                        log::error!("Received error message - terminating: {e}");
971                        break;
972                    }
973                    Ok(None) => {
974                        log::debug!("No message received - terminating");
975                        break;
976                    }
977                    Err(_) => {
978                        if let Some(timeout) = idle_timeout {
979                            let idle_duration = last_data_time.elapsed();
980                            if idle_duration >= timeout {
981                                log::warn!(
982                                    "Read idle timeout: no data received for {:.1}s",
983                                    idle_duration.as_secs_f64()
984                                );
985                                break;
986                            }
987                        }
988                    }
989                }
990            }
991
992            // Wake the controller immediately so it detects the dead read task
993            state_notify.notify_one();
994        })
995    }
996
997    /// Attempts to send all buffered messages after reconnection.
998    ///
999    /// Returns `true` if a send error occurred (caller should trigger reconnection).
1000    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
1001    async fn drain_reconnect_buffer(
1002        buffer: &mut VecDeque<Message>,
1003        writer: &mut MessageWriter,
1004    ) -> bool {
1005        if buffer.is_empty() {
1006            return false;
1007        }
1008
1009        let initial_buffer_len = buffer.len();
1010        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
1011
1012        let mut send_error_occurred = false;
1013
1014        while let Some(buffered_msg) = buffer.front() {
1015            // Clone message before attempting send (to keep in buffer if send fails)
1016            let msg_to_send = buffered_msg.clone();
1017
1018            if let Err(e) = writer.send(msg_to_send).await {
1019                log::error!(
1020                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
1021                    buffer.len()
1022                );
1023                send_error_occurred = true;
1024                break; // Stop processing buffer, remaining messages preserved for next reconnection
1025            }
1026
1027            // Only remove from buffer after successful send
1028            buffer.pop_front();
1029        }
1030
1031        if buffer.is_empty() {
1032            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
1033        }
1034
1035        send_error_occurred
1036    }
1037
1038    fn can_drain_reconnect_buffer(
1039        reconnect_buffer_waits_for_auth: &AtomicBool,
1040        auth_tracker: &Arc<OnceLock<AuthTracker>>,
1041    ) -> ReconnectBufferAction {
1042        if !reconnect_buffer_waits_for_auth.load(Ordering::Acquire) {
1043            return ReconnectBufferAction::Drain;
1044        }
1045
1046        match auth_tracker.get().map(AuthTracker::auth_state) {
1047            Some(AuthState::Authenticated) => ReconnectBufferAction::Drain,
1048            Some(AuthState::Failed) => ReconnectBufferAction::Discard,
1049            Some(AuthState::Unauthenticated) | None => ReconnectBufferAction::Wait,
1050        }
1051    }
1052
1053    fn spawn_write_task(
1054        connection_state: Arc<AtomicU8>,
1055        state_notify: Arc<tokio::sync::Notify>,
1056        writer: MessageWriter,
1057        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
1058        auth_tracker: Arc<OnceLock<AuthTracker>>,
1059        reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
1060    ) -> tokio::task::JoinHandle<()> {
1061        log_task_started("write");
1062
1063        // Interval between checking the connection mode
1064        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1065
1066        tokio::task::spawn(async move {
1067            let mut active_writer = writer;
1068            // Buffer for messages received during reconnection
1069            // VecDeque for efficient pop_front() operations
1070            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
1071
1072            loop {
1073                let mode = ConnectionMode::from_atomic(&connection_state);
1074
1075                match mode {
1076                    ConnectionMode::Disconnect => {
1077                        // Log any buffered messages that will be lost
1078                        if !reconnect_buffer.is_empty() {
1079                            log::warn!(
1080                                "Discarding {} buffered messages due to disconnect",
1081                                reconnect_buffer.len()
1082                            );
1083                            reconnect_buffer.clear();
1084                        }
1085
1086                        // Attempt to close the writer gracefully before exiting,
1087                        // we ignore any error as the writer may already be closed.
1088                        _ = dst::time::timeout(
1089                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
1090                            active_writer.close(),
1091                        )
1092                        .await;
1093                        break;
1094                    }
1095                    ConnectionMode::Closed => {
1096                        // Log any buffered messages that will be lost
1097                        if !reconnect_buffer.is_empty() {
1098                            log::warn!(
1099                                "Discarding {} buffered messages due to closed connection",
1100                                reconnect_buffer.len()
1101                            );
1102                            reconnect_buffer.clear();
1103                        }
1104                        break;
1105                    }
1106                    _ => {}
1107                }
1108
1109                if mode.is_active() && !reconnect_buffer.is_empty() {
1110                    match Self::can_drain_reconnect_buffer(
1111                        reconnect_buffer_waits_for_auth.as_ref(),
1112                        &auth_tracker,
1113                    ) {
1114                        ReconnectBufferAction::Drain => {
1115                            let send_error = Self::drain_reconnect_buffer(
1116                                &mut reconnect_buffer,
1117                                &mut active_writer,
1118                            )
1119                            .await;
1120
1121                            if send_error {
1122                                if let Some(tracker) = auth_tracker.get() {
1123                                    tracker.invalidate();
1124                                }
1125                                connection_state
1126                                    .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
1127                                state_notify.notify_one();
1128                            }
1129
1130                            continue;
1131                        }
1132                        ReconnectBufferAction::Discard => {
1133                            log::warn!(
1134                                "Discarding {} buffered messages after authentication failed",
1135                                reconnect_buffer.len()
1136                            );
1137                            reconnect_buffer.clear();
1138                            continue;
1139                        }
1140                        ReconnectBufferAction::Wait => {}
1141                    }
1142                }
1143
1144                match dst::time::timeout(check_interval, writer_rx.recv()).await {
1145                    Ok(Some(msg)) => {
1146                        // Re-check connection mode after receiving a message
1147                        let mode = ConnectionMode::from_atomic(&connection_state);
1148                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1149                            break;
1150                        }
1151
1152                        match msg {
1153                            WriterCommand::Update(new_writer, tx) => {
1154                                log::debug!("Received new writer");
1155
1156                                // Delay before closing connection
1157                                dst::time::sleep(Duration::from_millis(100)).await;
1158
1159                                // Attempt to close the writer gracefully on update,
1160                                // we ignore any error as the writer may already be closed.
1161                                _ = dst::time::timeout(
1162                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
1163                                    active_writer.close(),
1164                                )
1165                                .await;
1166
1167                                active_writer = new_writer;
1168                                log::debug!("Updated writer");
1169
1170                                if let Err(e) = tx.send(true) {
1171                                    log::error!(
1172                                        "Failed to report writer update to controller: {e:?}"
1173                                    );
1174                                }
1175                            }
1176                            WriterCommand::Send(msg) if mode.is_reconnect() => {
1177                                // Buffer messages during reconnection instead of dropping them
1178                                log::debug!(
1179                                    "Buffering message during reconnection (buffer size: {})",
1180                                    reconnect_buffer.len() + 1
1181                                );
1182                                reconnect_buffer.push_back(msg);
1183                            }
1184                            WriterCommand::Send(msg) => {
1185                                if let Err(e) = active_writer.send(msg.clone()).await {
1186                                    log::error!("Failed to send message: {e}");
1187                                    log::warn!("Writer triggering reconnect");
1188                                    reconnect_buffer.push_back(msg);
1189
1190                                    if let Some(tracker) = auth_tracker.get() {
1191                                        tracker.invalidate();
1192                                    }
1193                                    connection_state
1194                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
1195                                    state_notify.notify_one();
1196                                }
1197                            }
1198                        }
1199                    }
1200                    Ok(None) => {
1201                        // Channel closed - writer task should terminate
1202                        log::debug!("Writer channel closed, terminating writer task");
1203                        break;
1204                    }
1205                    Err(_) => {
1206                        // Timeout - just continue the loop
1207                    }
1208                }
1209            }
1210
1211            // Attempt to close the writer gracefully before exiting,
1212            // we ignore any error as the writer may already be closed.
1213            _ = dst::time::timeout(
1214                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
1215                active_writer.close(),
1216            )
1217            .await;
1218
1219            log_task_stopped("write");
1220        })
1221    }
1222
1223    fn spawn_heartbeat_task(
1224        connection_state: Arc<AtomicU8>,
1225        heartbeat_secs: u64,
1226        message: Option<String>,
1227        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
1228    ) -> tokio::task::JoinHandle<()> {
1229        log_task_started("heartbeat");
1230
1231        tokio::task::spawn(async move {
1232            let interval = Duration::from_secs(heartbeat_secs);
1233
1234            loop {
1235                dst::time::sleep(interval).await;
1236
1237                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
1238                    ConnectionMode::Active => {
1239                        let msg = match &message {
1240                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
1241                            None => WriterCommand::Send(Message::Ping(vec![].into())),
1242                        };
1243
1244                        match writer_tx.send(msg) {
1245                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
1246                            Err(e) => {
1247                                log::error!("Failed to send heartbeat to writer task: {e}");
1248                            }
1249                        }
1250                    }
1251                    ConnectionMode::Reconnect => {}
1252                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
1253                }
1254            }
1255
1256            log_task_stopped("heartbeat");
1257        })
1258    }
1259}
1260
1261impl Drop for WebSocketClientInner {
1262    fn drop(&mut self) {
1263        // Delegate to explicit cleanup handler
1264        self.clean_drop();
1265    }
1266}
1267
1268/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
1269impl CleanDrop for WebSocketClientInner {
1270    fn clean_drop(&mut self) {
1271        if let Some(ref read_task) = self.read_task.take()
1272            && !read_task.is_finished()
1273        {
1274            read_task.abort();
1275            log_task_aborted("read");
1276        }
1277
1278        if !self.write_task.is_finished() {
1279            self.write_task.abort();
1280            log_task_aborted("write");
1281        }
1282
1283        if let Some(ref handle) = self.heartbeat_task.take()
1284            && !handle.is_finished()
1285        {
1286            handle.abort();
1287            log_task_aborted("heartbeat");
1288        }
1289
1290        // Clear handlers to break potential reference cycles
1291        self.message_handler = None;
1292        self.ping_handler = None;
1293    }
1294}
1295
1296#[expect(
1297    clippy::missing_fields_in_debug,
1298    reason = "handler closures and internal task handles are intentionally omitted"
1299)]
1300impl Debug for WebSocketClientInner {
1301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1302        f.debug_struct(stringify!(WebSocketClientInner))
1303            .field("config", &self.config)
1304            .field(
1305                "connection_mode",
1306                &ConnectionMode::from_atomic(&self.connection_mode),
1307            )
1308            .field("reconnect_timeout", &self.reconnect_timeout)
1309            .field("is_stream_mode", &self.is_stream_mode)
1310            .finish()
1311    }
1312}
1313
1314/// WebSocket client with automatic reconnection.
1315///
1316/// Handles connection state, callbacks, and rate limiting.
1317/// See module docs for architecture details.
1318#[cfg_attr(
1319    feature = "python",
1320    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
1321)]
1322#[cfg_attr(
1323    feature = "python",
1324    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
1325)]
1326pub struct WebSocketClient {
1327    pub(crate) controller_task: tokio::task::JoinHandle<()>,
1328    pub(crate) connection_mode: Arc<AtomicU8>,
1329    pub(crate) state_notify: Arc<tokio::sync::Notify>,
1330    pub(crate) reconnect_timeout: Duration,
1331    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
1332    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
1333    auth_tracker: Arc<OnceLock<AuthTracker>>,
1334    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
1335}
1336
1337impl Debug for WebSocketClient {
1338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1339        f.debug_struct(stringify!(WebSocketClient)).finish()
1340    }
1341}
1342
1343impl WebSocketClient {
1344    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
1345    ///
1346    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
1347    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
1348    /// client transitions to CLOSED state and the caller must manually reconnect by calling
1349    /// `connect_stream` again.
1350    ///
1351    /// Use stream mode when you need custom reconnection logic, direct control over message
1352    /// reading, or fine-grained backpressure handling.
1353    ///
1354    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
1355    ///
1356    /// # Errors
1357    ///
1358    /// Returns an error if the connection cannot be established.
1359    pub async fn connect_stream(
1360        config: WebSocketConfig,
1361        keyed_quotas: Vec<(String, Quota)>,
1362        default_quota: Option<Quota>,
1363        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
1364    ) -> Result<(MessageReader, Self), TransportError> {
1365        install_cryptographic_provider();
1366
1367        // Create a single connection and split it, respecting configured headers
1368        let (writer, reader) = WebSocketClientInner::connect_with_server(
1369            &config.url,
1370            config.headers.clone(),
1371            config.backend,
1372            config.proxy_url.as_deref(),
1373        )
1374        .await?;
1375
1376        // Create inner without connecting (we'll provide the writer)
1377        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
1378
1379        let connection_mode = inner.connection_mode.clone();
1380        let state_notify = inner.state_notify.clone();
1381        let reconnect_timeout = inner.reconnect_timeout;
1382        let auth_tracker = Arc::clone(&inner.auth_tracker);
1383        let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1384        let keyed_quotas = keyed_quotas
1385            .into_iter()
1386            .map(|(key, quota)| (Ustr::from(&key), quota))
1387            .collect();
1388        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1389        let writer_tx = inner.writer_tx.clone();
1390
1391        let controller_task = Self::spawn_controller_task(
1392            inner,
1393            connection_mode.clone(),
1394            state_notify.clone(),
1395            post_reconnect,
1396            Arc::clone(&auth_tracker),
1397        );
1398
1399        Ok((
1400            reader,
1401            Self {
1402                controller_task,
1403                connection_mode,
1404                state_notify,
1405                reconnect_timeout,
1406                rate_limiter,
1407                writer_tx,
1408                auth_tracker,
1409                reconnect_buffer_waits_for_auth,
1410            },
1411        ))
1412    }
1413
1414    /// Creates a websocket client in **handler mode** with automatic reconnection.
1415    ///
1416    /// The handler is called for each incoming message on an internal task.
1417    /// Automatic reconnection is **enabled** with exponential backoff. On disconnection,
1418    /// the client automatically attempts to reconnect and replaces the internal reader
1419    /// (the handler continues working seamlessly).
1420    ///
1421    /// Use handler mode for simplified connection management, automatic reconnection, Python
1422    /// bindings, or callback-based message handling.
1423    ///
1424    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
1425    ///
1426    /// # Errors
1427    ///
1428    /// Returns an error if:
1429    /// - The connection cannot be established.
1430    /// - `message_handler` is `None` (use `connect_stream` instead).
1431    pub async fn connect(
1432        config: WebSocketConfig,
1433        message_handler: Option<MessageHandler>,
1434        ping_handler: Option<PingHandler>,
1435        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1436        keyed_quotas: Vec<(String, Quota)>,
1437        default_quota: Option<Quota>,
1438    ) -> Result<Self, TransportError> {
1439        // Validate that handler mode has a message handler
1440        if message_handler.is_none() {
1441            return Err(TransportError::Io(std::io::Error::new(
1442                std::io::ErrorKind::InvalidInput,
1443                "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
1444            )));
1445        }
1446
1447        log::debug!("Connecting");
1448        let inner =
1449            WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
1450        let connection_mode = inner.connection_mode.clone();
1451        let state_notify = inner.state_notify.clone();
1452        let writer_tx = inner.writer_tx.clone();
1453        let reconnect_timeout = inner.reconnect_timeout;
1454        let auth_tracker = Arc::clone(&inner.auth_tracker);
1455        let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1456
1457        let controller_task = Self::spawn_controller_task(
1458            inner,
1459            connection_mode.clone(),
1460            state_notify.clone(),
1461            post_reconnection,
1462            Arc::clone(&auth_tracker),
1463        );
1464
1465        let keyed_quotas = keyed_quotas
1466            .into_iter()
1467            .map(|(key, quota)| (Ustr::from(&key), quota))
1468            .collect();
1469        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1470
1471        Ok(Self {
1472            controller_task,
1473            connection_mode,
1474            state_notify,
1475            reconnect_timeout,
1476            rate_limiter,
1477            writer_tx,
1478            auth_tracker,
1479            reconnect_buffer_waits_for_auth,
1480        })
1481    }
1482
1483    /// Returns the current connection mode.
1484    #[must_use]
1485    pub fn connection_mode(&self) -> ConnectionMode {
1486        ConnectionMode::from_atomic(&self.connection_mode)
1487    }
1488
1489    /// Returns a clone of the connection mode atomic for external state tracking.
1490    ///
1491    /// This allows adapter clients to track connection state across reconnections
1492    /// without message-passing delays.
1493    #[must_use]
1494    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1495        Arc::clone(&self.connection_mode)
1496    }
1497
1498    /// Check if the client connection is active.
1499    ///
1500    /// Returns `true` if the client is connected and has not been signalled to disconnect.
1501    /// The client will automatically retry connection based on its configuration.
1502    #[inline]
1503    #[must_use]
1504    pub fn is_active(&self) -> bool {
1505        self.connection_mode().is_active()
1506    }
1507
1508    /// Check if the client is disconnected.
1509    #[must_use]
1510    pub fn is_disconnected(&self) -> bool {
1511        self.controller_task.is_finished()
1512    }
1513
1514    /// Check if the client is reconnecting.
1515    ///
1516    /// Returns `true` if the client lost connection and is attempting to reestablish it.
1517    /// The client will automatically retry connection based on its configuration.
1518    #[inline]
1519    #[must_use]
1520    pub fn is_reconnecting(&self) -> bool {
1521        self.connection_mode().is_reconnect()
1522    }
1523
1524    /// Registers an [`AuthTracker`] with the client.
1525    ///
1526    /// When the controller detects a dead connection and transitions to
1527    /// `Reconnect`, it calls `invalidate()` on the tracker so that any
1528    /// pending authenticated sends see the state change immediately.
1529    /// Set `reconnect_buffer_waits_for_auth` for clients that must not replay
1530    /// buffered messages until the next session authenticates.
1531    ///
1532    /// Call this once after construction, before any authenticated sends.
1533    pub fn set_auth_tracker(&self, tracker: AuthTracker, reconnect_buffer_waits_for_auth: bool) {
1534        let _ = self.auth_tracker.set(tracker);
1535        self.reconnect_buffer_waits_for_auth
1536            .store(reconnect_buffer_waits_for_auth, Ordering::Release);
1537    }
1538
1539    /// Check if the client is disconnecting.
1540    ///
1541    /// Returns `true` if the client is in disconnect mode.
1542    #[inline]
1543    #[must_use]
1544    pub fn is_disconnecting(&self) -> bool {
1545        self.connection_mode().is_disconnect()
1546    }
1547
1548    /// Check if the client is closed.
1549    ///
1550    /// Returns `true` if the client has been explicitly disconnected or reached
1551    /// maximum reconnection attempts. In this state, the client cannot be reused
1552    /// and a new client must be created for further connections.
1553    #[inline]
1554    #[must_use]
1555    pub fn is_closed(&self) -> bool {
1556        self.connection_mode().is_closed()
1557    }
1558
1559    /// Checks whether the connection is in a terminal state (disconnecting or closed).
1560    ///
1561    /// Single atomic load to fail fast before rate limiting or waiting.
1562    #[inline]
1563    fn check_not_terminal(&self) -> Result<(), SendError> {
1564        match self.connection_mode() {
1565            ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
1566            _ => Ok(()),
1567        }
1568    }
1569
1570    /// Waits for rate limiter quota, aborting early if connection enters a terminal state.
1571    async fn await_rate_limit_or_closed(&self, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1572        const CHECK_INTERVAL_MS: u64 = 100;
1573
1574        tokio::select! {
1575            biased;
1576            () = self.rate_limiter.await_keys_ready(keys) => Ok(()),
1577            () = async {
1578                loop {
1579                    let notified = self.state_notify.notified();
1580
1581                    if matches!(self.connection_mode(), ConnectionMode::Disconnect | ConnectionMode::Closed) {
1582                        break;
1583                    }
1584                    tokio::select! {
1585                        biased;
1586                        () = notified => {}
1587                        () = dst::time::sleep(Duration::from_millis(CHECK_INTERVAL_MS)) => {}
1588                    }
1589                }
1590            } => Err(SendError::Closed),
1591        }
1592    }
1593
1594    /// Waits for the client to become active before sending.
1595    ///
1596    /// Uses `state_notify` for event-driven wakeup so sends resume immediately
1597    /// after reconnection completes. A fallback interval guards against missed
1598    /// notifications.
1599    async fn wait_for_active(&self) -> Result<(), SendError> {
1600        const FALLBACK_INTERVAL_MS: u64 = 100;
1601
1602        let mode = self.connection_mode();
1603        if mode.is_active() {
1604            return Ok(());
1605        }
1606
1607        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1608            return Err(SendError::Closed);
1609        }
1610
1611        log::debug!("Waiting for client to become ACTIVE before sending...");
1612
1613        let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
1614
1615        dst::time::timeout(self.reconnect_timeout, async {
1616            loop {
1617                // Register notification interest BEFORE checking state to prevent
1618                // a race where the state changes between our check and the await
1619                let notified = self.state_notify.notified();
1620
1621                let mode = self.connection_mode();
1622                if mode.is_active() {
1623                    return Ok(());
1624                }
1625
1626                if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1627                    return Err(());
1628                }
1629
1630                tokio::select! {
1631                    biased;
1632                    () = notified => {}
1633                    () = dst::time::sleep(fallback_interval) => {}
1634                }
1635            }
1636        })
1637        .await
1638        .map_err(|_| SendError::Timeout)?
1639        .map_err(|()| SendError::Closed)
1640    }
1641
1642    /// Signals that the caller's reader has observed EOF or a fatal error.
1643    ///
1644    /// In stream mode the controller has no visibility into the caller-owned reader.
1645    /// Call this method when `reader.next().await` returns `None` or an unrecoverable
1646    /// error so the controller transitions to `Closed` and dependent tasks shut down.
1647    ///
1648    /// For peer-initiated close frames (`Message::Close`), use [`disconnect`](Self::disconnect)
1649    /// instead so the writer can send the close reply before shutting down.
1650    ///
1651    /// This is a no-op if the connection is already closed or disconnecting.
1652    pub fn notify_closed(&self) {
1653        let mode = self.connection_mode();
1654        if mode.is_disconnect() || mode.is_closed() {
1655            return;
1656        }
1657
1658        log::debug!("Stream reader signalled EOF, transitioning to CLOSED");
1659
1660        self.connection_mode
1661            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1662        self.state_notify.notify_waiters();
1663    }
1664
1665    /// Set disconnect mode to true.
1666    ///
1667    /// Controller task will periodically check the disconnect mode
1668    /// and shutdown the client if it is alive
1669    pub async fn disconnect(&self) {
1670        log::debug!("Disconnecting");
1671        self.connection_mode
1672            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1673        self.state_notify.notify_waiters();
1674
1675        if dst::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1676            while !self.is_disconnected() {
1677                dst::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1678            }
1679
1680            if !self.controller_task.is_finished() {
1681                self.controller_task.abort();
1682                log_task_aborted("controller");
1683            }
1684        })
1685        .await
1686            == Ok(())
1687        {
1688            log::debug!("Controller task finished");
1689        } else {
1690            log::error!("Timeout waiting for controller task to finish");
1691
1692            if !self.controller_task.is_finished() {
1693                self.controller_task.abort();
1694                log_task_aborted("controller");
1695            }
1696            self.connection_mode
1697                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1698        }
1699    }
1700
1701    /// Sends the given text `data` to the server.
1702    ///
1703    /// Returns `Ok(())` when the message is enqueued to the writer channel. This does NOT
1704    /// guarantee delivery: if a disconnect occurs concurrently, the writer task may drop the
1705    /// message. During reconnection, messages are buffered and replayed on the new connection.
1706    ///
1707    /// # Errors
1708    ///
1709    /// Returns a websocket error if unable to send.
1710    #[allow(unused_variables)]
1711    pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1712        self.check_not_terminal()?;
1713
1714        self.await_rate_limit_or_closed(keys).await?;
1715        self.wait_for_active().await?;
1716
1717        log::trace!("Sending text: {data:?}");
1718
1719        let msg = Message::Text(data.into());
1720        self.writer_tx
1721            .send(WriterCommand::Send(msg))
1722            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1723    }
1724
1725    /// Sends a pong frame back to the server.
1726    ///
1727    /// # Errors
1728    ///
1729    /// Returns a websocket error if unable to send.
1730    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1731        self.wait_for_active().await?;
1732
1733        log::trace!("Sending pong frame ({} bytes)", data.len());
1734
1735        let msg = Message::Pong(data.into());
1736        self.writer_tx
1737            .send(WriterCommand::Send(msg))
1738            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1739    }
1740
1741    /// Sends the given bytes `data` to the server.
1742    ///
1743    /// Returns `Ok(())` when the message is enqueued to the writer channel. This does NOT
1744    /// guarantee delivery: if a disconnect occurs concurrently, the writer task may drop the
1745    /// message. During reconnection, messages are buffered and replayed on the new connection.
1746    ///
1747    /// # Errors
1748    ///
1749    /// Returns a websocket error if unable to send.
1750    #[allow(unused_variables)]
1751    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1752        self.check_not_terminal()?;
1753
1754        self.await_rate_limit_or_closed(keys).await?;
1755        self.wait_for_active().await?;
1756
1757        log::trace!("Sending bytes: {data:?}");
1758
1759        let msg = Message::Binary(data.into());
1760        self.writer_tx
1761            .send(WriterCommand::Send(msg))
1762            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1763    }
1764
1765    /// Sends a close message to the server.
1766    ///
1767    /// # Errors
1768    ///
1769    /// Returns a websocket error if unable to send.
1770    pub async fn send_close_message(&self) -> Result<(), SendError> {
1771        self.wait_for_active().await?;
1772
1773        let msg = Message::Close(None);
1774        self.writer_tx
1775            .send(WriterCommand::Send(msg))
1776            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1777    }
1778
1779    fn spawn_controller_task(
1780        mut inner: WebSocketClientInner,
1781        connection_mode: Arc<AtomicU8>,
1782        state_notify: Arc<tokio::sync::Notify>,
1783        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1784        auth_tracker: Arc<OnceLock<AuthTracker>>,
1785    ) -> tokio::task::JoinHandle<()> {
1786        const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;
1787
1788        tokio::task::spawn(async move {
1789            log_task_started("controller");
1790
1791            let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);
1792
1793            loop {
1794                tokio::select! {
1795                    biased;
1796                    () = state_notify.notified() => {}
1797                    () = dst::time::sleep(fallback_interval) => {}
1798                }
1799
1800                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1801
1802                if mode.is_disconnect() {
1803                    log::debug!("Disconnecting");
1804
1805                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1806                    if dst::time::timeout(timeout, async {
1807                        // Delay awaiting graceful shutdown
1808                        dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1809
1810                        if let Some(task) = &inner.read_task
1811                            && !task.is_finished()
1812                        {
1813                            task.abort();
1814                            log_task_aborted("read");
1815                        }
1816
1817                        if let Some(task) = &inner.heartbeat_task
1818                            && !task.is_finished()
1819                        {
1820                            task.abort();
1821                            log_task_aborted("heartbeat");
1822                        }
1823                    })
1824                    .await
1825                    .is_err()
1826                    {
1827                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
1828                    }
1829
1830                    log::debug!("Closed");
1831                    break; // Controller finished
1832                }
1833
1834                if mode.is_closed() {
1835                    log::debug!("Connection closed");
1836                    break;
1837                }
1838
1839                if mode.is_active() && !inner.is_alive() {
1840                    let target = if inner.is_stream_mode {
1841                        ConnectionMode::Closed
1842                    } else {
1843                        ConnectionMode::Reconnect
1844                    };
1845
1846                    if connection_mode
1847                        .compare_exchange(
1848                            ConnectionMode::Active.as_u8(),
1849                            target.as_u8(),
1850                            Ordering::SeqCst,
1851                            Ordering::SeqCst,
1852                        )
1853                        .is_ok()
1854                    {
1855                        if let Some(tracker) = auth_tracker.get() {
1856                            tracker.invalidate();
1857                        }
1858                        log::debug!("Detected dead connection, transitioning to {target:?}");
1859                    }
1860                    mode = ConnectionMode::from_atomic(&connection_mode);
1861                }
1862
1863                if mode.is_reconnect() {
1864                    // Check if max reconnection attempts exceeded
1865                    if let Some(max_attempts) = inner.reconnect_max_attempts
1866                        && inner.reconnection_attempt_count >= max_attempts
1867                    {
1868                        log::error!(
1869                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1870                        );
1871                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1872                        state_notify.notify_waiters();
1873                        break;
1874                    }
1875
1876                    inner.reconnection_attempt_count += 1;
1877                    log::debug!(
1878                        "Reconnection attempt {} of {}",
1879                        inner.reconnection_attempt_count,
1880                        inner
1881                            .reconnect_max_attempts
1882                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1883                    );
1884
1885                    // Race reconnect against disconnect notification
1886                    let reconnect_result = tokio::select! {
1887                        biased;
1888                        result = inner.reconnect() => Some(result),
1889                        () = async {
1890                            loop {
1891                                state_notify.notified().await;
1892
1893                                if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1894                                    break;
1895                                }
1896                            }
1897                        } => None,
1898                    };
1899
1900                    match reconnect_result {
1901                        None => {
1902                            log::debug!("Reconnect interrupted by disconnect");
1903                        }
1904                        Some(Ok(())) => {
1905                            inner.backoff.reset();
1906                            inner.reconnection_attempt_count = 0;
1907
1908                            state_notify.notify_waiters();
1909
1910                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1911                                if let Some(ref handler) = inner.message_handler {
1912                                    let reconnected_msg =
1913                                        Message::Text(RECONNECTED.to_string().into());
1914                                    handler(reconnected_msg);
1915                                    log::debug!("Sent reconnected message to handler");
1916                                }
1917
1918                                // TODO: Retain this legacy callback for use from Python
1919                                if let Some(ref callback) = post_reconnection {
1920                                    callback();
1921                                    log::debug!("Called `post_reconnection` handler");
1922                                }
1923
1924                                log::debug!("Reconnected successfully");
1925                            } else {
1926                                log::debug!(
1927                                    "Skipping post_reconnection handlers due to disconnect state"
1928                                );
1929                            }
1930                        }
1931                        Some(Err(e)) => {
1932                            let duration = inner.backoff.next_duration();
1933                            log::warn!(
1934                                "Reconnect attempt {} failed: {e}",
1935                                inner.reconnection_attempt_count
1936                            );
1937
1938                            if !duration.is_zero() {
1939                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
1940                                // Race backoff sleep against disconnect
1941                                tokio::select! {
1942                                    biased;
1943                                    () = dst::time::sleep(duration) => {}
1944                                    () = async {
1945                                        loop {
1946                                            state_notify.notified().await;
1947
1948                                            if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1949                                                break;
1950                                            }
1951                                        }
1952                                    } => {
1953                                        log::debug!("Backoff interrupted by disconnect");
1954                                    }
1955                                }
1956                            }
1957                        }
1958                    }
1959                }
1960            }
1961            inner
1962                .connection_mode
1963                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1964
1965            log_task_stopped("controller");
1966        })
1967    }
1968}
1969
1970// Abort controller task on drop to clean up background tasks
1971impl Drop for WebSocketClient {
1972    fn drop(&mut self) {
1973        if !self.controller_task.is_finished() {
1974            self.controller_task.abort();
1975            log_task_aborted("controller");
1976        }
1977    }
1978}
1979
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
#[cfg(not(all(feature = "simulation", madsim)))] // transport-layer I/O not simulated
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message as WsMessage,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{TransportBackend, WebSocketClient, WebSocketConfig},
    };

    /// Local echo server; the accept loop is aborted when the value is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the request carries header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        #[expect(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // One lookup covers both "header present" and "header has the expected value"
            assert_eq!(request.headers().get(&self.key), Some(&self.value));

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral local port and spawns the accept loop.
        ///
        /// Every connection must send the `test: test` header. Text and binary
        /// frames are echoed back; the text `close-now` makes the server close
        /// the connection from its side.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_str("test").unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        // Inner if consumes `msg`, cannot hoist into a match guard
                        #[expect(clippy::collapsible_match)]
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                WsMessage::Text(txt) if txt == "close-now" => {
                                    log::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                WsMessage::Text(_) | WsMessage::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                WsMessage::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Builds the configuration shared by all tests: every optional setting is
    /// left unset and the Tungstenite backend is used.
    ///
    /// `send_test_header` adds the `test: test` header that [`TestServer`] requires.
    fn test_config(url: String, send_test_header: bool) -> WebSocketConfig {
        let headers = if send_test_header {
            vec![("test".into(), "test".into())]
        } else {
            vec![]
        };

        WebSocketConfig {
            url,
            headers,
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        }
    }

    /// Connects a handler-mode client (no-op message handler) to the local test server.
    async fn setup_test_client(port: u16) -> WebSocketClient {
        let config = test_config(format!("ws://127.0.0.1:{port}"), true);
        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // No heartbeat is configured by `setup_test_client`, so this only checks
        // that a client left idle for ~3s still disconnects cleanly
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it again, so that nothing listens
        // on it; a fixed port number could be in use on the test host
        let unused_port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        let config = test_config(format!("ws://127.0.0.1:{unused_port}"), false); // <-- No server
        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();

        let config = test_config(format!("ws://127.0.0.1:{}", server.port), true);

        let client = WebSocketClient::connect(
            config,
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // First 2 should succeed
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        // Third must succeed as well: `send_text` awaits quota instead of failing
        client.send_text("test3".into(), None).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];

        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
2254
2255#[cfg(test)]
2256#[cfg(not(feature = "turmoil"))]
2257#[cfg(not(all(feature = "simulation", madsim)))] // transport-layer I/O not simulated
2258mod rust_tests {
2259    use std::sync::{
2260        Arc, OnceLock,
2261        atomic::{AtomicBool, AtomicU8, Ordering},
2262    };
2263
2264    use futures_util::{SinkExt, StreamExt};
2265    use nautilus_common::testing::wait_until_async;
2266    use rstest::rstest;
2267    #[cfg(feature = "transport-sockudo")]
2268    use sockudo_ws::handshake as sockudo_handshake;
2269    #[cfg(feature = "transport-sockudo")]
2270    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
2271    use tokio::{
2272        net::TcpListener,
2273        task::{self, JoinHandle},
2274        time::{Duration, sleep},
2275    };
2276    use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
2277    #[cfg(feature = "transport-sockudo")]
2278    use tokio_tungstenite::{
2279        accept_hdr_async,
2280        tungstenite::{
2281            handshake::server::{self, Callback},
2282            http::HeaderValue,
2283        },
2284    };
2285
2286    use super::*;
2287    use crate::websocket::types::channel_message_handler;
2288
    /// In-process WebSocket test server that records the text frames it receives.
    struct RecordingServer {
        /// Handle of the spawned accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Local port the listener is bound to.
        port: u16,
        /// Text payloads received so far, shared with the per-connection tasks.
        messages: Arc<tokio::sync::Mutex<Vec<String>>>,
    }
2294
2295    #[cfg(feature = "transport-sockudo")]
2296    async fn read_http_request<S>(stream: &mut S) -> Vec<u8>
2297    where
2298        S: AsyncRead + Unpin,
2299    {
2300        let mut buf = Vec::new();
2301        let mut chunk = [0u8; 256];
2302
2303        loop {
2304            let n = stream.read(&mut chunk).await.unwrap();
2305            assert!(n > 0, "HTTP request closed before headers completed");
2306            buf.extend_from_slice(&chunk[..n]);
2307            if buf.windows(4).any(|window| window == b"\r\n\r\n") {
2308                return buf;
2309            }
2310        }
2311    }
2312
2313    #[cfg(feature = "transport-sockudo")]
2314    fn extract_header<'a>(request: &'a str, name: &str) -> Option<&'a str> {
2315        request.lines().find_map(|line| {
2316            let (header_name, header_value) = line.split_once(':')?;
2317            if header_name.eq_ignore_ascii_case(name) {
2318                Some(header_value.trim())
2319            } else {
2320                None
2321            }
2322        })
2323    }
2324
2325    #[cfg(feature = "transport-sockudo")]
2326    #[derive(Debug, Clone)]
2327    struct HeaderAssertCallback {
2328        key: String,
2329        value: HeaderValue,
2330    }
2331
2332    #[cfg(feature = "transport-sockudo")]
2333    impl Callback for HeaderAssertCallback {
2334        #[expect(
2335            clippy::panic_in_result_fn,
2336            reason = "assertion failures should fail the test"
2337        )]
2338        fn on_request(
2339            self,
2340            request: &server::Request,
2341            response: server::Response,
2342        ) -> Result<server::Response, server::ErrorResponse> {
2343            assert_eq!(request.headers().get(&self.key), Some(&self.value));
2344            Ok(response)
2345        }
2346    }
2347
2348    impl RecordingServer {
2349        async fn setup() -> Self {
2350            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2351            let port = listener.local_addr().unwrap().port();
2352            let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
2353            let messages_clone = Arc::clone(&messages);
2354
2355            let task = task::spawn(async move {
2356                loop {
2357                    let (stream, _) = listener.accept().await.unwrap();
2358                    let mut websocket = accept_async(stream).await.unwrap();
2359                    let messages = Arc::clone(&messages_clone);
2360
2361                    task::spawn(async move {
2362                        while let Some(Ok(msg)) = websocket.next().await {
2363                            match msg {
2364                                WsMessage::Text(text) => {
2365                                    messages.lock().await.push(text.to_string());
2366                                }
2367                                WsMessage::Close(_) => {
2368                                    let _ = websocket.close(None).await;
2369                                    break;
2370                                }
2371                                _ => {}
2372                            }
2373                        }
2374                    });
2375                }
2376            });
2377
2378            Self {
2379                task,
2380                port,
2381                messages,
2382            }
2383        }
2384
2385        async fn messages(&self) -> Vec<String> {
2386            self.messages.lock().await.clone()
2387        }
2388    }
2389
2390    impl Drop for RecordingServer {
2391        fn drop(&mut self) {
2392            self.task.abort();
2393        }
2394    }
2395
2396    #[rstest]
2397    #[tokio::test]
2398    async fn test_reconnect_then_disconnect() {
2399        // Bind an ephemeral port
2400        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2401        let port = listener.local_addr().unwrap().port();
2402
2403        // Server task: accept one ws connection then close it
2404        let server = task::spawn(async move {
2405            let (stream, _) = listener.accept().await.unwrap();
2406            let ws = accept_async(stream).await.unwrap();
2407            drop(ws);
2408            // Keep alive briefly
2409            sleep(Duration::from_secs(1)).await;
2410        });
2411
2412        // Build a channel-based message handler for incoming messages (unused here)
2413        let (handler, _rx) = channel_message_handler();
2414
2415        // Configure client with short reconnect backoff
2416        let config = WebSocketConfig {
2417            url: format!("ws://127.0.0.1:{port}"),
2418            headers: vec![],
2419            heartbeat: None,
2420            heartbeat_msg: None,
2421            reconnect_timeout_ms: Some(1_000),
2422            reconnect_delay_initial_ms: Some(50),
2423            reconnect_delay_max_ms: Some(100),
2424            reconnect_backoff_factor: Some(1.0),
2425            reconnect_jitter_ms: Some(0),
2426            reconnect_max_attempts: None,
2427            idle_timeout_ms: None,
2428            backend: TransportBackend::Tungstenite,
2429            proxy_url: None,
2430        };
2431
2432        // Connect the client
2433        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2434            .await
2435            .unwrap();
2436
2437        // Allow server to drop connection and client to detect
2438        sleep(Duration::from_millis(100)).await;
2439        // Now immediately disconnect the client
2440        client.disconnect().await;
2441        assert!(client.is_disconnected());
2442        server.abort();
2443    }
2444
2445    #[rstest]
2446    #[tokio::test]
2447    async fn test_reconnect_state_flips_when_reader_stops() {
2448        // Bind an ephemeral port and accept a single websocket connection which we drop.
2449        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2450        let port = listener.local_addr().unwrap().port();
2451
2452        let server = task::spawn(async move {
2453            if let Ok((stream, _)) = listener.accept().await
2454                && let Ok(ws) = accept_async(stream).await
2455            {
2456                drop(ws);
2457            }
2458            sleep(Duration::from_millis(50)).await;
2459        });
2460
2461        let (handler, _rx) = channel_message_handler();
2462
2463        let config = WebSocketConfig {
2464            url: format!("ws://127.0.0.1:{port}"),
2465            headers: vec![],
2466            heartbeat: None,
2467            heartbeat_msg: None,
2468            reconnect_timeout_ms: Some(1_000),
2469            reconnect_delay_initial_ms: Some(50),
2470            reconnect_delay_max_ms: Some(100),
2471            reconnect_backoff_factor: Some(1.0),
2472            reconnect_jitter_ms: Some(0),
2473            reconnect_max_attempts: None,
2474            idle_timeout_ms: None,
2475            backend: TransportBackend::Tungstenite,
2476            proxy_url: None,
2477        };
2478
2479        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2480            .await
2481            .unwrap();
2482
2483        tokio::time::timeout(Duration::from_secs(2), async {
2484            loop {
2485                if client.is_reconnecting() {
2486                    break;
2487                }
2488                tokio::time::sleep(Duration::from_millis(10)).await;
2489            }
2490        })
2491        .await
2492        .expect("client did not enter RECONNECT state");
2493
2494        client.disconnect().await;
2495        server.abort();
2496    }
2497
2498    #[rstest]
2499    #[tokio::test]
2500    async fn test_stream_mode_disables_auto_reconnect() {
2501        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
2502        // and that reconnect() transitions to CLOSED state for stream mode
2503        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2504        let port = listener.local_addr().unwrap().port();
2505
2506        let server = task::spawn(async move {
2507            if let Ok((stream, _)) = listener.accept().await
2508                && let Ok(_ws) = accept_async(stream).await
2509            {
2510                // Keep connection alive briefly
2511                sleep(Duration::from_millis(100)).await;
2512            }
2513        });
2514
2515        let config = WebSocketConfig {
2516            url: format!("ws://127.0.0.1:{port}"),
2517            headers: vec![],
2518            heartbeat: None,
2519            heartbeat_msg: None,
2520            reconnect_timeout_ms: Some(1_000),
2521            reconnect_delay_initial_ms: Some(50),
2522            reconnect_delay_max_ms: Some(100),
2523            reconnect_backoff_factor: Some(1.0),
2524            reconnect_jitter_ms: Some(0),
2525            reconnect_max_attempts: None,
2526            idle_timeout_ms: None,
2527            backend: TransportBackend::Tungstenite,
2528            proxy_url: None,
2529        };
2530
2531        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
2532            .await
2533            .unwrap();
2534
2535        // Note: We can't easily test the reconnect behavior from the outside since
2536        // the inner client is private. The key fix is that WebSocketClientInner
2537        // now has is_stream_mode=true for connect_stream, and reconnect() will
2538        // transition to CLOSED state instead of creating a new reader that gets dropped.
2539        // This is tested implicitly by the fact that stream users won't get stuck
2540        // in an infinite reconnect loop.
2541
2542        server.abort();
2543    }
2544
2545    #[rstest]
2546    #[tokio::test]
2547    async fn test_message_handler_mode_allows_auto_reconnect() {
2548        // Test that regular clients (with message handler) can auto-reconnect
2549        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2550        let port = listener.local_addr().unwrap().port();
2551
2552        let server = task::spawn(async move {
2553            // Accept first connection and close it
2554            if let Ok((stream, _)) = listener.accept().await
2555                && let Ok(ws) = accept_async(stream).await
2556            {
2557                drop(ws);
2558            }
2559            sleep(Duration::from_millis(50)).await;
2560        });
2561
2562        let (handler, _rx) = channel_message_handler();
2563
2564        let config = WebSocketConfig {
2565            url: format!("ws://127.0.0.1:{port}"),
2566            headers: vec![],
2567            heartbeat: None,
2568            heartbeat_msg: None,
2569            reconnect_timeout_ms: Some(1_000),
2570            reconnect_delay_initial_ms: Some(50),
2571            reconnect_delay_max_ms: Some(100),
2572            reconnect_backoff_factor: Some(1.0),
2573            reconnect_jitter_ms: Some(0),
2574            reconnect_max_attempts: None,
2575            idle_timeout_ms: None,
2576            backend: TransportBackend::Tungstenite,
2577            proxy_url: None,
2578        };
2579
2580        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2581            .await
2582            .unwrap();
2583
2584        // Wait for the connection to be dropped and reconnection to be attempted
2585        tokio::time::timeout(Duration::from_secs(2), async {
2586            loop {
2587                if client.is_reconnecting() || client.is_closed() {
2588                    break;
2589                }
2590                tokio::time::sleep(Duration::from_millis(10)).await;
2591            }
2592        })
2593        .await
2594        .expect("client should attempt reconnection or close");
2595
2596        // Should either be reconnecting or closed (depending on timing)
2597        // The important thing is it's not staying active forever
2598        assert!(
2599            client.is_reconnecting() || client.is_closed(),
2600            "Client with message handler should attempt reconnection"
2601        );
2602
2603        client.disconnect().await;
2604        server.abort();
2605    }
2606
2607    #[rstest]
2608    #[tokio::test]
2609    async fn test_handler_mode_reconnect_with_new_connection() {
2610        // Test that handler mode successfully reconnects and messages continue flowing
2611        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2612        let port = listener.local_addr().unwrap().port();
2613
2614        let server = task::spawn(async move {
2615            // First connection - accept and immediately close
2616            if let Ok((stream, _)) = listener.accept().await
2617                && let Ok(ws) = accept_async(stream).await
2618            {
2619                drop(ws);
2620            }
2621
2622            // Small delay to let client detect disconnection
2623            sleep(Duration::from_millis(100)).await;
2624
2625            // Second connection - accept, send a message, then keep alive
2626            if let Ok((stream, _)) = listener.accept().await
2627                && let Ok(mut ws) = accept_async(stream).await
2628            {
2629                use futures_util::SinkExt;
2630                let _ = ws
2631                    .send(WsMessage::Text("reconnected".to_string().into()))
2632                    .await;
2633                sleep(Duration::from_secs(1)).await;
2634            }
2635        });
2636
2637        let (handler, mut rx) = channel_message_handler();
2638
2639        let config = WebSocketConfig {
2640            url: format!("ws://127.0.0.1:{port}"),
2641            headers: vec![],
2642            heartbeat: None,
2643            heartbeat_msg: None,
2644            reconnect_timeout_ms: Some(2_000),
2645            reconnect_delay_initial_ms: Some(50),
2646            reconnect_delay_max_ms: Some(200),
2647            reconnect_backoff_factor: Some(1.5),
2648            reconnect_jitter_ms: Some(10),
2649            reconnect_max_attempts: None,
2650            idle_timeout_ms: None,
2651            backend: TransportBackend::Tungstenite,
2652            proxy_url: None,
2653        };
2654
2655        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2656            .await
2657            .unwrap();
2658
2659        // Wait for reconnection to happen and message to arrive
2660        let result = tokio::time::timeout(Duration::from_secs(5), async {
2661            loop {
2662                if let Ok(msg) = rx.try_recv()
2663                    && matches!(msg, WsMessage::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
2664                {
2665                    return true;
2666                }
2667                tokio::time::sleep(Duration::from_millis(10)).await;
2668            }
2669        })
2670        .await;
2671
2672        assert!(
2673            result.is_ok(),
2674            "Should receive message after reconnection within timeout"
2675        );
2676
2677        client.disconnect().await;
2678        server.abort();
2679    }
2680
2681    #[rstest]
2682    #[tokio::test]
2683    async fn test_stream_mode_no_auto_reconnect() {
2684        // Test that stream mode does not automatically reconnect when connection is lost
2685        // The caller owns the reader and is responsible for detecting disconnection
2686        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2687        let port = listener.local_addr().unwrap().port();
2688
2689        let server = task::spawn(async move {
2690            // Accept connection and send one message, then close
2691            if let Ok((stream, _)) = listener.accept().await
2692                && let Ok(mut ws) = accept_async(stream).await
2693            {
2694                use futures_util::SinkExt;
2695                let _ = ws.send(WsMessage::Text("hello".to_string().into())).await;
2696                sleep(Duration::from_millis(50)).await;
2697                // Connection closes when ws is dropped
2698            }
2699        });
2700
2701        let config = WebSocketConfig {
2702            url: format!("ws://127.0.0.1:{port}"),
2703            headers: vec![],
2704            heartbeat: None,
2705            heartbeat_msg: None,
2706            reconnect_timeout_ms: Some(1_000),
2707            reconnect_delay_initial_ms: Some(50),
2708            reconnect_delay_max_ms: Some(100),
2709            reconnect_backoff_factor: Some(1.0),
2710            reconnect_jitter_ms: Some(0),
2711            reconnect_max_attempts: None,
2712            idle_timeout_ms: None,
2713            backend: TransportBackend::Tungstenite,
2714            proxy_url: None,
2715        };
2716
2717        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2718            .await
2719            .unwrap();
2720
2721        // Initially active
2722        assert!(client.is_active(), "Client should start as active");
2723
2724        // Read the hello message
2725        let msg = reader.next().await;
2726        assert!(
2727            matches!(&msg, Some(Ok(Message::Text(bytes))) if bytes.as_ref() == b"hello"),
2728            "Should receive initial message"
2729        );
2730
2731        // Read until connection closes (reader will return None or error)
2732        while let Some(msg) = reader.next().await {
2733            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
2734                break;
2735            }
2736        }
2737
2738        // Controller cannot detect reader EOF (reader is owned by caller),
2739        // so the client stays ACTIVE until the caller signals.
2740        sleep(Duration::from_millis(200)).await;
2741        assert!(
2742            client.is_active(),
2743            "Stream mode client stays ACTIVE before notify_closed()"
2744        );
2745
2746        // Caller signals EOF via notify_closed()
2747        client.notify_closed();
2748
2749        assert!(
2750            client.is_closed(),
2751            "Stream mode client should be CLOSED after notify_closed()"
2752        );
2753        assert!(
2754            !client.is_reconnecting(),
2755            "Stream mode client should never attempt reconnection"
2756        );
2757
2758        client.disconnect().await;
2759        server.abort();
2760    }
2761
2762    #[rstest]
2763    #[tokio::test]
2764    async fn test_send_timeout_uses_configured_reconnect_timeout() {
2765        // Test that send operations respect the configured reconnect_timeout.
2766        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
2767        use nautilus_common::testing::wait_until_async;
2768
2769        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2770        let port = listener.local_addr().unwrap().port();
2771
2772        let server = task::spawn(async move {
2773            // Accept first connection and immediately close it
2774            if let Ok((stream, _)) = listener.accept().await
2775                && let Ok(ws) = accept_async(stream).await
2776            {
2777                drop(ws);
2778            }
2779            // Don't accept second connection - client will be stuck in RECONNECT
2780            sleep(Duration::from_mins(1)).await;
2781        });
2782
2783        let (handler, _rx) = channel_message_handler();
2784
2785        // Configure with SHORT 2s reconnect timeout
2786        let config = WebSocketConfig {
2787            url: format!("ws://127.0.0.1:{port}"),
2788            headers: vec![],
2789            heartbeat: None,
2790            heartbeat_msg: None,
2791            reconnect_timeout_ms: Some(2_000), // 2s timeout
2792            reconnect_delay_initial_ms: Some(50),
2793            reconnect_delay_max_ms: Some(100),
2794            reconnect_backoff_factor: Some(1.0),
2795            reconnect_jitter_ms: Some(0),
2796            reconnect_max_attempts: None,
2797            idle_timeout_ms: None,
2798            backend: TransportBackend::Tungstenite,
2799            proxy_url: None,
2800        };
2801
2802        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2803            .await
2804            .unwrap();
2805
2806        // Wait for client to enter RECONNECT state
2807        wait_until_async(
2808            || async { client.is_reconnecting() },
2809            Duration::from_secs(3),
2810        )
2811        .await;
2812
2813        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
2814        let start = std::time::Instant::now();
2815        let send_result = client.send_text("test".to_string(), None).await;
2816        let elapsed = start.elapsed();
2817
2818        assert!(
2819            send_result.is_err(),
2820            "Send should fail when client stuck in RECONNECT"
2821        );
2822        assert!(
2823            matches!(send_result, Err(crate::error::SendError::Timeout)),
2824            "Send should return Timeout error, was: {send_result:?}"
2825        );
2826        // Verify timeout respects configured value (2s), but don't check upper bound
2827        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2828        assert!(
2829            elapsed >= Duration::from_millis(1800),
2830            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2831        );
2832
2833        client.disconnect().await;
2834        server.abort();
2835    }
2836
2837    #[rstest]
2838    #[tokio::test]
2839    async fn test_send_waits_during_reconnection() {
2840        // Test that send operations wait for reconnection to complete (up to timeout)
2841        use nautilus_common::testing::wait_until_async;
2842
2843        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2844        let port = listener.local_addr().unwrap().port();
2845
2846        let server = task::spawn(async move {
2847            // First connection - accept and immediately close
2848            if let Ok((stream, _)) = listener.accept().await
2849                && let Ok(ws) = accept_async(stream).await
2850            {
2851                drop(ws);
2852            }
2853
2854            // Wait a bit before accepting second connection
2855            sleep(Duration::from_millis(500)).await;
2856
2857            // Second connection - accept and keep alive
2858            if let Ok((stream, _)) = listener.accept().await
2859                && let Ok(mut ws) = accept_async(stream).await
2860            {
2861                // Echo messages
2862                while let Some(Ok(msg)) = ws.next().await {
2863                    if ws.send(msg).await.is_err() {
2864                        break;
2865                    }
2866                }
2867            }
2868        });
2869
2870        let (handler, _rx) = channel_message_handler();
2871
2872        let config = WebSocketConfig {
2873            url: format!("ws://127.0.0.1:{port}"),
2874            headers: vec![],
2875            heartbeat: None,
2876            heartbeat_msg: None,
2877            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2878            reconnect_delay_initial_ms: Some(100),
2879            reconnect_delay_max_ms: Some(200),
2880            reconnect_backoff_factor: Some(1.0),
2881            reconnect_jitter_ms: Some(0),
2882            reconnect_max_attempts: None,
2883            idle_timeout_ms: None,
2884            backend: TransportBackend::Tungstenite,
2885            proxy_url: None,
2886        };
2887
2888        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2889            .await
2890            .unwrap();
2891
2892        // Wait for reconnection to trigger
2893        wait_until_async(
2894            || async { client.is_reconnecting() },
2895            Duration::from_secs(2),
2896        )
2897        .await;
2898
2899        // Try to send while reconnecting - should wait and succeed after reconnect
2900        let send_result = tokio::time::timeout(
2901            Duration::from_secs(3),
2902            client.send_text("test_message".to_string(), None),
2903        )
2904        .await;
2905
2906        assert!(
2907            send_result.is_ok() && send_result.unwrap().is_ok(),
2908            "Send should succeed after waiting for reconnection"
2909        );
2910
2911        client.disconnect().await;
2912        server.abort();
2913    }
2914
2915    #[rstest]
2916    #[tokio::test]
2917    async fn test_rate_limiter_before_active_wait() {
2918        // Test that rate limiting happens BEFORE active state check.
2919        // This prevents race conditions where connection state changes during rate limit wait.
2920        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2921        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2922        use std::{num::NonZeroU32, sync::Arc};
2923
2924        use nautilus_common::testing::wait_until_async;
2925
2926        use crate::ratelimiter::quota::Quota;
2927
2928        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2929        let port = listener.local_addr().unwrap().port();
2930
2931        let server = task::spawn(async move {
2932            // First connection - accept and close after receiving one message
2933            if let Ok((stream, _)) = listener.accept().await
2934                && let Ok(mut ws) = accept_async(stream).await
2935            {
2936                // Receive first message then close
2937                if let Some(Ok(_)) = ws.next().await {
2938                    drop(ws);
2939                }
2940            }
2941
2942            // Wait before accepting reconnection
2943            sleep(Duration::from_millis(500)).await;
2944
2945            // Second connection - accept and keep alive
2946            if let Ok((stream, _)) = listener.accept().await
2947                && let Ok(mut ws) = accept_async(stream).await
2948            {
2949                while let Some(Ok(msg)) = ws.next().await {
2950                    if ws.send(msg).await.is_err() {
2951                        break;
2952                    }
2953                }
2954            }
2955        });
2956
2957        let (handler, _rx) = channel_message_handler();
2958
2959        let config = WebSocketConfig {
2960            url: format!("ws://127.0.0.1:{port}"),
2961            headers: vec![],
2962            heartbeat: None,
2963            heartbeat_msg: None,
2964            reconnect_timeout_ms: Some(5_000),
2965            reconnect_delay_initial_ms: Some(50),
2966            reconnect_delay_max_ms: Some(100),
2967            reconnect_backoff_factor: Some(1.0),
2968            reconnect_jitter_ms: Some(0),
2969            reconnect_max_attempts: None,
2970            idle_timeout_ms: None,
2971            backend: TransportBackend::Tungstenite,
2972            proxy_url: None,
2973        };
2974
2975        // Very restrictive rate limit: 1 request per second, burst of 1
2976        let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2977            .unwrap()
2978            .allow_burst(NonZeroU32::new(1).unwrap());
2979
2980        let client = Arc::new(
2981            WebSocketClient::connect(
2982                config,
2983                Some(handler),
2984                None,
2985                None,
2986                vec![("test_key".to_string(), quota)],
2987                None,
2988            )
2989            .await
2990            .unwrap(),
2991        );
2992
2993        // First send exhausts burst capacity and triggers connection close
2994        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2995        client
2996            .send_text("msg1".to_string(), Some(test_key.as_slice()))
2997            .await
2998            .unwrap();
2999
3000        // Wait for client to enter RECONNECT state
3001        wait_until_async(
3002            || async { client.is_reconnecting() },
3003            Duration::from_secs(2),
3004        )
3005        .await;
3006
3007        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
3008        let start = std::time::Instant::now();
3009        let send_result = client
3010            .send_text("msg2".to_string(), Some(test_key.as_slice()))
3011            .await;
3012        let elapsed = start.elapsed();
3013
3014        // Should succeed after both rate limit AND reconnection
3015        assert!(
3016            send_result.is_ok(),
3017            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
3018        );
3019        // Total wait should be at least rate limit time (~1s)
3020        // The reconnection completes while rate limiting or after
3021        // Use 850ms threshold to account for timing jitter in CI
3022        assert!(
3023            elapsed >= Duration::from_millis(850),
3024            "Should wait for rate limit (~1s), waited {elapsed:?}"
3025        );
3026
3027        client.disconnect().await;
3028        server.abort();
3029    }
3030
3031    #[rstest]
3032    #[tokio::test]
3033    async fn test_disconnect_during_reconnect_exits_cleanly() {
3034        // Test CAS race condition: disconnect called during reconnection
3035        // Should exit cleanly without spawning new tasks
3036        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3037        let port = listener.local_addr().unwrap().port();
3038
3039        let server = task::spawn(async move {
3040            // Accept first connection and immediately close
3041            if let Ok((stream, _)) = listener.accept().await
3042                && let Ok(ws) = accept_async(stream).await
3043            {
3044                drop(ws);
3045            }
3046            // Don't accept second connection - let reconnect hang
3047            sleep(Duration::from_mins(1)).await;
3048        });
3049
3050        let (handler, _rx) = channel_message_handler();
3051
3052        let config = WebSocketConfig {
3053            url: format!("ws://127.0.0.1:{port}"),
3054            headers: vec![],
3055            heartbeat: None,
3056            heartbeat_msg: None,
3057            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
3058            reconnect_delay_initial_ms: Some(100),
3059            reconnect_delay_max_ms: Some(200),
3060            reconnect_backoff_factor: Some(1.0),
3061            reconnect_jitter_ms: Some(0),
3062            reconnect_max_attempts: None,
3063            idle_timeout_ms: None,
3064            backend: TransportBackend::Tungstenite,
3065            proxy_url: None,
3066        };
3067
3068        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3069            .await
3070            .unwrap();
3071
3072        // Wait for reconnection to start
3073        tokio::time::timeout(Duration::from_secs(2), async {
3074            while !client.is_reconnecting() {
3075                sleep(Duration::from_millis(10)).await;
3076            }
3077        })
3078        .await
3079        .expect("Client should enter RECONNECT state");
3080
3081        // Disconnect while reconnecting
3082        client.disconnect().await;
3083
3084        // Should be cleanly closed
3085        assert!(
3086            client.is_disconnected(),
3087            "Client should be cleanly disconnected"
3088        );
3089
3090        server.abort();
3091    }
3092
3093    #[rstest]
3094    #[tokio::test]
3095    async fn test_send_fails_fast_when_closed_before_rate_limit() {
3096        // Test that send operations check connection state BEFORE rate limiting,
3097        // preventing unnecessary delays when the connection is already closed.
3098        use std::{num::NonZeroU32, sync::Arc};
3099
3100        use nautilus_common::testing::wait_until_async;
3101
3102        use crate::ratelimiter::quota::Quota;
3103
3104        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3105        let port = listener.local_addr().unwrap().port();
3106
3107        let server = task::spawn(async move {
3108            // Accept connection and immediately close
3109            if let Ok((stream, _)) = listener.accept().await
3110                && let Ok(ws) = accept_async(stream).await
3111            {
3112                drop(ws);
3113            }
3114            sleep(Duration::from_mins(1)).await;
3115        });
3116
3117        let (handler, _rx) = channel_message_handler();
3118
3119        let config = WebSocketConfig {
3120            url: format!("ws://127.0.0.1:{port}"),
3121            headers: vec![],
3122            heartbeat: None,
3123            heartbeat_msg: None,
3124            reconnect_timeout_ms: Some(5_000),
3125            reconnect_delay_initial_ms: Some(50),
3126            reconnect_delay_max_ms: Some(100),
3127            reconnect_backoff_factor: Some(1.0),
3128            reconnect_jitter_ms: Some(0),
3129            reconnect_max_attempts: None,
3130            idle_timeout_ms: None,
3131            backend: TransportBackend::Tungstenite,
3132            proxy_url: None,
3133        };
3134
3135        // Very restrictive rate limit: 1 request per 10 seconds
3136        // This ensures that if we wait for rate limit, the test will timeout
3137        let quota = Quota::with_period(Duration::from_secs(10))
3138            .unwrap()
3139            .allow_burst(NonZeroU32::new(1).unwrap());
3140
3141        let client = Arc::new(
3142            WebSocketClient::connect(
3143                config,
3144                Some(handler),
3145                None,
3146                None,
3147                vec![("test_key".to_string(), quota)],
3148                None,
3149            )
3150            .await
3151            .unwrap(),
3152        );
3153
3154        // Wait for disconnection
3155        wait_until_async(
3156            || async { client.is_reconnecting() || client.is_closed() },
3157            Duration::from_secs(2),
3158        )
3159        .await;
3160
3161        // Explicitly disconnect to move away from ACTIVE state
3162        client.disconnect().await;
3163        assert!(
3164            !client.is_active(),
3165            "Client should not be active after disconnect"
3166        );
3167
3168        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
3169        let start = std::time::Instant::now();
3170        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
3171        let result = client
3172            .send_text("test".to_string(), Some(test_key.as_slice()))
3173            .await;
3174        let elapsed = start.elapsed();
3175
3176        // Should fail with Closed error
3177        assert!(result.is_err(), "Send should fail when client is closed");
3178        assert!(
3179            matches!(result, Err(crate::error::SendError::Closed)),
3180            "Send should return Closed error, was: {result:?}"
3181        );
3182
3183        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
3184        assert!(
3185            elapsed < Duration::from_millis(100),
3186            "Send should fail fast without rate limiting, took {elapsed:?}"
3187        );
3188
3189        server.abort();
3190    }
3191
3192    #[rstest]
3193    #[tokio::test]
3194    async fn test_connect_rejects_none_message_handler() {
3195        // Test that connect() properly rejects None message_handler
3196        // to prevent zombie connections that appear alive but never detect disconnections
3197
3198        let config = WebSocketConfig {
3199            url: "ws://127.0.0.1:9999".to_string(),
3200            headers: vec![],
3201            heartbeat: None,
3202            heartbeat_msg: None,
3203            reconnect_timeout_ms: Some(1_000),
3204            reconnect_delay_initial_ms: Some(100),
3205            reconnect_delay_max_ms: Some(500),
3206            reconnect_backoff_factor: Some(1.5),
3207            reconnect_jitter_ms: Some(0),
3208            reconnect_max_attempts: None,
3209            idle_timeout_ms: None,
3210            backend: TransportBackend::Tungstenite,
3211            proxy_url: None,
3212        };
3213
3214        // Pass None for message_handler - should be rejected
3215        let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
3216
3217        assert!(
3218            result.is_err(),
3219            "connect() should reject None message_handler"
3220        );
3221
3222        let err = result.unwrap_err();
3223        let err_msg = err.to_string();
3224        assert!(
3225            err_msg.contains("Handler mode requires message_handler"),
3226            "Error should mention missing message_handler, was: {err_msg}"
3227        );
3228    }
3229
3230    #[rstest]
3231    #[tokio::test]
3232    async fn test_client_without_handler_sets_stream_mode() {
3233        // Test that if a client is created without a handler via connect_url,
3234        // it properly sets is_stream_mode=true to prevent zombie connections
3235
3236        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3237        let port = listener.local_addr().unwrap().port();
3238
3239        let server = task::spawn(async move {
3240            // Accept and immediately close to simulate server disconnect
3241            if let Ok((stream, _)) = listener.accept().await
3242                && let Ok(ws) = accept_async(stream).await
3243            {
3244                drop(ws); // Drop connection immediately
3245            }
3246        });
3247
3248        let config = WebSocketConfig {
3249            url: format!("ws://127.0.0.1:{port}"),
3250            headers: vec![],
3251            heartbeat: None,
3252            heartbeat_msg: None,
3253            reconnect_timeout_ms: Some(1_000),
3254            reconnect_delay_initial_ms: Some(100),
3255            reconnect_delay_max_ms: Some(500),
3256            reconnect_backoff_factor: Some(1.5),
3257            reconnect_jitter_ms: Some(0),
3258            reconnect_max_attempts: None,
3259            idle_timeout_ms: None,
3260            backend: TransportBackend::Tungstenite,
3261            proxy_url: None,
3262        };
3263
3264        // Create client directly via connect_url with no handler (stream mode)
3265        let inner = WebSocketClientInner::connect_url(config, None, None)
3266            .await
3267            .unwrap();
3268
3269        // Verify is_stream_mode is true when no handler
3270        assert!(
3271            inner.is_stream_mode,
3272            "Client without handler should have is_stream_mode=true"
3273        );
3274
3275        // Verify that when stream mode is enabled, reconnection is disabled
3276        // (documented behavior - stream mode clients close instead of reconnecting)
3277
3278        server.abort();
3279    }
3280
3281    #[rstest]
3282    #[tokio::test]
3283    async fn test_idle_timeout_triggers_reconnect() {
3284        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3285        let port = listener.local_addr().unwrap().port();
3286
3287        // Server accepts WS connection but sends nothing (simulates silent death)
3288        let server = task::spawn(async move {
3289            let (stream, _) = listener.accept().await.unwrap();
3290            let _ws = accept_async(stream).await.unwrap();
3291            // Hold connection open but send nothing
3292            sleep(Duration::from_secs(5)).await;
3293        });
3294
3295        let (handler, _rx) = channel_message_handler();
3296
3297        let config = WebSocketConfig {
3298            url: format!("ws://127.0.0.1:{port}"),
3299            headers: vec![],
3300            heartbeat: None,
3301            heartbeat_msg: None,
3302            reconnect_timeout_ms: Some(2_000),
3303            reconnect_delay_initial_ms: Some(50),
3304            reconnect_delay_max_ms: Some(100),
3305            reconnect_backoff_factor: Some(1.0),
3306            reconnect_jitter_ms: Some(0),
3307            reconnect_max_attempts: Some(1),
3308            idle_timeout_ms: Some(500),
3309            backend: TransportBackend::Tungstenite,
3310            proxy_url: None,
3311        };
3312
3313        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3314            .await
3315            .unwrap();
3316
3317        assert!(client.is_active());
3318
3319        // Wait for idle timeout to fire and client to enter reconnect/closed
3320        wait_until_async(
3321            || async { client.is_reconnecting() || client.is_disconnected() },
3322            Duration::from_secs(3),
3323        )
3324        .await;
3325
3326        assert!(
3327            !client.is_active(),
3328            "Client should not be active after idle timeout"
3329        );
3330
3331        client.disconnect().await;
3332        server.abort();
3333    }
3334
3335    #[rstest]
3336    #[tokio::test]
3337    async fn test_idle_timeout_resets_on_data() {
3338        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3339        let port = listener.local_addr().unwrap().port();
3340
3341        // Server sends a message every 200ms (well within 1s idle timeout)
3342        let server = task::spawn(async move {
3343            let (stream, _) = listener.accept().await.unwrap();
3344            let mut ws = accept_async(stream).await.unwrap();
3345
3346            for _ in 0..10 {
3347                sleep(Duration::from_millis(200)).await;
3348
3349                if ws.send(WsMessage::Text("ping".into())).await.is_err() {
3350                    break;
3351                }
3352            }
3353        });
3354
3355        let (handler, _rx) = channel_message_handler();
3356
3357        let config = WebSocketConfig {
3358            url: format!("ws://127.0.0.1:{port}"),
3359            headers: vec![],
3360            heartbeat: None,
3361            heartbeat_msg: None,
3362            reconnect_timeout_ms: Some(2_000),
3363            reconnect_delay_initial_ms: Some(50),
3364            reconnect_delay_max_ms: Some(100),
3365            reconnect_backoff_factor: Some(1.0),
3366            reconnect_jitter_ms: Some(0),
3367            reconnect_max_attempts: Some(1),
3368            idle_timeout_ms: Some(1_000),
3369            backend: TransportBackend::Tungstenite,
3370            proxy_url: None,
3371        };
3372
3373        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3374            .await
3375            .unwrap();
3376
3377        assert!(client.is_active());
3378
3379        // Wait 1.5s - data arrives every 200ms so idle timeout (1s) should NOT fire
3380        sleep(Duration::from_millis(1_500)).await;
3381
3382        assert!(
3383            client.is_active(),
3384            "Client should remain active when data is flowing"
3385        );
3386
3387        client.disconnect().await;
3388        server.abort();
3389    }
3390
3391    #[rstest]
3392    #[tokio::test]
3393    async fn test_idle_timeout_fires_when_only_pings_received() {
3394        // Regression: pings and pongs are keep-alive frames, not application data,
3395        // so a peer that only emits control frames must still trip the idle timeout.
3396        // The peer keeps pinging for well past the observation window so the
3397        // pre-fix behavior (reset-on-ping) would keep the client active; under the
3398        // fix the idle timer never resets and fires after ~500ms.
3399        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3400        let port = listener.local_addr().unwrap().port();
3401
3402        let server = task::spawn(async move {
3403            let (stream, _) = listener.accept().await.unwrap();
3404            let mut ws = accept_async(stream).await.unwrap();
3405
3406            for _ in 0..60 {
3407                sleep(Duration::from_millis(100)).await;
3408
3409                if ws.send(WsMessage::Ping(Vec::new().into())).await.is_err() {
3410                    break;
3411                }
3412            }
3413        });
3414
3415        let (handler, _rx) = channel_message_handler();
3416
3417        let config = WebSocketConfig {
3418            url: format!("ws://127.0.0.1:{port}"),
3419            headers: vec![],
3420            heartbeat: None,
3421            heartbeat_msg: None,
3422            reconnect_timeout_ms: Some(2_000),
3423            reconnect_delay_initial_ms: Some(50),
3424            reconnect_delay_max_ms: Some(100),
3425            reconnect_backoff_factor: Some(1.0),
3426            reconnect_jitter_ms: Some(0),
3427            reconnect_max_attempts: Some(1),
3428            idle_timeout_ms: Some(500),
3429            backend: TransportBackend::Tungstenite,
3430            proxy_url: None,
3431        };
3432
3433        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3434            .await
3435            .unwrap();
3436
3437        assert!(client.is_active());
3438
3439        // Observation window is shorter than the ping stream (6s). If the idle
3440        // timer mistakenly reset on every ping the client would still be active
3441        // here; under the fix it goes inactive at ~500ms.
3442        wait_until_async(
3443            || async { client.is_reconnecting() || client.is_disconnected() },
3444            Duration::from_millis(1_500),
3445        )
3446        .await;
3447
3448        assert!(
3449            !client.is_active(),
3450            "Client should not be active after idle timeout when only pings/pongs flow"
3451        );
3452
3453        client.disconnect().await;
3454        server.abort();
3455    }
3456
3457    #[rstest]
3458    #[tokio::test]
3459    async fn test_idle_timeout_fires_when_only_pongs_received() {
3460        // Regression for the heartbeat-reply path. When the client heartbeat is
3461        // enabled, the peer auto-replies with pongs for every outgoing ping. If
3462        // those pongs refreshed last_data_time the idle timer would never fire on
3463        // a zombie connection (the motivating Polymarket scenario).
3464        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3465        let port = listener.local_addr().unwrap().port();
3466
3467        let server = task::spawn(async move {
3468            let (stream, _) = listener.accept().await.unwrap();
3469            let mut ws = accept_async(stream).await.unwrap();
3470
3471            // Drain incoming frames so tungstenite's internal pong replies are
3472            // actually flushed to the client. Hold the connection open well past
3473            // the observation window.
3474            let deadline = tokio::time::Instant::now() + Duration::from_secs(6);
3475            while tokio::time::Instant::now() < deadline {
3476                if let Ok(Some(Err(_)) | None) =
3477                    tokio::time::timeout(Duration::from_millis(100), ws.next()).await
3478                {
3479                    break;
3480                }
3481            }
3482        });
3483
3484        let (handler, _rx) = channel_message_handler();
3485
3486        let config = WebSocketConfig {
3487            url: format!("ws://127.0.0.1:{port}"),
3488            headers: vec![],
3489            heartbeat: Some(1),
3490            heartbeat_msg: None,
3491            reconnect_timeout_ms: Some(2_000),
3492            reconnect_delay_initial_ms: Some(50),
3493            reconnect_delay_max_ms: Some(100),
3494            reconnect_backoff_factor: Some(1.0),
3495            reconnect_jitter_ms: Some(0),
3496            reconnect_max_attempts: Some(1),
3497            idle_timeout_ms: Some(1_500),
3498            backend: TransportBackend::Tungstenite,
3499            proxy_url: None,
3500        };
3501
3502        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3503            .await
3504            .unwrap();
3505
3506        assert!(client.is_active());
3507
3508        // Heartbeat cadence is 1s; each ping draws a pong reply. Under the fix
3509        // the idle timer ignores those pongs and fires at ~1.5s. Under the bug
3510        // every pong reset the timer and the client would stay active.
3511        wait_until_async(
3512            || async { client.is_reconnecting() || client.is_disconnected() },
3513            Duration::from_millis(2_500),
3514        )
3515        .await;
3516
3517        assert!(
3518            !client.is_active(),
3519            "Client should not be active after idle timeout when only pongs flow"
3520        );
3521
3522        client.disconnect().await;
3523        server.abort();
3524    }
3525
3526    #[rstest]
3527    #[tokio::test]
3528    async fn test_disconnect_during_backoff_exits_promptly() {
3529        // Verify that disconnect interrupts backoff sleep (Finding 1).
3530        // Server accepts then drops, no second listener -> reconnect fails -> enters backoff.
3531        // We disconnect while backing off and assert the client shuts down quickly.
3532        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3533        let port = listener.local_addr().unwrap().port();
3534
3535        let server = task::spawn(async move {
3536            // Accept first connection, close immediately
3537            if let Ok((stream, _)) = listener.accept().await {
3538                let _ = accept_async(stream).await;
3539            }
3540            // Don't accept again so reconnect fails and enters backoff
3541            sleep(Duration::from_mins(1)).await;
3542        });
3543
3544        let (handler, _rx) = channel_message_handler();
3545
3546        let config = WebSocketConfig {
3547            url: format!("ws://127.0.0.1:{port}"),
3548            headers: vec![],
3549            heartbeat: None,
3550            heartbeat_msg: None,
3551            reconnect_timeout_ms: Some(1_000),
3552            reconnect_delay_initial_ms: Some(10_000), // 10s backoff to ensure we're sleeping
3553            reconnect_delay_max_ms: Some(10_000),
3554            reconnect_backoff_factor: Some(1.0),
3555            reconnect_jitter_ms: Some(0),
3556            reconnect_max_attempts: None,
3557            idle_timeout_ms: None,
3558            backend: TransportBackend::Tungstenite,
3559            proxy_url: None,
3560        };
3561
3562        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3563            .await
3564            .unwrap();
3565
3566        // Wait for client to enter reconnect
3567        wait_until_async(
3568            || async { client.is_reconnecting() },
3569            Duration::from_secs(3),
3570        )
3571        .await;
3572
3573        // Wait a bit more for the reconnect attempt to fail and enter backoff sleep
3574        sleep(Duration::from_millis(1_500)).await;
3575
3576        // Disconnect while backing off
3577        let start = std::time::Instant::now();
3578        client.disconnect().await;
3579        let elapsed = start.elapsed();
3580
3581        assert!(client.is_disconnected(), "Client should be disconnected");
3582        // Should exit well before the 10s backoff sleep completes
3583        assert!(
3584            elapsed < Duration::from_secs(2),
3585            "Disconnect should interrupt backoff sleep, took {elapsed:?}"
3586        );
3587
3588        server.abort();
3589    }
3590
3591    #[rstest]
3592    #[tokio::test]
3593    async fn test_rate_limit_cancelled_on_disconnect() {
3594        // Verify that a send blocked on rate limiting returns Closed when
3595        // the client disconnects (Finding 6).
3596        use std::{num::NonZeroU32, sync::Arc};
3597
3598        use crate::ratelimiter::quota::Quota;
3599
3600        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3601        let port = listener.local_addr().unwrap().port();
3602
3603        let server = task::spawn(async move {
3604            if let Ok((stream, _)) = listener.accept().await {
3605                let mut ws = accept_async(stream).await.unwrap();
3606                // Keep alive and echo
3607                while let Some(Ok(msg)) = ws.next().await {
3608                    if ws.send(msg).await.is_err() {
3609                        break;
3610                    }
3611                }
3612            }
3613        });
3614
3615        let (handler, _rx) = channel_message_handler();
3616
3617        let config = WebSocketConfig {
3618            url: format!("ws://127.0.0.1:{port}"),
3619            headers: vec![],
3620            heartbeat: None,
3621            heartbeat_msg: None,
3622            reconnect_timeout_ms: Some(5_000),
3623            reconnect_delay_initial_ms: Some(100),
3624            reconnect_delay_max_ms: Some(500),
3625            reconnect_backoff_factor: Some(1.5),
3626            reconnect_jitter_ms: Some(0),
3627            reconnect_max_attempts: None,
3628            idle_timeout_ms: None,
3629            backend: TransportBackend::Tungstenite,
3630            proxy_url: None,
3631        };
3632
3633        // Very restrictive: 1 req per 60 seconds
3634        let quota = Quota::with_period(Duration::from_mins(1))
3635            .unwrap()
3636            .allow_burst(NonZeroU32::new(1).unwrap());
3637
3638        let client = Arc::new(
3639            WebSocketClient::connect(
3640                config,
3641                Some(handler),
3642                None,
3643                None,
3644                vec![("rate_key".to_string(), quota)],
3645                None,
3646            )
3647            .await
3648            .unwrap(),
3649        );
3650
3651        let test_key: [Ustr; 1] = [Ustr::from("rate_key")];
3652
3653        // Exhaust the burst quota
3654        client
3655            .send_text("exhaust".to_string(), Some(test_key.as_slice()))
3656            .await
3657            .unwrap();
3658
3659        // Spawn a send that will block on rate limiter
3660        let client_clone = client.clone();
3661        let send_handle = task::spawn(async move {
3662            client_clone
3663                .send_text("blocked".to_string(), Some(&[Ustr::from("rate_key")]))
3664                .await
3665        });
3666
3667        // Let the send block on rate limit
3668        sleep(Duration::from_millis(200)).await;
3669
3670        // Disconnect while send is blocked
3671        let start = std::time::Instant::now();
3672        client.disconnect().await;
3673        let elapsed_disconnect = start.elapsed();
3674
3675        // The blocked send should return Closed
3676        let result = tokio::time::timeout(Duration::from_secs(2), send_handle)
3677            .await
3678            .expect("Send task should complete quickly")
3679            .expect("Send task should not panic");
3680
3681        assert!(
3682            matches!(result, Err(crate::error::SendError::Closed)),
3683            "Blocked send should return Closed, was: {result:?}"
3684        );
3685
3686        // Disconnect should be fast, not waiting for the 60s rate limit
3687        assert!(
3688            elapsed_disconnect < Duration::from_secs(3),
3689            "Disconnect should not wait for rate limiter, took {elapsed_disconnect:?}"
3690        );
3691
3692        server.abort();
3693    }
3694
3695    #[rstest]
3696    #[tokio::test]
3697    async fn test_stream_mode_transitions_to_closed_on_dead_write_task() {
3698        // Verify that stream mode transitions to CLOSED (not RECONNECT) when
3699        // the write task dies (Finding 4). We force write failure by sending
3700        // after the server closes the connection.
3701        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3702        let port = listener.local_addr().unwrap().port();
3703
3704        let server = task::spawn(async move {
3705            if let Ok((stream, _)) = listener.accept().await
3706                && let Ok(ws) = accept_async(stream).await
3707            {
3708                // Close immediately to cause write errors
3709                drop(ws);
3710            }
3711        });
3712
3713        let config = WebSocketConfig {
3714            url: format!("ws://127.0.0.1:{port}"),
3715            headers: vec![],
3716            heartbeat: None,
3717            heartbeat_msg: None,
3718            reconnect_timeout_ms: Some(1_000),
3719            reconnect_delay_initial_ms: Some(50),
3720            reconnect_delay_max_ms: Some(100),
3721            reconnect_backoff_factor: Some(1.0),
3722            reconnect_jitter_ms: Some(0),
3723            reconnect_max_attempts: None,
3724            idle_timeout_ms: None,
3725            backend: TransportBackend::Tungstenite,
3726            proxy_url: None,
3727        };
3728
3729        let (_reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
3730            .await
3731            .unwrap();
3732
3733        assert!(client.is_active(), "Client should start active");
3734
3735        // Wait for server to close, then send to trigger write task failure
3736        sleep(Duration::from_millis(100)).await;
3737
3738        // Keep sending until the write task detects the broken connection
3739        for _ in 0..20 {
3740            let _ = client.send_text("ping".to_string(), None).await;
3741            sleep(Duration::from_millis(50)).await;
3742
3743            if !client.is_active() {
3744                break;
3745            }
3746        }
3747
3748        // Wait for controller to process the state change
3749        wait_until_async(|| async { !client.is_active() }, Duration::from_secs(5)).await;
3750
3751        // Stream mode should go to CLOSED, not RECONNECT
3752        assert!(
3753            client.is_closed() || client.is_disconnected(),
3754            "Stream mode should transition to CLOSED, not RECONNECT. \
3755             is_reconnecting={}, is_closed={}, is_disconnected={}",
3756            client.is_reconnecting(),
3757            client.is_closed(),
3758            client.is_disconnected(),
3759        );
3760        assert!(
3761            !client.is_reconnecting(),
3762            "Stream mode should never attempt reconnection"
3763        );
3764
3765        server.abort();
3766    }
3767
3768    #[tokio::test]
3769    async fn test_write_task_waits_for_auth_before_replaying_buffer() {
3770        use nautilus_common::testing::wait_until_async;
3771
3772        let server = RecordingServer::setup().await;
3773        let url = format!("ws://127.0.0.1:{}", server.port);
3774        let (writer, _reader) = WebSocketClientInner::connect_with_server(
3775            &url,
3776            vec![],
3777            TransportBackend::Tungstenite,
3778            None,
3779        )
3780        .await
3781        .unwrap();
3782
3783        let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
3784        let state_notify = Arc::new(tokio::sync::Notify::new());
3785        let auth_tracker = Arc::new(OnceLock::new());
3786        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
3787        let tracker = AuthTracker::new();
3788        auth_tracker.set(tracker.clone()).unwrap();
3789
3790        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
3791        let write_task = WebSocketClientInner::spawn_write_task(
3792            Arc::clone(&connection_state),
3793            Arc::clone(&state_notify),
3794            writer,
3795            writer_rx,
3796            Arc::clone(&auth_tracker),
3797            Arc::clone(&reconnect_buffer_waits_for_auth),
3798        );
3799
3800        writer_tx
3801            .send(WriterCommand::Send(Message::Text("stale".into())))
3802            .unwrap();
3803
3804        let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
3805            &url,
3806            vec![],
3807            TransportBackend::Tungstenite,
3808            None,
3809        )
3810        .await
3811        .unwrap();
3812        let (tx, rx) = tokio::sync::oneshot::channel();
3813        writer_tx
3814            .send(WriterCommand::Update(new_writer, tx))
3815            .unwrap();
3816        assert!(rx.await.unwrap());
3817
3818        connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
3819
3820        tokio::time::sleep(Duration::from_millis(300)).await;
3821        assert!(
3822            server.messages().await.is_empty(),
3823            "buffered messages should wait for re-authentication"
3824        );
3825
3826        tracker.succeed();
3827
3828        wait_until_async(
3829            || {
3830                let messages = Arc::clone(&server.messages);
3831                async move { !messages.lock().await.is_empty() }
3832            },
3833            Duration::from_secs(3),
3834        )
3835        .await;
3836
3837        assert_eq!(server.messages().await, vec!["stale".to_string()]);
3838
3839        connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
3840        state_notify.notify_waiters();
3841        drop(writer_tx);
3842        write_task.abort();
3843    }
3844
3845    #[tokio::test]
3846    async fn test_write_task_discards_buffer_after_auth_failure() {
3847        let server = RecordingServer::setup().await;
3848        let url = format!("ws://127.0.0.1:{}", server.port);
3849        let (writer, _reader) = WebSocketClientInner::connect_with_server(
3850            &url,
3851            vec![],
3852            TransportBackend::Tungstenite,
3853            None,
3854        )
3855        .await
3856        .unwrap();
3857
3858        let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
3859        let state_notify = Arc::new(tokio::sync::Notify::new());
3860        let auth_tracker = Arc::new(OnceLock::new());
3861        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
3862        let tracker = AuthTracker::new();
3863        auth_tracker.set(tracker.clone()).unwrap();
3864
3865        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
3866        let write_task = WebSocketClientInner::spawn_write_task(
3867            Arc::clone(&connection_state),
3868            Arc::clone(&state_notify),
3869            writer,
3870            writer_rx,
3871            Arc::clone(&auth_tracker),
3872            Arc::clone(&reconnect_buffer_waits_for_auth),
3873        );
3874
3875        writer_tx
3876            .send(WriterCommand::Send(Message::Text("stale".into())))
3877            .unwrap();
3878
3879        let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
3880            &url,
3881            vec![],
3882            TransportBackend::Tungstenite,
3883            None,
3884        )
3885        .await
3886        .unwrap();
3887        let (tx, rx) = tokio::sync::oneshot::channel();
3888        writer_tx
3889            .send(WriterCommand::Update(new_writer, tx))
3890            .unwrap();
3891        assert!(rx.await.unwrap());
3892
3893        connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
3894        tracker.fail("rejected");
3895        tokio::time::sleep(Duration::from_millis(300)).await;
3896        assert!(
3897            server.messages().await.is_empty(),
3898            "buffered messages should be discarded after authentication failure"
3899        );
3900
3901        let _auth_receiver = tracker.begin();
3902        tracker.succeed();
3903        tokio::time::sleep(Duration::from_millis(300)).await;
3904        assert!(
3905            server.messages().await.is_empty(),
3906            "discarded buffered messages should not replay on a later auth success"
3907        );
3908
3909        connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
3910        state_notify.notify_waiters();
3911        drop(writer_tx);
3912        write_task.abort();
3913    }
3914
3915    #[rstest]
3916    #[tokio::test]
3917    async fn test_zero_idle_timeout_rejected() {
3918        let (handler, _rx) = channel_message_handler();
3919
3920        let config = WebSocketConfig {
3921            url: "ws://127.0.0.1:9999".to_string(),
3922            headers: vec![],
3923            heartbeat: None,
3924            heartbeat_msg: None,
3925            reconnect_timeout_ms: None,
3926            reconnect_delay_initial_ms: None,
3927            reconnect_delay_max_ms: None,
3928            reconnect_backoff_factor: None,
3929            reconnect_jitter_ms: None,
3930            reconnect_max_attempts: None,
3931            idle_timeout_ms: Some(0),
3932            backend: TransportBackend::Tungstenite,
3933            proxy_url: None,
3934        };
3935
3936        let result =
3937            WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
3938
3939        assert!(result.is_err(), "Zero idle timeout should be rejected");
3940        let err_msg = result.unwrap_err().to_string();
3941        assert!(
3942            err_msg.contains("Idle timeout cannot be zero"),
3943            "Error should mention zero idle timeout, was: {err_msg}"
3944        );
3945    }
3946
3947    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3948    #[rstest]
3949    #[tokio::test]
3950    async fn test_sockudo_backend_rejects_reserved_headers_before_connect() {
3951        let (handler, _rx) = channel_message_handler();
3952
3953        let config = WebSocketConfig {
3954            url: "ws://127.0.0.1:1".to_string(),
3955            headers: vec![("Host".to_string(), "example.com".to_string())],
3956            heartbeat: None,
3957            heartbeat_msg: None,
3958            reconnect_timeout_ms: None,
3959            reconnect_delay_initial_ms: None,
3960            reconnect_delay_max_ms: None,
3961            reconnect_backoff_factor: None,
3962            reconnect_jitter_ms: None,
3963            reconnect_max_attempts: None,
3964            idle_timeout_ms: None,
3965            backend: TransportBackend::Sockudo,
3966            proxy_url: None,
3967        };
3968
3969        let err = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3970            .await
3971            .expect_err("reserved header should fail before TCP connect");
3972
3973        assert!(
3974            err.to_string()
3975                .contains("reserved upgrade header not allowed in extra_headers"),
3976            "expected reserved-header failure, was: {err}"
3977        );
3978    }
3979
3980    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3981    #[rstest]
3982    #[tokio::test]
3983    async fn test_sockudo_backend_replays_leftover_without_custom_headers() {
3984        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3985        let port = listener.local_addr().unwrap().port();
3986
3987        let server = task::spawn(async move {
3988            if let Ok((mut stream, _)) = listener.accept().await {
3989                let request = read_http_request(&mut stream).await;
3990                let request = String::from_utf8(request).unwrap();
3991                let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
3992                let accept = sockudo_handshake::generate_accept_key(sec_websocket_key);
3993                let mut response = format!(
3994                    concat!(
3995                        "HTTP/1.1 101 Switching Protocols\r\n",
3996                        "Upgrade: websocket\r\n",
3997                        "Connection: Upgrade\r\n",
3998                        "Sec-WebSocket-Accept: {}\r\n",
3999                        "\r\n",
4000                    ),
4001                    accept
4002                )
4003                .into_bytes();
4004                response.extend_from_slice(b"\x81\x05hello");
4005                stream.write_all(&response).await.unwrap();
4006            }
4007        });
4008
4009        let (handler, mut rx) = channel_message_handler();
4010
4011        let config = WebSocketConfig {
4012            url: format!("ws://127.0.0.1:{port}/ws"),
4013            headers: vec![],
4014            heartbeat: None,
4015            heartbeat_msg: None,
4016            reconnect_timeout_ms: Some(2_000),
4017            reconnect_delay_initial_ms: Some(50),
4018            reconnect_delay_max_ms: Some(100),
4019            reconnect_backoff_factor: Some(1.0),
4020            reconnect_jitter_ms: Some(0),
4021            reconnect_max_attempts: None,
4022            idle_timeout_ms: None,
4023            backend: TransportBackend::Sockudo,
4024            proxy_url: None,
4025        };
4026
4027        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4028            .await
4029            .expect("sockudo connect without custom headers");
4030
4031        let received = tokio::time::timeout(Duration::from_secs(3), async {
4032            loop {
4033                if let Ok(msg) = rx.try_recv() {
4034                    return msg;
4035                }
4036                tokio::time::sleep(Duration::from_millis(10)).await;
4037            }
4038        })
4039        .await
4040        .expect("did not receive leftover frame before timeout");
4041
4042        match received {
4043            WsMessage::Text(t) => assert_eq!(t.as_str(), "hello"),
4044            other => panic!("expected text, was {other:?}"),
4045        }
4046
4047        client.disconnect().await;
4048        tokio::time::timeout(Duration::from_secs(3), server)
4049            .await
4050            .expect("server did not close before timeout")
4051            .unwrap();
4052    }
4053
4054    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4055    #[rstest]
4056    #[tokio::test]
4057    async fn test_sockudo_backend_sends_custom_headers() {
4058        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4059        let port = listener.local_addr().unwrap().port();
4060
4061        let server = task::spawn(async move {
4062            if let Ok((stream, _)) = listener.accept().await {
4063                let callback = HeaderAssertCallback {
4064                    key: "X-Test".to_string(),
4065                    value: HeaderValue::from_static("value"),
4066                };
4067
4068                if let Ok(mut ws) = accept_hdr_async(stream, callback).await {
4069                    while let Some(Ok(msg)) = ws.next().await {
4070                        if msg.is_text() || msg.is_binary() {
4071                            if ws.send(msg).await.is_err() {
4072                                break;
4073                            }
4074
4075                            continue;
4076                        }
4077
4078                        if msg.is_close() {
4079                            let _ = ws.close(None).await;
4080                            break;
4081                        }
4082                    }
4083                }
4084            }
4085        });
4086
4087        let (handler, mut rx) = channel_message_handler();
4088
4089        let config = WebSocketConfig {
4090            url: format!("ws://127.0.0.1:{port}"),
4091            headers: vec![("X-Test".to_string(), "value".to_string())],
4092            heartbeat: None,
4093            heartbeat_msg: None,
4094            reconnect_timeout_ms: Some(2_000),
4095            reconnect_delay_initial_ms: Some(50),
4096            reconnect_delay_max_ms: Some(100),
4097            reconnect_backoff_factor: Some(1.0),
4098            reconnect_jitter_ms: Some(0),
4099            reconnect_max_attempts: None,
4100            idle_timeout_ms: None,
4101            backend: TransportBackend::Sockudo,
4102            proxy_url: None,
4103        };
4104
4105        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4106            .await
4107            .expect("sockudo connect with custom headers");
4108
4109        client.send_text("ping".to_string(), None).await.unwrap();
4110
4111        let received = tokio::time::timeout(Duration::from_secs(3), async {
4112            loop {
4113                if let Ok(msg) = rx.try_recv() {
4114                    return msg;
4115                }
4116                tokio::time::sleep(Duration::from_millis(10)).await;
4117            }
4118        })
4119        .await
4120        .expect("did not receive echo before timeout");
4121
4122        match received {
4123            WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
4124            other => panic!("expected text, was {other:?}"),
4125        }
4126
4127        client.disconnect().await;
4128        tokio::time::timeout(Duration::from_secs(3), server)
4129            .await
4130            .expect("server did not close before timeout")
4131            .unwrap();
4132    }
4133
4134    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4135    #[rstest]
4136    #[tokio::test]
4137    async fn test_sockudo_backend_round_trip_text() {
4138        // tokio-tungstenite test peer paired with a sockudo client.
4139        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4140        let port = listener.local_addr().unwrap().port();
4141
4142        let server = task::spawn(async move {
4143            if let Ok((stream, _)) = listener.accept().await
4144                && let Ok(mut ws) = accept_async(stream).await
4145            {
4146                while let Some(Ok(msg)) = ws.next().await {
4147                    // Inner if consumes `msg`, cannot hoist into a match guard
4148                    #[expect(clippy::collapsible_match)]
4149                    match msg {
4150                        WsMessage::Text(_) | WsMessage::Binary(_) => {
4151                            if ws.send(msg).await.is_err() {
4152                                break;
4153                            }
4154                        }
4155                        WsMessage::Close(_) => {
4156                            let _ = ws.close(None).await;
4157                            break;
4158                        }
4159                        _ => {}
4160                    }
4161                }
4162            }
4163        });
4164
4165        let (handler, mut rx) = channel_message_handler();
4166        let config = WebSocketConfig {
4167            url: format!("ws://127.0.0.1:{port}"),
4168            headers: vec![],
4169            heartbeat: None,
4170            heartbeat_msg: None,
4171            reconnect_timeout_ms: Some(2_000),
4172            reconnect_delay_initial_ms: Some(50),
4173            reconnect_delay_max_ms: Some(100),
4174            reconnect_backoff_factor: Some(1.0),
4175            reconnect_jitter_ms: Some(0),
4176            reconnect_max_attempts: None,
4177            idle_timeout_ms: None,
4178            backend: TransportBackend::Sockudo,
4179            proxy_url: None,
4180        };
4181
4182        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4183            .await
4184            .expect("sockudo connect");
4185
4186        client.send_text("ping".to_string(), None).await.unwrap();
4187
4188        let received = tokio::time::timeout(Duration::from_secs(3), async {
4189            loop {
4190                if let Ok(msg) = rx.try_recv() {
4191                    return msg;
4192                }
4193                tokio::time::sleep(Duration::from_millis(10)).await;
4194            }
4195        })
4196        .await
4197        .expect("did not receive echo before timeout");
4198
4199        match received {
4200            WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
4201            other => panic!("expected text, was {other:?}"),
4202        }
4203
4204        client.disconnect().await;
4205        server.abort();
4206    }
4207
4208    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4209    #[rstest]
4210    #[case::ws_default_port("ws://example.com/ws", "example.com", "example.com", 80, "/ws", false)]
4211    #[case::wss_default_port(
4212        "wss://example.com/ws",
4213        "example.com",
4214        "example.com",
4215        443,
4216        "/ws",
4217        true
4218    )]
4219    // url::Url normalises explicit default ports (`:80` for ws, `:443` for wss)
4220    // away, so `parsed.port()` reports `None` here and Host stays unqualified.
4221    #[case::ws_explicit_default(
4222        "ws://example.com:80/ws",
4223        "example.com",
4224        "example.com",
4225        80,
4226        "/ws",
4227        false
4228    )]
4229    #[case::ws_non_default(
4230        "ws://example.com:8443/feed",
4231        "example.com",
4232        "example.com:8443",
4233        8443,
4234        "/feed",
4235        false
4236    )]
4237    #[case::wss_non_default(
4238        "wss://example.com:9443/feed",
4239        "example.com",
4240        "example.com:9443",
4241        9443,
4242        "/feed",
4243        true
4244    )]
4245    #[case::root_path(
4246        "ws://example.com:9000/",
4247        "example.com",
4248        "example.com:9000",
4249        9000,
4250        "/",
4251        false
4252    )]
4253    #[case::query_string(
4254        "ws://example.com/feed?token=abc&channel=trades",
4255        "example.com",
4256        "example.com",
4257        80,
4258        "/feed?token=abc&channel=trades",
4259        false
4260    )]
4261    // IPv6: bare host strips brackets for DNS/TCP/SNI; Host header keeps them.
4262    #[case::ipv6_default("ws://[::1]/feed", "::1", "[::1]", 80, "/feed", false)]
4263    #[case::ipv6_explicit_port("ws://[::1]:9000/feed", "::1", "[::1]:9000", 9000, "/feed", false)]
4264    #[case::ipv6_wss(
4265        "wss://[2001:db8::1]:8443/",
4266        "2001:db8::1",
4267        "[2001:db8::1]:8443",
4268        8443,
4269        "/",
4270        true
4271    )]
4272    fn sockudo_target_parses_url(
4273        #[case] url: &str,
4274        #[case] host: &str,
4275        #[case] host_header: &str,
4276        #[case] port: u16,
4277        #[case] path: &str,
4278        #[case] is_tls: bool,
4279    ) {
4280        let target = super::SockudoTarget::parse(url).expect("parse should succeed");
4281        assert_eq!(target.host, host);
4282        assert_eq!(target.host_header, host_header);
4283        assert_eq!(target.port, port);
4284        assert_eq!(target.path, path);
4285        assert_eq!(target.is_tls, is_tls);
4286    }
4287
4288    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4289    #[rstest]
4290    fn sockudo_target_rejects_unsupported_scheme() {
4291        let err = super::SockudoTarget::parse("http://example.com/feed").expect_err("not a ws URL");
4292        let msg = err.to_string();
4293        assert!(
4294            msg.contains("expected ws:// or wss://"),
4295            "unexpected error: {msg}"
4296        );
4297    }
4298
4299    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4300    #[rstest]
4301    fn sockudo_target_rejects_malformed_url() {
4302        let err = super::SockudoTarget::parse("not a url").expect_err("malformed URL");
4303        assert!(
4304            matches!(err, super::TransportError::InvalidUrl(_)),
4305            "expected InvalidUrl, was: {err:?}"
4306        );
4307    }
4308}
4309
#[cfg(test)]
#[cfg(feature = "turmoil")]
mod turmoil_tests {
    //! Simulated-network tests for how messages buffered during a reconnect interact
    //! with the authentication tracker.

    use std::{sync::Arc, time::Duration};

    use futures_util::{SinkExt, StreamExt};
    use nautilus_common::testing::wait_until_async;
    use rstest::rstest;
    use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
    use turmoil::{Builder, net};

    use super::*;
    use crate::websocket::types::channel_message_handler;

    /// Upper bound for every condition poll in these simulations.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(3);

    /// Grace period granted before asserting that nothing reached the server.
    const SETTLE: Duration = Duration::from_millis(300);

    /// Client configuration pointing at the simulated `server` host, with short,
    /// jitter-free reconnect delays.
    fn turmoil_websocket_config() -> WebSocketConfig {
        WebSocketConfig {
            url: "ws://server:8080".to_string(),
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(200),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        }
    }

    /// Simulated server: greets the first connection with `first` and drops it, pauses
    /// briefly, then accepts a second connection and records every text frame it
    /// receives into `messages` until the peer closes or the stream errors.
    async fn auth_buffer_server(
        messages: Arc<tokio::sync::Mutex<Vec<String>>>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let listener = net::TcpListener::bind("0.0.0.0:8080").await?;

        // First connection: send a single frame, then drop it at the end of the scope
        {
            let (stream, _) = listener.accept().await?;
            let mut first = accept_async(stream).await?;
            let _ = first.send(WsMessage::Text("first".into())).await;
        }

        tokio::time::sleep(Duration::from_millis(200)).await;

        // Second connection: record text frames
        let (stream, _) = listener.accept().await?;
        let mut second = accept_async(stream).await?;

        // A stream error fails the pattern and ends the loop, same as end-of-stream
        while let Some(Ok(frame)) = second.next().await {
            match frame {
                WsMessage::Text(text) => {
                    messages.lock().await.push(text.to_string());
                }
                WsMessage::Close(_) => {
                    let _ = second.close(None).await;
                    break;
                }
                _ => {}
            }
        }

        Ok(())
    }

    /// A message queued while the client is reconnecting must be held back until the
    /// auth tracker reports success, and must then reach the server.
    #[rstest]
    fn test_turmoil_reconnect_buffer_waits_for_auth() {
        let mut sim = Builder::new().build();
        let delivered = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let sink = Arc::clone(&delivered);

        sim.host("server", move || auth_buffer_server(Arc::clone(&sink)));

        sim.client("client", async move {
            let auth = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let config = turmoil_websocket_config();
            let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
                .await
                .expect("Should connect");

            client.set_auth_tracker(auth.clone(), true);
            assert!(client.is_active(), "Client should start active");

            wait_until_async(|| async { client.is_reconnecting() }, WAIT_TIMEOUT).await;

            // Queued while the client is still reconnecting
            let stale = WriterCommand::Send(Message::Text("stale".into()));
            client.writer_tx.send(stale).unwrap();

            wait_until_async(|| async { client.is_active() }, WAIT_TIMEOUT).await;

            let _auth_receiver = auth.begin();

            tokio::time::sleep(SETTLE).await;
            assert!(
                delivered.lock().await.is_empty(),
                "buffered messages should wait for auth after reconnect"
            );

            auth.succeed();

            wait_until_async(
                || {
                    let seen = Arc::clone(&delivered);
                    async move { seen.lock().await.as_slice() == ["stale"] }
                },
                WAIT_TIMEOUT,
            )
            .await;

            assert_eq!(delivered.lock().await.as_slice(), ["stale"]);

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }

    /// A message queued while the client is reconnecting must be dropped when
    /// authentication fails, and must not be replayed by a later successful attempt.
    #[rstest]
    fn test_turmoil_reconnect_buffer_discards_after_auth_failure() {
        let mut sim = Builder::new().build();
        let delivered = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let sink = Arc::clone(&delivered);

        sim.host("server", move || auth_buffer_server(Arc::clone(&sink)));

        sim.client("client", async move {
            let auth = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let config = turmoil_websocket_config();
            let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
                .await
                .expect("Should connect");

            client.set_auth_tracker(auth.clone(), true);
            assert!(client.is_active(), "Client should start active");

            wait_until_async(|| async { client.is_reconnecting() }, WAIT_TIMEOUT).await;

            // Queued while the client is still reconnecting
            let stale = WriterCommand::Send(Message::Text("stale".into()));
            client.writer_tx.send(stale).unwrap();

            wait_until_async(|| async { client.is_active() }, WAIT_TIMEOUT).await;

            let _auth_receiver = auth.begin();
            auth.fail("rejected");

            tokio::time::sleep(SETTLE).await;
            assert!(
                delivered.lock().await.is_empty(),
                "buffered messages should be discarded after auth failure"
            );

            // Second attempt succeeds; the dropped message must stay dropped
            let _retry_auth_receiver = auth.begin();
            auth.succeed();

            tokio::time::sleep(SETTLE).await;
            assert!(
                delivered.lock().await.is_empty(),
                "discarded messages should not replay on a later auth success"
            );

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }
}