Skip to main content

nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Event-driven state notification via `Notify` for immediate wakeup on transitions.
23//! - Heartbeat support.
24//! - Rate limiting integration.
25
26use std::{
27    collections::VecDeque,
28    fmt::Debug,
29    sync::{
30        Arc,
31        atomic::{AtomicU8, Ordering},
32    },
33    time::Duration,
34};
35
36use futures_util::{SinkExt, StreamExt};
37use http::HeaderName;
38use nautilus_core::CleanDrop;
39use nautilus_cryptography::providers::install_cryptographic_provider;
40#[cfg(feature = "turmoil")]
41use tokio_tungstenite::MaybeTlsStream;
42#[cfg(feature = "turmoil")]
43use tokio_tungstenite::client_async;
44#[cfg(not(feature = "turmoil"))]
45use tokio_tungstenite::connect_async_with_config;
46use tokio_tungstenite::tungstenite::{
47    Error, Message, client::IntoClientRequest, http::HeaderValue,
48};
49use ustr::Ustr;
50
51use super::{
52    config::WebSocketConfig,
53    consts::{
54        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
55        GRACEFUL_SHUTDOWN_TIMEOUT_SECS,
56    },
57    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
58};
59#[cfg(feature = "turmoil")]
60use crate::net::TcpConnector;
61use crate::{
62    RECONNECTED,
63    backoff::ExponentialBackoff,
64    error::SendError,
65    logging::{log_task_aborted, log_task_started, log_task_stopped},
66    mode::ConnectionMode,
67    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
68};
69
/// `WebSocketClient` connects to a websocket server to read and send messages.
///
/// The client is opinionated about how messages are read and written. It
/// assumes that data can only have one reader but multiple writers.
///
/// The client splits the connection into read and write halves. When a message
/// handler is supplied, the read half is moved into a tokio task (`read_task`)
/// which keeps receiving messages from the server and passes them to the
/// handler. The write half is moved into a dedicated tokio task (`write_task`);
/// writes are submitted to it as `WriterCommand`s through the unbounded
/// `writer_tx` channel, so data can be written to the server from multiple
/// scopes/tasks.
///
/// The client also maintains a heartbeat if `config.heartbeat` is set.
/// It's preferable to set the interval slightly lower - heartbeat more
/// frequently - than the required amount.
pub struct WebSocketClientInner {
    /// Client configuration (URL, headers, heartbeat, idle timeout and reconnect settings).
    config: WebSocketConfig,
    /// The function to handle incoming messages (stored separately from config).
    message_handler: Option<MessageHandler>,
    /// The handler for incoming pings (stored separately from config).
    ping_handler: Option<PingHandler>,
    /// Handle of the task reading from the server; `None` when no read task is
    /// managed by the client (no message handler / stream mode).
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the task that owns the write half and processes `WriterCommand`s.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender used to submit `WriterCommand`s (sends and writer updates) to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task; `None` when no heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current `ConnectionMode` stored as its `u8` representation, shared with the spawned tasks.
    connection_mode: Arc<AtomicU8>,
    /// Notified by the read task when it exits and by the write task when it
    /// requests a reconnect, so a waiter wakes without polling.
    state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound on the duration of a single `reconnect` attempt.
    reconnect_timeout: Duration,
    /// Exponential backoff built from the reconnect delay settings.
    backoff: ExponentialBackoff,
    /// True if this is a stream-based client (created via `connect_stream`).
    /// Stream-based clients disable auto-reconnect because the reader is
    /// owned by the caller and cannot be replaced during reconnection.
    is_stream_mode: bool,
    /// Maximum number of reconnection attempts before giving up (None = unlimited).
    reconnect_max_attempts: Option<u32>,
    /// Current count of consecutive reconnection attempts.
    reconnection_attempt_count: u32,
}
108
109impl WebSocketClientInner {
110    /// Create an inner websocket client with an existing writer.
111    ///
112    /// This is used for stream mode where the reader is owned by the caller.
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if the exponential backoff configuration is invalid.
117    pub async fn new_with_writer(
118        config: WebSocketConfig,
119        writer: MessageWriter,
120    ) -> Result<Self, Error> {
121        install_cryptographic_provider();
122
123        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
124        let state_notify = Arc::new(tokio::sync::Notify::new());
125
126        // Note: We don't spawn a read task here since the reader is handled externally
127        let read_task = None;
128
129        // Stream mode ignores reconnect settings, use harmless defaults
130        let backoff = ExponentialBackoff::new(
131            Duration::from_secs(2),
132            Duration::from_secs(30),
133            1.5,
134            100,
135            true,
136        )
137        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
138
139        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
140        let write_task = Self::spawn_write_task(
141            connection_mode.clone(),
142            state_notify.clone(),
143            writer,
144            writer_rx,
145        );
146
147        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
148            Some(Self::spawn_heartbeat_task(
149                connection_mode.clone(),
150                heartbeat_interval,
151                config.heartbeat_msg.clone(),
152                writer_tx.clone(),
153            ))
154        } else {
155            None
156        };
157
158        let reconnect_max_attempts = None; // Stream mode does not reconnect
159        let reconnect_timeout = Duration::from_secs(10);
160
161        Ok(Self {
162            config,
163            message_handler: None, // Stream mode has no handler
164            ping_handler: None,
165            writer_tx,
166            connection_mode,
167            state_notify,
168            reconnect_timeout,
169            heartbeat_task,
170            read_task,
171            write_task,
172            backoff,
173            is_stream_mode: true,
174            reconnect_max_attempts,
175            reconnection_attempt_count: 0,
176        })
177    }
178
179    /// Create an inner websocket client.
180    ///
181    /// # Errors
182    ///
183    /// Returns an error if:
184    /// - The connection to the server fails.
185    /// - The exponential backoff configuration is invalid.
186    pub async fn connect_url(
187        config: WebSocketConfig,
188        message_handler: Option<MessageHandler>,
189        ping_handler: Option<PingHandler>,
190    ) -> Result<Self, Error> {
191        install_cryptographic_provider();
192
193        if config.heartbeat == Some(0) {
194            return Err(Error::Io(std::io::Error::new(
195                std::io::ErrorKind::InvalidInput,
196                "Heartbeat interval cannot be zero",
197            )));
198        }
199
200        if config.idle_timeout_ms == Some(0) {
201            return Err(Error::Io(std::io::Error::new(
202                std::io::ErrorKind::InvalidInput,
203                "Idle timeout cannot be zero",
204            )));
205        }
206
207        // Capture whether we're in stream mode before moving config
208        let is_stream_mode = message_handler.is_none();
209        let reconnect_max_attempts = config.reconnect_max_attempts;
210
211        let (writer, reader) =
212            Self::connect_with_server(&config.url, config.headers.clone()).await?;
213
214        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
215        let state_notify = Arc::new(tokio::sync::Notify::new());
216
217        let read_task = if message_handler.is_some() {
218            Some(Self::spawn_message_handler_task(
219                connection_mode.clone(),
220                state_notify.clone(),
221                reader,
222                message_handler.as_ref(),
223                ping_handler.as_ref(),
224                config.idle_timeout_ms,
225            ))
226        } else {
227            None
228        };
229
230        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
231        let write_task = Self::spawn_write_task(
232            connection_mode.clone(),
233            state_notify.clone(),
234            writer,
235            writer_rx,
236        );
237
238        // Optionally spawn a heartbeat task to periodically ping server
239        let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
240            Self::spawn_heartbeat_task(
241                connection_mode.clone(),
242                heartbeat_secs,
243                config.heartbeat_msg.clone(),
244                writer_tx.clone(),
245            )
246        });
247
248        let reconnect_timeout =
249            Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
250        let backoff = ExponentialBackoff::new(
251            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
252            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
253            config.reconnect_backoff_factor.unwrap_or(1.5),
254            config.reconnect_jitter_ms.unwrap_or(100),
255            true, // immediate-first
256        )
257        .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
258
259        Ok(Self {
260            config,
261            message_handler,
262            ping_handler,
263            read_task,
264            write_task,
265            writer_tx,
266            heartbeat_task,
267            connection_mode,
268            state_notify,
269            reconnect_timeout,
270            backoff,
271            // Set stream mode when no message handler (reader not managed by client)
272            is_stream_mode,
273            reconnect_max_attempts,
274            reconnection_attempt_count: 0,
275        })
276    }
277
278    /// Connects with the server creating a tokio-tungstenite websocket stream.
279    /// Production version that uses `connect_async_with_config` convenience helper.
280    ///
281    /// # Errors
282    ///
283    /// Returns an error if:
284    /// - The URL cannot be parsed into a valid client request.
285    /// - Header values are invalid.
286    /// - The WebSocket connection fails.
287    #[inline]
288    #[cfg(not(feature = "turmoil"))]
289    pub async fn connect_with_server(
290        url: &str,
291        headers: Vec<(String, String)>,
292    ) -> Result<(MessageWriter, MessageReader), Error> {
293        let mut request = url.into_client_request()?;
294        let req_headers = request.headers_mut();
295
296        let mut header_names: Vec<HeaderName> = Vec::new();
297        for (key, val) in headers {
298            let header_value = HeaderValue::from_str(&val)?;
299            let header_name: HeaderName = key.parse()?;
300            header_names.push(header_name.clone());
301            req_headers.insert(header_name, header_value);
302        }
303
304        connect_async_with_config(request, None, true)
305            .await
306            .map(|resp| resp.0.split())
307    }
308
309    /// Connects with the server creating a tokio-tungstenite websocket stream.
310    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if:
315    /// - The URL cannot be parsed into a valid client request.
316    /// - The URL is missing a hostname.
317    /// - Header values are invalid.
318    /// - The TCP connection fails.
319    /// - TLS setup fails (for wss:// URLs).
320    /// - The WebSocket handshake fails.
321    #[inline]
322    #[cfg(feature = "turmoil")]
323    pub async fn connect_with_server(
324        url: &str,
325        headers: Vec<(String, String)>,
326    ) -> Result<(MessageWriter, MessageReader), Error> {
327        use rustls::ClientConfig;
328        use tokio_rustls::TlsConnector;
329
330        let mut request = url.into_client_request()?;
331        let req_headers = request.headers_mut();
332
333        let mut header_names: Vec<HeaderName> = Vec::new();
334        for (key, val) in headers {
335            let header_value = HeaderValue::from_str(&val)?;
336            let header_name: HeaderName = key.parse()?;
337            header_names.push(header_name.clone());
338            req_headers.insert(header_name, header_value);
339        }
340
341        let uri = request.uri();
342        let scheme = uri.scheme_str().unwrap_or("ws");
343        let host = uri.host().ok_or_else(|| {
344            Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
345        })?;
346
347        // Determine port: use explicit port if specified, otherwise default based on scheme
348        let port = uri
349            .port_u16()
350            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
351
352        let addr = format!("{host}:{port}");
353
354        // Use the connector to get a turmoil-compatible stream
355        let connector = crate::net::RealTcpConnector;
356        let tcp_stream = connector.connect(&addr).await?;
357        if let Err(e) = tcp_stream.set_nodelay(true) {
358            log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
359        }
360
361        // Wrap stream appropriately based on scheme
362        let maybe_tls_stream = if scheme == "wss" {
363            // Build TLS config with webpki roots
364            let mut root_store = rustls::RootCertStore::empty();
365            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
366
367            let config = ClientConfig::builder()
368                .with_root_certificates(root_store)
369                .with_no_client_auth();
370
371            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
372            let domain =
373                rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
374                    Error::Io(std::io::Error::new(
375                        std::io::ErrorKind::InvalidInput,
376                        format!("Invalid DNS name: {e}"),
377                    ))
378                })?;
379
380            let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
381            MaybeTlsStream::Rustls(tls_stream)
382        } else {
383            MaybeTlsStream::Plain(tcp_stream)
384        };
385
386        // Use client_async with the stream (plain or TLS)
387        client_async(request, maybe_tls_stream)
388            .await
389            .map(|resp| resp.0.split())
390    }
391
392    /// Reconnect with server.
393    ///
394    /// Make a new connection with server. Use the new read and write halves
395    /// to update self writer and read and heartbeat tasks.
396    ///
397    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
398    /// because the reader is owned by the caller and cannot be replaced. Stream users
399    /// should handle disconnections by creating a new connection.
400    ///
401    /// # Errors
402    ///
403    /// Returns an error if:
404    /// - The reconnection attempt times out.
405    /// - The connection to the server fails.
406    pub async fn reconnect(&mut self) -> Result<(), Error> {
407        log::debug!("Reconnecting");
408
409        if self.is_stream_mode {
410            log::warn!(
411                "Auto-reconnect disabled for stream-based WebSocket client; \
412                stream users must manually reconnect by creating a new connection"
413            );
414            // Transition to CLOSED state to stop reconnection attempts
415            self.connection_mode
416                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
417            return Ok(());
418        }
419
420        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
421            log::debug!("Reconnect aborted due to disconnect state");
422            return Ok(());
423        }
424
425        tokio::time::timeout(self.reconnect_timeout, async {
426            // Attempt to connect; abort early if a disconnect was requested
427            let (new_writer, reader) =
428                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
429
430            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
431                log::debug!("Reconnect aborted mid-flight (after connect)");
432                return Ok(());
433            }
434
435            // Use a oneshot channel to synchronize with the writer task.
436            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
437            // to prevent silent message loss if the new connection drops immediately.
438            let (tx, rx) = tokio::sync::oneshot::channel();
439            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
440                log::error!("{e}");
441                return Err(Error::Io(std::io::Error::new(
442                    std::io::ErrorKind::BrokenPipe,
443                    format!("Failed to send update command: {e}"),
444                )));
445            }
446
447            // Wait for writer to confirm it has drained the buffer
448            match rx.await {
449                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
450                Ok(false) => {
451                    log::warn!("Writer failed to drain buffer, aborting reconnect");
452                    // Return error to trigger retry logic in controller
453                    return Err(Error::Io(std::io::Error::other(
454                        "Failed to drain reconnection buffer",
455                    )));
456                }
457                Err(e) => {
458                    log::error!("Writer dropped update channel: {e}");
459                    return Err(Error::Io(std::io::Error::new(
460                        std::io::ErrorKind::BrokenPipe,
461                        "Writer task dropped response channel",
462                    )));
463                }
464            }
465
466            // Delay before closing connection
467            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
468
469            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
470                log::debug!("Reconnect aborted mid-flight (after delay)");
471                return Ok(());
472            }
473
474            if let Some(ref read_task) = self.read_task.take()
475                && !read_task.is_finished()
476            {
477                read_task.abort();
478                log_task_aborted("read");
479            }
480
481            // Atomically transition from Reconnect to Active
482            // This prevents race condition where disconnect could be requested between check and store
483            if self
484                .connection_mode
485                .compare_exchange(
486                    ConnectionMode::Reconnect.as_u8(),
487                    ConnectionMode::Active.as_u8(),
488                    Ordering::SeqCst,
489                    Ordering::SeqCst,
490                )
491                .is_err()
492            {
493                log::debug!("Reconnect aborted (state changed during reconnect)");
494                return Ok(());
495            }
496
497            self.read_task = if self.message_handler.is_some() {
498                Some(Self::spawn_message_handler_task(
499                    self.connection_mode.clone(),
500                    self.state_notify.clone(),
501                    reader,
502                    self.message_handler.as_ref(),
503                    self.ping_handler.as_ref(),
504                    self.config.idle_timeout_ms,
505                ))
506            } else {
507                None
508            };
509
510            log::debug!("Reconnect succeeded");
511            Ok(())
512        })
513        .await
514        .map_err(|_| {
515            Error::Io(std::io::Error::new(
516                std::io::ErrorKind::TimedOut,
517                format!(
518                    "reconnection timed out after {}s",
519                    self.reconnect_timeout.as_secs_f64()
520                ),
521            ))
522        })?
523    }
524
525    /// Check if the client is still alive.
526    ///
527    /// Returns `true` if both the read and write tasks are still running.
528    /// There may be some delay between the connection closing and the
529    /// client detecting it.
530    #[inline]
531    #[must_use]
532    pub fn is_alive(&self) -> bool {
533        match &self.read_task {
534            Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
535            None => !self.write_task.is_finished(),
536        }
537    }
538
539    fn spawn_message_handler_task(
540        connection_state: Arc<AtomicU8>,
541        state_notify: Arc<tokio::sync::Notify>,
542        mut reader: MessageReader,
543        message_handler: Option<&MessageHandler>,
544        ping_handler: Option<&PingHandler>,
545        idle_timeout_ms: Option<u64>,
546    ) -> tokio::task::JoinHandle<()> {
547        log::debug!("Started message handler task 'read'");
548
549        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
550        let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
551
552        // Clone Arc handlers for the async task
553        let message_handler = message_handler.cloned();
554        let ping_handler = ping_handler.cloned();
555
556        tokio::task::spawn(async move {
557            let mut last_data_time = tokio::time::Instant::now();
558
559            loop {
560                if !ConnectionMode::from_atomic(&connection_state).is_active() {
561                    break;
562                }
563
564                match tokio::time::timeout(check_interval, reader.next()).await {
565                    Ok(Some(Ok(Message::Binary(data)))) => {
566                        log::trace!("Received message <binary> {} bytes", data.len());
567                        last_data_time = tokio::time::Instant::now();
568
569                        if let Some(ref handler) = message_handler {
570                            handler(Message::Binary(data));
571                        }
572                    }
573                    Ok(Some(Ok(Message::Text(data)))) => {
574                        log::trace!("Received message: {data}");
575                        last_data_time = tokio::time::Instant::now();
576
577                        if let Some(ref handler) = message_handler {
578                            handler(Message::Text(data));
579                        }
580                    }
581                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
582                        log::trace!("Received ping: {ping_data:?}");
583                        last_data_time = tokio::time::Instant::now();
584
585                        if let Some(ref handler) = ping_handler {
586                            handler(ping_data.to_vec());
587                        }
588                    }
589                    Ok(Some(Ok(Message::Pong(_)))) => {
590                        log::trace!("Received pong");
591                        last_data_time = tokio::time::Instant::now();
592                    }
593                    Ok(Some(Ok(Message::Close(_)))) => {
594                        log::debug!("Received close message - terminating");
595                        break;
596                    }
597                    Ok(Some(Ok(_))) => (),
598                    Ok(Some(Err(e))) => {
599                        log::error!("Received error message - terminating: {e}");
600                        break;
601                    }
602                    Ok(None) => {
603                        log::debug!("No message received - terminating");
604                        break;
605                    }
606                    Err(_) => {
607                        if let Some(timeout) = idle_timeout {
608                            let idle_duration = last_data_time.elapsed();
609                            if idle_duration >= timeout {
610                                log::warn!(
611                                    "Read idle timeout: no data received for {:.1}s",
612                                    idle_duration.as_secs_f64()
613                                );
614                                break;
615                            }
616                        }
617                    }
618                }
619            }
620
621            // Wake the controller immediately so it detects the dead read task
622            state_notify.notify_one();
623        })
624    }
625
626    /// Attempts to send all buffered messages after reconnection.
627    ///
628    /// Returns `true` if a send error occurred (caller should trigger reconnection).
629    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
630    async fn drain_reconnect_buffer(
631        buffer: &mut VecDeque<Message>,
632        writer: &mut MessageWriter,
633    ) -> bool {
634        if buffer.is_empty() {
635            return false;
636        }
637
638        let initial_buffer_len = buffer.len();
639        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
640
641        let mut send_error_occurred = false;
642
643        while let Some(buffered_msg) = buffer.front() {
644            // Clone message before attempting send (to keep in buffer if send fails)
645            let msg_to_send = buffered_msg.clone();
646
647            if let Err(e) = writer.send(msg_to_send).await {
648                log::error!(
649                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
650                    buffer.len()
651                );
652                send_error_occurred = true;
653                break; // Stop processing buffer, remaining messages preserved for next reconnection
654            }
655
656            // Only remove from buffer after successful send
657            buffer.pop_front();
658        }
659
660        if buffer.is_empty() {
661            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
662        }
663
664        send_error_occurred
665    }
666
667    fn spawn_write_task(
668        connection_state: Arc<AtomicU8>,
669        state_notify: Arc<tokio::sync::Notify>,
670        writer: MessageWriter,
671        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
672    ) -> tokio::task::JoinHandle<()> {
673        log_task_started("write");
674
675        // Interval between checking the connection mode
676        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
677
678        tokio::task::spawn(async move {
679            let mut active_writer = writer;
680            // Buffer for messages received during reconnection
681            // VecDeque for efficient pop_front() operations
682            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
683
684            loop {
685                match ConnectionMode::from_atomic(&connection_state) {
686                    ConnectionMode::Disconnect => {
687                        // Log any buffered messages that will be lost
688                        if !reconnect_buffer.is_empty() {
689                            log::warn!(
690                                "Discarding {} buffered messages due to disconnect",
691                                reconnect_buffer.len()
692                            );
693                            reconnect_buffer.clear();
694                        }
695
696                        // Attempt to close the writer gracefully before exiting,
697                        // we ignore any error as the writer may already be closed.
698                        _ = tokio::time::timeout(
699                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
700                            active_writer.close(),
701                        )
702                        .await;
703                        break;
704                    }
705                    ConnectionMode::Closed => {
706                        // Log any buffered messages that will be lost
707                        if !reconnect_buffer.is_empty() {
708                            log::warn!(
709                                "Discarding {} buffered messages due to closed connection",
710                                reconnect_buffer.len()
711                            );
712                            reconnect_buffer.clear();
713                        }
714                        break;
715                    }
716                    _ => {}
717                }
718
719                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
720                    Ok(Some(msg)) => {
721                        // Re-check connection mode after receiving a message
722                        let mode = ConnectionMode::from_atomic(&connection_state);
723                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
724                            break;
725                        }
726
727                        match msg {
728                            WriterCommand::Update(new_writer, tx) => {
729                                log::debug!("Received new writer");
730
731                                // Delay before closing connection
732                                tokio::time::sleep(Duration::from_millis(100)).await;
733
734                                // Attempt to close the writer gracefully on update,
735                                // we ignore any error as the writer may already be closed.
736                                _ = tokio::time::timeout(
737                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
738                                    active_writer.close(),
739                                )
740                                .await;
741
742                                active_writer = new_writer;
743                                log::debug!("Updated writer");
744
745                                let send_error = Self::drain_reconnect_buffer(
746                                    &mut reconnect_buffer,
747                                    &mut active_writer,
748                                )
749                                .await;
750
751                                if let Err(e) = tx.send(!send_error) {
752                                    log::error!(
753                                        "Failed to report drain status to controller: {e:?}"
754                                    );
755                                }
756                            }
757                            WriterCommand::Send(msg) if mode.is_reconnect() => {
758                                // Buffer messages during reconnection instead of dropping them
759                                log::debug!(
760                                    "Buffering message during reconnection (buffer size: {})",
761                                    reconnect_buffer.len() + 1
762                                );
763                                reconnect_buffer.push_back(msg);
764                            }
765                            WriterCommand::Send(msg) => {
766                                if let Err(e) = active_writer.send(msg.clone()).await {
767                                    log::error!("Failed to send message: {e}");
768                                    log::warn!("Writer triggering reconnect");
769                                    reconnect_buffer.push_back(msg);
770                                    connection_state
771                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
772                                    state_notify.notify_one();
773                                }
774                            }
775                        }
776                    }
777                    Ok(None) => {
778                        // Channel closed - writer task should terminate
779                        log::debug!("Writer channel closed, terminating writer task");
780                        break;
781                    }
782                    Err(_) => {
783                        // Timeout - just continue the loop
784                    }
785                }
786            }
787
788            // Attempt to close the writer gracefully before exiting,
789            // we ignore any error as the writer may already be closed.
790            _ = tokio::time::timeout(
791                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
792                active_writer.close(),
793            )
794            .await;
795
796            log_task_stopped("write");
797        })
798    }
799
800    fn spawn_heartbeat_task(
801        connection_state: Arc<AtomicU8>,
802        heartbeat_secs: u64,
803        message: Option<String>,
804        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
805    ) -> tokio::task::JoinHandle<()> {
806        log_task_started("heartbeat");
807
808        tokio::task::spawn(async move {
809            let interval = Duration::from_secs(heartbeat_secs);
810
811            loop {
812                tokio::time::sleep(interval).await;
813
814                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
815                    ConnectionMode::Active => {
816                        let msg = match &message {
817                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
818                            None => WriterCommand::Send(Message::Ping(vec![].into())),
819                        };
820
821                        match writer_tx.send(msg) {
822                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
823                            Err(e) => {
824                                log::error!("Failed to send heartbeat to writer task: {e}");
825                            }
826                        }
827                    }
828                    ConnectionMode::Reconnect => {}
829                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
830                }
831            }
832
833            log_task_stopped("heartbeat");
834        })
835    }
836}
837
838impl Drop for WebSocketClientInner {
839    fn drop(&mut self) {
840        // Delegate to explicit cleanup handler
841        self.clean_drop();
842    }
843}
844
845/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
846impl CleanDrop for WebSocketClientInner {
847    fn clean_drop(&mut self) {
848        if let Some(ref read_task) = self.read_task.take()
849            && !read_task.is_finished()
850        {
851            read_task.abort();
852            log_task_aborted("read");
853        }
854
855        if !self.write_task.is_finished() {
856            self.write_task.abort();
857            log_task_aborted("write");
858        }
859
860        if let Some(ref handle) = self.heartbeat_task.take()
861            && !handle.is_finished()
862        {
863            handle.abort();
864            log_task_aborted("heartbeat");
865        }
866
867        // Clear handlers to break potential reference cycles
868        self.message_handler = None;
869        self.ping_handler = None;
870    }
871}
872
873impl Debug for WebSocketClientInner {
874    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875        f.debug_struct(stringify!(WebSocketClientInner))
876            .field("config", &self.config)
877            .field(
878                "connection_mode",
879                &ConnectionMode::from_atomic(&self.connection_mode),
880            )
881            .field("reconnect_timeout", &self.reconnect_timeout)
882            .field("is_stream_mode", &self.is_stream_mode)
883            .finish()
884    }
885}
886
887/// WebSocket client with automatic reconnection.
888///
889/// Handles connection state, callbacks, and rate limiting.
890/// See module docs for architecture details.
891#[cfg_attr(
892    feature = "python",
893    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
894)]
895#[cfg_attr(
896    feature = "python",
897    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
898)]
899pub struct WebSocketClient {
900    pub(crate) controller_task: tokio::task::JoinHandle<()>,
901    pub(crate) connection_mode: Arc<AtomicU8>,
902    pub(crate) state_notify: Arc<tokio::sync::Notify>,
903    pub(crate) reconnect_timeout: Duration,
904    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
905    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
906}
907
908impl Debug for WebSocketClient {
909    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
910        f.debug_struct(stringify!(WebSocketClient)).finish()
911    }
912}
913
914impl WebSocketClient {
915    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
916    ///
917    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
918    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
919    /// client transitions to CLOSED state and the caller must manually reconnect by calling
920    /// `connect_stream` again.
921    ///
922    /// Use stream mode when you need custom reconnection logic, direct control over message
923    /// reading, or fine-grained backpressure handling.
924    ///
925    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
926    ///
927    /// # Errors
928    ///
929    /// Returns an error if the connection cannot be established.
930    #[allow(clippy::too_many_arguments)]
931    pub async fn connect_stream(
932        config: WebSocketConfig,
933        keyed_quotas: Vec<(String, Quota)>,
934        default_quota: Option<Quota>,
935        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
936    ) -> Result<(MessageReader, Self), Error> {
937        install_cryptographic_provider();
938
939        // Create a single connection and split it, respecting configured headers
940        let (writer, reader) =
941            WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
942
943        // Create inner without connecting (we'll provide the writer)
944        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
945
946        let connection_mode = inner.connection_mode.clone();
947        let state_notify = inner.state_notify.clone();
948        let reconnect_timeout = inner.reconnect_timeout;
949        let keyed_quotas = keyed_quotas
950            .into_iter()
951            .map(|(key, quota)| (Ustr::from(&key), quota))
952            .collect();
953        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
954        let writer_tx = inner.writer_tx.clone();
955
956        let controller_task = Self::spawn_controller_task(
957            inner,
958            connection_mode.clone(),
959            state_notify.clone(),
960            post_reconnect,
961        );
962
963        Ok((
964            reader,
965            Self {
966                controller_task,
967                connection_mode,
968                state_notify,
969                reconnect_timeout,
970                rate_limiter,
971                writer_tx,
972            },
973        ))
974    }
975
976    /// Creates a websocket client in **handler mode** with automatic reconnection.
977    ///
978    /// The handler is called for each incoming message on an internal task.
979    /// Automatic reconnection is **enabled** with exponential backoff. On disconnection,
980    /// the client automatically attempts to reconnect and replaces the internal reader
981    /// (the handler continues working seamlessly).
982    ///
983    /// Use handler mode for simplified connection management, automatic reconnection, Python
984    /// bindings, or callback-based message handling.
985    ///
986    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
987    ///
988    /// # Errors
989    ///
990    /// Returns an error if:
991    /// - The connection cannot be established.
992    /// - `message_handler` is `None` (use `connect_stream` instead).
993    pub async fn connect(
994        config: WebSocketConfig,
995        message_handler: Option<MessageHandler>,
996        ping_handler: Option<PingHandler>,
997        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
998        keyed_quotas: Vec<(String, Quota)>,
999        default_quota: Option<Quota>,
1000    ) -> Result<Self, Error> {
1001        // Validate that handler mode has a message handler
1002        if message_handler.is_none() {
1003            return Err(Error::Io(std::io::Error::new(
1004                std::io::ErrorKind::InvalidInput,
1005                "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
1006            )));
1007        }
1008
1009        log::debug!("Connecting");
1010        let inner =
1011            WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
1012        let connection_mode = inner.connection_mode.clone();
1013        let state_notify = inner.state_notify.clone();
1014        let writer_tx = inner.writer_tx.clone();
1015        let reconnect_timeout = inner.reconnect_timeout;
1016
1017        let controller_task = Self::spawn_controller_task(
1018            inner,
1019            connection_mode.clone(),
1020            state_notify.clone(),
1021            post_reconnection,
1022        );
1023
1024        let keyed_quotas = keyed_quotas
1025            .into_iter()
1026            .map(|(key, quota)| (Ustr::from(&key), quota))
1027            .collect();
1028        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1029
1030        Ok(Self {
1031            controller_task,
1032            connection_mode,
1033            state_notify,
1034            reconnect_timeout,
1035            rate_limiter,
1036            writer_tx,
1037        })
1038    }
1039
1040    /// Returns the current connection mode.
1041    #[must_use]
1042    pub fn connection_mode(&self) -> ConnectionMode {
1043        ConnectionMode::from_atomic(&self.connection_mode)
1044    }
1045
1046    /// Returns a clone of the connection mode atomic for external state tracking.
1047    ///
1048    /// This allows adapter clients to track connection state across reconnections
1049    /// without message-passing delays.
1050    #[must_use]
1051    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1052        Arc::clone(&self.connection_mode)
1053    }
1054
1055    /// Check if the client connection is active.
1056    ///
1057    /// Returns `true` if the client is connected and has not been signalled to disconnect.
1058    /// The client will automatically retry connection based on its configuration.
1059    #[inline]
1060    #[must_use]
1061    pub fn is_active(&self) -> bool {
1062        self.connection_mode().is_active()
1063    }
1064
1065    /// Check if the client is disconnected.
1066    #[must_use]
1067    pub fn is_disconnected(&self) -> bool {
1068        self.controller_task.is_finished()
1069    }
1070
1071    /// Check if the client is reconnecting.
1072    ///
1073    /// Returns `true` if the client lost connection and is attempting to reestablish it.
1074    /// The client will automatically retry connection based on its configuration.
1075    #[inline]
1076    #[must_use]
1077    pub fn is_reconnecting(&self) -> bool {
1078        self.connection_mode().is_reconnect()
1079    }
1080
1081    /// Check if the client is disconnecting.
1082    ///
1083    /// Returns `true` if the client is in disconnect mode.
1084    #[inline]
1085    #[must_use]
1086    pub fn is_disconnecting(&self) -> bool {
1087        self.connection_mode().is_disconnect()
1088    }
1089
1090    /// Check if the client is closed.
1091    ///
1092    /// Returns `true` if the client has been explicitly disconnected or reached
1093    /// maximum reconnection attempts. In this state, the client cannot be reused
1094    /// and a new client must be created for further connections.
1095    #[inline]
1096    #[must_use]
1097    pub fn is_closed(&self) -> bool {
1098        self.connection_mode().is_closed()
1099    }
1100
1101    /// Checks whether the connection is in a terminal state (disconnecting or closed).
1102    ///
1103    /// Single atomic load to fail fast before rate limiting or waiting.
1104    #[inline]
1105    fn check_not_terminal(&self) -> Result<(), SendError> {
1106        match self.connection_mode() {
1107            ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
1108            _ => Ok(()),
1109        }
1110    }
1111
1112    /// Waits for rate limiter quota, aborting early if connection enters a terminal state.
1113    async fn await_rate_limit_or_closed(&self, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1114        const CHECK_INTERVAL_MS: u64 = 100;
1115
1116        tokio::select! {
1117            () = self.rate_limiter.await_keys_ready(keys) => Ok(()),
1118            () = async {
1119                loop {
1120                    let notified = self.state_notify.notified();
1121
1122                    if matches!(self.connection_mode(), ConnectionMode::Disconnect | ConnectionMode::Closed) {
1123                        break;
1124                    }
1125                    tokio::select! {
1126                        () = notified => {}
1127                        () = tokio::time::sleep(Duration::from_millis(CHECK_INTERVAL_MS)) => {}
1128                    }
1129                }
1130            } => Err(SendError::Closed),
1131        }
1132    }
1133
1134    /// Waits for the client to become active before sending.
1135    ///
1136    /// Uses `state_notify` for event-driven wakeup so sends resume immediately
1137    /// after reconnection completes. A fallback interval guards against missed
1138    /// notifications.
1139    async fn wait_for_active(&self) -> Result<(), SendError> {
1140        const FALLBACK_INTERVAL_MS: u64 = 100;
1141
1142        let mode = self.connection_mode();
1143        if mode.is_active() {
1144            return Ok(());
1145        }
1146
1147        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1148            return Err(SendError::Closed);
1149        }
1150
1151        log::debug!("Waiting for client to become ACTIVE before sending...");
1152
1153        let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
1154
1155        tokio::time::timeout(self.reconnect_timeout, async {
1156            loop {
1157                // Register notification interest BEFORE checking state to prevent
1158                // a race where the state changes between our check and the await
1159                let notified = self.state_notify.notified();
1160
1161                let mode = self.connection_mode();
1162                if mode.is_active() {
1163                    return Ok(());
1164                }
1165
1166                if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1167                    return Err(());
1168                }
1169
1170                tokio::select! {
1171                    () = notified => {}
1172                    () = tokio::time::sleep(fallback_interval) => {}
1173                }
1174            }
1175        })
1176        .await
1177        .map_err(|_| SendError::Timeout)?
1178        .map_err(|()| SendError::Closed)
1179    }
1180
1181    /// Signals that the caller's reader has observed EOF or a fatal error.
1182    ///
1183    /// In stream mode the controller has no visibility into the caller-owned reader.
1184    /// Call this method when `reader.next().await` returns `None` or an unrecoverable
1185    /// error so the controller transitions to `Closed` and dependent tasks shut down.
1186    ///
1187    /// For peer-initiated close frames (`Message::Close`), use [`disconnect`](Self::disconnect)
1188    /// instead so the writer can send the close reply before shutting down.
1189    ///
1190    /// This is a no-op if the connection is already closed or disconnecting.
1191    pub fn notify_closed(&self) {
1192        let mode = self.connection_mode();
1193        if mode.is_disconnect() || mode.is_closed() {
1194            return;
1195        }
1196
1197        log::debug!("Stream reader signalled EOF, transitioning to CLOSED");
1198
1199        self.connection_mode
1200            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1201        self.state_notify.notify_waiters();
1202    }
1203
1204    /// Set disconnect mode to true.
1205    ///
1206    /// Controller task will periodically check the disconnect mode
1207    /// and shutdown the client if it is alive
1208    pub async fn disconnect(&self) {
1209        log::debug!("Disconnecting");
1210        self.connection_mode
1211            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1212        self.state_notify.notify_waiters();
1213
1214        if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1215            while !self.is_disconnected() {
1216                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1217            }
1218
1219            if !self.controller_task.is_finished() {
1220                self.controller_task.abort();
1221                log_task_aborted("controller");
1222            }
1223        })
1224        .await
1225            == Ok(())
1226        {
1227            log::debug!("Controller task finished");
1228        } else {
1229            log::error!("Timeout waiting for controller task to finish");
1230
1231            if !self.controller_task.is_finished() {
1232                self.controller_task.abort();
1233                log_task_aborted("controller");
1234            }
1235            self.connection_mode
1236                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1237        }
1238    }
1239
1240    /// Sends the given text `data` to the server.
1241    ///
1242    /// Returns `Ok(())` when the message is enqueued to the writer channel. This does NOT
1243    /// guarantee delivery: if a disconnect occurs concurrently, the writer task may drop the
1244    /// message. During reconnection, messages are buffered and replayed on the new connection.
1245    ///
1246    /// # Errors
1247    ///
1248    /// Returns a websocket error if unable to send.
1249    #[allow(unused_variables)]
1250    pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1251        self.check_not_terminal()?;
1252
1253        self.await_rate_limit_or_closed(keys).await?;
1254        self.wait_for_active().await?;
1255
1256        log::trace!("Sending text: {data:?}");
1257
1258        let msg = Message::Text(data.into());
1259        self.writer_tx
1260            .send(WriterCommand::Send(msg))
1261            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1262    }
1263
1264    /// Sends a pong frame back to the server.
1265    ///
1266    /// # Errors
1267    ///
1268    /// Returns a websocket error if unable to send.
1269    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1270        self.wait_for_active().await?;
1271
1272        log::trace!("Sending pong frame ({} bytes)", data.len());
1273
1274        let msg = Message::Pong(data.into());
1275        self.writer_tx
1276            .send(WriterCommand::Send(msg))
1277            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1278    }
1279
1280    /// Sends the given bytes `data` to the server.
1281    ///
1282    /// Returns `Ok(())` when the message is enqueued to the writer channel. This does NOT
1283    /// guarantee delivery: if a disconnect occurs concurrently, the writer task may drop the
1284    /// message. During reconnection, messages are buffered and replayed on the new connection.
1285    ///
1286    /// # Errors
1287    ///
1288    /// Returns a websocket error if unable to send.
1289    #[allow(unused_variables)]
1290    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1291        self.check_not_terminal()?;
1292
1293        self.await_rate_limit_or_closed(keys).await?;
1294        self.wait_for_active().await?;
1295
1296        log::trace!("Sending bytes: {data:?}");
1297
1298        let msg = Message::Binary(data.into());
1299        self.writer_tx
1300            .send(WriterCommand::Send(msg))
1301            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1302    }
1303
1304    /// Sends a close message to the server.
1305    ///
1306    /// # Errors
1307    ///
1308    /// Returns a websocket error if unable to send.
1309    pub async fn send_close_message(&self) -> Result<(), SendError> {
1310        self.wait_for_active().await?;
1311
1312        let msg = Message::Close(None);
1313        self.writer_tx
1314            .send(WriterCommand::Send(msg))
1315            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1316    }
1317
1318    fn spawn_controller_task(
1319        mut inner: WebSocketClientInner,
1320        connection_mode: Arc<AtomicU8>,
1321        state_notify: Arc<tokio::sync::Notify>,
1322        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1323    ) -> tokio::task::JoinHandle<()> {
1324        const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;
1325
1326        tokio::task::spawn(async move {
1327            log_task_started("controller");
1328
1329            let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);
1330
1331            loop {
1332                tokio::select! {
1333                    () = state_notify.notified() => {}
1334                    () = tokio::time::sleep(fallback_interval) => {}
1335                }
1336
1337                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1338
1339                if mode.is_disconnect() {
1340                    log::debug!("Disconnecting");
1341
1342                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1343                    if tokio::time::timeout(timeout, async {
1344                        // Delay awaiting graceful shutdown
1345                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1346
1347                        if let Some(task) = &inner.read_task
1348                            && !task.is_finished()
1349                        {
1350                            task.abort();
1351                            log_task_aborted("read");
1352                        }
1353
1354                        if let Some(task) = &inner.heartbeat_task
1355                            && !task.is_finished()
1356                        {
1357                            task.abort();
1358                            log_task_aborted("heartbeat");
1359                        }
1360                    })
1361                    .await
1362                    .is_err()
1363                    {
1364                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
1365                    }
1366
1367                    log::debug!("Closed");
1368                    break; // Controller finished
1369                }
1370
1371                if mode.is_closed() {
1372                    log::debug!("Connection closed");
1373                    break;
1374                }
1375
1376                if mode.is_active() && !inner.is_alive() {
1377                    let target = if inner.is_stream_mode {
1378                        ConnectionMode::Closed
1379                    } else {
1380                        ConnectionMode::Reconnect
1381                    };
1382
1383                    if connection_mode
1384                        .compare_exchange(
1385                            ConnectionMode::Active.as_u8(),
1386                            target.as_u8(),
1387                            Ordering::SeqCst,
1388                            Ordering::SeqCst,
1389                        )
1390                        .is_ok()
1391                    {
1392                        log::debug!("Detected dead connection, transitioning to {target:?}");
1393                    }
1394                    mode = ConnectionMode::from_atomic(&connection_mode);
1395                }
1396
1397                if mode.is_reconnect() {
1398                    // Check if max reconnection attempts exceeded
1399                    if let Some(max_attempts) = inner.reconnect_max_attempts
1400                        && inner.reconnection_attempt_count >= max_attempts
1401                    {
1402                        log::error!(
1403                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1404                        );
1405                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1406                        state_notify.notify_waiters();
1407                        break;
1408                    }
1409
1410                    inner.reconnection_attempt_count += 1;
1411                    log::debug!(
1412                        "Reconnection attempt {} of {}",
1413                        inner.reconnection_attempt_count,
1414                        inner
1415                            .reconnect_max_attempts
1416                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1417                    );
1418
1419                    // Race reconnect against disconnect notification
1420                    let reconnect_result = tokio::select! {
1421                        result = inner.reconnect() => Some(result),
1422                        () = async {
1423                            loop {
1424                                state_notify.notified().await;
1425
1426                                if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1427                                    break;
1428                                }
1429                            }
1430                        } => None,
1431                    };
1432
1433                    match reconnect_result {
1434                        None => {
1435                            log::debug!("Reconnect interrupted by disconnect");
1436                        }
1437                        Some(Ok(())) => {
1438                            inner.backoff.reset();
1439                            inner.reconnection_attempt_count = 0;
1440
1441                            state_notify.notify_waiters();
1442
1443                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1444                                if let Some(ref handler) = inner.message_handler {
1445                                    let reconnected_msg =
1446                                        Message::Text(RECONNECTED.to_string().into());
1447                                    handler(reconnected_msg);
1448                                    log::debug!("Sent reconnected message to handler");
1449                                }
1450
1451                                // TODO: Retain this legacy callback for use from Python
1452                                if let Some(ref callback) = post_reconnection {
1453                                    callback();
1454                                    log::debug!("Called `post_reconnection` handler");
1455                                }
1456
1457                                log::debug!("Reconnected successfully");
1458                            } else {
1459                                log::debug!(
1460                                    "Skipping post_reconnection handlers due to disconnect state"
1461                                );
1462                            }
1463                        }
1464                        Some(Err(e)) => {
1465                            let duration = inner.backoff.next_duration();
1466                            log::warn!(
1467                                "Reconnect attempt {} failed: {e}",
1468                                inner.reconnection_attempt_count
1469                            );
1470
1471                            if !duration.is_zero() {
1472                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
1473                                // Race backoff sleep against disconnect
1474                                tokio::select! {
1475                                    () = tokio::time::sleep(duration) => {}
1476                                    () = async {
1477                                        loop {
1478                                            state_notify.notified().await;
1479
1480                                            if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1481                                                break;
1482                                            }
1483                                        }
1484                                    } => {
1485                                        log::debug!("Backoff interrupted by disconnect");
1486                                    }
1487                                }
1488                            }
1489                        }
1490                    }
1491                }
1492            }
1493            inner
1494                .connection_mode
1495                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1496
1497            log_task_stopped("controller");
1498        })
1499    }
1500}
1501
1502// Abort controller task on drop to clean up background tasks
1503impl Drop for WebSocketClient {
1504    fn drop(&mut self) {
1505        if !self.controller_task.is_finished() {
1506            self.controller_task.abort();
1507            log_task_aborted("controller");
1508        }
1509    }
1510}
1511
1512#[cfg(test)]
1513#[cfg(not(feature = "turmoil"))]
1514#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1515mod tests {
1516    use std::{num::NonZeroU32, sync::Arc};
1517
1518    use futures_util::{SinkExt, StreamExt};
1519    use tokio::{
1520        net::TcpListener,
1521        task::{self, JoinHandle},
1522    };
1523    use tokio_tungstenite::{
1524        accept_hdr_async,
1525        tungstenite::{
1526            handshake::server::{self, Callback},
1527            http::HeaderValue,
1528        },
1529    };
1530
1531    use crate::{
1532        ratelimiter::quota::Quota,
1533        websocket::{WebSocketClient, WebSocketConfig},
1534    };
1535
    /// Handle to the in-process WebSocket server used by the tests in this module.
    struct TestServer {
        /// Accept-loop task; aborted when the `TestServer` is dropped.
        task: JoinHandle<()>,
        /// Local port the listener is bound to.
        port: u16,
    }
