1use std::{
26 collections::VecDeque,
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicU8, Ordering},
31 },
32 time::Duration,
33};
34
35use futures_util::{SinkExt, StreamExt};
36use http::HeaderName;
37use nautilus_core::CleanDrop;
38use nautilus_cryptography::providers::install_cryptographic_provider;
39#[cfg(feature = "turmoil")]
40use tokio_tungstenite::MaybeTlsStream;
41#[cfg(feature = "turmoil")]
42use tokio_tungstenite::client_async;
43#[cfg(not(feature = "turmoil"))]
44use tokio_tungstenite::connect_async_with_config;
45use tokio_tungstenite::tungstenite::{
46 Error, Message, client::IntoClientRequest, http::HeaderValue,
47};
48use ustr::Ustr;
49
50use super::{
51 config::WebSocketConfig,
52 consts::{
53 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
54 GRACEFUL_SHUTDOWN_TIMEOUT_SECS, SEND_OPERATION_CHECK_INTERVAL_MS,
55 },
56 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
57};
58#[cfg(feature = "turmoil")]
59use crate::net::TcpConnector;
60use crate::{
61 RECONNECTED,
62 backoff::ExponentialBackoff,
63 error::SendError,
64 logging::{log_task_aborted, log_task_started, log_task_stopped},
65 mode::ConnectionMode,
66 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
67};
68
/// Internal state of a WebSocket client: the connection configuration, the background
/// read / write / heartbeat tasks, and the bookkeeping used for reconnection.
///
/// The background tasks coordinate through the shared `connection_mode` atomic
/// (a `ConnectionMode` encoded as a `u8`).
pub struct WebSocketClientInner {
    /// Configuration the client was created from; `reconnect` reuses its URL and headers.
    config: WebSocketConfig,
    /// Callback for text/binary frames; `None` in stream mode.
    message_handler: Option<MessageHandler>,
    /// Callback invoked with the payload of received ping frames.
    ping_handler: Option<PingHandler>,
    /// Task reading from the socket and dispatching to the handlers; `None` when no
    /// message handler is configured.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task owning the sink half of the socket. It survives reconnects: replacement
    /// writers are delivered to it through `writer_tx`.
    write_task: tokio::task::JoinHandle<()>,
    /// Command channel into `write_task` (outbound messages and writer swaps).
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Periodic heartbeat sender; `None` when `config.heartbeat` is unset.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection state, read and written by every task.
    connection_mode: Arc<AtomicU8>,
    /// Upper bound on the duration of a single `reconnect` attempt.
    reconnect_timeout: Duration,
    /// Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// `true` when there is no message handler; automatic reconnection is disabled.
    is_stream_mode: bool,
    /// Maximum consecutive reconnect attempts before giving up; `None` means unlimited.
    reconnect_max_attempts: Option<u32>,
    /// Consecutive reconnect attempts so far; reset to 0 after a successful reconnect.
    reconnection_attempt_count: u32,
}
106
107impl WebSocketClientInner {
108 pub async fn new_with_writer(
116 config: WebSocketConfig,
117 writer: MessageWriter,
118 ) -> Result<Self, Error> {
119 install_cryptographic_provider();
120
121 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
122
123 let read_task = None;
125
126 let backoff = ExponentialBackoff::new(
127 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
128 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
129 config.reconnect_backoff_factor.unwrap_or(1.5),
130 config.reconnect_jitter_ms.unwrap_or(100),
131 true, )
133 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
134
135 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
136 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
137
138 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
139 Some(Self::spawn_heartbeat_task(
140 connection_mode.clone(),
141 heartbeat_interval,
142 config.heartbeat_msg.clone(),
143 writer_tx.clone(),
144 ))
145 } else {
146 None
147 };
148
149 let reconnect_max_attempts = config.reconnect_max_attempts;
150 let reconnect_timeout = Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10000));
151
152 Ok(Self {
153 config,
154 message_handler: None, ping_handler: None,
156 writer_tx,
157 connection_mode,
158 reconnect_timeout,
159 heartbeat_task,
160 read_task,
161 write_task,
162 backoff,
163 is_stream_mode: true,
164 reconnect_max_attempts,
165 reconnection_attempt_count: 0,
166 })
167 }
168
169 pub async fn connect_url(
177 config: WebSocketConfig,
178 message_handler: Option<MessageHandler>,
179 ping_handler: Option<PingHandler>,
180 ) -> Result<Self, Error> {
181 install_cryptographic_provider();
182
183 if config.heartbeat == Some(0) {
184 return Err(Error::Io(std::io::Error::new(
185 std::io::ErrorKind::InvalidInput,
186 "Heartbeat interval cannot be zero",
187 )));
188 }
189
190 if config.idle_timeout_ms == Some(0) {
191 return Err(Error::Io(std::io::Error::new(
192 std::io::ErrorKind::InvalidInput,
193 "Idle timeout cannot be zero",
194 )));
195 }
196
197 let is_stream_mode = message_handler.is_none();
199 let reconnect_max_attempts = config.reconnect_max_attempts;
200
201 let (writer, reader) =
202 Self::connect_with_server(&config.url, config.headers.clone()).await?;
203
204 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
205
206 let read_task = if message_handler.is_some() {
207 Some(Self::spawn_message_handler_task(
208 connection_mode.clone(),
209 reader,
210 message_handler.as_ref(),
211 ping_handler.as_ref(),
212 config.idle_timeout_ms,
213 ))
214 } else {
215 None
216 };
217
218 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
219 let write_task = Self::spawn_write_task(connection_mode.clone(), writer, writer_rx);
220
221 let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
223 Self::spawn_heartbeat_task(
224 connection_mode.clone(),
225 heartbeat_secs,
226 config.heartbeat_msg.clone(),
227 writer_tx.clone(),
228 )
229 });
230
231 let reconnect_timeout =
232 Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
233 let backoff = ExponentialBackoff::new(
234 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
235 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
236 config.reconnect_backoff_factor.unwrap_or(1.5),
237 config.reconnect_jitter_ms.unwrap_or(100),
238 true, )
240 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
241
242 Ok(Self {
243 config,
244 message_handler,
245 ping_handler,
246 read_task,
247 write_task,
248 writer_tx,
249 heartbeat_task,
250 connection_mode,
251 reconnect_timeout,
252 backoff,
253 is_stream_mode,
255 reconnect_max_attempts,
256 reconnection_attempt_count: 0,
257 })
258 }
259
260 #[inline]
270 #[cfg(not(feature = "turmoil"))]
271 pub async fn connect_with_server(
272 url: &str,
273 headers: Vec<(String, String)>,
274 ) -> Result<(MessageWriter, MessageReader), Error> {
275 let mut request = url.into_client_request()?;
276 let req_headers = request.headers_mut();
277
278 let mut header_names: Vec<HeaderName> = Vec::new();
279 for (key, val) in headers {
280 let header_value = HeaderValue::from_str(&val)?;
281 let header_name: HeaderName = key.parse()?;
282 header_names.push(header_name.clone());
283 req_headers.insert(header_name, header_value);
284 }
285
286 connect_async_with_config(request, None, true)
287 .await
288 .map(|resp| resp.0.split())
289 }
290
291 #[inline]
304 #[cfg(feature = "turmoil")]
305 pub async fn connect_with_server(
306 url: &str,
307 headers: Vec<(String, String)>,
308 ) -> Result<(MessageWriter, MessageReader), Error> {
309 use rustls::ClientConfig;
310 use tokio_rustls::TlsConnector;
311
312 let mut request = url.into_client_request()?;
313 let req_headers = request.headers_mut();
314
315 let mut header_names: Vec<HeaderName> = Vec::new();
316 for (key, val) in headers {
317 let header_value = HeaderValue::from_str(&val)?;
318 let header_name: HeaderName = key.parse()?;
319 header_names.push(header_name.clone());
320 req_headers.insert(header_name, header_value);
321 }
322
323 let uri = request.uri();
324 let scheme = uri.scheme_str().unwrap_or("ws");
325 let host = uri.host().ok_or_else(|| {
326 Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
327 })?;
328
329 let port = uri
331 .port_u16()
332 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
333
334 let addr = format!("{host}:{port}");
335
336 let connector = crate::net::RealTcpConnector;
338 let tcp_stream = connector.connect(&addr).await?;
339 if let Err(e) = tcp_stream.set_nodelay(true) {
340 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
341 }
342
343 let maybe_tls_stream = if scheme == "wss" {
345 let mut root_store = rustls::RootCertStore::empty();
347 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
348
349 let config = ClientConfig::builder()
350 .with_root_certificates(root_store)
351 .with_no_client_auth();
352
353 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
354 let domain =
355 rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
356 Error::Io(std::io::Error::new(
357 std::io::ErrorKind::InvalidInput,
358 format!("Invalid DNS name: {e}"),
359 ))
360 })?;
361
362 let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
363 MaybeTlsStream::Rustls(tls_stream)
364 } else {
365 MaybeTlsStream::Plain(tcp_stream)
366 };
367
368 client_async(request, maybe_tls_stream)
370 .await
371 .map(|resp| resp.0.split())
372 }
373
    /// Re-establishes the connection using the stored `config` (handler mode only).
    ///
    /// Stream-mode clients are not reconnected: the mode is set to CLOSED and `Ok(())` is
    /// returned. Otherwise, within `reconnect_timeout`, this:
    /// 1. connects a new socket,
    /// 2. hands the new writer to the write task (`WriterCommand::Update`) and waits for it
    ///    to report whether its reconnect buffer was fully flushed,
    /// 3. waits `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the old read task,
    /// 4. moves the mode RECONNECT -> ACTIVE and spawns a new read task on the new reader.
    ///
    /// At each checkpoint, a DISCONNECT request (or any state change that makes the
    /// RECONNECT -> ACTIVE swap fail) aborts the procedure and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns connection errors from `connect_with_server`, and `BrokenPipe` if the
    /// write task cannot be reached or drops the response channel. Returns a generic I/O
    /// error if the buffer flush failed, and `TimedOut` if the whole attempt exceeds
    /// `reconnect_timeout`.
    pub async fn reconnect(&mut self) -> Result<(), Error> {
        log::debug!("Reconnecting");

        // Stream-mode callers own the reader themselves, so it cannot be swapped here.
        if self.is_stream_mode {
            log::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                stream users must manually reconnect by creating a new connection"
            );
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            let (new_writer, reader) =
                Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;

            // A disconnect requested while connecting wins; the fresh socket is dropped.
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // The write task swaps writers, flushes its reconnect buffer, and replies on `tx`
            // with `true` only if every buffered message was sent.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            match rx.await {
                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    log::warn!("Writer failed to drain buffer, aborting reconnect");
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    log::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            // NOTE(review): at this point the write task already uses the new writer; aborting
            // here leaves the new socket's read half unconsumed.
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Retire the previous reader before a new one is spawned on the new socket.
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Only go ACTIVE if nobody changed the state while we were reconnecting.
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawned after the ACTIVE transition: the read loop exits immediately when
            // the mode is not ACTIVE.
            self.read_task = if self.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    reader,
                    self.message_handler.as_ref(),
                    self.ping_handler.as_ref(),
                    self.config.idle_timeout_ms,
                ))
            } else {
                None
            };

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
505
506 #[inline]
512 #[must_use]
513 pub fn is_alive(&self) -> bool {
514 match &self.read_task {
515 Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
516 None => !self.write_task.is_finished(),
517 }
518 }
519
520 fn spawn_message_handler_task(
521 connection_state: Arc<AtomicU8>,
522 mut reader: MessageReader,
523 message_handler: Option<&MessageHandler>,
524 ping_handler: Option<&PingHandler>,
525 idle_timeout_ms: Option<u64>,
526 ) -> tokio::task::JoinHandle<()> {
527 log::debug!("Started message handler task 'read'");
528
529 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
530 let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
531
532 let message_handler = message_handler.cloned();
534 let ping_handler = ping_handler.cloned();
535
536 tokio::task::spawn(async move {
537 let mut last_data_time = tokio::time::Instant::now();
538
539 loop {
540 if !ConnectionMode::from_atomic(&connection_state).is_active() {
541 break;
542 }
543
544 match tokio::time::timeout(check_interval, reader.next()).await {
545 Ok(Some(Ok(Message::Binary(data)))) => {
546 log::trace!("Received message <binary> {} bytes", data.len());
547 last_data_time = tokio::time::Instant::now();
548 if let Some(ref handler) = message_handler {
549 handler(Message::Binary(data));
550 }
551 }
552 Ok(Some(Ok(Message::Text(data)))) => {
553 log::trace!("Received message: {data}");
554 last_data_time = tokio::time::Instant::now();
555 if let Some(ref handler) = message_handler {
556 handler(Message::Text(data));
557 }
558 }
559 Ok(Some(Ok(Message::Ping(ping_data)))) => {
560 log::trace!("Received ping: {ping_data:?}");
561 last_data_time = tokio::time::Instant::now();
562 if let Some(ref handler) = ping_handler {
563 handler(ping_data.to_vec());
564 }
565 }
566 Ok(Some(Ok(Message::Pong(_)))) => {
567 log::trace!("Received pong");
568 last_data_time = tokio::time::Instant::now();
569 }
570 Ok(Some(Ok(Message::Close(_)))) => {
571 log::debug!("Received close message - terminating");
572 break;
573 }
574 Ok(Some(Ok(_))) => (),
575 Ok(Some(Err(e))) => {
576 log::error!("Received error message - terminating: {e}");
577 break;
578 }
579 Ok(None) => {
580 log::debug!("No message received - terminating");
581 break;
582 }
583 Err(_) => {
584 if let Some(timeout) = idle_timeout {
585 let idle_duration = last_data_time.elapsed();
586 if idle_duration >= timeout {
587 log::warn!(
588 "Read idle timeout: no data received for {:.1}s",
589 idle_duration.as_secs_f64()
590 );
591 break;
592 }
593 }
594 continue;
595 }
596 }
597 }
598 })
599 }
600
601 async fn drain_reconnect_buffer(
606 buffer: &mut VecDeque<Message>,
607 writer: &mut MessageWriter,
608 ) -> bool {
609 if buffer.is_empty() {
610 return false;
611 }
612
613 let initial_buffer_len = buffer.len();
614 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
615
616 let mut send_error_occurred = false;
617
618 while let Some(buffered_msg) = buffer.front() {
619 let msg_to_send = buffered_msg.clone();
621
622 if let Err(e) = writer.send(msg_to_send).await {
623 log::error!(
624 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
625 buffer.len()
626 );
627 send_error_occurred = true;
628 break; }
630
631 buffer.pop_front();
633 }
634
635 if buffer.is_empty() {
636 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
637 }
638
639 send_error_occurred
640 }
641
    /// Spawns the writer task that owns the sink half of the socket.
    ///
    /// All outbound traffic arrives through `writer_rx`:
    /// - `WriterCommand::Send` is written immediately while ACTIVE, or queued in a local
    ///   reconnect buffer while RECONNECT. A failed write re-queues the message and stores
    ///   RECONNECT into the shared state.
    /// - `WriterCommand::Update` closes the current writer, switches to the new one, flushes
    ///   the reconnect buffer into it, and replies `true` on the oneshot if the flush fully
    ///   succeeded.
    ///
    /// The task exits on DISCONNECT (after a bounded graceful close), on CLOSED, or once all
    /// senders for `writer_rx` are dropped. Buffered messages are discarded on DISCONNECT
    /// and CLOSED.
    fn spawn_write_task(
        connection_state: Arc<AtomicU8>,
        writer: MessageWriter,
        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
    ) -> tokio::task::JoinHandle<()> {
        log_task_started("write");

        // Maximum wait for a command before the shared connection state is re-checked.
        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        tokio::task::spawn(async move {
            let mut active_writer = writer;
            // Messages that could not be written while the connection was being replaced.
            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();

            loop {
                match ConnectionMode::from_atomic(&connection_state) {
                    ConnectionMode::Disconnect => {
                        if !reconnect_buffer.is_empty() {
                            log::warn!(
                                "Discarding {} buffered messages due to disconnect",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                        }

                        // Best-effort close, bounded so shutdown cannot hang on a dead peer.
                        // NOTE(review): the writer is closed again after the loop; presumably a
                        // no-op on an already-closed sink — confirm.
                        _ = tokio::time::timeout(
                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                            active_writer.close(),
                        )
                        .await;
                        break;
                    }
                    ConnectionMode::Closed => {
                        if !reconnect_buffer.is_empty() {
                            log::warn!(
                                "Discarding {} buffered messages due to closed connection",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                        }
                        break;
                    }
                    _ => {}
                }

                match tokio::time::timeout(check_interval, writer_rx.recv()).await {
                    Ok(Some(msg)) => {
                        // Re-read the state: it may have changed while we were waiting.
                        let mode = ConnectionMode::from_atomic(&connection_state);
                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
                            break;
                        }

                        match msg {
                            WriterCommand::Update(new_writer, tx) => {
                                log::debug!("Received new writer");

                                // NOTE(review): fixed 100ms pause before retiring the old
                                // writer; its purpose is not visible here.
                                tokio::time::sleep(Duration::from_millis(100)).await;

                                _ = tokio::time::timeout(
                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                                    active_writer.close(),
                                )
                                .await;

                                active_writer = new_writer;
                                log::debug!("Updated writer");

                                let send_error = Self::drain_reconnect_buffer(
                                    &mut reconnect_buffer,
                                    &mut active_writer,
                                )
                                .await;

                                // `reconnect` awaits this reply to decide whether it succeeded.
                                if let Err(e) = tx.send(!send_error) {
                                    log::error!(
                                        "Failed to report drain status to controller: {e:?}"
                                    );
                                }
                            }
                            WriterCommand::Send(msg) if mode.is_reconnect() => {
                                log::debug!(
                                    "Buffering message during reconnection (buffer size: {})",
                                    reconnect_buffer.len() + 1
                                );
                                reconnect_buffer.push_back(msg);
                            }
                            WriterCommand::Send(msg) => {
                                // Send a clone so the message can be re-queued on failure.
                                if let Err(e) = active_writer.send(msg.clone()).await {
                                    log::error!("Failed to send message: {e}");
                                    log::warn!("Writer triggering reconnect");
                                    reconnect_buffer.push_back(msg);
                                    connection_state
                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
                                }
                            }
                        }
                    }
                    Ok(None) => {
                        log::debug!("Writer channel closed, terminating writer task");
                        break;
                    }
                    Err(_) => {
                        // No command within `check_interval`: loop to re-check the state.
                        continue;
                    }
                }
            }

            _ = tokio::time::timeout(
                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                active_writer.close(),
            )
            .await;

            log_task_stopped("write");
        })
    }
773
774 fn spawn_heartbeat_task(
775 connection_state: Arc<AtomicU8>,
776 heartbeat_secs: u64,
777 message: Option<String>,
778 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
779 ) -> tokio::task::JoinHandle<()> {
780 log_task_started("heartbeat");
781
782 tokio::task::spawn(async move {
783 let interval = Duration::from_secs(heartbeat_secs);
784
785 loop {
786 tokio::time::sleep(interval).await;
787
788 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
789 ConnectionMode::Active => {
790 let msg = match &message {
791 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
792 None => WriterCommand::Send(Message::Ping(vec![].into())),
793 };
794
795 match writer_tx.send(msg) {
796 Ok(()) => log::trace!("Sent heartbeat to writer task"),
797 Err(e) => {
798 log::error!("Failed to send heartbeat to writer task: {e}");
799 }
800 }
801 }
802 ConnectionMode::Reconnect => continue,
803 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
804 }
805 }
806
807 log_task_stopped("heartbeat");
808 })
809 }
810}
811
812impl Drop for WebSocketClientInner {
813 fn drop(&mut self) {
814 self.clean_drop();
816 }
817}
818
819impl CleanDrop for WebSocketClientInner {
821 fn clean_drop(&mut self) {
822 if let Some(ref read_task) = self.read_task.take()
823 && !read_task.is_finished()
824 {
825 read_task.abort();
826 log_task_aborted("read");
827 }
828
829 if !self.write_task.is_finished() {
830 self.write_task.abort();
831 log_task_aborted("write");
832 }
833
834 if let Some(ref handle) = self.heartbeat_task.take()
835 && !handle.is_finished()
836 {
837 handle.abort();
838 log_task_aborted("heartbeat");
839 }
840
841 self.message_handler = None;
843 self.ping_handler = None;
844 }
845}
846
847impl Debug for WebSocketClientInner {
848 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849 f.debug_struct(stringify!(WebSocketClientInner))
850 .field("config", &self.config)
851 .field(
852 "connection_mode",
853 &ConnectionMode::from_atomic(&self.connection_mode),
854 )
855 .field("reconnect_timeout", &self.reconnect_timeout)
856 .field("is_stream_mode", &self.is_stream_mode)
857 .finish()
858 }
859}
860
/// Public handle to a WebSocket connection supervised by a background controller task.
///
/// Outbound messages are forwarded to the inner writer task through `writer_tx`. Text
/// and binary sends are additionally throttled by `rate_limiter`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
pub struct WebSocketClient {
    /// Supervisor task that owns the inner client and drives reconnection and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection state shared with the inner client's tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// How long send operations wait for the client to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    /// Keyed (plus optional default) rate limiter awaited by `send_text` / `send_bytes`.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    /// Sender half of the inner writer task's command channel.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
876
877impl Debug for WebSocketClient {
878 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
879 f.debug_struct(stringify!(WebSocketClient)).finish()
880 }
881}
882
883impl WebSocketClient {
884 #[allow(clippy::too_many_arguments)]
900 pub async fn connect_stream(
901 config: WebSocketConfig,
902 keyed_quotas: Vec<(String, Quota)>,
903 default_quota: Option<Quota>,
904 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
905 ) -> Result<(MessageReader, Self), Error> {
906 install_cryptographic_provider();
907
908 let (writer, reader) =
910 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
911
912 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
914
915 let connection_mode = inner.connection_mode.clone();
916 let reconnect_timeout = inner.reconnect_timeout;
917 let keyed_quotas = keyed_quotas
918 .into_iter()
919 .map(|(key, quota)| (Ustr::from(&key), quota))
920 .collect();
921 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
922 let writer_tx = inner.writer_tx.clone();
923
924 let controller_task =
925 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnect);
926
927 Ok((
928 reader,
929 Self {
930 controller_task,
931 connection_mode,
932 reconnect_timeout,
933 rate_limiter,
934 writer_tx,
935 },
936 ))
937 }
938
939 pub async fn connect(
957 config: WebSocketConfig,
958 message_handler: Option<MessageHandler>,
959 ping_handler: Option<PingHandler>,
960 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
961 keyed_quotas: Vec<(String, Quota)>,
962 default_quota: Option<Quota>,
963 ) -> Result<Self, Error> {
964 if message_handler.is_none() {
966 return Err(Error::Io(std::io::Error::new(
967 std::io::ErrorKind::InvalidInput,
968 "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
969 )));
970 }
971
972 log::debug!("Connecting");
973 let inner =
974 WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
975 let connection_mode = inner.connection_mode.clone();
976 let writer_tx = inner.writer_tx.clone();
977 let reconnect_timeout = inner.reconnect_timeout;
978
979 let controller_task =
980 Self::spawn_controller_task(inner, connection_mode.clone(), post_reconnection);
981
982 let keyed_quotas = keyed_quotas
983 .into_iter()
984 .map(|(key, quota)| (Ustr::from(&key), quota))
985 .collect();
986 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
987
988 Ok(Self {
989 controller_task,
990 connection_mode,
991 reconnect_timeout,
992 rate_limiter,
993 writer_tx,
994 })
995 }
996
997 #[must_use]
999 pub fn connection_mode(&self) -> ConnectionMode {
1000 ConnectionMode::from_atomic(&self.connection_mode)
1001 }
1002
1003 #[must_use]
1008 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1009 Arc::clone(&self.connection_mode)
1010 }
1011
1012 #[inline]
1017 #[must_use]
1018 pub fn is_active(&self) -> bool {
1019 self.connection_mode().is_active()
1020 }
1021
1022 #[must_use]
1024 pub fn is_disconnected(&self) -> bool {
1025 self.controller_task.is_finished()
1026 }
1027
1028 #[inline]
1033 #[must_use]
1034 pub fn is_reconnecting(&self) -> bool {
1035 self.connection_mode().is_reconnect()
1036 }
1037
1038 #[inline]
1042 #[must_use]
1043 pub fn is_disconnecting(&self) -> bool {
1044 self.connection_mode().is_disconnect()
1045 }
1046
1047 #[inline]
1053 #[must_use]
1054 pub fn is_closed(&self) -> bool {
1055 self.connection_mode().is_closed()
1056 }
1057
1058 async fn wait_for_active(&self) -> Result<(), SendError> {
1062 if self.is_closed() {
1063 return Err(SendError::Closed);
1064 }
1065
1066 let timeout = self.reconnect_timeout;
1067 let check_interval = Duration::from_millis(SEND_OPERATION_CHECK_INTERVAL_MS);
1068
1069 if !self.is_active() {
1070 log::debug!("Waiting for client to become ACTIVE before sending...");
1071
1072 let inner = tokio::time::timeout(timeout, async {
1073 loop {
1074 if self.is_active() {
1075 return Ok(());
1076 }
1077 if matches!(
1078 self.connection_mode(),
1079 ConnectionMode::Disconnect | ConnectionMode::Closed
1080 ) {
1081 return Err(());
1082 }
1083 tokio::time::sleep(check_interval).await;
1084 }
1085 })
1086 .await
1087 .map_err(|_| SendError::Timeout)?;
1088 inner.map_err(|()| SendError::Closed)?;
1089 }
1090
1091 Ok(())
1092 }
1093
1094 pub async fn disconnect(&self) {
1099 log::debug!("Disconnecting");
1100 self.connection_mode
1101 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1102
1103 if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1104 while !self.is_disconnected() {
1105 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1106 }
1107
1108 if !self.controller_task.is_finished() {
1109 self.controller_task.abort();
1110 log_task_aborted("controller");
1111 }
1112 })
1113 .await
1114 == Ok(())
1115 {
1116 log::debug!("Controller task finished");
1117 } else {
1118 log::error!("Timeout waiting for controller task to finish");
1119 if !self.controller_task.is_finished() {
1120 self.controller_task.abort();
1121 log_task_aborted("controller");
1122 }
1123 self.connection_mode
1124 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1125 }
1126 }
1127
1128 #[allow(unused_variables)]
1134 pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1135 if self.is_closed() || self.is_disconnecting() {
1137 return Err(SendError::Closed);
1138 }
1139
1140 self.rate_limiter.await_keys_ready(keys).await;
1141 self.wait_for_active().await?;
1142
1143 log::trace!("Sending text: {data:?}");
1144
1145 let msg = Message::Text(data.into());
1146 self.writer_tx
1147 .send(WriterCommand::Send(msg))
1148 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1149 }
1150
1151 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1157 self.wait_for_active().await?;
1158
1159 log::trace!("Sending pong frame ({} bytes)", data.len());
1160
1161 let msg = Message::Pong(data.into());
1162 self.writer_tx
1163 .send(WriterCommand::Send(msg))
1164 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1165 }
1166
1167 #[allow(unused_variables)]
1173 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1174 if self.is_closed() || self.is_disconnecting() {
1176 return Err(SendError::Closed);
1177 }
1178
1179 self.rate_limiter.await_keys_ready(keys).await;
1180 self.wait_for_active().await?;
1181
1182 log::trace!("Sending bytes: {data:?}");
1183
1184 let msg = Message::Binary(data.into());
1185 self.writer_tx
1186 .send(WriterCommand::Send(msg))
1187 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1188 }
1189
1190 pub async fn send_close_message(&self) -> Result<(), SendError> {
1196 self.wait_for_active().await?;
1197
1198 let msg = Message::Close(None);
1199 self.writer_tx
1200 .send(WriterCommand::Send(msg))
1201 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1202 }
1203
    /// Spawns the controller task, which owns `inner` and drives the connection state machine.
    ///
    /// Every `CONNECTION_STATE_CHECK_INTERVAL_MS` it inspects the shared mode:
    /// - DISCONNECT: after `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the read and heartbeat tasks
    ///   (the whole step bounded by `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`) and exits. The write
    ///   task reacts to the same state on its own.
    /// - CLOSED: exits.
    /// - ACTIVE while `inner.is_alive()` is false: moves to RECONNECT.
    /// - RECONNECT: calls `inner.reconnect()`.
    ///   - Gives up (moving to CLOSED) once `reconnect_max_attempts` is reached.
    ///   - Sleeps for the backoff duration after each failure.
    ///   - After a success that leaves the client ACTIVE, sends a `RECONNECTED` text message
    ///     to the message handler and invokes `post_reconnection`.
    ///
    /// On exit the mode is stored as CLOSED. `inner` is moved into the task, so it is dropped
    /// when the task completes or is aborted.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

            loop {
                tokio::time::sleep(check_interval).await;
                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    log::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if tokio::time::timeout(timeout, async {
                        // Presumably gives the write task time to observe DISCONNECT and
                        // close the socket before the reader is torn down — confirm.
                        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    log::debug!("Closed");
                    break;
                }

                if mode.is_closed() {
                    log::debug!("Connection closed");
                    break;
                }

                // `is_alive` also covers a finished write task, despite the log wording below.
                if mode.is_active() && !inner.is_alive() {
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        log::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnection_attempt_count >= max_attempts
                    {
                        log::error!(
                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        break;
                    }

                    inner.reconnection_attempt_count += 1;
                    log::debug!(
                        "Reconnection attempt {} of {}",
                        inner.reconnection_attempt_count,
                        inner
                            .reconnect_max_attempts
                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
                    );

                    match inner.reconnect().await {
                        Ok(()) => {
                            // `reconnect` also returns Ok when it aborted because of a
                            // disconnect; the mode check below filters that case out.
                            inner.backoff.reset();
                            inner.reconnection_attempt_count = 0;
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    log::debug!("Sent reconnected message to handler");
                                }

                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    log::debug!("Called `post_reconnection` handler");
                                }

                                log::debug!("Reconnected successfully");
                            } else {
                                log::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            log::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnection_attempt_count
                            );
                            if !duration.is_zero() {
                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            // NOTE(review): this sleep does not observe a DISCONNECT request;
                            // shutdown is delayed until it elapses.
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1339}
1340
1341impl Drop for WebSocketClient {
1343 fn drop(&mut self) {
1344 if !self.controller_task.is_finished() {
1345 self.controller_task.abort();
1346 log_task_aborted("controller");
1347 }
1348 }
1349}
1350
1351#[cfg(test)]
1352#[cfg(not(feature = "turmoil"))]
1353#[cfg(target_os = "linux")] mod tests {
1355 use std::{num::NonZeroU32, sync::Arc};
1356
1357 use futures_util::{SinkExt, StreamExt};
1358 use tokio::{
1359 net::TcpListener,
1360 task::{self, JoinHandle},
1361 };
1362 use tokio_tungstenite::{
1363 accept_hdr_async,
1364 tungstenite::{
1365 handshake::server::{self, Callback},
1366 http::HeaderValue,
1367 },
1368 };
1369
1370 use crate::{
1371 ratelimiter::quota::Quota,
1372 websocket::{WebSocketClient, WebSocketConfig},
1373 };
1374
1375 struct TestServer {
1376 task: JoinHandle<()>,
1377 port: u16,
1378 }
1379
1380 #[derive(Debug, Clone)]
1381 struct TestCallback {
1382 key: String,
1383 value: HeaderValue,
1384 }
1385
1386 impl Callback for TestCallback {
1387 #[allow(clippy::panic_in_result_fn)]
1388 fn on_request(
1389 self,
1390 request: &server::Request,
1391 response: server::Response,
1392 ) -> Result<server::Response, server::ErrorResponse> {
1393 let _ = response;
1394 let value = request.headers().get(&self.key);
1395 assert!(value.is_some());
1396
1397 if let Some(value) = request.headers().get(&self.key) {
1398 assert_eq!(value, self.value);
1399 }
1400
1401 Ok(response)
1402 }
1403 }
1404
1405 impl TestServer {
1406 async fn setup() -> Self {
1407 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
1408 let port = TcpListener::local_addr(&server).unwrap().port();
1409
1410 let header_key = "test".to_string();
1411 let header_value = "test".to_string();
1412
1413 let test_call_back = TestCallback {
1414 key: header_key,
1415 value: HeaderValue::from_str(&header_value).unwrap(),
1416 };
1417
1418 let task = task::spawn(async move {
1419 loop {
1421 let (conn, _) = server.accept().await.unwrap();
1422 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
1423 .await
1424 .unwrap();
1425
1426 task::spawn(async move {
1427 while let Some(Ok(msg)) = websocket.next().await {
1428 match msg {
1429 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
1430 if txt == "close-now" =>
1431 {
1432 log::debug!("Forcibly closing from server side");
1433 let _ = websocket.close(None).await;
1435 break;
1436 }
1437 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
1439 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
1440 if websocket.send(msg).await.is_err() {
1441 break;
1442 }
1443 }
1444 tokio_tungstenite::tungstenite::protocol::Message::Close(
1446 _frame,
1447 ) => {
1448 let _ = websocket.close(None).await;
1449 break;
1450 }
1451 _ => {}
1453 }
1454 }
1455 });
1456 }
1457 });
1458
1459 Self { task, port }
1460 }
1461 }
1462
1463 impl Drop for TestServer {
1464 fn drop(&mut self) {
1465 self.task.abort();
1466 }
1467 }
1468
1469 async fn setup_test_client(port: u16) -> WebSocketClient {
1470 let config = WebSocketConfig {
1471 url: format!("ws://127.0.0.1:{port}"),
1472 headers: vec![("test".into(), "test".into())],
1473 heartbeat: None,
1474 heartbeat_msg: None,
1475 reconnect_timeout_ms: None,
1476 reconnect_delay_initial_ms: None,
1477 reconnect_backoff_factor: None,
1478 reconnect_delay_max_ms: None,
1479 reconnect_jitter_ms: None,
1480 reconnect_max_attempts: None,
1481 idle_timeout_ms: None,
1482 };
1483 WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1484 .await
1485 .expect("Failed to connect")
1486 }
1487
1488 #[tokio::test]
1489 async fn test_websocket_basic() {
1490 let server = TestServer::setup().await;
1491 let client = setup_test_client(server.port).await;
1492
1493 assert!(!client.is_disconnected());
1494
1495 client.disconnect().await;
1496 assert!(client.is_disconnected());
1497 }
1498
1499 #[tokio::test]
1500 async fn test_websocket_heartbeat() {
1501 let server = TestServer::setup().await;
1502 let client = setup_test_client(server.port).await;
1503
1504 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1506
1507 client.disconnect().await;
1509 assert!(client.is_disconnected());
1510 }
1511
1512 #[tokio::test]
1513 async fn test_websocket_reconnect_exhausted() {
1514 let config = WebSocketConfig {
1515 url: "ws://127.0.0.1:9997".into(), headers: vec![],
1517 heartbeat: None,
1518 heartbeat_msg: None,
1519 reconnect_timeout_ms: None,
1520 reconnect_delay_initial_ms: None,
1521 reconnect_backoff_factor: None,
1522 reconnect_delay_max_ms: None,
1523 reconnect_jitter_ms: None,
1524 reconnect_max_attempts: None,
1525 idle_timeout_ms: None,
1526 };
1527 let res =
1528 WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
1529 .await;
1530 assert!(res.is_err(), "Should fail quickly with no server");
1531 }
1532
1533 #[tokio::test]
1534 async fn test_websocket_forced_close_reconnect() {
1535 let server = TestServer::setup().await;
1536 let client = setup_test_client(server.port).await;
1537
1538 client.send_text("Hello".into(), None).await.unwrap();
1540
1541 client.send_text("close-now".into(), None).await.unwrap();
1543
1544 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1546
1547 assert!(!client.is_disconnected());
1549
1550 client.disconnect().await;
1552 assert!(client.is_disconnected());
1553 }
1554
1555 #[tokio::test]
1556 async fn test_rate_limiter() {
1557 let server = TestServer::setup().await;
1558 let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();
1559
1560 let config = WebSocketConfig {
1561 url: format!("ws://127.0.0.1:{}", server.port),
1562 headers: vec![("test".into(), "test".into())],
1563 heartbeat: None,
1564 heartbeat_msg: None,
1565 reconnect_timeout_ms: None,
1566 reconnect_delay_initial_ms: None,
1567 reconnect_backoff_factor: None,
1568 reconnect_delay_max_ms: None,
1569 reconnect_jitter_ms: None,
1570 reconnect_max_attempts: None,
1571 idle_timeout_ms: None,
1572 };
1573
1574 let client = WebSocketClient::connect(
1575 config,
1576 Some(Arc::new(|_| {})),
1577 None,
1578 None,
1579 vec![("default".into(), quota)],
1580 None,
1581 )
1582 .await
1583 .unwrap();
1584
1585 client.send_text("test1".into(), None).await.unwrap();
1587 client.send_text("test2".into(), None).await.unwrap();
1588
1589 client.send_text("test3".into(), None).await.unwrap();
1591
1592 client.disconnect().await;
1594 assert!(client.is_disconnected());
1595 }
1596
1597 #[tokio::test]
1598 async fn test_concurrent_writers() {
1599 let server = TestServer::setup().await;
1600 let client = Arc::new(setup_test_client(server.port).await);
1601
1602 let mut handles = vec![];
1603 for i in 0..10 {
1604 let client = client.clone();
1605 handles.push(task::spawn(async move {
1606 client.send_text(format!("test{i}"), None).await.unwrap();
1607 }));
1608 }
1609
1610 for handle in handles {
1611 handle.await.unwrap();
1612 }
1613
1614 client.disconnect().await;
1616 assert!(client.is_disconnected());
1617 }
1618}
1619
1620#[cfg(test)]
1621#[cfg(not(feature = "turmoil"))]
1622mod rust_tests {
1623 use futures_util::{SinkExt, StreamExt};
1624 use nautilus_common::testing::wait_until_async;
1625 use rstest::rstest;
1626 use tokio::{
1627 net::TcpListener,
1628 task,
1629 time::{Duration, sleep},
1630 };
1631 use tokio_tungstenite::accept_async;
1632
1633 use super::*;
1634 use crate::websocket::types::channel_message_handler;
1635
1636 #[rstest]
1637 #[tokio::test]
1638 async fn test_reconnect_then_disconnect() {
1639 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1641 let port = listener.local_addr().unwrap().port();
1642
1643 let server = task::spawn(async move {
1645 let (stream, _) = listener.accept().await.unwrap();
1646 let ws = accept_async(stream).await.unwrap();
1647 drop(ws);
1648 sleep(Duration::from_secs(1)).await;
1650 });
1651
1652 let (handler, _rx) = channel_message_handler();
1654
1655 let config = WebSocketConfig {
1657 url: format!("ws://127.0.0.1:{port}"),
1658 headers: vec![],
1659 heartbeat: None,
1660 heartbeat_msg: None,
1661 reconnect_timeout_ms: Some(1_000),
1662 reconnect_delay_initial_ms: Some(50),
1663 reconnect_delay_max_ms: Some(100),
1664 reconnect_backoff_factor: Some(1.0),
1665 reconnect_jitter_ms: Some(0),
1666 reconnect_max_attempts: None,
1667 idle_timeout_ms: None,
1668 };
1669
1670 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1672 .await
1673 .unwrap();
1674
1675 sleep(Duration::from_millis(100)).await;
1677 client.disconnect().await;
1679 assert!(client.is_disconnected());
1680 server.abort();
1681 }
1682
1683 #[rstest]
1684 #[tokio::test]
1685 async fn test_reconnect_state_flips_when_reader_stops() {
1686 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1688 let port = listener.local_addr().unwrap().port();
1689
1690 let server = task::spawn(async move {
1691 if let Ok((stream, _)) = listener.accept().await
1692 && let Ok(ws) = accept_async(stream).await
1693 {
1694 drop(ws);
1695 }
1696 sleep(Duration::from_millis(50)).await;
1697 });
1698
1699 let (handler, _rx) = channel_message_handler();
1700
1701 let config = WebSocketConfig {
1702 url: format!("ws://127.0.0.1:{port}"),
1703 headers: vec![],
1704 heartbeat: None,
1705 heartbeat_msg: None,
1706 reconnect_timeout_ms: Some(1_000),
1707 reconnect_delay_initial_ms: Some(50),
1708 reconnect_delay_max_ms: Some(100),
1709 reconnect_backoff_factor: Some(1.0),
1710 reconnect_jitter_ms: Some(0),
1711 reconnect_max_attempts: None,
1712 idle_timeout_ms: None,
1713 };
1714
1715 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1716 .await
1717 .unwrap();
1718
1719 tokio::time::timeout(Duration::from_secs(2), async {
1720 loop {
1721 if client.is_reconnecting() {
1722 break;
1723 }
1724 tokio::time::sleep(Duration::from_millis(10)).await;
1725 }
1726 })
1727 .await
1728 .expect("client did not enter RECONNECT state");
1729
1730 client.disconnect().await;
1731 server.abort();
1732 }
1733
1734 #[rstest]
1735 #[tokio::test]
1736 async fn test_stream_mode_disables_auto_reconnect() {
1737 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1740 let port = listener.local_addr().unwrap().port();
1741
1742 let server = task::spawn(async move {
1743 if let Ok((stream, _)) = listener.accept().await
1744 && let Ok(_ws) = accept_async(stream).await
1745 {
1746 sleep(Duration::from_millis(100)).await;
1748 }
1749 });
1750
1751 let config = WebSocketConfig {
1752 url: format!("ws://127.0.0.1:{port}"),
1753 headers: vec![],
1754 heartbeat: None,
1755 heartbeat_msg: None,
1756 reconnect_timeout_ms: Some(1_000),
1757 reconnect_delay_initial_ms: Some(50),
1758 reconnect_delay_max_ms: Some(100),
1759 reconnect_backoff_factor: Some(1.0),
1760 reconnect_jitter_ms: Some(0),
1761 reconnect_max_attempts: None,
1762 idle_timeout_ms: None,
1763 };
1764
1765 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1766 .await
1767 .unwrap();
1768
1769 server.abort();
1777 }
1778
1779 #[rstest]
1780 #[tokio::test]
1781 async fn test_message_handler_mode_allows_auto_reconnect() {
1782 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1784 let port = listener.local_addr().unwrap().port();
1785
1786 let server = task::spawn(async move {
1787 if let Ok((stream, _)) = listener.accept().await
1789 && let Ok(ws) = accept_async(stream).await
1790 {
1791 drop(ws);
1792 }
1793 sleep(Duration::from_millis(50)).await;
1794 });
1795
1796 let (handler, _rx) = channel_message_handler();
1797
1798 let config = WebSocketConfig {
1799 url: format!("ws://127.0.0.1:{port}"),
1800 headers: vec![],
1801 heartbeat: None,
1802 heartbeat_msg: None,
1803 reconnect_timeout_ms: Some(1_000),
1804 reconnect_delay_initial_ms: Some(50),
1805 reconnect_delay_max_ms: Some(100),
1806 reconnect_backoff_factor: Some(1.0),
1807 reconnect_jitter_ms: Some(0),
1808 reconnect_max_attempts: None,
1809 idle_timeout_ms: None,
1810 };
1811
1812 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1813 .await
1814 .unwrap();
1815
1816 tokio::time::timeout(Duration::from_secs(2), async {
1818 loop {
1819 if client.is_reconnecting() || client.is_closed() {
1820 break;
1821 }
1822 tokio::time::sleep(Duration::from_millis(10)).await;
1823 }
1824 })
1825 .await
1826 .expect("client should attempt reconnection or close");
1827
1828 assert!(
1831 client.is_reconnecting() || client.is_closed(),
1832 "Client with message handler should attempt reconnection"
1833 );
1834
1835 client.disconnect().await;
1836 server.abort();
1837 }
1838
1839 #[rstest]
1840 #[tokio::test]
1841 async fn test_handler_mode_reconnect_with_new_connection() {
1842 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1844 let port = listener.local_addr().unwrap().port();
1845
1846 let server = task::spawn(async move {
1847 if let Ok((stream, _)) = listener.accept().await
1849 && let Ok(ws) = accept_async(stream).await
1850 {
1851 drop(ws);
1852 }
1853
1854 sleep(Duration::from_millis(100)).await;
1856
1857 if let Ok((stream, _)) = listener.accept().await
1859 && let Ok(mut ws) = accept_async(stream).await
1860 {
1861 use futures_util::SinkExt;
1862 let _ = ws
1863 .send(Message::Text("reconnected".to_string().into()))
1864 .await;
1865 sleep(Duration::from_secs(1)).await;
1866 }
1867 });
1868
1869 let (handler, mut rx) = channel_message_handler();
1870
1871 let config = WebSocketConfig {
1872 url: format!("ws://127.0.0.1:{port}"),
1873 headers: vec![],
1874 heartbeat: None,
1875 heartbeat_msg: None,
1876 reconnect_timeout_ms: Some(2_000),
1877 reconnect_delay_initial_ms: Some(50),
1878 reconnect_delay_max_ms: Some(200),
1879 reconnect_backoff_factor: Some(1.5),
1880 reconnect_jitter_ms: Some(10),
1881 reconnect_max_attempts: None,
1882 idle_timeout_ms: None,
1883 };
1884
1885 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1886 .await
1887 .unwrap();
1888
1889 let result = tokio::time::timeout(Duration::from_secs(5), async {
1891 loop {
1892 if let Ok(msg) = rx.try_recv()
1893 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
1894 {
1895 return true;
1896 }
1897 tokio::time::sleep(Duration::from_millis(10)).await;
1898 }
1899 })
1900 .await;
1901
1902 assert!(
1903 result.is_ok(),
1904 "Should receive message after reconnection within timeout"
1905 );
1906
1907 client.disconnect().await;
1908 server.abort();
1909 }
1910
1911 #[rstest]
1912 #[tokio::test]
1913 async fn test_stream_mode_no_auto_reconnect() {
1914 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1917 let port = listener.local_addr().unwrap().port();
1918
1919 let server = task::spawn(async move {
1920 if let Ok((stream, _)) = listener.accept().await
1922 && let Ok(mut ws) = accept_async(stream).await
1923 {
1924 use futures_util::SinkExt;
1925 let _ = ws.send(Message::Text("hello".to_string().into())).await;
1926 sleep(Duration::from_millis(50)).await;
1927 }
1929 });
1930
1931 let config = WebSocketConfig {
1932 url: format!("ws://127.0.0.1:{port}"),
1933 headers: vec![],
1934 heartbeat: None,
1935 heartbeat_msg: None,
1936 reconnect_timeout_ms: Some(1_000),
1937 reconnect_delay_initial_ms: Some(50),
1938 reconnect_delay_max_ms: Some(100),
1939 reconnect_backoff_factor: Some(1.0),
1940 reconnect_jitter_ms: Some(0),
1941 reconnect_max_attempts: None,
1942 idle_timeout_ms: None,
1943 };
1944
1945 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
1946 .await
1947 .unwrap();
1948
1949 assert!(client.is_active(), "Client should start as active");
1951
1952 let msg = reader.next().await;
1954 assert!(
1955 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
1956 "Should receive initial message"
1957 );
1958
1959 while let Some(msg) = reader.next().await {
1961 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
1962 break;
1963 }
1964 }
1965
1966 sleep(Duration::from_millis(200)).await;
1969
1970 assert!(
1973 client.is_active() || client.is_closed(),
1974 "Stream mode client stays ACTIVE (caller owns reader) or caller disconnected"
1975 );
1976 assert!(
1977 !client.is_reconnecting(),
1978 "Stream mode client should never attempt reconnection"
1979 );
1980
1981 client.disconnect().await;
1982 server.abort();
1983 }
1984
1985 #[rstest]
1986 #[tokio::test]
1987 async fn test_send_timeout_uses_configured_reconnect_timeout() {
1988 use nautilus_common::testing::wait_until_async;
1991
1992 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1993 let port = listener.local_addr().unwrap().port();
1994
1995 let server = task::spawn(async move {
1996 if let Ok((stream, _)) = listener.accept().await
1998 && let Ok(ws) = accept_async(stream).await
1999 {
2000 drop(ws);
2001 }
2002 sleep(Duration::from_secs(60)).await;
2004 });
2005
2006 let (handler, _rx) = channel_message_handler();
2007
2008 let config = WebSocketConfig {
2010 url: format!("ws://127.0.0.1:{port}"),
2011 headers: vec![],
2012 heartbeat: None,
2013 heartbeat_msg: None,
2014 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
2016 reconnect_delay_max_ms: Some(100),
2017 reconnect_backoff_factor: Some(1.0),
2018 reconnect_jitter_ms: Some(0),
2019 reconnect_max_attempts: None,
2020 idle_timeout_ms: None,
2021 };
2022
2023 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2024 .await
2025 .unwrap();
2026
2027 wait_until_async(
2029 || async { client.is_reconnecting() },
2030 Duration::from_secs(3),
2031 )
2032 .await;
2033
2034 let start = std::time::Instant::now();
2036 let send_result = client.send_text("test".to_string(), None).await;
2037 let elapsed = start.elapsed();
2038
2039 assert!(
2040 send_result.is_err(),
2041 "Send should fail when client stuck in RECONNECT"
2042 );
2043 assert!(
2044 matches!(send_result, Err(crate::error::SendError::Timeout)),
2045 "Send should return Timeout error, was: {send_result:?}"
2046 );
2047 assert!(
2050 elapsed >= Duration::from_millis(1800),
2051 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2052 );
2053
2054 client.disconnect().await;
2055 server.abort();
2056 }
2057
2058 #[rstest]
2059 #[tokio::test]
2060 async fn test_send_waits_during_reconnection() {
2061 use nautilus_common::testing::wait_until_async;
2063
2064 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2065 let port = listener.local_addr().unwrap().port();
2066
2067 let server = task::spawn(async move {
2068 if let Ok((stream, _)) = listener.accept().await
2070 && let Ok(ws) = accept_async(stream).await
2071 {
2072 drop(ws);
2073 }
2074
2075 sleep(Duration::from_millis(500)).await;
2077
2078 if let Ok((stream, _)) = listener.accept().await
2080 && let Ok(mut ws) = accept_async(stream).await
2081 {
2082 while let Some(Ok(msg)) = ws.next().await {
2084 if ws.send(msg).await.is_err() {
2085 break;
2086 }
2087 }
2088 }
2089 });
2090
2091 let (handler, _rx) = channel_message_handler();
2092
2093 let config = WebSocketConfig {
2094 url: format!("ws://127.0.0.1:{port}"),
2095 headers: vec![],
2096 heartbeat: None,
2097 heartbeat_msg: None,
2098 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2100 reconnect_delay_max_ms: Some(200),
2101 reconnect_backoff_factor: Some(1.0),
2102 reconnect_jitter_ms: Some(0),
2103 reconnect_max_attempts: None,
2104 idle_timeout_ms: None,
2105 };
2106
2107 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2108 .await
2109 .unwrap();
2110
2111 wait_until_async(
2113 || async { client.is_reconnecting() },
2114 Duration::from_secs(2),
2115 )
2116 .await;
2117
2118 let send_result = tokio::time::timeout(
2120 Duration::from_secs(3),
2121 client.send_text("test_message".to_string(), None),
2122 )
2123 .await;
2124
2125 assert!(
2126 send_result.is_ok() && send_result.unwrap().is_ok(),
2127 "Send should succeed after waiting for reconnection"
2128 );
2129
2130 client.disconnect().await;
2131 server.abort();
2132 }
2133
2134 #[rstest]
2135 #[tokio::test]
2136 async fn test_rate_limiter_before_active_wait() {
2137 use std::{num::NonZeroU32, sync::Arc};
2142
2143 use nautilus_common::testing::wait_until_async;
2144
2145 use crate::ratelimiter::quota::Quota;
2146
2147 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2148 let port = listener.local_addr().unwrap().port();
2149
2150 let server = task::spawn(async move {
2151 if let Ok((stream, _)) = listener.accept().await
2153 && let Ok(mut ws) = accept_async(stream).await
2154 {
2155 if let Some(Ok(_)) = ws.next().await {
2157 drop(ws);
2158 }
2159 }
2160
2161 sleep(Duration::from_millis(500)).await;
2163
2164 if let Ok((stream, _)) = listener.accept().await
2166 && let Ok(mut ws) = accept_async(stream).await
2167 {
2168 while let Some(Ok(msg)) = ws.next().await {
2169 if ws.send(msg).await.is_err() {
2170 break;
2171 }
2172 }
2173 }
2174 });
2175
2176 let (handler, _rx) = channel_message_handler();
2177
2178 let config = WebSocketConfig {
2179 url: format!("ws://127.0.0.1:{port}"),
2180 headers: vec![],
2181 heartbeat: None,
2182 heartbeat_msg: None,
2183 reconnect_timeout_ms: Some(5_000),
2184 reconnect_delay_initial_ms: Some(50),
2185 reconnect_delay_max_ms: Some(100),
2186 reconnect_backoff_factor: Some(1.0),
2187 reconnect_jitter_ms: Some(0),
2188 reconnect_max_attempts: None,
2189 idle_timeout_ms: None,
2190 };
2191
2192 let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2194 .unwrap()
2195 .allow_burst(NonZeroU32::new(1).unwrap());
2196
2197 let client = Arc::new(
2198 WebSocketClient::connect(
2199 config,
2200 Some(handler),
2201 None,
2202 None,
2203 vec![("test_key".to_string(), quota)],
2204 None,
2205 )
2206 .await
2207 .unwrap(),
2208 );
2209
2210 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2212 client
2213 .send_text("msg1".to_string(), Some(test_key.as_slice()))
2214 .await
2215 .unwrap();
2216
2217 wait_until_async(
2219 || async { client.is_reconnecting() },
2220 Duration::from_secs(2),
2221 )
2222 .await;
2223
2224 let start = std::time::Instant::now();
2226 let send_result = client
2227 .send_text("msg2".to_string(), Some(test_key.as_slice()))
2228 .await;
2229 let elapsed = start.elapsed();
2230
2231 assert!(
2233 send_result.is_ok(),
2234 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2235 );
2236 assert!(
2240 elapsed >= Duration::from_millis(850),
2241 "Should wait for rate limit (~1s), waited {elapsed:?}"
2242 );
2243
2244 client.disconnect().await;
2245 server.abort();
2246 }
2247
2248 #[rstest]
2249 #[tokio::test]
2250 async fn test_disconnect_during_reconnect_exits_cleanly() {
2251 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2254 let port = listener.local_addr().unwrap().port();
2255
2256 let server = task::spawn(async move {
2257 if let Ok((stream, _)) = listener.accept().await
2259 && let Ok(ws) = accept_async(stream).await
2260 {
2261 drop(ws);
2262 }
2263 sleep(Duration::from_secs(60)).await;
2265 });
2266
2267 let (handler, _rx) = channel_message_handler();
2268
2269 let config = WebSocketConfig {
2270 url: format!("ws://127.0.0.1:{port}"),
2271 headers: vec![],
2272 heartbeat: None,
2273 heartbeat_msg: None,
2274 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2276 reconnect_delay_max_ms: Some(200),
2277 reconnect_backoff_factor: Some(1.0),
2278 reconnect_jitter_ms: Some(0),
2279 reconnect_max_attempts: None,
2280 idle_timeout_ms: None,
2281 };
2282
2283 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2284 .await
2285 .unwrap();
2286
2287 tokio::time::timeout(Duration::from_secs(2), async {
2289 while !client.is_reconnecting() {
2290 sleep(Duration::from_millis(10)).await;
2291 }
2292 })
2293 .await
2294 .expect("Client should enter RECONNECT state");
2295
2296 client.disconnect().await;
2298
2299 assert!(
2301 client.is_disconnected(),
2302 "Client should be cleanly disconnected"
2303 );
2304
2305 server.abort();
2306 }
2307
2308 #[rstest]
2309 #[tokio::test]
2310 async fn test_send_fails_fast_when_closed_before_rate_limit() {
2311 use std::{num::NonZeroU32, sync::Arc};
2314
2315 use nautilus_common::testing::wait_until_async;
2316
2317 use crate::ratelimiter::quota::Quota;
2318
2319 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2320 let port = listener.local_addr().unwrap().port();
2321
2322 let server = task::spawn(async move {
2323 if let Ok((stream, _)) = listener.accept().await
2325 && let Ok(ws) = accept_async(stream).await
2326 {
2327 drop(ws);
2328 }
2329 sleep(Duration::from_secs(60)).await;
2330 });
2331
2332 let (handler, _rx) = channel_message_handler();
2333
2334 let config = WebSocketConfig {
2335 url: format!("ws://127.0.0.1:{port}"),
2336 headers: vec![],
2337 heartbeat: None,
2338 heartbeat_msg: None,
2339 reconnect_timeout_ms: Some(5_000),
2340 reconnect_delay_initial_ms: Some(50),
2341 reconnect_delay_max_ms: Some(100),
2342 reconnect_backoff_factor: Some(1.0),
2343 reconnect_jitter_ms: Some(0),
2344 reconnect_max_attempts: None,
2345 idle_timeout_ms: None,
2346 };
2347
2348 let quota = Quota::with_period(Duration::from_secs(10))
2351 .unwrap()
2352 .allow_burst(NonZeroU32::new(1).unwrap());
2353
2354 let client = Arc::new(
2355 WebSocketClient::connect(
2356 config,
2357 Some(handler),
2358 None,
2359 None,
2360 vec![("test_key".to_string(), quota)],
2361 None,
2362 )
2363 .await
2364 .unwrap(),
2365 );
2366
2367 wait_until_async(
2369 || async { client.is_reconnecting() || client.is_closed() },
2370 Duration::from_secs(2),
2371 )
2372 .await;
2373
2374 client.disconnect().await;
2376 assert!(
2377 !client.is_active(),
2378 "Client should not be active after disconnect"
2379 );
2380
2381 let start = std::time::Instant::now();
2383 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2384 let result = client
2385 .send_text("test".to_string(), Some(test_key.as_slice()))
2386 .await;
2387 let elapsed = start.elapsed();
2388
2389 assert!(result.is_err(), "Send should fail when client is closed");
2391 assert!(
2392 matches!(result, Err(crate::error::SendError::Closed)),
2393 "Send should return Closed error, was: {result:?}"
2394 );
2395
2396 assert!(
2398 elapsed < Duration::from_millis(100),
2399 "Send should fail fast without rate limiting, took {elapsed:?}"
2400 );
2401
2402 server.abort();
2403 }
2404
2405 #[rstest]
2406 #[tokio::test]
2407 async fn test_connect_rejects_none_message_handler() {
2408 let config = WebSocketConfig {
2412 url: "ws://127.0.0.1:9999".to_string(),
2413 headers: vec![],
2414 heartbeat: None,
2415 heartbeat_msg: None,
2416 reconnect_timeout_ms: Some(1_000),
2417 reconnect_delay_initial_ms: Some(100),
2418 reconnect_delay_max_ms: Some(500),
2419 reconnect_backoff_factor: Some(1.5),
2420 reconnect_jitter_ms: Some(0),
2421 reconnect_max_attempts: None,
2422 idle_timeout_ms: None,
2423 };
2424
2425 let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2427
2428 assert!(
2429 result.is_err(),
2430 "connect() should reject None message_handler"
2431 );
2432
2433 let err = result.unwrap_err();
2434 let err_msg = err.to_string();
2435 assert!(
2436 err_msg.contains("Handler mode requires message_handler"),
2437 "Error should mention missing message_handler, was: {err_msg}"
2438 );
2439 }
2440
2441 #[rstest]
2442 #[tokio::test]
2443 async fn test_client_without_handler_sets_stream_mode() {
2444 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2448 let port = listener.local_addr().unwrap().port();
2449
2450 let server = task::spawn(async move {
2451 if let Ok((stream, _)) = listener.accept().await
2453 && let Ok(ws) = accept_async(stream).await
2454 {
2455 drop(ws); }
2457 });
2458
2459 let config = WebSocketConfig {
2460 url: format!("ws://127.0.0.1:{port}"),
2461 headers: vec![],
2462 heartbeat: None,
2463 heartbeat_msg: None,
2464 reconnect_timeout_ms: Some(1_000),
2465 reconnect_delay_initial_ms: Some(100),
2466 reconnect_delay_max_ms: Some(500),
2467 reconnect_backoff_factor: Some(1.5),
2468 reconnect_jitter_ms: Some(0),
2469 reconnect_max_attempts: None,
2470 idle_timeout_ms: None,
2471 };
2472
2473 let inner = WebSocketClientInner::connect_url(config, None, None)
2475 .await
2476 .unwrap();
2477
2478 assert!(
2480 inner.is_stream_mode,
2481 "Client without handler should have is_stream_mode=true"
2482 );
2483
2484 server.abort();
2488 }
2489
2490 #[rstest]
2491 #[tokio::test]
2492 async fn test_idle_timeout_triggers_reconnect() {
2493 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2494 let port = listener.local_addr().unwrap().port();
2495
2496 let server = task::spawn(async move {
2498 let (stream, _) = listener.accept().await.unwrap();
2499 let _ws = accept_async(stream).await.unwrap();
2500 sleep(Duration::from_secs(5)).await;
2502 });
2503
2504 let (handler, _rx) = channel_message_handler();
2505
2506 let config = WebSocketConfig {
2507 url: format!("ws://127.0.0.1:{port}"),
2508 headers: vec![],
2509 heartbeat: None,
2510 heartbeat_msg: None,
2511 reconnect_timeout_ms: Some(2_000),
2512 reconnect_delay_initial_ms: Some(50),
2513 reconnect_delay_max_ms: Some(100),
2514 reconnect_backoff_factor: Some(1.0),
2515 reconnect_jitter_ms: Some(0),
2516 reconnect_max_attempts: Some(1),
2517 idle_timeout_ms: Some(500),
2518 };
2519
2520 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2521 .await
2522 .unwrap();
2523
2524 assert!(client.is_active());
2525
2526 wait_until_async(
2528 || async { client.is_reconnecting() || client.is_disconnected() },
2529 Duration::from_secs(3),
2530 )
2531 .await;
2532
2533 assert!(
2534 !client.is_active(),
2535 "Client should not be active after idle timeout"
2536 );
2537
2538 client.disconnect().await;
2539 server.abort();
2540 }
2541
2542 #[rstest]
2543 #[tokio::test]
2544 async fn test_idle_timeout_resets_on_data() {
2545 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2546 let port = listener.local_addr().unwrap().port();
2547
2548 let server = task::spawn(async move {
2550 let (stream, _) = listener.accept().await.unwrap();
2551 let mut ws = accept_async(stream).await.unwrap();
2552 for _ in 0..10 {
2553 sleep(Duration::from_millis(200)).await;
2554 if ws
2555 .send(tokio_tungstenite::tungstenite::Message::Text("ping".into()))
2556 .await
2557 .is_err()
2558 {
2559 break;
2560 }
2561 }
2562 });
2563
2564 let (handler, _rx) = channel_message_handler();
2565
2566 let config = WebSocketConfig {
2567 url: format!("ws://127.0.0.1:{port}"),
2568 headers: vec![],
2569 heartbeat: None,
2570 heartbeat_msg: None,
2571 reconnect_timeout_ms: Some(2_000),
2572 reconnect_delay_initial_ms: Some(50),
2573 reconnect_delay_max_ms: Some(100),
2574 reconnect_backoff_factor: Some(1.0),
2575 reconnect_jitter_ms: Some(0),
2576 reconnect_max_attempts: Some(1),
2577 idle_timeout_ms: Some(1_000),
2578 };
2579
2580 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2581 .await
2582 .unwrap();
2583
2584 assert!(client.is_active());
2585
2586 sleep(Duration::from_millis(1_500)).await;
2588
2589 assert!(
2590 client.is_active(),
2591 "Client should remain active when data is flowing"
2592 );
2593
2594 client.disconnect().await;
2595 server.abort();
2596 }
2597
2598 #[rstest]
2599 #[tokio::test]
2600 async fn test_zero_idle_timeout_rejected() {
2601 let (handler, _rx) = channel_message_handler();
2602
2603 let config = WebSocketConfig {
2604 url: "ws://127.0.0.1:9999".to_string(),
2605 headers: vec![],
2606 heartbeat: None,
2607 heartbeat_msg: None,
2608 reconnect_timeout_ms: None,
2609 reconnect_delay_initial_ms: None,
2610 reconnect_delay_max_ms: None,
2611 reconnect_backoff_factor: None,
2612 reconnect_jitter_ms: None,
2613 reconnect_max_attempts: None,
2614 idle_timeout_ms: Some(0),
2615 };
2616
2617 let result =
2618 WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
2619
2620 assert!(result.is_err(), "Zero idle timeout should be rejected");
2621 let err_msg = result.unwrap_err().to_string();
2622 assert!(
2623 err_msg.contains("Idle timeout cannot be zero"),
2624 "Error should mention zero idle timeout, was: {err_msg}"
2625 );
2626 }
2627}