Skip to main content

nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Heartbeat support.
23//! - Rate limiting integration.
24
25use std::{
26    collections::VecDeque,
27    fmt::Debug,
28    sync::{
29        Arc,
30        atomic::{AtomicU8, Ordering},
31    },
32    time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46    Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48use ustr::Ustr;
49
50use super::{
51    config::WebSocketConfig,
52    consts::{
53        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
54        GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
55    },
56    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
57};
58#[cfg(feature = "turmoil")]
59use crate::net::TcpConnector;
60use crate::{
61    RECONNECTED,
62    backoff::ExponentialBackoff,
63    error::SendError,
64    logging::{log_task_aborted, log_task_started, log_task_stopped},
65    mode::ConnectionMode,
66    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
67};
68
69/// `WebSocketClient` connects to a websocket server to read and send messages.
70///
71/// The client is opinionated about how messages are read and written. It
72/// assumes that data can only have one reader but multiple writers.
73///
74/// The client splits the connection into read and write halves. It moves
75/// the read half into a tokio task which keeps receiving messages from the
76/// server and calls a handler - a Python function that takes the data
77/// as its parameter. It stores the write half in the struct wrapped
78/// with an Arc Mutex. This way the client struct can be used to write
79/// data to the server from multiple scopes/tasks.
80///
81/// The client also maintains a heartbeat if given a duration in seconds.
82/// It's preferable to set the duration slightly lower - heartbeat more
83/// frequently - than the required amount.
84pub struct WebSocketClientInner {
85    config: WebSocketConfig,
86    /// The function to handle incoming messages (stored separately from config).
87    message_handler: Option<MessageHandler>,
88    /// The handler for incoming pings (stored separately from config).
89    ping_handler: Option<PingHandler>,
90    read_task: Option<tokio::task::JoinHandle<()>>,
91    write_task: tokio::task::JoinHandle<()>,
92    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
93    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
94    connection_mode: Arc<AtomicU8>,
95    reconnect_timeout: Duration,
96    backoff: ExponentialBackoff,
97    /// True if this is a stream-based client (created via `connect_stream`).
98    /// Stream-based clients disable auto-reconnect because the reader is
99    /// owned by the caller and cannot be replaced during reconnection.
100    is_stream_mode: bool,
101    /// Maximum number of reconnection attempts before giving up (None = unlimited).
102    reconnect_max_attempts: Option<u32>,
103    /// Current count of consecutive reconnection attempts.
104    reconnection_attempt_count: u32,
105}
106
107impl WebSocketClientInner {
108    /// Create an inner websocket client with an existing writer.
109    ///
110    /// This is used for stream mode where the reader is owned by the caller.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the exponential backoff configuration is invalid.
115    pub async fn new_with_writer(
116        config: WebSocketConfig,
117        writer: MessageWriter,
118    ) -> Result<Self, Error> {
119        install_cryptographic_provider();
120
121        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
122
123        // Note: We don't spawn a read task here since the reader is handled externally
124        let read_task = None;
125
126        let backoff = ExponentialBackoff::new(
127            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
128            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
129            config.reconnect_backoff_factor.unwrap_or(1.5),
130            config.reconnect_jitter_ms.unwrap_or(100),
131            true, // immediate-first
132        )
133        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
134
135        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
136        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
137
138        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
139            Some(Self::spawn_heartbeat_task(
140                connection_mode.clone(),
141                heartbeat_interval,
142                config.heartbeat_msg.clone(),
143                writer_tx.clone(),
144            ))
145        } else {
146            None
147        };
148
149        let reconnect_max_attempts = config.reconnect_max_attempts;
150        let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
151
152        Ok(Self {
153            config,
154            message_handler: None, // Stream mode has no handler
155            ping_handler: None,
156            writer_tx,
157            connection_mode,
158            reconnect_timeout,
159            heartbeat_task,
160            read_task,
161            write_task,
162            backoff,
163            is_stream_mode: true,
164            reconnect_max_attempts,
165            reconnection_attempt_count: 0,
166        })
167    }
168
169    /// Create an inner websocket client.
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if:
174    /// - The connection to the server fails.
175    /// - The exponential backoff configuration is invalid.
176    pub async fn connect_url(
177        config: WebSocketConfig,
178        message_handler: Option<MessageHandler>,
179        ping_handler: Option<PingHandler>,
180    ) -> Result<Self, Error> {
181        install_cryptographic_provider();
182
183        if config.heartbeat == Some(0) {
184            return Err(Error::Io(std::io::Error::new(
185                std::io::ErrorKind::InvalidInput,
186                "Heartbeat interval cannot be zero",
187            )));
188        }
189
190        if config.idle_timeout_ms == Some(0) {
191            return Err(Error::Io(std::io::Error::new(
192                std::io::ErrorKind::InvalidInput,
193                "Idle timeout cannot be zero",
194            )));
195        }
196
197        // Capture whether we're in stream mode before moving config
198        let is_stream_mode = message_handler.is_none();
199        let reconnect_max_attempts = config.reconnect_max_attempts;
200
201        let (writer, reader) =
202            Self::connect_with_server(&config.url, config.headers.clone()).await?;
203
204        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
205
206        let read_task = if message_handler.is_some() {
207            Some(Self::spawn_message_handler_task(
208                connection_mode.clone(),
209                reader,
210                message_handler.as_ref(),
211                ping_handler.as_ref(),
212                config.idle_timeout_ms,
213            ))
214        } else {
215            None
216        };
217
218        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
219        let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
220
221        // Optionally spawn a heartbeat task to periodically ping server
222        let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
223            Self::spawn_heartbeat_task(
224                connection_mode.clone(),
225                heartbeat_secs,
226                config.heartbeat_msg.clone(),
227                writer_tx.clone(),
228            )
229        });
230
231        let reconnect_timeout =
232            Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
233        let backoff = ExponentialBackoff::new(
234            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
235            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
236            config.reconnect_backoff_factor.unwrap_or(1.5),
237            config.reconnect_jitter_ms.unwrap_or(100),
238            true, // immediate-first
239        )
240        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
241
242        Ok(Self {
243            config,
244            message_handler,
245            ping_handler,
246            read_task,
247            write_task,
248            writer_tx,
249            heartbeat_task,
250            connection_mode,
251            reconnect_timeout,
252            backoff,
253            // Set stream mode when no message handler (reader not managed by client)
254            is_stream_mode,
255            reconnect_max_attempts,
256            reconnection_attempt_count: 0,
257        })
258    }
259
260    /// Connects with the server creating a tokio-tungstenite websocket stream.
261    /// Production version that uses `connect_async_with_config` convenience helper.
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if:
266    /// - The URL cannot be parsed into a valid client request.
267    /// - Header values are invalid.
268    /// - The WebSocket connection fails.
269    #[inline]
270    #[cfg(not(feature = "turmoil"))]
271    pub async fn connect_with_server(
272        url: &str,
273        headers: Vec<(String, String)>,
274    ) -> Result<(MessageWriter, MessageReader), Error> {
275        let mut request = url.into_client_request()?;
276        let req_headers = request.headers_mut();
277
278        let mut header_names: Vec<HeaderName> = Vec::new();
279        for (key, val) in headers {
280            let header_value = HeaderValue::from_str(&val)?;
281            let header_name: HeaderName = key.parse()?;
282            header_names.push(header_name.clone());
283            req_headers.insert(header_name, header_value);
284        }
285
286        connect_async_with_config(request, None, true)
287            .await
288            .map(|resp| resp.0.split())
289    }
290
291    /// Connects with the server creating a tokio-tungstenite websocket stream.
292    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
293    ///
294    /// # Errors
295    ///
296    /// Returns an error if:
297    /// - The URL cannot be parsed into a valid client request.
298    /// - The URL is missing a hostname.
299    /// - Header values are invalid.
300    /// - The TCP connection fails.
301    /// - TLS setup fails (for wss:// URLs).
302    /// - The WebSocket handshake fails.
303    #[inline]
304    #[cfg(feature = "turmoil")]
305    pub async fn connect_with_server(
306        url: &str,
307        headers: Vec<(String, String)>,
308    ) -> Result<(MessageWriter, MessageReader), Error> {
309        use rustls::ClientConfig;
310        use tokio_rustls::TlsConnector;
311
312        let mut request = url.into_client_request()?;
313        let req_headers = request.headers_mut();
314
315        let mut header_names: Vec<HeaderName> = Vec::new();
316        for (key, val) in headers {
317            let header_value = HeaderValue::from_str(&val)?;
318            let header_name: HeaderName = key.parse()?;
319            header_names.push(header_name.clone());
320            req_headers.insert(header_name, header_value);
321        }
322
323        let uri = request.uri();
324        let scheme = uri.scheme_str().unwrap_or("ws");
325        let host = uri.host().ok_or_else(|| {
326            Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
327        })?;
328
329        // Determine port: use explicit port if specified, otherwise default based on scheme
330        let port = uri
331            .port_u16()
332            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
333
334        let addr = format!("{host}:{port}");
335
336        // Use the connector to get a turmoil-compatible stream
337        let connector = crate::net::RealTcpConnector;
338        let tcp_stream = connector.connect(&addr).await?;
339        if let Err(e) = tcp_stream.set_nodelay(true) {
340            log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
341        }
342
343        // Wrap stream appropriately based on scheme
344        let maybe_tls_stream = if scheme == "wss" {
345            // Build TLS config with webpki roots
346            let mut root_store = rustls::RootCertStore::empty();
347            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
348
349            let config = ClientConfig::builder()
350                .with_root_certificates(root_store)
351                .with_no_client_auth();
352
353            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
354            let domain =
355                rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
356                    Error::Io(std::io::Error::new(
357                        std::io::ErrorKind::InvalidInput,
358                        format!("Invalid DNS name: {e}"),
359                    ))
360                })?;
361
362            let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
363            MaybeTlsStream::Rustls(tls_stream)
364        } else {
365            MaybeTlsStream::Plain(tcp_stream)
366        };
367
368        // Use client_async with the stream (plain or TLS)
369        client_async(request, maybe_tls_stream)
370            .await
371            .map(|resp| resp.0.split())
372    }
373
    /// Reconnect with server.
    ///
    /// Makes a new connection with the server and hands the new write half to the
    /// writer task via `WriterCommand::Update`. The writer task drains messages
    /// buffered during reconnection into it and reports the outcome on a oneshot.
    /// On success, the previous read task is aborted, the mode moves from reconnect
    /// to active, and a new read task is spawned on the new read half when a
    /// message handler is set. The heartbeat task is not touched; it keeps
    /// sending through the same writer channel.
    ///
    /// The whole attempt is bounded by `reconnect_timeout`. If the connection mode
    /// reports a disconnect before or during the attempt, or the mode is no longer
    /// "reconnect" when the final transition to active is tried, the attempt is
    /// abandoned and `Ok(())` is returned.
    ///
    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
    /// because the reader is owned by the caller and cannot be replaced. The connection
    /// mode is set to closed instead. Stream users should handle disconnections by
    /// creating a new connection.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The reconnection attempt times out.
    /// - The connection to the server fails.
    /// - The writer task cannot receive the new writer, drops the response
    ///   channel, or fails to drain the reconnection buffer.
    pub async fn reconnect(&mut self) -> Result<(), Error> {
        log::debug!("Reconnecting");

        if self.is_stream_mode {
            log::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                stream users must manually reconnect by creating a new connection"
            );
            // Transition to CLOSED state to stop reconnection attempts
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        // Bound the whole attempt (connect, writer handover, grace delay) by `reconnect_timeout`
        tokio::time::timeout(self.reconnect_timeout, async {
            // Attempt to connect; abort early if a disconnect was requested meanwhile
            let (new_writer, reader) =
                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // Use a oneshot channel to synchronize with the writer task.
            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
            // to prevent silent message loss if the new connection drops immediately.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            // Wait for writer to confirm it has drained the buffer
            match rx.await {
                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    log::warn!("Writer failed to drain buffer, aborting reconnect");
                    // Return error to trigger retry logic in controller
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    log::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Grace period after the writer handover before the read side is replaced.
            // NOTE(review): nothing is closed here (the writer task closes the old
            // writer itself) - confirm the intent of this delay.
            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // The previous read task only loops while the mode is active, so it has
            // normally exited by now; abort it in case it is still running.
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Atomically transition from Reconnect to Active
            // This prevents race condition where disconnect could be requested between check and store
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawn the replacement read task only after the mode is active, since
            // the read task exits as soon as it observes a non-active mode
            self.read_task = if self.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    reader,
                    self.message_handler.as_ref(),
                    self.ping_handler.as_ref(),
                    self.config.idle_timeout_ms,
                ))
            } else {
                None
            };

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // Elapsed -> timeout error; the trailing `?` then yields the inner attempt's result
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
505
506    /// Check if the client is still alive.
507    ///
508    /// Returns `true` if both the read and write tasks are still running.
509    /// There may be some delay between the connection closing and the
510    /// client detecting it.
511    #[inline]
512    #[must_use]
513    pub fn is_alive(&self) -> bool {
514        match &self.read_task {
515            Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
516            None => !self.write_task.is_finished(),
517        }
518    }
519
520    fn spawn_message_handler_task(
521        connection_state: Arc<AtomicU8>,
522        mut reader: MessageReader,
523        message_handler: Option<&MessageHandler>,
524        ping_handler: Option<&PingHandler>,
525        idle_timeout_ms: Option<u64>,
526    ) -> tokio::task::JoinHandle<()> {
527        log::debug!("Started message handler task 'read'");
528
529        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
530        let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
531
532        // Clone Arc handlers for the async task
533        let message_handler = message_handler.cloned();
534        let ping_handler = ping_handler.cloned();
535
536        tokio::task::spawn(async move {
537            let mut last_data_time = tokio::time::Instant::now();
538
539            loop {
540                if !ConnectionMode::from_atomic(&connection_state).is_active() {
541                    break;
542                }
543
544                match tokio::time::timeout(check_interval, reader.next()).await {
545                    Ok(Some(Ok(Message::Binary(data)))) => {
546                        log::trace!("Received message <binary> {} bytes", data.len());
547                        last_data_time = tokio::time::Instant::now();
548                        if let Some(ref handler) = message_handler {
549                            handler(Message::Binary(data));
550                        }
551                    }
552                    Ok(Some(Ok(Message::Text(data)))) => {
553                        log::trace!("Received message: {data}");
554                        last_data_time = tokio::time::Instant::now();
555                        if let Some(ref handler) = message_handler {
556                            handler(Message::Text(data));
557                        }
558                    }
559                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
560                        log::trace!("Received ping: {ping_data:?}");
561                        last_data_time = tokio::time::Instant::now();
562                        if let Some(ref handler) = ping_handler {
563                            handler(ping_data.to_vec());
564                        }
565                    }
566                    Ok(Some(Ok(Message::Pong(_)))) => {
567                        log::trace!("Received pong");
568                        last_data_time = tokio::time::Instant::now();
569                    }
570                    Ok(Some(Ok(Message::Close(_)))) => {
571                        log::debug!("Received close message - terminating");
572                        break;
573                    }
574                    Ok(Some(Ok(_))) => (),
575                    Ok(Some(Err(e))) => {
576                        log::error!("Received error message - terminating: {e}");
577                        break;
578                    }
579                    Ok(None) => {
580                        log::debug!("No message received - terminating");
581                        break;
582                    }
583                    Err(_) => {
584                        if let Some(timeout) = idle_timeout {
585                            let idle_duration = last_data_time.elapsed();
586                            if idle_duration >= timeout {
587                                log::warn!(
588                                    "Read idle timeout: no data received for {:.1}s",
589                                    idle_duration.as_secs_f64()
590                                );
591                                break;
592                            }
593                        }
594                        continue;
595                    }
596                }
597            }
598        })
599    }
600
601    /// Attempts to send all buffered messages after reconnection.
602    ///
603    /// Returns `true` if a send error occurred (caller should trigger reconnection).
604    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
605    async fn drain_reconnect_buffer(
606        buffer: &mut VecDeque<Message>,
607        writer: &mut MessageWriter,
608    ) -> bool {
609        if buffer.is_empty() {
610            return false;
611        }
612
613        let initial_buffer_len = buffer.len();
614        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
615
616        let mut send_error_occurred = false;
617
618        while let Some(buffered_msg) = buffer.front() {
619            // Clone message before attempting send (to keep in buffer if send fails)
620            let msg_to_send = buffered_msg.clone();
621
622            if let Err(e) = writer.send(msg_to_send).await {
623                log::error!(
624                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
625                    buffer.len()
626                );
627                send_error_occurred = true;
628                break; // Stop processing buffer, remaining messages preserved for next reconnection
629            }
630
631            // Only remove from buffer after successful send
632            buffer.pop_front();
633        }
634
635        if buffer.is_empty() {
636            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
637        }
638
639        send_error_occurred
640    }
641
642    fn spawn_write_task(
643        connection_state: Arc<AtomicU8>,
644        writer: MessageWriter,
645        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
646    ) -> tokio::task::JoinHandle<()> {
647        log_task_started("write");
648
649        // Interval between checking the connection mode
650        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
651
652        tokio::task::spawn(async move {
653            let mut active_writer = writer;
654            // Buffer for messages received during reconnection
655            // VecDeque for efficient pop_front() operations
656            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
657
658            loop {
659                match ConnectionMode::from_atomic(&connection_state) {
660                    ConnectionMode::Disconnect => {
661                        // Log any buffered messages that will be lost
662                        if !reconnect_buffer.is_empty() {
663                            log::warn!(
664                                "Discarding {} buffered messages due to disconnect",
665                                reconnect_buffer.len()
666                            );
667                            reconnect_buffer.clear();
668                        }
669
670                        // Attempt to close the writer gracefully before exiting,
671                        // we ignore any error as the writer may already be closed.
672                        _ = tokio::time::timeout(
673                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
674                            active_writer.close(),
675                        )
676                        .await;
677                        break;
678                    }
679                    ConnectionMode::Closed => {
680                        // Log any buffered messages that will be lost
681                        if !reconnect_buffer.is_empty() {
682                            log::warn!(
683                                "Discarding {} buffered messages due to closed connection",
684                                reconnect_buffer.len()
685                            );
686                            reconnect_buffer.clear();
687                        }
688                        break;
689                    }
690                    _ => {}
691                }
692
693                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
694                    Ok(Some(msg)) => {
695                        // Re-check connection mode after receiving a message
696                        let mode = ConnectionMode::from_atomic(&connection_state);
697                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
698                            break;
699                        }
700
701                        match msg {
702                            WriterCommand::Update(new_writer, tx) => {
703                                log::debug!("Received new writer");
704
705                                // Delay before closing connection
706                                tokio::time::sleep(Duration::from_millis(100)).await;
707
708                                // Attempt to close the writer gracefully on update,
709                                // we ignore any error as the writer may already be closed.
710                                _ = tokio::time::timeout(
711                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
712                                    active_writer.close(),
713                                )
714                                .await;
715
716                                active_writer = new_writer;
717                                log::debug!("Updated writer");
718
719                                let send_error = Self::drain_reconnect_buffer(
720                                    &mut reconnect_buffer,
721                                    &mut active_writer,
722                                )
723                                .await;
724
725                                if let Err(e) = tx.send(!send_error) {
726                                    log::error!(
727                                        "Failed to report drain status to controller: {e:?}"
728                                    );
729                                }
730                            }
731                            WriterCommand::Send(msg) if mode.is_reconnect() => {
732                                // Buffer messages during reconnection instead of dropping them
733                                log::debug!(
734                                    "Buffering message during reconnection (buffer size: {})",
735                                    reconnect_buffer.len() + 1
736                                );
737                                reconnect_buffer.push_back(msg);
738                            }
739                            WriterCommand::Send(msg) => {
740                                if let Err(e) = active_writer.send(msg.clone()).await {
741                                    log::error!("Failed to send message: {e}");
742                                    log::warn!("Writer triggering reconnect");
743                                    reconnect_buffer.push_back(msg);
744                                    connection_state
745                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
746                                }
747                            }
748                        }
749                    }
750                    Ok(None) => {
751                        // Channel closed - writer task should terminate
752                        log::debug!("Writer channel closed, terminating writer task");
753                        break;
754                    }
755                    Err(_) => {
756                        // Timeout - just continue the loop
757                        continue;
758                    }
759                }
760            }
761
762            // Attempt to close the writer gracefully before exiting,
763            // we ignore any error as the writer may already be closed.
764            _ = tokio::time::timeout(
765                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
766                active_writer.close(),
767            )
768            .await;
769
770            log_task_stopped("write");
771        })
772    }
773
774    fn spawn_heartbeat_task(
775        connection_state: Arc<AtomicU8>,
776        heartbeat_secs: u64,
777        message: Option<String>,
778        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
779    ) -> tokio::task::JoinHandle<()> {
780        log_task_started("heartbeat");
781
782        tokio::task::spawn(async move {
783            let interval = Duration::from_secs(heartbeat_secs);
784
785            loop {
786                tokio::time::sleep(interval).await;
787
788                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
789                    ConnectionMode::Active => {
790                        let msg = match &message {
791                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
792                            None => WriterCommand::Send(Message::Ping(vec![].into())),
793                        };
794
795                        match writer_tx.send(msg) {
796                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
797                            Err(e) => {
798                                log::error!("Failed to send heartbeat to writer task: {e}");
799                            }
800                        }
801                    }
802                    ConnectionMode::Reconnect => continue,
803                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
804                }
805            }
806
807            log_task_stopped("heartbeat");
808        })
809    }
810}
811
812impl Drop for WebSocketClientInner {
813    fn drop(&mut self) {
814        // Delegate to explicit cleanup handler
815        self.clean_drop();
816    }
817}
818
819/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
820impl CleanDrop for WebSocketClientInner {
821    fn clean_drop(&mut self) {
822        if let Some(ref read_task) = self.read_task.take()
823            && !read_task.is_finished()
824        {
825            read_task.abort();
826            log_task_aborted("read");
827        }
828
829        if !self.write_task.is_finished() {
830            self.write_task.abort();
831            log_task_aborted("write");
832        }
833
834        if let Some(ref handle) = self.heartbeat_task.take()
835            && !handle.is_finished()
836        {
837            handle.abort();
838            log_task_aborted("heartbeat");
839        }
840
841        // Clear handlers to break potential reference cycles
842        self.message_handler = None;
843        self.ping_handler = None;
844    }
845}
846
847impl Debug for WebSocketClientInner {
848    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849        f.debug_struct(stringify!(WebSocketClientInner))
850            .field("config", &self.config)
851            .field(
852                "connection_mode",
853                &ConnectionMode::from_atomic(&self.connection_mode),
854            )
855            .field("reconnect_timeout", &self.reconnect_timeout)
856            .field("is_stream_mode", &self.is_stream_mode)
857            .finish()
858    }
859}
860
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, callbacks, and rate limiting.
/// See module docs for architecture details.
///
/// Note: fields are dropped in declaration order after `Drop::drop` has aborted the
/// controller task, so keep the declaration order stable.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Supervisor task that owns the inner client and drives disconnect/reconnect handling.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared atomic holding the current [`ConnectionMode`] (also exposed via
    /// `connection_mode_atomic`).
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Upper bound on how long a send waits for the client to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    /// Keyed rate limiter awaited by `send_text` and `send_bytes`.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    /// Channel used to hand [`WriterCommand`]s to the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
876
877impl Debug for WebSocketClient {
878    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
879        f.debug_struct(stringify!(WebSocketClient)).finish()
880    }
881}
882
883impl WebSocketClient {
884    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
885    ///
886    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
887    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
888    /// client transitions to CLOSED state and the caller must manually reconnect by calling
889    /// `connect_stream` again.
890    ///
891    /// Use stream mode when you need custom reconnection logic, direct control over message
892    /// reading, or fine-grained backpressure handling.
893    ///
894    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
895    ///
896    /// # Errors
897    ///
898    /// Returns an error if the connection cannot be established.
899    #[allow(clippy::too_many_arguments)]
900    pub async fn connect_stream(
901        config: WebSocketConfig,
902        keyed_quotas: Vec<(String, Quota)>,
903        default_quota: Option<Quota>,
904        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
905    ) -> Result<(MessageReader, Self), Error> {
906        install_cryptographic_provider();
907
908        // Create a single connection and split it, respecting configured headers
909        let (writer, reader) =
910            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
911
912        // Create inner without connecting (we'll provide the writer)
913        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
914
915        let connection_mode = inner.connection_mode.clone();
916        let reconnect_timeout = inner.reconnect_timeout;
917        let keyed_quotas = keyed_quotas
918            .into_iter()
919            .map(|(key, quota)| (Ustr::from(&key), quota))
920            .collect();
921        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
922        let writer_tx = inner.writer_tx.clone();
923
924        let controller_task =
925            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
926
927        Ok((
928            reader,
929            Self {
930                controller_task,
931                connection_mode,
932                reconnect_timeout,
933                rate_limiter,
934                writer_tx,
935            },
936        ))
937    }
938
939    /// Creates a websocket client in **handler mode** with automatic reconnection.
940    ///
941    /// The handler is called for each incoming message on an internal task.
942    /// Automatic reconnection is **enabled** with exponential backoff. On disconnection,
943    /// the client automatically attempts to reconnect and replaces the internal reader
944    /// (the handler continues working seamlessly).
945    ///
946    /// Use handler mode for simplified connection management, automatic reconnection, Python
947    /// bindings, or callback-based message handling.
948    ///
949    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
950    ///
951    /// # Errors
952    ///
953    /// Returns an error if:
954    /// - The connection cannot be established.
955    /// - `message_handler` is `None` (use `connect_stream` instead).
956    pub async fn connect(
957        config: WebSocketConfig,
958        message_handler: Option<MessageHandler>,
959        ping_handler: Option<PingHandler>,
960        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
961        keyed_quotas: Vec<(String, Quota)>,
962        default_quota: Option<Quota>,
963    ) -> Result<Self, Error> {
964        // Validate that handler mode has a message handler
965        if message_handler.is_none() {
966            return Err(Error::Io(std::io::Error::new(
967                std::io::ErrorKind::InvalidInput,
968                "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
969            )));
970        }
971
972        log::debug!("Connecting");
973        let inner =
974            WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
975        let connection_mode = inner.connection_mode.clone();
976        let writer_tx = inner.writer_tx.clone();
977        let reconnect_timeout = inner.reconnect_timeout;
978
979        let controller_task =
980            Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
981
982        let keyed_quotas = keyed_quotas
983            .into_iter()
984            .map(|(key, quota)| (Ustr::from(&key), quota))
985            .collect();
986        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
987
988        Ok(Self {
989            controller_task,
990            connection_mode,
991            reconnect_timeout,
992            rate_limiter,
993            writer_tx,
994        })
995    }
996
997    /// Returns the current connection mode.
998    #[must_use]
999    pub fn connection_mode(&self) -> ConnectionMode {
1000        ConnectionMode::from_atomic(&self.connection_mode)
1001    }
1002
1003    /// Returns a clone of the connection mode atomic for external state tracking.
1004    ///
1005    /// This allows adapter clients to track connection state across reconnections
1006    /// without message-passing delays.
1007    #[must_use]
1008    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1009        Arc::clone(&self.connection_mode)
1010    }
1011
1012    /// Check if the client connection is active.
1013    ///
1014    /// Returns `true` if the client is connected and has not been signalled to disconnect.
1015    /// The client will automatically retry connection based on its configuration.
1016    #[inline]
1017    #[must_use]
1018    pub fn is_active(&self) -> bool {
1019        self.connection_mode().is_active()
1020    }
1021
1022    /// Check if the client is disconnected.
1023    #[must_use]
1024    pub fn is_disconnected(&self) -> bool {
1025        self.controller_task.is_finished()
1026    }
1027
1028    /// Check if the client is reconnecting.
1029    ///
1030    /// Returns `true` if the client lost connection and is attempting to reestablish it.
1031    /// The client will automatically retry connection based on its configuration.
1032    #[inline]
1033    #[must_use]
1034    pub fn is_reconnecting(&self) -> bool {
1035        self.connection_mode().is_reconnect()
1036    }
1037
1038    /// Check if the client is disconnecting.
1039    ///
1040    /// Returns `true` if the client is in disconnect mode.
1041    #[inline]
1042    #[must_use]
1043    pub fn is_disconnecting(&self) -> bool {
1044        self.connection_mode().is_disconnect()
1045    }
1046
1047    /// Check if the client is closed.
1048    ///
1049    /// Returns `true` if the client has been explicitly disconnected or reached
1050    /// maximum reconnection attempts. In this state, the client cannot be reused
1051    /// and a new client must be created for further connections.
1052    #[inline]
1053    #[must_use]
1054    pub fn is_closed(&self) -> bool {
1055        self.connection_mode().is_closed()
1056    }
1057
1058    /// Wait for the client to become active before sending.
1059    ///
1060    /// Returns an error if the client is closed, disconnecting, or if the wait times out.
1061    async fn wait_for_active(&self) -> Result<(), SendError> {
1062        if self.is_closed() {
1063            return Err(SendError::Closed);
1064        }
1065
1066        let timeout = self.reconnect_timeout;
1067        let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1068
1069        if !self.is_active() {
1070            log::debug!("Waiting for client to become ACTIVE before sending...");
1071
1072            let inner = tokio::time::timeout(timeout, async {
1073                loop {
1074                    if self.is_active() {
1075                        return Ok(());
1076                    }
1077                    if matches!(
1078                        self.connection_mode(),
1079                        ConnectionMode::Disconnect | ConnectionMode::Closed
1080                    ) {
1081                        return Err(());
1082                    }
1083                    tokio::time::sleep(check_interval).await;
1084                }
1085            })
1086            .await
1087            .map_err(|_| SendError::Timeout)?;
1088            inner.map_err(|()| SendError::Closed)?;
1089        }
1090
1091        Ok(())
1092    }
1093
1094    /// Set disconnect mode to true.
1095    ///
1096    /// Controller task will periodically check the disconnect mode
1097    /// and shutdown the client if it is alive
1098    pub async fn disconnect(&self) {
1099        log::debug!("Disconnecting");
1100        self.connection_mode
1101            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1102
1103        if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1104            while !self.is_disconnected() {
1105                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1106            }
1107
1108            if !self.controller_task.is_finished() {
1109                self.controller_task.abort();
1110                log_task_aborted("controller");
1111            }
1112        })
1113        .await
1114            == Ok(())
1115        {
1116            log::debug!("Controller task finished");
1117        } else {
1118            log::error!("Timeout waiting for controller task to finish");
1119            if !self.controller_task.is_finished() {
1120                self.controller_task.abort();
1121                log_task_aborted("controller");
1122            }
1123            self.connection_mode
1124                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1125        }
1126    }
1127
1128    /// Sends the given text `data` to the server.
1129    ///
1130    /// # Errors
1131    ///
1132    /// Returns a websocket error if unable to send.
1133    #[allow(unused_variables)]
1134    pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1135        // Check connection state before rate limiting to fail fast
1136        if self.is_closed() || self.is_disconnecting() {
1137            return Err(SendError::Closed);
1138        }
1139
1140        self.rate_limiter.await_keys_ready(keys).await;
1141        self.wait_for_active().await?;
1142
1143        log::trace!("Sending text: {data:?}");
1144
1145        let msg = Message::Text(data.into());
1146        self.writer_tx
1147            .send(WriterCommand::Send(msg))
1148            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1149    }
1150
1151    /// Sends a pong frame back to the server.
1152    ///
1153    /// # Errors
1154    ///
1155    /// Returns a websocket error if unable to send.
1156    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1157        self.wait_for_active().await?;
1158
1159        log::trace!("Sending pong frame ({} bytes)", data.len());
1160
1161        let msg = Message::Pong(data.into());
1162        self.writer_tx
1163            .send(WriterCommand::Send(msg))
1164            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1165    }
1166
1167    /// Sends the given bytes `data` to the server.
1168    ///
1169    /// # Errors
1170    ///
1171    /// Returns a websocket error if unable to send.
1172    #[allow(unused_variables)]
1173    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1174        // Check connection state before rate limiting to fail fast
1175        if self.is_closed() || self.is_disconnecting() {
1176            return Err(SendError::Closed);
1177        }
1178
1179        self.rate_limiter.await_keys_ready(keys).await;
1180        self.wait_for_active().await?;
1181
1182        log::trace!("Sending bytes: {data:?}");
1183
1184        let msg = Message::Binary(data.into());
1185        self.writer_tx
1186            .send(WriterCommand::Send(msg))
1187            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1188    }
1189
1190    /// Sends a close message to the server.
1191    ///
1192    /// # Errors
1193    ///
1194    /// Returns a websocket error if unable to send.
1195    pub async fn send_close_message(&self) -> Result<(), SendError> {
1196        self.wait_for_active().await?;
1197
1198        let msg = Message::Close(None);
1199        self.writer_tx
1200            .send(WriterCommand::Send(msg))
1201            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1202    }
1203
1204    fn spawn_controller_task(
1205        mut inner: WebSocketClientInner,
1206        connection_mode: Arc<AtomicU8>,
1207        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1208    ) -> tokio::task::JoinHandle<()> {
1209        tokio::task::spawn(async move {
1210            log_task_started("controller");
1211
1212            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1213
1214            loop {
1215                tokio::time::sleep(check_interval).await;
1216                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1217
1218                if mode.is_disconnect() {
1219                    log::debug!("Disconnecting");
1220
1221                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1222                    if tokio::time::timeout(timeout, async {
1223                        // Delay awaiting graceful shutdown
1224                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1225
1226                        if let Some(task) = &inner.read_task
1227                            && !task.is_finished()
1228                        {
1229                            task.abort();
1230                            log_task_aborted("read");
1231                        }
1232
1233                        if let Some(task) = &inner.heartbeat_task
1234                            && !task.is_finished()
1235                        {
1236                            task.abort();
1237                            log_task_aborted("heartbeat");
1238                        }
1239                    })
1240                    .await
1241                    .is_err()
1242                    {
1243                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
1244                    }
1245
1246                    log::debug!("Closed");
1247                    break; // Controller finished
1248                }
1249
1250                if mode.is_closed() {
1251                    log::debug!("Connection closed");
1252                    break;
1253                }
1254
1255                if mode.is_active() && !inner.is_alive() {
1256                    if connection_mode
1257                        .compare_exchange(
1258                            ConnectionMode::Active.as_u8(),
1259                            ConnectionMode::Reconnect.as_u8(),
1260                            Ordering::SeqCst,
1261                            Ordering::SeqCst,
1262                        )
1263                        .is_ok()
1264                    {
1265                        log::debug!("Detected dead read task, transitioning to RECONNECT");
1266                    }
1267                    mode = ConnectionMode::from_atomic(&connection_mode);
1268                }
1269
1270                if mode.is_reconnect() {
1271                    // Check if max reconnection attempts exceeded
1272                    if let Some(max_attempts) = inner.reconnect_max_attempts
1273                        && inner.reconnection_attempt_count >= max_attempts
1274                    {
1275                        log::error!(
1276                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1277                        );
1278                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1279                        break;
1280                    }
1281
1282                    inner.reconnection_attempt_count += 1;
1283                    log::debug!(
1284                        "Reconnection attempt {} of {}",
1285                        inner.reconnection_attempt_count,
1286                        inner
1287                            .reconnect_max_attempts
1288                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1289                    );
1290
1291                    match inner.reconnect().await {
1292                        Ok(()) => {
1293                            inner.backoff.reset();
1294                            inner.reconnection_attempt_count = 0; // Reset counter on success
1295
1296                            // Only invoke callbacks if not in disconnect state
1297                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1298                                if let Some(ref handler) = inner.message_handler {
1299                                    let reconnected_msg =
1300                                        Message::Text(RECONNECTED.to_string().into());
1301                                    handler(reconnected_msg);
1302                                    log::debug!("Sent reconnected message to handler");
1303                                }
1304
1305                                // TODO: Retain this legacy callback for use from Python
1306                                if let Some(ref callback) = post_reconnection {
1307                                    callback();
1308                                    log::debug!("Called `post_reconnection` handler");
1309                                }
1310
1311                                log::debug!("Reconnected successfully");
1312                            } else {
1313                                log::debug!(
1314                                    "Skipping post_reconnection handlers due to disconnect state"
1315                                );
1316                            }
1317                        }
1318                        Err(e) => {
1319                            let duration = inner.backoff.next_duration();
1320                            log::warn!(
1321                                "Reconnect attempt {} failed: {e}",
1322                                inner.reconnection_attempt_count
1323                            );
1324                            if !duration.is_zero() {
1325                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
1326                            }
1327                            tokio::time::sleep(duration).await;
1328                        }
1329                    }
1330                }
1331            }
1332            inner
1333                .connection_mode
1334                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1335
1336            log_task_stopped("controller");
1337        })
1338    }
1339}
1340
1341// Abort controller task on drop to clean up background tasks
1342impl Drop for WebSocketClient {
1343    fn drop(&mut self) {
1344        if !self.controller_task.is_finished() {
1345            self.controller_task.abort();
1346            log_task_aborted("controller");
1347        }
1348    }
1349}
1350
1351#[cfg(test)]
1352#[cfg(not(feature = "turmoil"))]
1353#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1354mod tests {
1355    use std::{num::NonZeroU32, sync::Arc};
1356
1357    use futures_util::{SinkExt, StreamExt};
1358    use tokio::{
1359        net::TcpListener,
1360        task::{self, JoinHandle},
1361    };
1362    use tokio_tungstenite::{
1363        accept_hdr_async,
1364        tungstenite::{
1365            handshake::server::{self, Callback},
1366            http::HeaderValue,
1367        },
1368    };
1369
1370    use crate::{
1371        ratelimiter::quota::Quota,
1372        websocket::{WebSocketClient, WebSocketConfig},
1373    };
1374
1375    struct TestServer {
1376        task: JoinHandle<()>,
1377        port: u16,
1378    }
1379
1380    #[derive(Debug, Clone)]
1381    struct TestCallback {
1382        key: String,
1383        value: HeaderValue,
1384    }
1385
1386    impl Callback for TestCallback {
1387        #[allow(clippy::panic_in_result_fn)]
1388        fn on_request(
1389            self,
1390            request: &server::Request,
1391            response: server::Response,
1392        ) -> Result<server::Response, server::ErrorResponse> {
1393            let _ = response;
1394            let value = request.headers().get(&self.key);
1395            assert!(value.is_some());
1396
1397            if let Some(value) = request.headers().get(&self.key) {
1398                assert_eq!(value, self.value);
1399            }
1400
1401            Ok(response)
1402        }
1403    }
1404
1405    impl TestServer {
1406        async fn setup() -> Self {
1407            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1408            let port = TcpListener::local_addr(&server).unwrap().port();
1409
1410            let header_key = "test".to_string();
1411            let header_value = "test".to_string();
1412
1413            let test_call_back = TestCallback {
1414                key: header_key,
1415                value: HeaderValue::from_str(&header_value).unwrap(),
1416            };
1417
1418            let task = task::spawn(async move {
1419                // Keep accepting connections
1420                loop {
1421                    let (conn, _) = server.accept().await.unwrap();
1422                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1423                        .await
1424                        .unwrap();
1425
1426                    task::spawn(async move {
1427                        while let Some(Ok(msg)) = websocket.next().await {
1428                            match msg {
1429                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1430                                    if txt == "close-now" =>
1431                                {
1432                                    log::debug!("Forcibly closing from server side");
1433                                    // This sends a close frame, then stops reading
1434                                    let _ = websocket.close(None).await;
1435                                    break;
1436                                }
1437                                // Echo text/binary frames
1438                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1439                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1440                                    if websocket.send(msg).await.is_err() {
1441                                        break;
1442                                    }
1443                                }
1444                                // If the client closes, we also break
1445                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1446                                    _frame,
1447                                ) => {
1448                                    let _ = websocket.close(None).await;
1449                                    break;
1450                                }
1451                                // Ignore pings/pongs
1452                                _ => {}
1453                            }
1454                        }
1455                    });
1456                }
1457            });
1458
1459            Self { task, port }
1460        }
1461    }
1462
1463    impl Drop for TestServer {
1464        fn drop(&mut self) {
1465            self.task.abort();
1466        }
1467    }
1468
1469    async fn setup_test_client(port: u16) -> WebSocketClient {
1470        let config = WebSocketConfig {
1471            url: format!("ws://127.0.0.1:{port}"),
1472            headers: vec![("test".into(), "test".into())],
1473            heartbeat: None,
1474            heartbeat_msg: None,
1475            reconnect_timeout_ms: None,
1476            reconnect_delay_initial_ms: None,
1477            reconnect_backoff_factor: None,
1478            reconnect_delay_max_ms: None,
1479            reconnect_jitter_ms: None,
1480            reconnect_max_attempts: None,
1481            idle_timeout_ms: None,
1482        };
1483        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1484            .await
1485            .expect("Failed to connect")
1486    }
1487
1488    #[tokio::test]
1489    async fn test_websocket_basic() {
1490        let server = TestServer::setup().await;
1491        let client = setup_test_client(server.port).await;
1492
1493        assert!(!client.is_disconnected());
1494
1495        client.disconnect().await;
1496        assert!(client.is_disconnected());
1497    }
1498
1499    #[tokio::test]
1500    async fn test_websocket_heartbeat() {
1501        let server = TestServer::setup().await;
1502        let client = setup_test_client(server.port).await;
1503
1504        // Wait ~3s => server should see multiple "ping"
1505        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1506
1507        // Cleanup
1508        client.disconnect().await;
1509        assert!(client.is_disconnected());
1510    }
1511
1512    #[tokio::test]
1513    async fn test_websocket_reconnect_exhausted() {
1514        let config = WebSocketConfig {
1515            url: "ws://127.0.0.1:9997".into(), // <-- No server
1516            headers: vec![],
1517            heartbeat: None,
1518            heartbeat_msg: None,
1519            reconnect_timeout_ms: None,
1520            reconnect_delay_initial_ms: None,
1521            reconnect_backoff_factor: None,
1522            reconnect_delay_max_ms: None,
1523            reconnect_jitter_ms: None,
1524            reconnect_max_attempts: None,
1525            idle_timeout_ms: None,
1526        };
1527        let res =
1528            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1529                .await;
1530        assert!(res.is_err(), "Should fail quickly with no server");
1531    }
1532
1533    #[tokio::test]
1534    async fn test_websocket_forced_close_reconnect() {
1535        let server = TestServer::setup().await;
1536        let client = setup_test_client(server.port).await;
1537
1538        // 1) Send normal message
1539        client.send_text("Hello".into(), None).await.unwrap();
1540
1541        // 2) Trigger forced close from server
1542        client.send_text("close-now".into(), None).await.unwrap();
1543
1544        // 3) Wait a bit => read loop sees close => reconnect
1545        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1546
1547        // Confirm not disconnected
1548        assert!(!client.is_disconnected());
1549
1550        // Cleanup
1551        client.disconnect().await;
1552        assert!(client.is_disconnected());
1553    }
1554
1555    #[tokio::test]
1556    async fn test_rate_limiter() {
1557        let server = TestServer::setup().await;
1558        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();
1559
1560        let config = WebSocketConfig {
1561            url: format!("ws://127.0.0.1:{}", server.port),
1562            headers: vec![("test".into(), "test".into())],
1563            heartbeat: None,
1564            heartbeat_msg: None,
1565            reconnect_timeout_ms: None,
1566            reconnect_delay_initial_ms: None,
1567            reconnect_backoff_factor: None,
1568            reconnect_delay_max_ms: None,
1569            reconnect_jitter_ms: None,
1570            reconnect_max_attempts: None,
1571            idle_timeout_ms: None,
1572        };
1573
1574        let client = WebSocketClient::connect(
1575            config,
1576            Some(Arc::new(|_| {})),
1577            None,
1578            None,
1579            vec![("default".into(), quota)],
1580            None,
1581        )
1582        .await
1583        .unwrap();
1584
1585        // First 2 should succeed
1586        client.send_text("test1".into(), None).await.unwrap();
1587        client.send_text("test2".into(), None).await.unwrap();
1588
1589        // Third should error
1590        client.send_text("test3".into(), None).await.unwrap();
1591
1592        // Cleanup
1593        client.disconnect().await;
1594        assert!(client.is_disconnected());
1595    }
1596
1597    #[tokio::test]
1598    async fn test_concurrent_writers() {
1599        let server = TestServer::setup().await;
1600        let client = Arc::new(setup_test_client(server.port).await);
1601
1602        let mut handles = vec![];
1603        for i in 0..10 {
1604            let client = client.clone();
1605            handles.push(task::spawn(async move {
1606                client.send_text(format!("test{i}"), None).await.unwrap();
1607            }));
1608        }
1609
1610        for handle in handles {
1611            handle.await.unwrap();
1612        }
1613
1614        // Cleanup
1615        client.disconnect().await;
1616        assert!(client.is_disconnected());
1617    }
1618}
1619
1620#[cfg(test)]
1621#[cfg(not(feature = "turmoil"))]
1622mod rust_tests {
1623    use futures_util::{SinkExt, StreamExt};
1624    use nautilus_common::testing::wait_until_async;
1625    use rstest::rstest;
1626    use tokio::{
1627        net::TcpListener,
1628        task,
1629        time::{Duration, sleep},
1630    };
1631    use tokio_tungstenite::accept_async;
1632
1633    use super::*;
1634    use crate::websocket::types::channel_message_handler;
1635
1636    #[rstest]
1637    #[tokio::test]
1638    async fn test_reconnect_then_disconnect() {
1639        // Bind an ephemeral port
1640        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1641        let port = listener.local_addr().unwrap().port();
1642
1643        // Server task: accept one ws connection then close it
1644        let server = task::spawn(async move {
1645            let (stream, _) = listener.accept().await.unwrap();
1646            let ws = accept_async(stream).await.unwrap();
1647            drop(ws);
1648            // Keep alive briefly
1649            sleep(Duration::from_secs(1)).await;
1650        });
1651
1652        // Build a channel-based message handler for incoming messages (unused here)
1653        let (handler, _rx) = channel_message_handler();
1654
1655        // Configure client with short reconnect backoff
1656        let config = WebSocketConfig {
1657            url: format!("ws://127.0.0.1:{port}"),
1658            headers: vec![],
1659            heartbeat: None,
1660            heartbeat_msg: None,
1661            reconnect_timeout_ms: Some(1_000),
1662            reconnect_delay_initial_ms: Some(50),
1663            reconnect_delay_max_ms: Some(100),
1664            reconnect_backoff_factor: Some(1.0),
1665            reconnect_jitter_ms: Some(0),
1666            reconnect_max_attempts: None,
1667            idle_timeout_ms: None,
1668        };
1669
1670        // Connect the client
1671        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1672            .await
1673            .unwrap();
1674
1675        // Allow server to drop connection and client to detect
1676        sleep(Duration::from_millis(100)).await;
1677        // Now immediately disconnect the client
1678        client.disconnect().await;
1679        assert!(client.is_disconnected());
1680        server.abort();
1681    }
1682
1683    #[rstest]
1684    #[tokio::test]
1685    async fn test_reconnect_state_flips_when_reader_stops() {
1686        // Bind an ephemeral port and accept a single websocket connection which we drop.
1687        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1688        let port = listener.local_addr().unwrap().port();
1689
1690        let server = task::spawn(async move {
1691            if let Ok((stream, _)) = listener.accept().await
1692                && let Ok(ws) = accept_async(stream).await
1693            {
1694                drop(ws);
1695            }
1696            sleep(Duration::from_millis(50)).await;
1697        });
1698
1699        let (handler, _rx) = channel_message_handler();
1700
1701        let config = WebSocketConfig {
1702            url: format!("ws://127.0.0.1:{port}"),
1703            headers: vec![],
1704            heartbeat: None,
1705            heartbeat_msg: None,
1706            reconnect_timeout_ms: Some(1_000),
1707            reconnect_delay_initial_ms: Some(50),
1708            reconnect_delay_max_ms: Some(100),
1709            reconnect_backoff_factor: Some(1.0),
1710            reconnect_jitter_ms: Some(0),
1711            reconnect_max_attempts: None,
1712            idle_timeout_ms: None,
1713        };
1714
1715        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1716            .await
1717            .unwrap();
1718
1719        tokio::time::timeout(Duration::from_secs(2), async {
1720            loop {
1721                if client.is_reconnecting() {
1722                    break;
1723                }
1724                tokio::time::sleep(Duration::from_millis(10)).await;
1725            }
1726        })
1727        .await
1728        .expect("client did not enter RECONNECT state");
1729
1730        client.disconnect().await;
1731        server.abort();
1732    }
1733
1734    #[rstest]
1735    #[tokio::test]
1736    async fn test_stream_mode_disables_auto_reconnect() {
1737        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1738        // and that reconnect() transitions to CLOSED state for stream mode
1739        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1740        let port = listener.local_addr().unwrap().port();
1741
1742        let server = task::spawn(async move {
1743            if let Ok((stream, _)) = listener.accept().await
1744                && let Ok(_ws) = accept_async(stream).await
1745            {
1746                // Keep connection alive briefly
1747                sleep(Duration::from_millis(100)).await;
1748            }
1749        });
1750
1751        let config = WebSocketConfig {
1752            url: format!("ws://127.0.0.1:{port}"),
1753            headers: vec![],
1754            heartbeat: None,
1755            heartbeat_msg: None,
1756            reconnect_timeout_ms: Some(1_000),
1757            reconnect_delay_initial_ms: Some(50),
1758            reconnect_delay_max_ms: Some(100),
1759            reconnect_backoff_factor: Some(1.0),
1760            reconnect_jitter_ms: Some(0),
1761            reconnect_max_attempts: None,
1762            idle_timeout_ms: None,
1763        };
1764
1765        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1766            .await
1767            .unwrap();
1768
1769        // Note: We can't easily test the reconnect behavior from the outside since
1770        // the inner client is private. The key fix is that WebSocketClientInner
1771        // now has is_stream_mode=true for connect_stream, and reconnect() will
1772        // transition to CLOSED state instead of creating a new reader that gets dropped.
1773        // This is tested implicitly by the fact that stream users won't get stuck
1774        // in an infinite reconnect loop.
1775
1776        server.abort();
1777    }
1778
1779    #[rstest]
1780    #[tokio::test]
1781    async fn test_message_handler_mode_allows_auto_reconnect() {
1782        // Test that regular clients (with message handler) can auto-reconnect
1783        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1784        let port = listener.local_addr().unwrap().port();
1785
1786        let server = task::spawn(async move {
1787            // Accept first connection and close it
1788            if let Ok((stream, _)) = listener.accept().await
1789                && let Ok(ws) = accept_async(stream).await
1790            {
1791                drop(ws);
1792            }
1793            sleep(Duration::from_millis(50)).await;
1794        });
1795
1796        let (handler, _rx) = channel_message_handler();
1797
1798        let config = WebSocketConfig {
1799            url: format!("ws://127.0.0.1:{port}"),
1800            headers: vec![],
1801            heartbeat: None,
1802            heartbeat_msg: None,
1803            reconnect_timeout_ms: Some(1_000),
1804            reconnect_delay_initial_ms: Some(50),
1805            reconnect_delay_max_ms: Some(100),
1806            reconnect_backoff_factor: Some(1.0),
1807            reconnect_jitter_ms: Some(0),
1808            reconnect_max_attempts: None,
1809            idle_timeout_ms: None,
1810        };
1811
1812        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1813            .await
1814            .unwrap();
1815
1816        // Wait for the connection to be dropped and reconnection to be attempted
1817        tokio::time::timeout(Duration::from_secs(2), async {
1818            loop {
1819                if client.is_reconnecting() || client.is_closed() {
1820                    break;
1821                }
1822                tokio::time::sleep(Duration::from_millis(10)).await;
1823            }
1824        })
1825        .await
1826        .expect("client should attempt reconnection or close");
1827
1828        // Should either be reconnecting or closed (depending on timing)
1829        // The important thing is it's not staying active forever
1830        assert!(
1831            client.is_reconnecting() || client.is_closed(),
1832            "Client with message handler should attempt reconnection"
1833        );
1834
1835        client.disconnect().await;
1836        server.abort();
1837    }
1838
1839    #[rstest]
1840    #[tokio::test]
1841    async fn test_handler_mode_reconnect_with_new_connection() {
1842        // Test that handler mode successfully reconnects and messages continue flowing
1843        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1844        let port = listener.local_addr().unwrap().port();
1845
1846        let server = task::spawn(async move {
1847            // First connection - accept and immediately close
1848            if let Ok((stream, _)) = listener.accept().await
1849                && let Ok(ws) = accept_async(stream).await
1850            {
1851                drop(ws);
1852            }
1853
1854            // Small delay to let client detect disconnection
1855            sleep(Duration::from_millis(100)).await;
1856
1857            // Second connection - accept, send a message, then keep alive
1858            if let Ok((stream, _)) = listener.accept().await
1859                && let Ok(mut ws) = accept_async(stream).await
1860            {
1861                use futures_util::SinkExt;
1862                let _ = ws
1863                    .send(Message::Text("reconnected".to_string().into()))
1864                    .await;
1865                sleep(Duration::from_secs(1)).await;
1866            }
1867        });
1868
1869        let (handler, mut rx) = channel_message_handler();
1870
1871        let config = WebSocketConfig {
1872            url: format!("ws://127.0.0.1:{port}"),
1873            headers: vec![],
1874            heartbeat: None,
1875            heartbeat_msg: None,
1876            reconnect_timeout_ms: Some(2_000),
1877            reconnect_delay_initial_ms: Some(50),
1878            reconnect_delay_max_ms: Some(200),
1879            reconnect_backoff_factor: Some(1.5),
1880            reconnect_jitter_ms: Some(10),
1881            reconnect_max_attempts: None,
1882            idle_timeout_ms: None,
1883        };
1884
1885        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1886            .await
1887            .unwrap();
1888
1889        // Wait for reconnection to happen and message to arrive
1890        let result = tokio::time::timeout(Duration::from_secs(5), async {
1891            loop {
1892                if let Ok(msg) = rx.try_recv()
1893                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1894                {
1895                    return true;
1896                }
1897                tokio::time::sleep(Duration::from_millis(10)).await;
1898            }
1899        })
1900        .await;
1901
1902        assert!(
1903            result.is_ok(),
1904            "Should receive message after reconnection within timeout"
1905        );
1906
1907        client.disconnect().await;
1908        server.abort();
1909    }
1910
1911    #[rstest]
1912    #[tokio::test]
1913    async fn test_stream_mode_no_auto_reconnect() {
1914        // Test that stream mode does not automatically reconnect when connection is lost
1915        // The caller owns the reader and is responsible for detecting disconnection
1916        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1917        let port = listener.local_addr().unwrap().port();
1918
1919        let server = task::spawn(async move {
1920            // Accept connection and send one message, then close
1921            if let Ok((stream, _)) = listener.accept().await
1922                && let Ok(mut ws) = accept_async(stream).await
1923            {
1924                use futures_util::SinkExt;
1925                let _ = ws.send(Message::Text("hello".to_string().into())).await;
1926                sleep(Duration::from_millis(50)).await;
1927                // Connection closes when ws is dropped
1928            }
1929        });
1930
1931        let config = WebSocketConfig {
1932            url: format!("ws://127.0.0.1:{port}"),
1933            headers: vec![],
1934            heartbeat: None,
1935            heartbeat_msg: None,
1936            reconnect_timeout_ms: Some(1_000),
1937            reconnect_delay_initial_ms: Some(50),
1938            reconnect_delay_max_ms: Some(100),
1939            reconnect_backoff_factor: Some(1.0),
1940            reconnect_jitter_ms: Some(0),
1941            reconnect_max_attempts: None,
1942            idle_timeout_ms: None,
1943        };
1944
1945        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1946            .await
1947            .unwrap();
1948
1949        // Initially active
1950        assert!(client.is_active(), "Client should start as active");
1951
1952        // Read the hello message
1953        let msg = reader.next().await;
1954        assert!(
1955            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1956            "Should receive initial message"
1957        );
1958
1959        // Read until connection closes (reader will return None or error)
1960        while let Some(msg) = reader.next().await {
1961            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1962                break;
1963            }
1964        }
1965
1966        // In stream mode, the controller cannot detect disconnection (reader is owned by caller)
1967        // The client remains ACTIVE - it's the caller's responsibility to call disconnect()
1968        sleep(Duration::from_millis(200)).await;
1969
1970        // Client should still be ACTIVE (not RECONNECTING or CLOSED)
1971        // This is correct behavior - stream mode doesn't auto-detect disconnection
1972        assert!(
1973            client.is_active() || client.is_closed(),
1974            "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1975        );
1976        assert!(
1977            !client.is_reconnecting(),
1978            "Stream mode client should never attempt reconnection"
1979        );
1980
1981        client.disconnect().await;
1982        server.abort();
1983    }
1984
1985    #[rstest]
1986    #[tokio::test]
1987    async fn test_send_timeout_uses_configured_reconnect_timeout() {
1988        // Test that send operations respect the configured reconnect_timeout.
1989        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1990        use nautilus_common::testing::wait_until_async;
1991
1992        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1993        let port = listener.local_addr().unwrap().port();
1994
1995        let server = task::spawn(async move {
1996            // Accept first connection and immediately close it
1997            if let Ok((stream, _)) = listener.accept().await
1998                && let Ok(ws) = accept_async(stream).await
1999            {
2000                drop(ws);
2001            }
2002            // Don't accept second connection - client will be stuck in RECONNECT
2003            sleep(Duration::from_secs(60)).await;
2004        });
2005
2006        let (handler, _rx) = channel_message_handler();
2007
2008        // Configure with SHORT 2s reconnect timeout
2009        let config = WebSocketConfig {
2010            url: format!("ws://127.0.0.1:{port}"),
2011            headers: vec![],
2012            heartbeat: None,
2013            heartbeat_msg: None,
2014            reconnect_timeout_ms: Some(2_000), // 2s timeout
2015            reconnect_delay_initial_ms: Some(50),
2016            reconnect_delay_max_ms: Some(100),
2017            reconnect_backoff_factor: Some(1.0),
2018            reconnect_jitter_ms: Some(0),
2019            reconnect_max_attempts: None,
2020            idle_timeout_ms: None,
2021        };
2022
2023        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2024            .await
2025            .unwrap();
2026
2027        // Wait for client to enter RECONNECT state
2028        wait_until_async(
2029            || async { client.is_reconnecting() },
2030            Duration::from_secs(3),
2031        )
2032        .await;
2033
2034        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
2035        let start = std::time::Instant::now();
2036        let send_result = client.send_text("test".to_string(), None).await;
2037        let elapsed = start.elapsed();
2038
2039        assert!(
2040            send_result.is_err(),
2041            "Send should fail when client stuck in RECONNECT"
2042        );
2043        assert!(
2044            matches!(send_result, Err(crate::error::SendError::Timeout)),
2045            "Send should return Timeout error, was: {send_result:?}"
2046        );
2047        // Verify timeout respects configured value (2s), but don't check upper bound
2048        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2049        assert!(
2050            elapsed >= Duration::from_millis(1800),
2051            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2052        );
2053
2054        client.disconnect().await;
2055        server.abort();
2056    }
2057
2058    #[rstest]
2059    #[tokio::test]
2060    async fn test_send_waits_during_reconnection() {
2061        // Test that send operations wait for reconnection to complete (up to timeout)
2062        use nautilus_common::testing::wait_until_async;
2063
2064        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2065        let port = listener.local_addr().unwrap().port();
2066
2067        let server = task::spawn(async move {
2068            // First connection - accept and immediately close
2069            if let Ok((stream, _)) = listener.accept().await
2070                && let Ok(ws) = accept_async(stream).await
2071            {
2072                drop(ws);
2073            }
2074
2075            // Wait a bit before accepting second connection
2076            sleep(Duration::from_millis(500)).await;
2077
2078            // Second connection - accept and keep alive
2079            if let Ok((stream, _)) = listener.accept().await
2080                && let Ok(mut ws) = accept_async(stream).await
2081            {
2082                // Echo messages
2083                while let Some(Ok(msg)) = ws.next().await {
2084                    if ws.send(msg).await.is_err() {
2085                        break;
2086                    }
2087                }
2088            }
2089        });
2090
2091        let (handler, _rx) = channel_message_handler();
2092
2093        let config = WebSocketConfig {
2094            url: format!("ws://127.0.0.1:{port}"),
2095            headers: vec![],
2096            heartbeat: None,
2097            heartbeat_msg: None,
2098            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2099            reconnect_delay_initial_ms: Some(100),
2100            reconnect_delay_max_ms: Some(200),
2101            reconnect_backoff_factor: Some(1.0),
2102            reconnect_jitter_ms: Some(0),
2103            reconnect_max_attempts: None,
2104            idle_timeout_ms: None,
2105        };
2106
2107        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2108            .await
2109            .unwrap();
2110
2111        // Wait for reconnection to trigger
2112        wait_until_async(
2113            || async { client.is_reconnecting() },
2114            Duration::from_secs(2),
2115        )
2116        .await;
2117
2118        // Try to send while reconnecting - should wait and succeed after reconnect
2119        let send_result = tokio::time::timeout(
2120            Duration::from_secs(3),
2121            client.send_text("test_message".to_string(), None),
2122        )
2123        .await;
2124
2125        assert!(
2126            send_result.is_ok() && send_result.unwrap().is_ok(),
2127            "Send should succeed after waiting for reconnection"
2128        );
2129
2130        client.disconnect().await;
2131        server.abort();
2132    }
2133
2134    #[rstest]
2135    #[tokio::test]
2136    async fn test_rate_limiter_before_active_wait() {
2137        // Test that rate limiting happens BEFORE active state check.
2138        // This prevents race conditions where connection state changes during rate limit wait.
2139        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2140        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2141        use std::{num::NonZeroU32, sync::Arc};
2142
2143        use nautilus_common::testing::wait_until_async;
2144
2145        use crate::ratelimiter::quota::Quota;
2146
2147        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2148        let port = listener.local_addr().unwrap().port();
2149
2150        let server = task::spawn(async move {
2151            // First connection - accept and close after receiving one message
2152            if let Ok((stream, _)) = listener.accept().await
2153                && let Ok(mut ws) = accept_async(stream).await
2154            {
2155                // Receive first message then close
2156                if let Some(Ok(_)) = ws.next().await {
2157                    drop(ws);
2158                }
2159            }
2160
2161            // Wait before accepting reconnection
2162            sleep(Duration::from_millis(500)).await;
2163
2164            // Second connection - accept and keep alive
2165            if let Ok((stream, _)) = listener.accept().await
2166                && let Ok(mut ws) = accept_async(stream).await
2167            {
2168                while let Some(Ok(msg)) = ws.next().await {
2169                    if ws.send(msg).await.is_err() {
2170                        break;
2171                    }
2172                }
2173            }
2174        });
2175
2176        let (handler, _rx) = channel_message_handler();
2177
2178        let config = WebSocketConfig {
2179            url: format!("ws://127.0.0.1:{port}"),
2180            headers: vec![],
2181            heartbeat: None,
2182            heartbeat_msg: None,
2183            reconnect_timeout_ms: Some(5_000),
2184            reconnect_delay_initial_ms: Some(50),
2185            reconnect_delay_max_ms: Some(100),
2186            reconnect_backoff_factor: Some(1.0),
2187            reconnect_jitter_ms: Some(0),
2188            reconnect_max_attempts: None,
2189            idle_timeout_ms: None,
2190        };
2191
2192        // Very restrictive rate limit: 1 request per second, burst of 1
2193        let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2194            .unwrap()
2195            .allow_burst(NonZeroU32::new(1).unwrap());
2196
2197        let client = Arc::new(
2198            WebSocketClient::connect(
2199                config,
2200                Some(handler),
2201                None,
2202                None,
2203                vec![("test_key".to_string(), quota)],
2204                None,
2205            )
2206            .await
2207            .unwrap(),
2208        );
2209
2210        // First send exhausts burst capacity and triggers connection close
2211        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2212        client
2213            .send_text("msg1".to_string(), Some(test_key.as_slice()))
2214            .await
2215            .unwrap();
2216
2217        // Wait for client to enter RECONNECT state
2218        wait_until_async(
2219            || async { client.is_reconnecting() },
2220            Duration::from_secs(2),
2221        )
2222        .await;
2223
2224        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2225        let start = std::time::Instant::now();
2226        let send_result = client
2227            .send_text("msg2".to_string(), Some(test_key.as_slice()))
2228            .await;
2229        let elapsed = start.elapsed();
2230
2231        // Should succeed after both rate limit AND reconnection
2232        assert!(
2233            send_result.is_ok(),
2234            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2235        );
2236        // Total wait should be at least rate limit time (~1s)
2237        // The reconnection completes while rate limiting or after
2238        // Use 850ms threshold to account for timing jitter in CI
2239        assert!(
2240            elapsed >= Duration::from_millis(850),
2241            "Should wait for rate limit (~1s), waited {elapsed:?}"
2242        );
2243
2244        client.disconnect().await;
2245        server.abort();
2246    }
2247
2248    #[rstest]
2249    #[tokio::test]
2250    async fn test_disconnect_during_reconnect_exits_cleanly() {
2251        // Test CAS race condition: disconnect called during reconnection
2252        // Should exit cleanly without spawning new tasks
2253        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2254        let port = listener.local_addr().unwrap().port();
2255
2256        let server = task::spawn(async move {
2257            // Accept first connection and immediately close
2258            if let Ok((stream, _)) = listener.accept().await
2259                && let Ok(ws) = accept_async(stream).await
2260            {
2261                drop(ws);
2262            }
2263            // Don't accept second connection - let reconnect hang
2264            sleep(Duration::from_secs(60)).await;
2265        });
2266
2267        let (handler, _rx) = channel_message_handler();
2268
2269        let config = WebSocketConfig {
2270            url: format!("ws://127.0.0.1:{port}"),
2271            headers: vec![],
2272            heartbeat: None,
2273            heartbeat_msg: None,
2274            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2275            reconnect_delay_initial_ms: Some(100),
2276            reconnect_delay_max_ms: Some(200),
2277            reconnect_backoff_factor: Some(1.0),
2278            reconnect_jitter_ms: Some(0),
2279            reconnect_max_attempts: None,
2280            idle_timeout_ms: None,
2281        };
2282
2283        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2284            .await
2285            .unwrap();
2286
2287        // Wait for reconnection to start
2288        tokio::time::timeout(Duration::from_secs(2), async {
2289            while !client.is_reconnecting() {
2290                sleep(Duration::from_millis(10)).await;
2291            }
2292        })
2293        .await
2294        .expect("Client should enter RECONNECT state");
2295
2296        // Disconnect while reconnecting
2297        client.disconnect().await;
2298
2299        // Should be cleanly closed
2300        assert!(
2301            client.is_disconnected(),
2302            "Client should be cleanly disconnected"
2303        );
2304
2305        server.abort();
2306    }
2307
2308    #[rstest]
2309    #[tokio::test]
2310    async fn test_send_fails_fast_when_closed_before_rate_limit() {
2311        // Test that send operations check connection state BEFORE rate limiting,
2312        // preventing unnecessary delays when the connection is already closed.
2313        use std::{num::NonZeroU32, sync::Arc};
2314
2315        use nautilus_common::testing::wait_until_async;
2316
2317        use crate::ratelimiter::quota::Quota;
2318
2319        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2320        let port = listener.local_addr().unwrap().port();
2321
2322        let server = task::spawn(async move {
2323            // Accept connection and immediately close
2324            if let Ok((stream, _)) = listener.accept().await
2325                && let Ok(ws) = accept_async(stream).await
2326            {
2327                drop(ws);
2328            }
2329            sleep(Duration::from_secs(60)).await;
2330        });
2331
2332        let (handler, _rx) = channel_message_handler();
2333
2334        let config = WebSocketConfig {
2335            url: format!("ws://127.0.0.1:{port}"),
2336            headers: vec![],
2337            heartbeat: None,
2338            heartbeat_msg: None,
2339            reconnect_timeout_ms: Some(5_000),
2340            reconnect_delay_initial_ms: Some(50),
2341            reconnect_delay_max_ms: Some(100),
2342            reconnect_backoff_factor: Some(1.0),
2343            reconnect_jitter_ms: Some(0),
2344            reconnect_max_attempts: None,
2345            idle_timeout_ms: None,
2346        };
2347
2348        // Very restrictive rate limit: 1 request per 10 seconds
2349        // This ensures that if we wait for rate limit, the test will timeout
2350        let quota = Quota::with_period(Duration::from_secs(10))
2351            .unwrap()
2352            .allow_burst(NonZeroU32::new(1).unwrap());
2353
2354        let client = Arc::new(
2355            WebSocketClient::connect(
2356                config,
2357                Some(handler),
2358                None,
2359                None,
2360                vec![("test_key".to_string(), quota)],
2361                None,
2362            )
2363            .await
2364            .unwrap(),
2365        );
2366
2367        // Wait for disconnection
2368        wait_until_async(
2369            || async { client.is_reconnecting() || client.is_closed() },
2370            Duration::from_secs(2),
2371        )
2372        .await;
2373
2374        // Explicitly disconnect to move away from ACTIVE state
2375        client.disconnect().await;
2376        assert!(
2377            !client.is_active(),
2378            "Client should not be active after disconnect"
2379        );
2380
2381        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
2382        let start = std::time::Instant::now();
2383        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2384        let result = client
2385            .send_text("test".to_string(), Some(test_key.as_slice()))
2386            .await;
2387        let elapsed = start.elapsed();
2388
2389        // Should fail with Closed error
2390        assert!(result.is_err(), "Send should fail when client is closed");
2391        assert!(
2392            matches!(result, Err(crate::error::SendError::Closed)),
2393            "Send should return Closed error, was: {result:?}"
2394        );
2395
2396        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
2397        assert!(
2398            elapsed < Duration::from_millis(100),
2399            "Send should fail fast without rate limiting, took {elapsed:?}"
2400        );
2401
2402        server.abort();
2403    }
2404
2405    #[rstest]
2406    #[tokio::test]
2407    async fn test_connect_rejects_none_message_handler() {
2408        // Test that connect() properly rejects None message_handler
2409        // to prevent zombie connections that appear alive but never detect disconnections
2410
2411        let config = WebSocketConfig {
2412            url: "ws://127.0.0.1:9999".to_string(),
2413            headers: vec![],
2414            heartbeat: None,
2415            heartbeat_msg: None,
2416            reconnect_timeout_ms: Some(1_000),
2417            reconnect_delay_initial_ms: Some(100),
2418            reconnect_delay_max_ms: Some(500),
2419            reconnect_backoff_factor: Some(1.5),
2420            reconnect_jitter_ms: Some(0),
2421            reconnect_max_attempts: None,
2422            idle_timeout_ms: None,
2423        };
2424
2425        // Pass None for message_handler - should be rejected
2426        let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2427
2428        assert!(
2429            result.is_err(),
2430            "connect() should reject None message_handler"
2431        );
2432
2433        let err = result.unwrap_err();
2434        let err_msg = err.to_string();
2435        assert!(
2436            err_msg.contains("Handler mode requires message_handler"),
2437            "Error should mention missing message_handler, was: {err_msg}"
2438        );
2439    }
2440
2441    #[rstest]
2442    #[tokio::test]
2443    async fn test_client_without_handler_sets_stream_mode() {
2444        // Test that if a client is created without a handler via connect_url,
2445        // it properly sets is_stream_mode=true to prevent zombie connections
2446
2447        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2448        let port = listener.local_addr().unwrap().port();
2449
2450        let server = task::spawn(async move {
2451            // Accept and immediately close to simulate server disconnect
2452            if let Ok((stream, _)) = listener.accept().await
2453                && let Ok(ws) = accept_async(stream).await
2454            {
2455                drop(ws); // Drop connection immediately
2456            }
2457        });
2458
2459        let config = WebSocketConfig {
2460            url: format!("ws://127.0.0.1:{port}"),
2461            headers: vec![],
2462            heartbeat: None,
2463            heartbeat_msg: None,
2464            reconnect_timeout_ms: Some(1_000),
2465            reconnect_delay_initial_ms: Some(100),
2466            reconnect_delay_max_ms: Some(500),
2467            reconnect_backoff_factor: Some(1.5),
2468            reconnect_jitter_ms: Some(0),
2469            reconnect_max_attempts: None,
2470            idle_timeout_ms: None,
2471        };
2472
2473        // Create client directly via connect_url with no handler (stream mode)
2474        let inner = WebSocketClientInner::connect_url(config, None, None)
2475            .await
2476            .unwrap();
2477
2478        // Verify is_stream_mode is true when no handler
2479        assert!(
2480            inner.is_stream_mode,
2481            "Client without handler should have is_stream_mode=true"
2482        );
2483
2484        // Verify that when stream mode is enabled, reconnection is disabled
2485        // (documented behavior - stream mode clients close instead of reconnecting)
2486
2487        server.abort();
2488    }
2489
2490    #[rstest]
2491    #[tokio::test]
2492    async fn test_idle_timeout_triggers_reconnect() {
2493        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2494        let port = listener.local_addr().unwrap().port();
2495
2496        // Server accepts WS connection but sends nothing (simulates silent death)
2497        let server = task::spawn(async move {
2498            let (stream, _) = listener.accept().await.unwrap();
2499            let _ws = accept_async(stream).await.unwrap();
2500            // Hold connection open but send nothing
2501            sleep(Duration::from_secs(5)).await;
2502        });
2503
2504        let (handler, _rx) = channel_message_handler();
2505
2506        let config = WebSocketConfig {
2507            url: format!("ws://127.0.0.1:{port}"),
2508            headers: vec![],
2509            heartbeat: None,
2510            heartbeat_msg: None,
2511            reconnect_timeout_ms: Some(2_000),
2512            reconnect_delay_initial_ms: Some(50),
2513            reconnect_delay_max_ms: Some(100),
2514            reconnect_backoff_factor: Some(1.0),
2515            reconnect_jitter_ms: Some(0),
2516            reconnect_max_attempts: Some(1),
2517            idle_timeout_ms: Some(500),
2518        };
2519
2520        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2521            .await
2522            .unwrap();
2523
2524        assert!(client.is_active());
2525
2526        // Wait for idle timeout to fire and client to enter reconnect/closed
2527        wait_until_async(
2528            || async { client.is_reconnecting() || client.is_disconnected() },
2529            Duration::from_secs(3),
2530        )
2531        .await;
2532
2533        assert!(
2534            !client.is_active(),
2535            "Client should not be active after idle timeout"
2536        );
2537
2538        client.disconnect().await;
2539        server.abort();
2540    }
2541
2542    #[rstest]
2543    #[tokio::test]
2544    async fn test_idle_timeout_resets_on_data() {
2545        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2546        let port = listener.local_addr().unwrap().port();
2547
2548        // Server sends a message every 200ms (well within 1s idle timeout)
2549        let server = task::spawn(async move {
2550            let (stream, _) = listener.accept().await.unwrap();
2551            let mut ws = accept_async(stream).await.unwrap();
2552            for _ in 0..10 {
2553                sleep(Duration::from_millis(200)).await;
2554                if ws
2555                    .send(tokio_tungstenite::tungstenite::Message::Text("ping".into()))
2556                    .await
2557                    .is_err()
2558                {
2559                    break;
2560                }
2561            }
2562        });
2563
2564        let (handler, _rx) = channel_message_handler();
2565
2566        let config = WebSocketConfig {
2567            url: format!("ws://127.0.0.1:{port}"),
2568            headers: vec![],
2569            heartbeat: None,
2570            heartbeat_msg: None,
2571            reconnect_timeout_ms: Some(2_000),
2572            reconnect_delay_initial_ms: Some(50),
2573            reconnect_delay_max_ms: Some(100),
2574            reconnect_backoff_factor: Some(1.0),
2575            reconnect_jitter_ms: Some(0),
2576            reconnect_max_attempts: Some(1),
2577            idle_timeout_ms: Some(1_000),
2578        };
2579
2580        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2581            .await
2582            .unwrap();
2583
2584        assert!(client.is_active());
2585
2586        // Wait 1.5s - data arrives every 200ms so idle timeout (1s) should NOT fire
2587        sleep(Duration::from_millis(1_500)).await;
2588
2589        assert!(
2590            client.is_active(),
2591            "Client should remain active when data is flowing"
2592        );
2593
2594        client.disconnect().await;
2595        server.abort();
2596    }
2597
2598    #[rstest]
2599    #[tokio::test]
2600    async fn test_zero_idle_timeout_rejected() {
2601        let (handler, _rx) = channel_message_handler();
2602
2603        let config = WebSocketConfig {
2604            url: "ws://127.0.0.1:9999".to_string(),
2605            headers: vec![],
2606            heartbeat: None,
2607            heartbeat_msg: None,
2608            reconnect_timeout_ms: None,
2609            reconnect_delay_initial_ms: None,
2610            reconnect_delay_max_ms: None,
2611            reconnect_backoff_factor: None,
2612            reconnect_jitter_ms: None,
2613            reconnect_max_attempts: None,
2614            idle_timeout_ms: Some(0),
2615        };
2616
2617        let result =
2618            WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
2619
2620        assert!(result.is_err(), "Zero idle timeout should be rejected");
2621        let err_msg = result.unwrap_err().to_string();
2622        assert!(
2623            err_msg.contains("Idle timeout cannot be zero"),
2624            "Error should mention zero idle timeout, was: {err_msg}"
2625        );
2626    }
2627}