1540
    /// Handshake callback that checks one request header against an expected value.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header looked up on the incoming request.
        key: String,
        /// Value the header is expected to carry.
        value: HeaderValue,
    }
1546
1547    impl Callback for TestCallback {
1548        #[allow(clippy::panic_in_result_fn)]
1549        fn on_request(
1550            self,
1551            request: &server::Request,
1552            response: server::Response,
1553        ) -> Result<server::Response, server::ErrorResponse> {
1554            let _ = response;
1555            let value = request.headers().get(&self.key);
1556            assert!(value.is_some());
1557
1558            if let Some(value) = request.headers().get(&self.key) {
1559                assert_eq!(value, self.value);
1560            }
1561
1562            Ok(response)
1563        }
1564    }
1565
1566    impl TestServer {
1567        async fn setup() -> Self {
1568            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1569            let port = TcpListener::local_addr(&server).unwrap().port();
1570
1571            let header_key = "test".to_string();
1572            let header_value = "test".to_string();
1573
1574            let test_call_back = TestCallback {
1575                key: header_key,
1576                value: HeaderValue::from_str(&header_value).unwrap(),
1577            };
1578
1579            let task = task::spawn(async move {
1580                // Keep accepting connections
1581                loop {
1582                    let (conn, _) = server.accept().await.unwrap();
1583                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1584                        .await
1585                        .unwrap();
1586
1587                    task::spawn(async move {
1588                        while let Some(Ok(msg)) = websocket.next().await {
1589                            match msg {
1590                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1591                                    if txt == "close-now" =>
1592                                {
1593                                    log::debug!("Forcibly closing from server side");
1594                                    // This sends a close frame, then stops reading
1595                                    let _ = websocket.close(None).await;
1596                                    break;
1597                                }
1598                                // Echo text/binary frames
1599                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1600                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1601                                    if websocket.send(msg).await.is_err() {
1602                                        break;
1603                                    }
1604                                }
1605                                // If the client closes, we also break
1606                                tokio_tungstenite::tungstenite::protocol::Message::Close(
1607                                    _frame,
1608                                ) => {
1609                                    let _ = websocket.close(None).await;
1610                                    break;
1611                                }
1612                                // Ignore pings/pongs
1613                                _ => {}
1614                            }
1615                        }
1616                    });
1617                }
1618            });
1619
1620            Self { task, port }
1621        }
1622    }
1623
    impl Drop for TestServer {
        /// Aborts the server's accept-loop task when the handle is dropped.
        fn drop(&mut self) {
            self.task.abort();
        }
    }
1629
1630    async fn setup_test_client(port: u16) -> WebSocketClient {
1631        let config = WebSocketConfig {
1632            url: format!("ws://127.0.0.1:{port}"),
1633            headers: vec![("test".into(), "test".into())],
1634            heartbeat: None,
1635            heartbeat_msg: None,
1636            reconnect_timeout_ms: None,
1637            reconnect_delay_initial_ms: None,
1638            reconnect_backoff_factor: None,
1639            reconnect_delay_max_ms: None,
1640            reconnect_jitter_ms: None,
1641            reconnect_max_attempts: None,
1642            idle_timeout_ms: None,
1643        };
1644        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1645            .await
1646            .expect("Failed to connect")
1647    }
1648
1649    #[tokio::test]
1650    async fn test_websocket_basic() {
1651        let server = TestServer::setup().await;
1652        let client = setup_test_client(server.port).await;
1653
1654        assert!(!client.is_disconnected());
1655
1656        client.disconnect().await;
1657        assert!(client.is_disconnected());
1658    }
1659
1660    #[tokio::test]
1661    async fn test_websocket_heartbeat() {
1662        let server = TestServer::setup().await;
1663        let client = setup_test_client(server.port).await;
1664
1665        // Wait ~3s => server should see multiple "ping"
1666        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1667
1668        // Cleanup
1669        client.disconnect().await;
1670        assert!(client.is_disconnected());
1671    }
1672
1673    #[tokio::test]
1674    async fn test_websocket_reconnect_exhausted() {
1675        let config = WebSocketConfig {
1676            url: "ws://127.0.0.1:9997".into(), // <-- No server
1677            headers: vec![],
1678            heartbeat: None,
1679            heartbeat_msg: None,
1680            reconnect_timeout_ms: None,
1681            reconnect_delay_initial_ms: None,
1682            reconnect_backoff_factor: None,
1683            reconnect_delay_max_ms: None,
1684            reconnect_jitter_ms: None,
1685            reconnect_max_attempts: None,
1686            idle_timeout_ms: None,
1687        };
1688        let res =
1689            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1690                .await;
1691        assert!(res.is_err(), "Should fail quickly with no server");
1692    }
1693
1694    #[tokio::test]
1695    async fn test_websocket_forced_close_reconnect() {
1696        let server = TestServer::setup().await;
1697        let client = setup_test_client(server.port).await;
1698
1699        // 1) Send normal message
1700        client.send_text("Hello".into(), None).await.unwrap();
1701
1702        // 2) Trigger forced close from server
1703        client.send_text("close-now".into(), None).await.unwrap();
1704
1705        // 3) Wait a bit => read loop sees close => reconnect
1706        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1707
1708        // Confirm not disconnected
1709        assert!(!client.is_disconnected());
1710
1711        // Cleanup
1712        client.disconnect().await;
1713        assert!(client.is_disconnected());
1714    }
1715
1716    #[tokio::test]
1717    async fn test_rate_limiter() {
1718        let server = TestServer::setup().await;
1719        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();
1720
1721        let config = WebSocketConfig {
1722            url: format!("ws://127.0.0.1:{}", server.port),
1723            headers: vec![("test".into(), "test".into())],
1724            heartbeat: None,
1725            heartbeat_msg: None,
1726            reconnect_timeout_ms: None,
1727            reconnect_delay_initial_ms: None,
1728            reconnect_backoff_factor: None,
1729            reconnect_delay_max_ms: None,
1730            reconnect_jitter_ms: None,
1731            reconnect_max_attempts: None,
1732            idle_timeout_ms: None,
1733        };
1734
1735        let client = WebSocketClient::connect(
1736            config,
1737            Some(Arc::new(|_| {})),
1738            None,
1739            None,
1740            vec![("default".into(), quota)],
1741            None,
1742        )
1743        .await
1744        .unwrap();
1745
1746        // First 2 should succeed
1747        client.send_text("test1".into(), None).await.unwrap();
1748        client.send_text("test2".into(), None).await.unwrap();
1749
1750        // Third should error
1751        client.send_text("test3".into(), None).await.unwrap();
1752
1753        // Cleanup
1754        client.disconnect().await;
1755        assert!(client.is_disconnected());
1756    }
1757
1758    #[tokio::test]
1759    async fn test_concurrent_writers() {
1760        let server = TestServer::setup().await;
1761        let client = Arc::new(setup_test_client(server.port).await);
1762
1763        let mut handles = vec![];
1764        for i in 0..10 {
1765            let client = client.clone();
1766            handles.push(task::spawn(async move {
1767                client.send_text(format!("test{i}"), None).await.unwrap();
1768            }));
1769        }
1770
1771        for handle in handles {
1772            handle.await.unwrap();
1773        }
1774
1775        // Cleanup
1776        client.disconnect().await;
1777        assert!(client.is_disconnected());
1778    }
1779}
1780
1781#[cfg(test)]
1782#[cfg(not(feature = "turmoil"))]
1783mod rust_tests {
1784    use futures_util::{SinkExt, StreamExt};
1785    use nautilus_common::testing::wait_until_async;
1786    use rstest::rstest;
1787    use tokio::{
1788        net::TcpListener,
1789        task,
1790        time::{Duration, sleep},
1791    };
1792    use tokio_tungstenite::accept_async;
1793
1794    use super::*;
1795    use crate::websocket::types::channel_message_handler;
1796
1797    #[rstest]
1798    #[tokio::test]
1799    async fn test_reconnect_then_disconnect() {
1800        // Bind an ephemeral port
1801        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1802        let port = listener.local_addr().unwrap().port();
1803
1804        // Server task: accept one ws connection then close it
1805        let server = task::spawn(async move {
1806            let (stream, _) = listener.accept().await.unwrap();
1807            let ws = accept_async(stream).await.unwrap();
1808            drop(ws);
1809            // Keep alive briefly
1810            sleep(Duration::from_secs(1)).await;
1811        });
1812
1813        // Build a channel-based message handler for incoming messages (unused here)
1814        let (handler, _rx) = channel_message_handler();
1815
1816        // Configure client with short reconnect backoff
1817        let config = WebSocketConfig {
1818            url: format!("ws://127.0.0.1:{port}"),
1819            headers: vec![],
1820            heartbeat: None,
1821            heartbeat_msg: None,
1822            reconnect_timeout_ms: Some(1_000),
1823            reconnect_delay_initial_ms: Some(50),
1824            reconnect_delay_max_ms: Some(100),
1825            reconnect_backoff_factor: Some(1.0),
1826            reconnect_jitter_ms: Some(0),
1827            reconnect_max_attempts: None,
1828            idle_timeout_ms: None,
1829        };
1830
1831        // Connect the client
1832        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1833            .await
1834            .unwrap();
1835
1836        // Allow server to drop connection and client to detect
1837        sleep(Duration::from_millis(100)).await;
1838        // Now immediately disconnect the client
1839        client.disconnect().await;
1840        assert!(client.is_disconnected());
1841        server.abort();
1842    }
1843
1844    #[rstest]
1845    #[tokio::test]
1846    async fn test_reconnect_state_flips_when_reader_stops() {
1847        // Bind an ephemeral port and accept a single websocket connection which we drop.
1848        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1849        let port = listener.local_addr().unwrap().port();
1850
1851        let server = task::spawn(async move {
1852            if let Ok((stream, _)) = listener.accept().await
1853                && let Ok(ws) = accept_async(stream).await
1854            {
1855                drop(ws);
1856            }
1857            sleep(Duration::from_millis(50)).await;
1858        });
1859
1860        let (handler, _rx) = channel_message_handler();
1861
1862        let config = WebSocketConfig {
1863            url: format!("ws://127.0.0.1:{port}"),
1864            headers: vec![],
1865            heartbeat: None,
1866            heartbeat_msg: None,
1867            reconnect_timeout_ms: Some(1_000),
1868            reconnect_delay_initial_ms: Some(50),
1869            reconnect_delay_max_ms: Some(100),
1870            reconnect_backoff_factor: Some(1.0),
1871            reconnect_jitter_ms: Some(0),
1872            reconnect_max_attempts: None,
1873            idle_timeout_ms: None,
1874        };
1875
1876        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1877            .await
1878            .unwrap();
1879
1880        tokio::time::timeout(Duration::from_secs(2), async {
1881            loop {
1882                if client.is_reconnecting() {
1883                    break;
1884                }
1885                tokio::time::sleep(Duration::from_millis(10)).await;
1886            }
1887        })
1888        .await
1889        .expect("client did not enter RECONNECT state");
1890
1891        client.disconnect().await;
1892        server.abort();
1893    }
1894
1895    #[rstest]
1896    #[tokio::test]
1897    async fn test_stream_mode_disables_auto_reconnect() {
1898        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
1899        // and that reconnect() transitions to CLOSED state for stream mode
1900        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1901        let port = listener.local_addr().unwrap().port();
1902
1903        let server = task::spawn(async move {
1904            if let Ok((stream, _)) = listener.accept().await
1905                && let Ok(_ws) = accept_async(stream).await
1906            {
1907                // Keep connection alive briefly
1908                sleep(Duration::from_millis(100)).await;
1909            }
1910        });
1911
1912        let config = WebSocketConfig {
1913            url: format!("ws://127.0.0.1:{port}"),
1914            headers: vec![],
1915            heartbeat: None,
1916            heartbeat_msg: None,
1917            reconnect_timeout_ms: Some(1_000),
1918            reconnect_delay_initial_ms: Some(50),
1919            reconnect_delay_max_ms: Some(100),
1920            reconnect_backoff_factor: Some(1.0),
1921            reconnect_jitter_ms: Some(0),
1922            reconnect_max_attempts: None,
1923            idle_timeout_ms: None,
1924        };
1925
1926        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1927            .await
1928            .unwrap();
1929
1930        // Note: We can't easily test the reconnect behavior from the outside since
1931        // the inner client is private. The key fix is that WebSocketClientInner
1932        // now has is_stream_mode=true for connect_stream, and reconnect() will
1933        // transition to CLOSED state instead of creating a new reader that gets dropped.
1934        // This is tested implicitly by the fact that stream users won't get stuck
1935        // in an infinite reconnect loop.
1936
1937        server.abort();
1938    }
1939
1940    #[rstest]
1941    #[tokio::test]
1942    async fn test_message_handler_mode_allows_auto_reconnect() {
1943        // Test that regular clients (with message handler) can auto-reconnect
1944        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1945        let port = listener.local_addr().unwrap().port();
1946
1947        let server = task::spawn(async move {
1948            // Accept first connection and close it
1949            if let Ok((stream, _)) = listener.accept().await
1950                && let Ok(ws) = accept_async(stream).await
1951            {
1952                drop(ws);
1953            }
1954            sleep(Duration::from_millis(50)).await;
1955        });
1956
1957        let (handler, _rx) = channel_message_handler();
1958
1959        let config = WebSocketConfig {
1960            url: format!("ws://127.0.0.1:{port}"),
1961            headers: vec![],
1962            heartbeat: None,
1963            heartbeat_msg: None,
1964            reconnect_timeout_ms: Some(1_000),
1965            reconnect_delay_initial_ms: Some(50),
1966            reconnect_delay_max_ms: Some(100),
1967            reconnect_backoff_factor: Some(1.0),
1968            reconnect_jitter_ms: Some(0),
1969            reconnect_max_attempts: None,
1970            idle_timeout_ms: None,
1971        };
1972
1973        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1974            .await
1975            .unwrap();
1976
1977        // Wait for the connection to be dropped and reconnection to be attempted
1978        tokio::time::timeout(Duration::from_secs(2), async {
1979            loop {
1980                if client.is_reconnecting() || client.is_closed() {
1981                    break;
1982                }
1983                tokio::time::sleep(Duration::from_millis(10)).await;
1984            }
1985        })
1986        .await
1987        .expect("client should attempt reconnection or close");
1988
1989        // Should either be reconnecting or closed (depending on timing)
1990        // The important thing is it's not staying active forever
1991        assert!(
1992            client.is_reconnecting() || client.is_closed(),
1993            "Client with message handler should attempt reconnection"
1994        );
1995
1996        client.disconnect().await;
1997        server.abort();
1998    }
1999
2000    #[rstest]
2001    #[tokio::test]
2002    async fn test_handler_mode_reconnect_with_new_connection() {
2003        // Test that handler mode successfully reconnects and messages continue flowing
2004        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2005        let port = listener.local_addr().unwrap().port();
2006
2007        let server = task::spawn(async move {
2008            // First connection - accept and immediately close
2009            if let Ok((stream, _)) = listener.accept().await
2010                && let Ok(ws) = accept_async(stream).await
2011            {
2012                drop(ws);
2013            }
2014
2015            // Small delay to let client detect disconnection
2016            sleep(Duration::from_millis(100)).await;
2017
2018            // Second connection - accept, send a message, then keep alive
2019            if let Ok((stream, _)) = listener.accept().await
2020                && let Ok(mut ws) = accept_async(stream).await
2021            {
2022                use futures_util::SinkExt;
2023                let _ = ws
2024                    .send(Message::Text("reconnected".to_string().into()))
2025                    .await;
2026                sleep(Duration::from_secs(1)).await;
2027            }
2028        });
2029
2030        let (handler, mut rx) = channel_message_handler();
2031
2032        let config = WebSocketConfig {
2033            url: format!("ws://127.0.0.1:{port}"),
2034            headers: vec![],
2035            heartbeat: None,
2036            heartbeat_msg: None,
2037            reconnect_timeout_ms: Some(2_000),
2038            reconnect_delay_initial_ms: Some(50),
2039            reconnect_delay_max_ms: Some(200),
2040            reconnect_backoff_factor: Some(1.5),
2041            reconnect_jitter_ms: Some(10),
2042            reconnect_max_attempts: None,
2043            idle_timeout_ms: None,
2044        };
2045
2046        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2047            .await
2048            .unwrap();
2049
2050        // Wait for reconnection to happen and message to arrive
2051        let result = tokio::time::timeout(Duration::from_secs(5), async {
2052            loop {
2053                if let Ok(msg) = rx.try_recv()
2054                    && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
2055                {
2056                    return true;
2057                }
2058                tokio::time::sleep(Duration::from_millis(10)).await;
2059            }
2060        })
2061        .await;
2062
2063        assert!(
2064            result.is_ok(),
2065            "Should receive message after reconnection within timeout"
2066        );
2067
2068        client.disconnect().await;
2069        server.abort();
2070    }
2071
2072    #[rstest]
2073    #[tokio::test]
2074    async fn test_stream_mode_no_auto_reconnect() {
2075        // Test that stream mode does not automatically reconnect when connection is lost
2076        // The caller owns the reader and is responsible for detecting disconnection
2077        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2078        let port = listener.local_addr().unwrap().port();
2079
2080        let server = task::spawn(async move {
2081            // Accept connection and send one message, then close
2082            if let Ok((stream, _)) = listener.accept().await
2083                && let Ok(mut ws) = accept_async(stream).await
2084            {
2085                use futures_util::SinkExt;
2086                let _ = ws.send(Message::Text("hello".to_string().into())).await;
2087                sleep(Duration::from_millis(50)).await;
2088                // Connection closes when ws is dropped
2089            }
2090        });
2091
2092        let config = WebSocketConfig {
2093            url: format!("ws://127.0.0.1:{port}"),
2094            headers: vec![],
2095            heartbeat: None,
2096            heartbeat_msg: None,
2097            reconnect_timeout_ms: Some(1_000),
2098            reconnect_delay_initial_ms: Some(50),
2099            reconnect_delay_max_ms: Some(100),
2100            reconnect_backoff_factor: Some(1.0),
2101            reconnect_jitter_ms: Some(0),
2102            reconnect_max_attempts: None,
2103            idle_timeout_ms: None,
2104        };
2105
2106        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2107            .await
2108            .unwrap();
2109
2110        // Initially active
2111        assert!(client.is_active(), "Client should start as active");
2112
2113        // Read the hello message
2114        let msg = reader.next().await;
2115        assert!(
2116            matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
2117            "Should receive initial message"
2118        );
2119
2120        // Read until connection closes (reader will return None or error)
2121        while let Some(msg) = reader.next().await {
2122            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
2123                break;
2124            }
2125        }
2126
2127        // Controller cannot detect reader EOF (reader is owned by caller),
2128        // so the client stays ACTIVE until the caller signals.
2129        sleep(Duration::from_millis(200)).await;
2130        assert!(
2131            client.is_active(),
2132            "Stream mode client stays ACTIVE before notify_closed()"
2133        );
2134
2135        // Caller signals EOF via notify_closed()
2136        client.notify_closed();
2137
2138        assert!(
2139            client.is_closed(),
2140            "Stream mode client should be CLOSED after notify_closed()"
2141        );
2142        assert!(
2143            !client.is_reconnecting(),
2144            "Stream mode client should never attempt reconnection"
2145        );
2146
2147        client.disconnect().await;
2148        server.abort();
2149    }
2150
2151    #[rstest]
2152    #[tokio::test]
2153    async fn test_send_timeout_uses_configured_reconnect_timeout() {
2154        // Test that send operations respect the configured reconnect_timeout.
2155        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
2156        use nautilus_common::testing::wait_until_async;
2157
2158        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2159        let port = listener.local_addr().unwrap().port();
2160
2161        let server = task::spawn(async move {
2162            // Accept first connection and immediately close it
2163            if let Ok((stream, _)) = listener.accept().await
2164                && let Ok(ws) = accept_async(stream).await
2165            {
2166                drop(ws);
2167            }
2168            // Don't accept second connection - client will be stuck in RECONNECT
2169            sleep(Duration::from_secs(60)).await;
2170        });
2171
2172        let (handler, _rx) = channel_message_handler();
2173
2174        // Configure with SHORT 2s reconnect timeout
2175        let config = WebSocketConfig {
2176            url: format!("ws://127.0.0.1:{port}"),
2177            headers: vec![],
2178            heartbeat: None,
2179            heartbeat_msg: None,
2180            reconnect_timeout_ms: Some(2_000), // 2s timeout
2181            reconnect_delay_initial_ms: Some(50),
2182            reconnect_delay_max_ms: Some(100),
2183            reconnect_backoff_factor: Some(1.0),
2184            reconnect_jitter_ms: Some(0),
2185            reconnect_max_attempts: None,
2186            idle_timeout_ms: None,
2187        };
2188
2189        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2190            .await
2191            .unwrap();
2192
2193        // Wait for client to enter RECONNECT state
2194        wait_until_async(
2195            || async { client.is_reconnecting() },
2196            Duration::from_secs(3),
2197        )
2198        .await;
2199
2200        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
2201        let start = std::time::Instant::now();
2202        let send_result = client.send_text("test".to_string(), None).await;
2203        let elapsed = start.elapsed();
2204
2205        assert!(
2206            send_result.is_err(),
2207            "Send should fail when client stuck in RECONNECT"
2208        );
2209        assert!(
2210            matches!(send_result, Err(crate::error::SendError::Timeout)),
2211            "Send should return Timeout error, was: {send_result:?}"
2212        );
2213        // Verify timeout respects configured value (2s), but don't check upper bound
2214        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2215        assert!(
2216            elapsed >= Duration::from_millis(1800),
2217            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2218        );
2219
2220        client.disconnect().await;
2221        server.abort();
2222    }
2223
2224    #[rstest]
2225    #[tokio::test]
2226    async fn test_send_waits_during_reconnection() {
2227        // Test that send operations wait for reconnection to complete (up to timeout)
2228        use nautilus_common::testing::wait_until_async;
2229
2230        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2231        let port = listener.local_addr().unwrap().port();
2232
2233        let server = task::spawn(async move {
2234            // First connection - accept and immediately close
2235            if let Ok((stream, _)) = listener.accept().await
2236                && let Ok(ws) = accept_async(stream).await
2237            {
2238                drop(ws);
2239            }
2240
2241            // Wait a bit before accepting second connection
2242            sleep(Duration::from_millis(500)).await;
2243
2244            // Second connection - accept and keep alive
2245            if let Ok((stream, _)) = listener.accept().await
2246                && let Ok(mut ws) = accept_async(stream).await
2247            {
2248                // Echo messages
2249                while let Some(Ok(msg)) = ws.next().await {
2250                    if ws.send(msg).await.is_err() {
2251                        break;
2252                    }
2253                }
2254            }
2255        });
2256
2257        let (handler, _rx) = channel_message_handler();
2258
2259        let config = WebSocketConfig {
2260            url: format!("ws://127.0.0.1:{port}"),
2261            headers: vec![],
2262            heartbeat: None,
2263            heartbeat_msg: None,
2264            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2265            reconnect_delay_initial_ms: Some(100),
2266            reconnect_delay_max_ms: Some(200),
2267            reconnect_backoff_factor: Some(1.0),
2268            reconnect_jitter_ms: Some(0),
2269            reconnect_max_attempts: None,
2270            idle_timeout_ms: None,
2271        };
2272
2273        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2274            .await
2275            .unwrap();
2276
2277        // Wait for reconnection to trigger
2278        wait_until_async(
2279            || async { client.is_reconnecting() },
2280            Duration::from_secs(2),
2281        )
2282        .await;
2283
2284        // Try to send while reconnecting - should wait and succeed after reconnect
2285        let send_result = tokio::time::timeout(
2286            Duration::from_secs(3),
2287            client.send_text("test_message".to_string(), None),
2288        )
2289        .await;
2290
2291        assert!(
2292            send_result.is_ok() && send_result.unwrap().is_ok(),
2293            "Send should succeed after waiting for reconnection"
2294        );
2295
2296        client.disconnect().await;
2297        server.abort();
2298    }
2299
2300    #[rstest]
2301    #[tokio::test]
2302    async fn test_rate_limiter_before_active_wait() {
2303        // Test that rate limiting happens BEFORE active state check.
2304        // This prevents race conditions where connection state changes during rate limit wait.
2305        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2306        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2307        use std::{num::NonZeroU32, sync::Arc};
2308
2309        use nautilus_common::testing::wait_until_async;
2310
2311        use crate::ratelimiter::quota::Quota;
2312
2313        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2314        let port = listener.local_addr().unwrap().port();
2315
2316        let server = task::spawn(async move {
2317            // First connection - accept and close after receiving one message
2318            if let Ok((stream, _)) = listener.accept().await
2319                && let Ok(mut ws) = accept_async(stream).await
2320            {
2321                // Receive first message then close
2322                if let Some(Ok(_)) = ws.next().await {
2323                    drop(ws);
2324                }
2325            }
2326
2327            // Wait before accepting reconnection
2328            sleep(Duration::from_millis(500)).await;
2329
2330            // Second connection - accept and keep alive
2331            if let Ok((stream, _)) = listener.accept().await
2332                && let Ok(mut ws) = accept_async(stream).await
2333            {
2334                while let Some(Ok(msg)) = ws.next().await {
2335                    if ws.send(msg).await.is_err() {
2336                        break;
2337                    }
2338                }
2339            }
2340        });
2341
2342        let (handler, _rx) = channel_message_handler();
2343
2344        let config = WebSocketConfig {
2345            url: format!("ws://127.0.0.1:{port}"),
2346            headers: vec![],
2347            heartbeat: None,
2348            heartbeat_msg: None,
2349            reconnect_timeout_ms: Some(5_000),
2350            reconnect_delay_initial_ms: Some(50),
2351            reconnect_delay_max_ms: Some(100),
2352            reconnect_backoff_factor: Some(1.0),
2353            reconnect_jitter_ms: Some(0),
2354            reconnect_max_attempts: None,
2355            idle_timeout_ms: None,
2356        };
2357
2358        // Very restrictive rate limit: 1 request per second, burst of 1
2359        let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2360            .unwrap()
2361            .allow_burst(NonZeroU32::new(1).unwrap());
2362
2363        let client = Arc::new(
2364            WebSocketClient::connect(
2365                config,
2366                Some(handler),
2367                None,
2368                None,
2369                vec![("test_key".to_string(), quota)],
2370                None,
2371            )
2372            .await
2373            .unwrap(),
2374        );
2375
2376        // First send exhausts burst capacity and triggers connection close
2377        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2378        client
2379            .send_text("msg1".to_string(), Some(test_key.as_slice()))
2380            .await
2381            .unwrap();
2382
2383        // Wait for client to enter RECONNECT state
2384        wait_until_async(
2385            || async { client.is_reconnecting() },
2386            Duration::from_secs(2),
2387        )
2388        .await;
2389
2390        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2391        let start = std::time::Instant::now();
2392        let send_result = client
2393            .send_text("msg2".to_string(), Some(test_key.as_slice()))
2394            .await;
2395        let elapsed = start.elapsed();
2396
2397        // Should succeed after both rate limit AND reconnection
2398        assert!(
2399            send_result.is_ok(),
2400            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2401        );
2402        // Total wait should be at least rate limit time (~1s)
2403        // The reconnection completes while rate limiting or after
2404        // Use 850ms threshold to account for timing jitter in CI
2405        assert!(
2406            elapsed >= Duration::from_millis(850),
2407            "Should wait for rate limit (~1s), waited {elapsed:?}"
2408        );
2409
2410        client.disconnect().await;
2411        server.abort();
2412    }
2413
2414    #[rstest]
2415    #[tokio::test]
2416    async fn test_disconnect_during_reconnect_exits_cleanly() {
2417        // Test CAS race condition: disconnect called during reconnection
2418        // Should exit cleanly without spawning new tasks
2419        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2420        let port = listener.local_addr().unwrap().port();
2421
2422        let server = task::spawn(async move {
2423            // Accept first connection and immediately close
2424            if let Ok((stream, _)) = listener.accept().await
2425                && let Ok(ws) = accept_async(stream).await
2426            {
2427                drop(ws);
2428            }
2429            // Don't accept second connection - let reconnect hang
2430            sleep(Duration::from_secs(60)).await;
2431        });
2432
2433        let (handler, _rx) = channel_message_handler();
2434
2435        let config = WebSocketConfig {
2436            url: format!("ws://127.0.0.1:{port}"),
2437            headers: vec![],
2438            heartbeat: None,
2439            heartbeat_msg: None,
2440            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
2441            reconnect_delay_initial_ms: Some(100),
2442            reconnect_delay_max_ms: Some(200),
2443            reconnect_backoff_factor: Some(1.0),
2444            reconnect_jitter_ms: Some(0),
2445            reconnect_max_attempts: None,
2446            idle_timeout_ms: None,
2447        };
2448
2449        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2450            .await
2451            .unwrap();
2452
2453        // Wait for reconnection to start
2454        tokio::time::timeout(Duration::from_secs(2), async {
2455            while !client.is_reconnecting() {
2456                sleep(Duration::from_millis(10)).await;
2457            }
2458        })
2459        .await
2460        .expect("Client should enter RECONNECT state");
2461
2462        // Disconnect while reconnecting
2463        client.disconnect().await;
2464
2465        // Should be cleanly closed
2466        assert!(
2467            client.is_disconnected(),
2468            "Client should be cleanly disconnected"
2469        );
2470
2471        server.abort();
2472    }
2473
2474    #[rstest]
2475    #[tokio::test]
2476    async fn test_send_fails_fast_when_closed_before_rate_limit() {
2477        // Test that send operations check connection state BEFORE rate limiting,
2478        // preventing unnecessary delays when the connection is already closed.
2479        use std::{num::NonZeroU32, sync::Arc};
2480
2481        use nautilus_common::testing::wait_until_async;
2482
2483        use crate::ratelimiter::quota::Quota;
2484
2485        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2486        let port = listener.local_addr().unwrap().port();
2487
2488        let server = task::spawn(async move {
2489            // Accept connection and immediately close
2490            if let Ok((stream, _)) = listener.accept().await
2491                && let Ok(ws) = accept_async(stream).await
2492            {
2493                drop(ws);
2494            }
2495            sleep(Duration::from_secs(60)).await;
2496        });
2497
2498        let (handler, _rx) = channel_message_handler();
2499
2500        let config = WebSocketConfig {
2501            url: format!("ws://127.0.0.1:{port}"),
2502            headers: vec![],
2503            heartbeat: None,
2504            heartbeat_msg: None,
2505            reconnect_timeout_ms: Some(5_000),
2506            reconnect_delay_initial_ms: Some(50),
2507            reconnect_delay_max_ms: Some(100),
2508            reconnect_backoff_factor: Some(1.0),
2509            reconnect_jitter_ms: Some(0),
2510            reconnect_max_attempts: None,
2511            idle_timeout_ms: None,
2512        };
2513
2514        // Very restrictive rate limit: 1 request per 10 seconds
2515        // This ensures that if we wait for rate limit, the test will timeout
2516        let quota = Quota::with_period(Duration::from_secs(10))
2517            .unwrap()
2518            .allow_burst(NonZeroU32::new(1).unwrap());
2519
2520        let client = Arc::new(
2521            WebSocketClient::connect(
2522                config,
2523                Some(handler),
2524                None,
2525                None,
2526                vec![("test_key".to_string(), quota)],
2527                None,
2528            )
2529            .await
2530            .unwrap(),
2531        );
2532
2533        // Wait for disconnection
2534        wait_until_async(
2535            || async { client.is_reconnecting() || client.is_closed() },
2536            Duration::from_secs(2),
2537        )
2538        .await;
2539
2540        // Explicitly disconnect to move away from ACTIVE state
2541        client.disconnect().await;
2542        assert!(
2543            !client.is_active(),
2544            "Client should not be active after disconnect"
2545        );
2546
2547        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
2548        let start = std::time::Instant::now();
2549        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2550        let result = client
2551            .send_text("test".to_string(), Some(test_key.as_slice()))
2552            .await;
2553        let elapsed = start.elapsed();
2554
2555        // Should fail with Closed error
2556        assert!(result.is_err(), "Send should fail when client is closed");
2557        assert!(
2558            matches!(result, Err(crate::error::SendError::Closed)),
2559            "Send should return Closed error, was: {result:?}"
2560        );
2561
2562        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
2563        assert!(
2564            elapsed < Duration::from_millis(100),
2565            "Send should fail fast without rate limiting, took {elapsed:?}"
2566        );
2567
2568        server.abort();
2569    }
2570
2571    #[rstest]
2572    #[tokio::test]
2573    async fn test_connect_rejects_none_message_handler() {
2574        // Test that connect() properly rejects None message_handler
2575        // to prevent zombie connections that appear alive but never detect disconnections
2576
2577        let config = WebSocketConfig {
2578            url: "ws://127.0.0.1:9999".to_string(),
2579            headers: vec![],
2580            heartbeat: None,
2581            heartbeat_msg: None,
2582            reconnect_timeout_ms: Some(1_000),
2583            reconnect_delay_initial_ms: Some(100),
2584            reconnect_delay_max_ms: Some(500),
2585            reconnect_backoff_factor: Some(1.5),
2586            reconnect_jitter_ms: Some(0),
2587            reconnect_max_attempts: None,
2588            idle_timeout_ms: None,
2589        };
2590
2591        // Pass None for message_handler - should be rejected
2592        let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2593
2594        assert!(
2595            result.is_err(),
2596            "connect() should reject None message_handler"
2597        );
2598
2599        let err = result.unwrap_err();
2600        let err_msg = err.to_string();
2601        assert!(
2602            err_msg.contains("Handler mode requires message_handler"),
2603            "Error should mention missing message_handler, was: {err_msg}"
2604        );
2605    }
2606
2607    #[rstest]
2608    #[tokio::test]
2609    async fn test_client_without_handler_sets_stream_mode() {
2610        // Test that if a client is created without a handler via connect_url,
2611        // it properly sets is_stream_mode=true to prevent zombie connections
2612
2613        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2614        let port = listener.local_addr().unwrap().port();
2615
2616        let server = task::spawn(async move {
2617            // Accept and immediately close to simulate server disconnect
2618            if let Ok((stream, _)) = listener.accept().await
2619                && let Ok(ws) = accept_async(stream).await
2620            {
2621                drop(ws); // Drop connection immediately
2622            }
2623        });
2624
2625        let config = WebSocketConfig {
2626            url: format!("ws://127.0.0.1:{port}"),
2627            headers: vec![],
2628            heartbeat: None,
2629            heartbeat_msg: None,
2630            reconnect_timeout_ms: Some(1_000),
2631            reconnect_delay_initial_ms: Some(100),
2632            reconnect_delay_max_ms: Some(500),
2633            reconnect_backoff_factor: Some(1.5),
2634            reconnect_jitter_ms: Some(0),
2635            reconnect_max_attempts: None,
2636            idle_timeout_ms: None,
2637        };
2638
2639        // Create client directly via connect_url with no handler (stream mode)
2640        let inner = WebSocketClientInner::connect_url(config, None, None)
2641            .await
2642            .unwrap();
2643
2644        // Verify is_stream_mode is true when no handler
2645        assert!(
2646            inner.is_stream_mode,
2647            "Client without handler should have is_stream_mode=true"
2648        );
2649
2650        // Verify that when stream mode is enabled, reconnection is disabled
2651        // (documented behavior - stream mode clients close instead of reconnecting)
2652
2653        server.abort();
2654    }
2655
2656    #[rstest]
2657    #[tokio::test]
2658    async fn test_idle_timeout_triggers_reconnect() {
2659        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2660        let port = listener.local_addr().unwrap().port();
2661
2662        // Server accepts WS connection but sends nothing (simulates silent death)
2663        let server = task::spawn(async move {
2664            let (stream, _) = listener.accept().await.unwrap();
2665            let _ws = accept_async(stream).await.unwrap();
2666            // Hold connection open but send nothing
2667            sleep(Duration::from_secs(5)).await;
2668        });
2669
2670        let (handler, _rx) = channel_message_handler();
2671
2672        let config = WebSocketConfig {
2673            url: format!("ws://127.0.0.1:{port}"),
2674            headers: vec![],
2675            heartbeat: None,
2676            heartbeat_msg: None,
2677            reconnect_timeout_ms: Some(2_000),
2678            reconnect_delay_initial_ms: Some(50),
2679            reconnect_delay_max_ms: Some(100),
2680            reconnect_backoff_factor: Some(1.0),
2681            reconnect_jitter_ms: Some(0),
2682            reconnect_max_attempts: Some(1),
2683            idle_timeout_ms: Some(500),
2684        };
2685
2686        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2687            .await
2688            .unwrap();
2689
2690        assert!(client.is_active());
2691
2692        // Wait for idle timeout to fire and client to enter reconnect/closed
2693        wait_until_async(
2694            || async { client.is_reconnecting() || client.is_disconnected() },
2695            Duration::from_secs(3),
2696        )
2697        .await;
2698
2699        assert!(
2700            !client.is_active(),
2701            "Client should not be active after idle timeout"
2702        );
2703
2704        client.disconnect().await;
2705        server.abort();
2706    }
2707
2708    #[rstest]
2709    #[tokio::test]
2710    async fn test_idle_timeout_resets_on_data() {
2711        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2712        let port = listener.local_addr().unwrap().port();
2713
2714        // Server sends a message every 200ms (well within 1s idle timeout)
2715        let server = task::spawn(async move {
2716            let (stream, _) = listener.accept().await.unwrap();
2717            let mut ws = accept_async(stream).await.unwrap();
2718            for _ in 0..10 {
2719                sleep(Duration::from_millis(200)).await;
2720
2721                if ws
2722                    .send(tokio_tungstenite::tungstenite::Message::Text("ping".into()))
2723                    .await
2724                    .is_err()
2725                {
2726                    break;
2727                }
2728            }
2729        });
2730
2731        let (handler, _rx) = channel_message_handler();
2732
2733        let config = WebSocketConfig {
2734            url: format!("ws://127.0.0.1:{port}"),
2735            headers: vec![],
2736            heartbeat: None,
2737            heartbeat_msg: None,
2738            reconnect_timeout_ms: Some(2_000),
2739            reconnect_delay_initial_ms: Some(50),
2740            reconnect_delay_max_ms: Some(100),
2741            reconnect_backoff_factor: Some(1.0),
2742            reconnect_jitter_ms: Some(0),
2743            reconnect_max_attempts: Some(1),
2744            idle_timeout_ms: Some(1_000),
2745        };
2746
2747        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2748            .await
2749            .unwrap();
2750
2751        assert!(client.is_active());
2752
2753        // Wait 1.5s - data arrives every 200ms so idle timeout (1s) should NOT fire
2754        sleep(Duration::from_millis(1_500)).await;
2755
2756        assert!(
2757            client.is_active(),
2758            "Client should remain active when data is flowing"
2759        );
2760
2761        client.disconnect().await;
2762        server.abort();
2763    }
2764
2765    #[rstest]
2766    #[tokio::test]
2767    async fn test_disconnect_during_backoff_exits_promptly() {
2768        // Verify that disconnect interrupts backoff sleep (Finding 1).
2769        // Server accepts then drops, no second listener -> reconnect fails -> enters backoff.
2770        // We disconnect while backing off and assert the client shuts down quickly.
2771        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2772        let port = listener.local_addr().unwrap().port();
2773
2774        let server = task::spawn(async move {
2775            // Accept first connection, close immediately
2776            if let Ok((stream, _)) = listener.accept().await {
2777                let _ = accept_async(stream).await;
2778            }
2779            // Don't accept again so reconnect fails and enters backoff
2780            sleep(Duration::from_secs(60)).await;
2781        });
2782
2783        let (handler, _rx) = channel_message_handler();
2784
2785        let config = WebSocketConfig {
2786            url: format!("ws://127.0.0.1:{port}"),
2787            headers: vec![],
2788            heartbeat: None,
2789            heartbeat_msg: None,
2790            reconnect_timeout_ms: Some(1_000),
2791            reconnect_delay_initial_ms: Some(10_000), // 10s backoff to ensure we're sleeping
2792            reconnect_delay_max_ms: Some(10_000),
2793            reconnect_backoff_factor: Some(1.0),
2794            reconnect_jitter_ms: Some(0),
2795            reconnect_max_attempts: None,
2796            idle_timeout_ms: None,
2797        };
2798
2799        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2800            .await
2801            .unwrap();
2802
2803        // Wait for client to enter reconnect
2804        wait_until_async(
2805            || async { client.is_reconnecting() },
2806            Duration::from_secs(3),
2807        )
2808        .await;
2809
2810        // Wait a bit more for the reconnect attempt to fail and enter backoff sleep
2811        sleep(Duration::from_millis(1_500)).await;
2812
2813        // Disconnect while backing off
2814        let start = std::time::Instant::now();
2815        client.disconnect().await;
2816        let elapsed = start.elapsed();
2817
2818        assert!(client.is_disconnected(), "Client should be disconnected");
2819        // Should exit well before the 10s backoff sleep completes
2820        assert!(
2821            elapsed < Duration::from_secs(2),
2822            "Disconnect should interrupt backoff sleep, took {elapsed:?}"
2823        );
2824
2825        server.abort();
2826    }
2827
2828    #[rstest]
2829    #[tokio::test]
2830    async fn test_rate_limit_cancelled_on_disconnect() {
2831        // Verify that a send blocked on rate limiting returns Closed when
2832        // the client disconnects (Finding 6).
2833        use std::{num::NonZeroU32, sync::Arc};
2834
2835        use crate::ratelimiter::quota::Quota;
2836
2837        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2838        let port = listener.local_addr().unwrap().port();
2839
2840        let server = task::spawn(async move {
2841            if let Ok((stream, _)) = listener.accept().await {
2842                let mut ws = accept_async(stream).await.unwrap();
2843                // Keep alive and echo
2844                while let Some(Ok(msg)) = ws.next().await {
2845                    if ws.send(msg).await.is_err() {
2846                        break;
2847                    }
2848                }
2849            }
2850        });
2851
2852        let (handler, _rx) = channel_message_handler();
2853
2854        let config = WebSocketConfig {
2855            url: format!("ws://127.0.0.1:{port}"),
2856            headers: vec![],
2857            heartbeat: None,
2858            heartbeat_msg: None,
2859            reconnect_timeout_ms: Some(5_000),
2860            reconnect_delay_initial_ms: Some(100),
2861            reconnect_delay_max_ms: Some(500),
2862            reconnect_backoff_factor: Some(1.5),
2863            reconnect_jitter_ms: Some(0),
2864            reconnect_max_attempts: None,
2865            idle_timeout_ms: None,
2866        };
2867
2868        // Very restrictive: 1 req per 60 seconds
2869        let quota = Quota::with_period(Duration::from_secs(60))
2870            .unwrap()
2871            .allow_burst(NonZeroU32::new(1).unwrap());
2872
2873        let client = Arc::new(
2874            WebSocketClient::connect(
2875                config,
2876                Some(handler),
2877                None,
2878                None,
2879                vec![("rate_key".to_string(), quota)],
2880                None,
2881            )
2882            .await
2883            .unwrap(),
2884        );
2885
2886        let test_key: [Ustr; 1] = [Ustr::from("rate_key")];
2887
2888        // Exhaust the burst quota
2889        client
2890            .send_text("exhaust".to_string(), Some(test_key.as_slice()))
2891            .await
2892            .unwrap();
2893
2894        // Spawn a send that will block on rate limiter
2895        let client_clone = client.clone();
2896        let send_handle = task::spawn(async move {
2897            client_clone
2898                .send_text("blocked".to_string(), Some(&[Ustr::from("rate_key")]))
2899                .await
2900        });
2901
2902        // Let the send block on rate limit
2903        sleep(Duration::from_millis(200)).await;
2904
2905        // Disconnect while send is blocked
2906        let start = std::time::Instant::now();
2907        client.disconnect().await;
2908        let elapsed_disconnect = start.elapsed();
2909
2910        // The blocked send should return Closed
2911        let result = tokio::time::timeout(Duration::from_secs(2), send_handle)
2912            .await
2913            .expect("Send task should complete quickly")
2914            .expect("Send task should not panic");
2915
2916        assert!(
2917            matches!(result, Err(crate::error::SendError::Closed)),
2918            "Blocked send should return Closed, was: {result:?}"
2919        );
2920
2921        // Disconnect should be fast, not waiting for the 60s rate limit
2922        assert!(
2923            elapsed_disconnect < Duration::from_secs(3),
2924            "Disconnect should not wait for rate limiter, took {elapsed_disconnect:?}"
2925        );
2926
2927        server.abort();
2928    }
2929
2930    #[rstest]
2931    #[tokio::test]
2932    async fn test_stream_mode_transitions_to_closed_on_dead_write_task() {
2933        // Verify that stream mode transitions to CLOSED (not RECONNECT) when
2934        // the write task dies (Finding 4). We force write failure by sending
2935        // after the server closes the connection.
2936        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2937        let port = listener.local_addr().unwrap().port();
2938
2939        let server = task::spawn(async move {
2940            if let Ok((stream, _)) = listener.accept().await
2941                && let Ok(ws) = accept_async(stream).await
2942            {
2943                // Close immediately to cause write errors
2944                drop(ws);
2945            }
2946        });
2947
2948        let config = WebSocketConfig {
2949            url: format!("ws://127.0.0.1:{port}"),
2950            headers: vec![],
2951            heartbeat: None,
2952            heartbeat_msg: None,
2953            reconnect_timeout_ms: Some(1_000),
2954            reconnect_delay_initial_ms: Some(50),
2955            reconnect_delay_max_ms: Some(100),
2956            reconnect_backoff_factor: Some(1.0),
2957            reconnect_jitter_ms: Some(0),
2958            reconnect_max_attempts: None,
2959            idle_timeout_ms: None,
2960        };
2961
2962        let (_reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2963            .await
2964            .unwrap();
2965
2966        assert!(client.is_active(), "Client should start active");
2967
2968        // Wait for server to close, then send to trigger write task failure
2969        sleep(Duration::from_millis(100)).await;
2970
2971        // Keep sending until the write task detects the broken connection
2972        for _ in 0..20 {
2973            let _ = client.send_text("ping".to_string(), None).await;
2974            sleep(Duration::from_millis(50)).await;
2975
2976            if !client.is_active() {
2977                break;
2978            }
2979        }
2980
2981        // Wait for controller to process the state change
2982        wait_until_async(|| async { !client.is_active() }, Duration::from_secs(5)).await;
2983
2984        // Stream mode should go to CLOSED, not RECONNECT
2985        assert!(
2986            client.is_closed() || client.is_disconnected(),
2987            "Stream mode should transition to CLOSED, not RECONNECT. \
2988             is_reconnecting={}, is_closed={}, is_disconnected={}",
2989            client.is_reconnecting(),
2990            client.is_closed(),
2991            client.is_disconnected(),
2992        );
2993        assert!(
2994            !client.is_reconnecting(),
2995            "Stream mode should never attempt reconnection"
2996        );
2997
2998        server.abort();
2999    }
3000
3001    #[rstest]
3002    #[tokio::test]
3003    async fn test_zero_idle_timeout_rejected() {
3004        let (handler, _rx) = channel_message_handler();
3005
3006        let config = WebSocketConfig {
3007            url: "ws://127.0.0.1:9999".to_string(),
3008            headers: vec![],
3009            heartbeat: None,
3010            heartbeat_msg: None,
3011            reconnect_timeout_ms: None,
3012            reconnect_delay_initial_ms: None,
3013            reconnect_delay_max_ms: None,
3014            reconnect_backoff_factor: None,
3015            reconnect_jitter_ms: None,
3016            reconnect_max_attempts: None,
3017            idle_timeout_ms: Some(0),
3018        };
3019
3020        let result =
3021            WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
3022
3023        assert!(result.is_err(), "Zero idle timeout should be rejected");
3024        let err_msg = result.unwrap_err().to_string();
3025        assert!(
3026            err_msg.contains("Idle timeout cannot be zero"),
3027            "Error should mention zero idle timeout, was: {err_msg}"
3028        );
3029    }
3030}