1use std::{
27 collections::VecDeque,
28 fmt::Debug,
29 sync::{
30 Arc,
31 atomic::{AtomicU8, Ordering},
32 },
33 time::Duration,
34};
35
36use futures_util::{SinkExt, StreamExt};
37use http::HeaderName;
38use nautilus_core::CleanDrop;
39use nautilus_cryptography::providers::install_cryptographic_provider;
40#[cfg(feature = "turmoil")]
41use tokio_tungstenite::MaybeTlsStream;
42#[cfg(feature = "turmoil")]
43use tokio_tungstenite::client_async;
44#[cfg(not(feature = "turmoil"))]
45use tokio_tungstenite::connect_async_with_config;
46use tokio_tungstenite::tungstenite::{
47 Error, Message, client::IntoClientRequest, http::HeaderValue,
48};
49use ustr::Ustr;
50
51use super::{
52 config::WebSocketConfig,
53 consts::{
54 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
55 GRACEFUL_SHUTDOWN_TIMEOUT_SECS,
56 },
57 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
58};
59#[cfg(feature = "turmoil")]
60use crate::net::TcpConnector;
61use crate::{
62 RECONNECTED,
63 backoff::ExponentialBackoff,
64 error::SendError,
65 logging::{log_task_aborted, log_task_started, log_task_stopped},
66 mode::ConnectionMode,
67 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
68};
69
/// Connection state and background task handles that back a [`WebSocketClient`].
///
/// Owns the read task (handler mode only), the write task and the optional heartbeat task.
/// Outbound traffic reaches the write task through `writer_tx`. The connection state is shared
/// through `connection_mode` (an encoded `ConnectionMode`). Changes to it are signalled through
/// `state_notify`.
pub struct WebSocketClientInner {
    // Configuration the client was created with (URL, headers, heartbeat, timeouts, ...).
    config: WebSocketConfig,
    // Callback the read task invokes for inbound text/binary messages; `None` in stream mode.
    message_handler: Option<MessageHandler>,
    // Callback the read task invokes with the payload of inbound ping frames.
    ping_handler: Option<PingHandler>,
    // Read task handle; `None` in stream mode, where the caller owns the reader.
    read_task: Option<tokio::task::JoinHandle<()>>,
    // Write task handle; this task owns the active `MessageWriter`.
    write_task: tokio::task::JoinHandle<()>,
    // Command channel into the write task (message sends and writer swaps on reconnect).
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    // Heartbeat task handle; present only when a heartbeat interval is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    // Shared connection state, stored as a `ConnectionMode` discriminant.
    connection_mode: Arc<AtomicU8>,
    // Wakes tasks that wait on connection-state changes.
    state_notify: Arc<tokio::sync::Notify>,
    // Time limit for a single reconnect attempt.
    reconnect_timeout: Duration,
    // Delay policy applied between failed reconnect attempts.
    backoff: ExponentialBackoff,
    // `true` when no message handler was supplied; such clients are not auto-reconnected.
    is_stream_mode: bool,
    // Maximum number of reconnect attempts before giving up; `None` means unlimited.
    reconnect_max_attempts: Option<u32>,
    // Reconnect attempts made since the last successful (re)connection.
    reconnection_attempt_count: u32,
}
108
109impl WebSocketClientInner {
110 pub async fn new_with_writer(
118 config: WebSocketConfig,
119 writer: MessageWriter,
120 ) -> Result<Self, Error> {
121 install_cryptographic_provider();
122
123 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
124 let state_notify = Arc::new(tokio::sync::Notify::new());
125
126 let read_task = None;
128
129 let backoff = ExponentialBackoff::new(
131 Duration::from_secs(2),
132 Duration::from_secs(30),
133 1.5,
134 100,
135 true,
136 )
137 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
138
139 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
140 let write_task = Self::spawn_write_task(
141 connection_mode.clone(),
142 state_notify.clone(),
143 writer,
144 writer_rx,
145 );
146
147 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
148 Some(Self::spawn_heartbeat_task(
149 connection_mode.clone(),
150 heartbeat_interval,
151 config.heartbeat_msg.clone(),
152 writer_tx.clone(),
153 ))
154 } else {
155 None
156 };
157
158 let reconnect_max_attempts = None; let reconnect_timeout = Duration::from_secs(10);
160
161 Ok(Self {
162 config,
163 message_handler: None, ping_handler: None,
165 writer_tx,
166 connection_mode,
167 state_notify,
168 reconnect_timeout,
169 heartbeat_task,
170 read_task,
171 write_task,
172 backoff,
173 is_stream_mode: true,
174 reconnect_max_attempts,
175 reconnection_attempt_count: 0,
176 })
177 }
178
179 pub async fn connect_url(
187 config: WebSocketConfig,
188 message_handler: Option<MessageHandler>,
189 ping_handler: Option<PingHandler>,
190 ) -> Result<Self, Error> {
191 install_cryptographic_provider();
192
193 if config.heartbeat == Some(0) {
194 return Err(Error::Io(std::io::Error::new(
195 std::io::ErrorKind::InvalidInput,
196 "Heartbeat interval cannot be zero",
197 )));
198 }
199
200 if config.idle_timeout_ms == Some(0) {
201 return Err(Error::Io(std::io::Error::new(
202 std::io::ErrorKind::InvalidInput,
203 "Idle timeout cannot be zero",
204 )));
205 }
206
207 let is_stream_mode = message_handler.is_none();
209 let reconnect_max_attempts = config.reconnect_max_attempts;
210
211 let (writer, reader) =
212 Self::connect_with_server(&config.url, config.headers.clone()).await?;
213
214 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
215 let state_notify = Arc::new(tokio::sync::Notify::new());
216
217 let read_task = if message_handler.is_some() {
218 Some(Self::spawn_message_handler_task(
219 connection_mode.clone(),
220 state_notify.clone(),
221 reader,
222 message_handler.as_ref(),
223 ping_handler.as_ref(),
224 config.idle_timeout_ms,
225 ))
226 } else {
227 None
228 };
229
230 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
231 let write_task = Self::spawn_write_task(
232 connection_mode.clone(),
233 state_notify.clone(),
234 writer,
235 writer_rx,
236 );
237
238 let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
240 Self::spawn_heartbeat_task(
241 connection_mode.clone(),
242 heartbeat_secs,
243 config.heartbeat_msg.clone(),
244 writer_tx.clone(),
245 )
246 });
247
248 let reconnect_timeout =
249 Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
250 let backoff = ExponentialBackoff::new(
251 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
252 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
253 config.reconnect_backoff_factor.unwrap_or(1.5),
254 config.reconnect_jitter_ms.unwrap_or(100),
255 true, )
257 .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
258
259 Ok(Self {
260 config,
261 message_handler,
262 ping_handler,
263 read_task,
264 write_task,
265 writer_tx,
266 heartbeat_task,
267 connection_mode,
268 state_notify,
269 reconnect_timeout,
270 backoff,
271 is_stream_mode,
273 reconnect_max_attempts,
274 reconnection_attempt_count: 0,
275 })
276 }
277
278 #[inline]
288 #[cfg(not(feature = "turmoil"))]
289 pub async fn connect_with_server(
290 url: &str,
291 headers: Vec<(String, String)>,
292 ) -> Result<(MessageWriter, MessageReader), Error> {
293 let mut request = url.into_client_request()?;
294 let req_headers = request.headers_mut();
295
296 let mut header_names: Vec<HeaderName> = Vec::new();
297 for (key, val) in headers {
298 let header_value = HeaderValue::from_str(&val)?;
299 let header_name: HeaderName = key.parse()?;
300 header_names.push(header_name.clone());
301 req_headers.insert(header_name, header_value);
302 }
303
304 connect_async_with_config(request, None, true)
305 .await
306 .map(|resp| resp.0.split())
307 }
308
309 #[inline]
322 #[cfg(feature = "turmoil")]
323 pub async fn connect_with_server(
324 url: &str,
325 headers: Vec<(String, String)>,
326 ) -> Result<(MessageWriter, MessageReader), Error> {
327 use rustls::ClientConfig;
328 use tokio_rustls::TlsConnector;
329
330 let mut request = url.into_client_request()?;
331 let req_headers = request.headers_mut();
332
333 let mut header_names: Vec<HeaderName> = Vec::new();
334 for (key, val) in headers {
335 let header_value = HeaderValue::from_str(&val)?;
336 let header_name: HeaderName = key.parse()?;
337 header_names.push(header_name.clone());
338 req_headers.insert(header_name, header_value);
339 }
340
341 let uri = request.uri();
342 let scheme = uri.scheme_str().unwrap_or("ws");
343 let host = uri.host().ok_or_else(|| {
344 Error::Url(tokio_tungstenite::tungstenite::error::UrlError::NoHostName)
345 })?;
346
347 let port = uri
349 .port_u16()
350 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
351
352 let addr = format!("{host}:{port}");
353
354 let connector = crate::net::RealTcpConnector;
356 let tcp_stream = connector.connect(&addr).await?;
357 if let Err(e) = tcp_stream.set_nodelay(true) {
358 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
359 }
360
361 let maybe_tls_stream = if scheme == "wss" {
363 let mut root_store = rustls::RootCertStore::empty();
365 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
366
367 let config = ClientConfig::builder()
368 .with_root_certificates(root_store)
369 .with_no_client_auth();
370
371 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
372 let domain =
373 rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|e| {
374 Error::Io(std::io::Error::new(
375 std::io::ErrorKind::InvalidInput,
376 format!("Invalid DNS name: {e}"),
377 ))
378 })?;
379
380 let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
381 MaybeTlsStream::Rustls(tls_stream)
382 } else {
383 MaybeTlsStream::Plain(tcp_stream)
384 };
385
386 client_async(request, maybe_tls_stream)
388 .await
389 .map(|resp| resp.0.split())
390 }
391
392 pub async fn reconnect(&mut self) -> Result<(), Error> {
407 log::debug!("Reconnecting");
408
409 if self.is_stream_mode {
410 log::warn!(
411 "Auto-reconnect disabled for stream-based WebSocket client; \
412 stream users must manually reconnect by creating a new connection"
413 );
414 self.connection_mode
416 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
417 return Ok(());
418 }
419
420 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
421 log::debug!("Reconnect aborted due to disconnect state");
422 return Ok(());
423 }
424
425 tokio::time::timeout(self.reconnect_timeout, async {
426 let (new_writer, reader) =
428 Self::connect_with_server(&self.config.url, self.config.headers.clone()).await?;
429
430 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
431 log::debug!("Reconnect aborted mid-flight (after connect)");
432 return Ok(());
433 }
434
435 let (tx, rx) = tokio::sync::oneshot::channel();
439 if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
440 log::error!("{e}");
441 return Err(Error::Io(std::io::Error::new(
442 std::io::ErrorKind::BrokenPipe,
443 format!("Failed to send update command: {e}"),
444 )));
445 }
446
447 match rx.await {
449 Ok(true) => log::debug!("Writer confirmed buffer drain success"),
450 Ok(false) => {
451 log::warn!("Writer failed to drain buffer, aborting reconnect");
452 return Err(Error::Io(std::io::Error::other(
454 "Failed to drain reconnection buffer",
455 )));
456 }
457 Err(e) => {
458 log::error!("Writer dropped update channel: {e}");
459 return Err(Error::Io(std::io::Error::new(
460 std::io::ErrorKind::BrokenPipe,
461 "Writer task dropped response channel",
462 )));
463 }
464 }
465
466 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
468
469 if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
470 log::debug!("Reconnect aborted mid-flight (after delay)");
471 return Ok(());
472 }
473
474 if let Some(ref read_task) = self.read_task.take()
475 && !read_task.is_finished()
476 {
477 read_task.abort();
478 log_task_aborted("read");
479 }
480
481 if self
484 .connection_mode
485 .compare_exchange(
486 ConnectionMode::Reconnect.as_u8(),
487 ConnectionMode::Active.as_u8(),
488 Ordering::SeqCst,
489 Ordering::SeqCst,
490 )
491 .is_err()
492 {
493 log::debug!("Reconnect aborted (state changed during reconnect)");
494 return Ok(());
495 }
496
497 self.read_task = if self.message_handler.is_some() {
498 Some(Self::spawn_message_handler_task(
499 self.connection_mode.clone(),
500 self.state_notify.clone(),
501 reader,
502 self.message_handler.as_ref(),
503 self.ping_handler.as_ref(),
504 self.config.idle_timeout_ms,
505 ))
506 } else {
507 None
508 };
509
510 log::debug!("Reconnect succeeded");
511 Ok(())
512 })
513 .await
514 .map_err(|_| {
515 Error::Io(std::io::Error::new(
516 std::io::ErrorKind::TimedOut,
517 format!(
518 "reconnection timed out after {}s",
519 self.reconnect_timeout.as_secs_f64()
520 ),
521 ))
522 })?
523 }
524
525 #[inline]
531 #[must_use]
532 pub fn is_alive(&self) -> bool {
533 match &self.read_task {
534 Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
535 None => !self.write_task.is_finished(),
536 }
537 }
538
539 fn spawn_message_handler_task(
540 connection_state: Arc<AtomicU8>,
541 state_notify: Arc<tokio::sync::Notify>,
542 mut reader: MessageReader,
543 message_handler: Option<&MessageHandler>,
544 ping_handler: Option<&PingHandler>,
545 idle_timeout_ms: Option<u64>,
546 ) -> tokio::task::JoinHandle<()> {
547 log::debug!("Started message handler task 'read'");
548
549 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
550 let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
551
552 let message_handler = message_handler.cloned();
554 let ping_handler = ping_handler.cloned();
555
556 tokio::task::spawn(async move {
557 let mut last_data_time = tokio::time::Instant::now();
558
559 loop {
560 if !ConnectionMode::from_atomic(&connection_state).is_active() {
561 break;
562 }
563
564 match tokio::time::timeout(check_interval, reader.next()).await {
565 Ok(Some(Ok(Message::Binary(data)))) => {
566 log::trace!("Received message <binary> {} bytes", data.len());
567 last_data_time = tokio::time::Instant::now();
568
569 if let Some(ref handler) = message_handler {
570 handler(Message::Binary(data));
571 }
572 }
573 Ok(Some(Ok(Message::Text(data)))) => {
574 log::trace!("Received message: {data}");
575 last_data_time = tokio::time::Instant::now();
576
577 if let Some(ref handler) = message_handler {
578 handler(Message::Text(data));
579 }
580 }
581 Ok(Some(Ok(Message::Ping(ping_data)))) => {
582 log::trace!("Received ping: {ping_data:?}");
583 last_data_time = tokio::time::Instant::now();
584
585 if let Some(ref handler) = ping_handler {
586 handler(ping_data.to_vec());
587 }
588 }
589 Ok(Some(Ok(Message::Pong(_)))) => {
590 log::trace!("Received pong");
591 last_data_time = tokio::time::Instant::now();
592 }
593 Ok(Some(Ok(Message::Close(_)))) => {
594 log::debug!("Received close message - terminating");
595 break;
596 }
597 Ok(Some(Ok(_))) => (),
598 Ok(Some(Err(e))) => {
599 log::error!("Received error message - terminating: {e}");
600 break;
601 }
602 Ok(None) => {
603 log::debug!("No message received - terminating");
604 break;
605 }
606 Err(_) => {
607 if let Some(timeout) = idle_timeout {
608 let idle_duration = last_data_time.elapsed();
609 if idle_duration >= timeout {
610 log::warn!(
611 "Read idle timeout: no data received for {:.1}s",
612 idle_duration.as_secs_f64()
613 );
614 break;
615 }
616 }
617 }
618 }
619 }
620
621 state_notify.notify_one();
623 })
624 }
625
626 async fn drain_reconnect_buffer(
631 buffer: &mut VecDeque<Message>,
632 writer: &mut MessageWriter,
633 ) -> bool {
634 if buffer.is_empty() {
635 return false;
636 }
637
638 let initial_buffer_len = buffer.len();
639 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
640
641 let mut send_error_occurred = false;
642
643 while let Some(buffered_msg) = buffer.front() {
644 let msg_to_send = buffered_msg.clone();
646
647 if let Err(e) = writer.send(msg_to_send).await {
648 log::error!(
649 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
650 buffer.len()
651 );
652 send_error_occurred = true;
653 break; }
655
656 buffer.pop_front();
658 }
659
660 if buffer.is_empty() {
661 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
662 }
663
664 send_error_occurred
665 }
666
667 fn spawn_write_task(
668 connection_state: Arc<AtomicU8>,
669 state_notify: Arc<tokio::sync::Notify>,
670 writer: MessageWriter,
671 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
672 ) -> tokio::task::JoinHandle<()> {
673 log_task_started("write");
674
675 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
677
678 tokio::task::spawn(async move {
679 let mut active_writer = writer;
680 let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
683
684 loop {
685 match ConnectionMode::from_atomic(&connection_state) {
686 ConnectionMode::Disconnect => {
687 if !reconnect_buffer.is_empty() {
689 log::warn!(
690 "Discarding {} buffered messages due to disconnect",
691 reconnect_buffer.len()
692 );
693 reconnect_buffer.clear();
694 }
695
696 _ = tokio::time::timeout(
699 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
700 active_writer.close(),
701 )
702 .await;
703 break;
704 }
705 ConnectionMode::Closed => {
706 if !reconnect_buffer.is_empty() {
708 log::warn!(
709 "Discarding {} buffered messages due to closed connection",
710 reconnect_buffer.len()
711 );
712 reconnect_buffer.clear();
713 }
714 break;
715 }
716 _ => {}
717 }
718
719 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
720 Ok(Some(msg)) => {
721 let mode = ConnectionMode::from_atomic(&connection_state);
723 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
724 break;
725 }
726
727 match msg {
728 WriterCommand::Update(new_writer, tx) => {
729 log::debug!("Received new writer");
730
731 tokio::time::sleep(Duration::from_millis(100)).await;
733
734 _ = tokio::time::timeout(
737 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
738 active_writer.close(),
739 )
740 .await;
741
742 active_writer = new_writer;
743 log::debug!("Updated writer");
744
745 let send_error = Self::drain_reconnect_buffer(
746 &mut reconnect_buffer,
747 &mut active_writer,
748 )
749 .await;
750
751 if let Err(e) = tx.send(!send_error) {
752 log::error!(
753 "Failed to report drain status to controller: {e:?}"
754 );
755 }
756 }
757 WriterCommand::Send(msg) if mode.is_reconnect() => {
758 log::debug!(
760 "Buffering message during reconnection (buffer size: {})",
761 reconnect_buffer.len() + 1
762 );
763 reconnect_buffer.push_back(msg);
764 }
765 WriterCommand::Send(msg) => {
766 if let Err(e) = active_writer.send(msg.clone()).await {
767 log::error!("Failed to send message: {e}");
768 log::warn!("Writer triggering reconnect");
769 reconnect_buffer.push_back(msg);
770 connection_state
771 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
772 state_notify.notify_one();
773 }
774 }
775 }
776 }
777 Ok(None) => {
778 log::debug!("Writer channel closed, terminating writer task");
780 break;
781 }
782 Err(_) => {
783 }
785 }
786 }
787
788 _ = tokio::time::timeout(
791 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
792 active_writer.close(),
793 )
794 .await;
795
796 log_task_stopped("write");
797 })
798 }
799
800 fn spawn_heartbeat_task(
801 connection_state: Arc<AtomicU8>,
802 heartbeat_secs: u64,
803 message: Option<String>,
804 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
805 ) -> tokio::task::JoinHandle<()> {
806 log_task_started("heartbeat");
807
808 tokio::task::spawn(async move {
809 let interval = Duration::from_secs(heartbeat_secs);
810
811 loop {
812 tokio::time::sleep(interval).await;
813
814 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
815 ConnectionMode::Active => {
816 let msg = match &message {
817 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
818 None => WriterCommand::Send(Message::Ping(vec![].into())),
819 };
820
821 match writer_tx.send(msg) {
822 Ok(()) => log::trace!("Sent heartbeat to writer task"),
823 Err(e) => {
824 log::error!("Failed to send heartbeat to writer task: {e}");
825 }
826 }
827 }
828 ConnectionMode::Reconnect => {}
829 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
830 }
831 }
832
833 log_task_stopped("heartbeat");
834 })
835 }
836}
837
838impl Drop for WebSocketClientInner {
839 fn drop(&mut self) {
840 self.clean_drop();
842 }
843}
844
845impl CleanDrop for WebSocketClientInner {
847 fn clean_drop(&mut self) {
848 if let Some(ref read_task) = self.read_task.take()
849 && !read_task.is_finished()
850 {
851 read_task.abort();
852 log_task_aborted("read");
853 }
854
855 if !self.write_task.is_finished() {
856 self.write_task.abort();
857 log_task_aborted("write");
858 }
859
860 if let Some(ref handle) = self.heartbeat_task.take()
861 && !handle.is_finished()
862 {
863 handle.abort();
864 log_task_aborted("heartbeat");
865 }
866
867 self.message_handler = None;
869 self.ping_handler = None;
870 }
871}
872
873impl Debug for WebSocketClientInner {
874 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875 f.debug_struct(stringify!(WebSocketClientInner))
876 .field("config", &self.config)
877 .field(
878 "connection_mode",
879 &ConnectionMode::from_atomic(&self.connection_mode),
880 )
881 .field("reconnect_timeout", &self.reconnect_timeout)
882 .field("is_stream_mode", &self.is_stream_mode)
883 .finish()
884 }
885}
886
/// Public handle to a WebSocket connection.
///
/// The connection itself is owned by a controller task. This handle keeps the shared
/// connection state, the write channel and the rate limiter used when sending.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct WebSocketClient {
    // Task that owns the inner client and supervises reconnects and shutdown.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    // Shared connection state, stored as a `ConnectionMode` discriminant.
    pub(crate) connection_mode: Arc<AtomicU8>,
    // Wakes waiters on connection-state changes.
    pub(crate) state_notify: Arc<tokio::sync::Notify>,
    // Upper bound on how long a send waits for the client to become ACTIVE.
    pub(crate) reconnect_timeout: Duration,
    // Per-key (and optional default) send quotas.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    // Command channel into the write task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
907
908impl Debug for WebSocketClient {
909 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
910 f.debug_struct(stringify!(WebSocketClient)).finish()
911 }
912}
913
914impl WebSocketClient {
915 #[allow(clippy::too_many_arguments)]
931 pub async fn connect_stream(
932 config: WebSocketConfig,
933 keyed_quotas: Vec<(String, Quota)>,
934 default_quota: Option<Quota>,
935 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
936 ) -> Result<(MessageReader, Self), Error> {
937 install_cryptographic_provider();
938
939 let (writer, reader) =
941 WebSocketClientInner::connect_with_server(&config.url, config.headers.clone()).await?;
942
943 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
945
946 let connection_mode = inner.connection_mode.clone();
947 let state_notify = inner.state_notify.clone();
948 let reconnect_timeout = inner.reconnect_timeout;
949 let keyed_quotas = keyed_quotas
950 .into_iter()
951 .map(|(key, quota)| (Ustr::from(&key), quota))
952 .collect();
953 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
954 let writer_tx = inner.writer_tx.clone();
955
956 let controller_task = Self::spawn_controller_task(
957 inner,
958 connection_mode.clone(),
959 state_notify.clone(),
960 post_reconnect,
961 );
962
963 Ok((
964 reader,
965 Self {
966 controller_task,
967 connection_mode,
968 state_notify,
969 reconnect_timeout,
970 rate_limiter,
971 writer_tx,
972 },
973 ))
974 }
975
976 pub async fn connect(
994 config: WebSocketConfig,
995 message_handler: Option<MessageHandler>,
996 ping_handler: Option<PingHandler>,
997 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
998 keyed_quotas: Vec<(String, Quota)>,
999 default_quota: Option<Quota>,
1000 ) -> Result<Self, Error> {
1001 if message_handler.is_none() {
1003 return Err(Error::Io(std::io::Error::new(
1004 std::io::ErrorKind::InvalidInput,
1005 "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
1006 )));
1007 }
1008
1009 log::debug!("Connecting");
1010 let inner =
1011 WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
1012 let connection_mode = inner.connection_mode.clone();
1013 let state_notify = inner.state_notify.clone();
1014 let writer_tx = inner.writer_tx.clone();
1015 let reconnect_timeout = inner.reconnect_timeout;
1016
1017 let controller_task = Self::spawn_controller_task(
1018 inner,
1019 connection_mode.clone(),
1020 state_notify.clone(),
1021 post_reconnection,
1022 );
1023
1024 let keyed_quotas = keyed_quotas
1025 .into_iter()
1026 .map(|(key, quota)| (Ustr::from(&key), quota))
1027 .collect();
1028 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1029
1030 Ok(Self {
1031 controller_task,
1032 connection_mode,
1033 state_notify,
1034 reconnect_timeout,
1035 rate_limiter,
1036 writer_tx,
1037 })
1038 }
1039
1040 #[must_use]
1042 pub fn connection_mode(&self) -> ConnectionMode {
1043 ConnectionMode::from_atomic(&self.connection_mode)
1044 }
1045
1046 #[must_use]
1051 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1052 Arc::clone(&self.connection_mode)
1053 }
1054
1055 #[inline]
1060 #[must_use]
1061 pub fn is_active(&self) -> bool {
1062 self.connection_mode().is_active()
1063 }
1064
1065 #[must_use]
1067 pub fn is_disconnected(&self) -> bool {
1068 self.controller_task.is_finished()
1069 }
1070
1071 #[inline]
1076 #[must_use]
1077 pub fn is_reconnecting(&self) -> bool {
1078 self.connection_mode().is_reconnect()
1079 }
1080
1081 #[inline]
1085 #[must_use]
1086 pub fn is_disconnecting(&self) -> bool {
1087 self.connection_mode().is_disconnect()
1088 }
1089
1090 #[inline]
1096 #[must_use]
1097 pub fn is_closed(&self) -> bool {
1098 self.connection_mode().is_closed()
1099 }
1100
1101 #[inline]
1105 fn check_not_terminal(&self) -> Result<(), SendError> {
1106 match self.connection_mode() {
1107 ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
1108 _ => Ok(()),
1109 }
1110 }
1111
1112 async fn await_rate_limit_or_closed(&self, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1114 const CHECK_INTERVAL_MS: u64 = 100;
1115
1116 tokio::select! {
1117 () = self.rate_limiter.await_keys_ready(keys) => Ok(()),
1118 () = async {
1119 loop {
1120 let notified = self.state_notify.notified();
1121
1122 if matches!(self.connection_mode(), ConnectionMode::Disconnect | ConnectionMode::Closed) {
1123 break;
1124 }
1125 tokio::select! {
1126 () = notified => {}
1127 () = tokio::time::sleep(Duration::from_millis(CHECK_INTERVAL_MS)) => {}
1128 }
1129 }
1130 } => Err(SendError::Closed),
1131 }
1132 }
1133
1134 async fn wait_for_active(&self) -> Result<(), SendError> {
1140 const FALLBACK_INTERVAL_MS: u64 = 100;
1141
1142 let mode = self.connection_mode();
1143 if mode.is_active() {
1144 return Ok(());
1145 }
1146
1147 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1148 return Err(SendError::Closed);
1149 }
1150
1151 log::debug!("Waiting for client to become ACTIVE before sending...");
1152
1153 let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
1154
1155 tokio::time::timeout(self.reconnect_timeout, async {
1156 loop {
1157 let notified = self.state_notify.notified();
1160
1161 let mode = self.connection_mode();
1162 if mode.is_active() {
1163 return Ok(());
1164 }
1165
1166 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1167 return Err(());
1168 }
1169
1170 tokio::select! {
1171 () = notified => {}
1172 () = tokio::time::sleep(fallback_interval) => {}
1173 }
1174 }
1175 })
1176 .await
1177 .map_err(|_| SendError::Timeout)?
1178 .map_err(|()| SendError::Closed)
1179 }
1180
1181 pub fn notify_closed(&self) {
1192 let mode = self.connection_mode();
1193 if mode.is_disconnect() || mode.is_closed() {
1194 return;
1195 }
1196
1197 log::debug!("Stream reader signalled EOF, transitioning to CLOSED");
1198
1199 self.connection_mode
1200 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1201 self.state_notify.notify_waiters();
1202 }
1203
1204 pub async fn disconnect(&self) {
1209 log::debug!("Disconnecting");
1210 self.connection_mode
1211 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1212 self.state_notify.notify_waiters();
1213
1214 if tokio::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1215 while !self.is_disconnected() {
1216 tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1217 }
1218
1219 if !self.controller_task.is_finished() {
1220 self.controller_task.abort();
1221 log_task_aborted("controller");
1222 }
1223 })
1224 .await
1225 == Ok(())
1226 {
1227 log::debug!("Controller task finished");
1228 } else {
1229 log::error!("Timeout waiting for controller task to finish");
1230
1231 if !self.controller_task.is_finished() {
1232 self.controller_task.abort();
1233 log_task_aborted("controller");
1234 }
1235 self.connection_mode
1236 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1237 }
1238 }
1239
1240 #[allow(unused_variables)]
1250 pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1251 self.check_not_terminal()?;
1252
1253 self.await_rate_limit_or_closed(keys).await?;
1254 self.wait_for_active().await?;
1255
1256 log::trace!("Sending text: {data:?}");
1257
1258 let msg = Message::Text(data.into());
1259 self.writer_tx
1260 .send(WriterCommand::Send(msg))
1261 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1262 }
1263
1264 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1270 self.wait_for_active().await?;
1271
1272 log::trace!("Sending pong frame ({} bytes)", data.len());
1273
1274 let msg = Message::Pong(data.into());
1275 self.writer_tx
1276 .send(WriterCommand::Send(msg))
1277 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1278 }
1279
1280 #[allow(unused_variables)]
1290 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1291 self.check_not_terminal()?;
1292
1293 self.await_rate_limit_or_closed(keys).await?;
1294 self.wait_for_active().await?;
1295
1296 log::trace!("Sending bytes: {data:?}");
1297
1298 let msg = Message::Binary(data.into());
1299 self.writer_tx
1300 .send(WriterCommand::Send(msg))
1301 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1302 }
1303
1304 pub async fn send_close_message(&self) -> Result<(), SendError> {
1310 self.wait_for_active().await?;
1311
1312 let msg = Message::Close(None);
1313 self.writer_tx
1314 .send(WriterCommand::Send(msg))
1315 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1316 }
1317
1318 fn spawn_controller_task(
1319 mut inner: WebSocketClientInner,
1320 connection_mode: Arc<AtomicU8>,
1321 state_notify: Arc<tokio::sync::Notify>,
1322 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1323 ) -> tokio::task::JoinHandle<()> {
1324 const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;
1325
1326 tokio::task::spawn(async move {
1327 log_task_started("controller");
1328
1329 let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);
1330
1331 loop {
1332 tokio::select! {
1333 () = state_notify.notified() => {}
1334 () = tokio::time::sleep(fallback_interval) => {}
1335 }
1336
1337 let mut mode = ConnectionMode::from_atomic(&connection_mode);
1338
1339 if mode.is_disconnect() {
1340 log::debug!("Disconnecting");
1341
1342 let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1343 if tokio::time::timeout(timeout, async {
1344 tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1346
1347 if let Some(task) = &inner.read_task
1348 && !task.is_finished()
1349 {
1350 task.abort();
1351 log_task_aborted("read");
1352 }
1353
1354 if let Some(task) = &inner.heartbeat_task
1355 && !task.is_finished()
1356 {
1357 task.abort();
1358 log_task_aborted("heartbeat");
1359 }
1360 })
1361 .await
1362 .is_err()
1363 {
1364 log::error!("Shutdown timed out after {}s", timeout.as_secs());
1365 }
1366
1367 log::debug!("Closed");
1368 break; }
1370
1371 if mode.is_closed() {
1372 log::debug!("Connection closed");
1373 break;
1374 }
1375
1376 if mode.is_active() && !inner.is_alive() {
1377 let target = if inner.is_stream_mode {
1378 ConnectionMode::Closed
1379 } else {
1380 ConnectionMode::Reconnect
1381 };
1382
1383 if connection_mode
1384 .compare_exchange(
1385 ConnectionMode::Active.as_u8(),
1386 target.as_u8(),
1387 Ordering::SeqCst,
1388 Ordering::SeqCst,
1389 )
1390 .is_ok()
1391 {
1392 log::debug!("Detected dead connection, transitioning to {target:?}");
1393 }
1394 mode = ConnectionMode::from_atomic(&connection_mode);
1395 }
1396
1397 if mode.is_reconnect() {
1398 if let Some(max_attempts) = inner.reconnect_max_attempts
1400 && inner.reconnection_attempt_count >= max_attempts
1401 {
1402 log::error!(
1403 "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1404 );
1405 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1406 state_notify.notify_waiters();
1407 break;
1408 }
1409
1410 inner.reconnection_attempt_count += 1;
1411 log::debug!(
1412 "Reconnection attempt {} of {}",
1413 inner.reconnection_attempt_count,
1414 inner
1415 .reconnect_max_attempts
1416 .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1417 );
1418
1419 let reconnect_result = tokio::select! {
1421 result = inner.reconnect() => Some(result),
1422 () = async {
1423 loop {
1424 state_notify.notified().await;
1425
1426 if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1427 break;
1428 }
1429 }
1430 } => None,
1431 };
1432
1433 match reconnect_result {
1434 None => {
1435 log::debug!("Reconnect interrupted by disconnect");
1436 }
1437 Some(Ok(())) => {
1438 inner.backoff.reset();
1439 inner.reconnection_attempt_count = 0;
1440
1441 state_notify.notify_waiters();
1442
1443 if ConnectionMode::from_atomic(&connection_mode).is_active() {
1444 if let Some(ref handler) = inner.message_handler {
1445 let reconnected_msg =
1446 Message::Text(RECONNECTED.to_string().into());
1447 handler(reconnected_msg);
1448 log::debug!("Sent reconnected message to handler");
1449 }
1450
1451 if let Some(ref callback) = post_reconnection {
1453 callback();
1454 log::debug!("Called `post_reconnection` handler");
1455 }
1456
1457 log::debug!("Reconnected successfully");
1458 } else {
1459 log::debug!(
1460 "Skipping post_reconnection handlers due to disconnect state"
1461 );
1462 }
1463 }
1464 Some(Err(e)) => {
1465 let duration = inner.backoff.next_duration();
1466 log::warn!(
1467 "Reconnect attempt {} failed: {e}",
1468 inner.reconnection_attempt_count
1469 );
1470
1471 if !duration.is_zero() {
1472 log::warn!("Backing off for {}s...", duration.as_secs_f64());
1473 tokio::select! {
1475 () = tokio::time::sleep(duration) => {}
1476 () = async {
1477 loop {
1478 state_notify.notified().await;
1479
1480 if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1481 break;
1482 }
1483 }
1484 } => {
1485 log::debug!("Backoff interrupted by disconnect");
1486 }
1487 }
1488 }
1489 }
1490 }
1491 }
1492 }
1493 inner
1494 .connection_mode
1495 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1496
1497 log_task_stopped("controller");
1498 })
1499 }
1500}
1501
1502impl Drop for WebSocketClient {
1504 fn drop(&mut self) {
1505 if !self.controller_task.is_finished() {
1506 self.controller_task.abort();
1507 log_task_aborted("controller");
1508 }
1509 }
1510}
1511
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            handshake::server::{self, Callback},
            http::HeaderValue,
            protocol::Message,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{WebSocketClient, WebSocketConfig},
    };

    /// Local echo server used by the tests in this module.
    ///
    /// Each handshake must carry the `test: test` header (checked by [`TestCallback`]).
    /// Text and binary frames are echoed back, a `close-now` text frame makes the server
    /// close the connection, and a client close frame is answered with a close.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the upgrade request carries `key: value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        // Panicking on an unexpected request is intentional: it fails the handshake loudly.
        #[allow(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            let Some(value) = request.headers().get(&self.key) else {
                panic!("handshake request is missing the `{}` header", self.key);
            };
            assert_eq!(value, self.value);

            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral localhost port and starts the accept loop.
        async fn setup() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();

            let callback = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_static("test"),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = listener.accept().await.unwrap();
                    let callback = callback.clone();

                    // Handshake on the per-connection task so a slow or failed handshake
                    // (e.g. a client that gives up mid-upgrade) neither blocks nor kills
                    // the accept loop.
                    task::spawn(async move {
                        let mut websocket = match accept_hdr_async(conn, callback).await {
                            Ok(websocket) => websocket,
                            Err(e) => {
                                log::debug!("Test server handshake failed: {e}");
                                return;
                            }
                        };

                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                Message::Text(txt) if txt == "close-now" => {
                                    log::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                Message::Text(_) | Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                Message::Close(_) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Baseline client config for [`TestServer`]: sends the required `test: test` header
    /// and leaves heartbeat, reconnect tuning and idle timeout unset (`None`).
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
        }
    }

    /// Connects a handler-mode client with a no-op message handler and no rate limits.
    async fn connect_client(config: WebSocketConfig) -> WebSocketClient {
        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        connect_client(test_config(port)).await
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;

        // The shared config leaves the heartbeat disabled; enable a 1s heartbeat so this
        // test actually exercises the heartbeat task.
        let config = WebSocketConfig {
            heartbeat: Some(1),
            ..test_config(server.port)
        };
        let client = connect_client(config).await;

        // Let several heartbeats fire; the connection must stay up throughout.
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it straight away so nothing is listening
        // there, rather than assuming a hard-coded port is free on the test machine.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let config = WebSocketConfig {
            headers: vec![],
            ..test_config(port)
        };

        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // Ask the server to close its side; the handler-mode client should recover.
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();

        let client = WebSocketClient::connect(
            test_config(server.port),
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();
        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];
        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
1780
1781#[cfg(test)]
1782#[cfg(not(feature = "turmoil"))]
1783mod rust_tests {
1784 use futures_util::{SinkExt, StreamExt};
1785 use nautilus_common::testing::wait_until_async;
1786 use rstest::rstest;
1787 use tokio::{
1788 net::TcpListener,
1789 task,
1790 time::{Duration, sleep},
1791 };
1792 use tokio_tungstenite::accept_async;
1793
1794 use super::*;
1795 use crate::websocket::types::channel_message_handler;
1796
1797 #[rstest]
1798 #[tokio::test]
1799 async fn test_reconnect_then_disconnect() {
1800 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1802 let port = listener.local_addr().unwrap().port();
1803
1804 let server = task::spawn(async move {
1806 let (stream, _) = listener.accept().await.unwrap();
1807 let ws = accept_async(stream).await.unwrap();
1808 drop(ws);
1809 sleep(Duration::from_secs(1)).await;
1811 });
1812
1813 let (handler, _rx) = channel_message_handler();
1815
1816 let config = WebSocketConfig {
1818 url: format!("ws://127.0.0.1:{port}"),
1819 headers: vec![],
1820 heartbeat: None,
1821 heartbeat_msg: None,
1822 reconnect_timeout_ms: Some(1_000),
1823 reconnect_delay_initial_ms: Some(50),
1824 reconnect_delay_max_ms: Some(100),
1825 reconnect_backoff_factor: Some(1.0),
1826 reconnect_jitter_ms: Some(0),
1827 reconnect_max_attempts: None,
1828 idle_timeout_ms: None,
1829 };
1830
1831 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1833 .await
1834 .unwrap();
1835
1836 sleep(Duration::from_millis(100)).await;
1838 client.disconnect().await;
1840 assert!(client.is_disconnected());
1841 server.abort();
1842 }
1843
1844 #[rstest]
1845 #[tokio::test]
1846 async fn test_reconnect_state_flips_when_reader_stops() {
1847 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1849 let port = listener.local_addr().unwrap().port();
1850
1851 let server = task::spawn(async move {
1852 if let Ok((stream, _)) = listener.accept().await
1853 && let Ok(ws) = accept_async(stream).await
1854 {
1855 drop(ws);
1856 }
1857 sleep(Duration::from_millis(50)).await;
1858 });
1859
1860 let (handler, _rx) = channel_message_handler();
1861
1862 let config = WebSocketConfig {
1863 url: format!("ws://127.0.0.1:{port}"),
1864 headers: vec![],
1865 heartbeat: None,
1866 heartbeat_msg: None,
1867 reconnect_timeout_ms: Some(1_000),
1868 reconnect_delay_initial_ms: Some(50),
1869 reconnect_delay_max_ms: Some(100),
1870 reconnect_backoff_factor: Some(1.0),
1871 reconnect_jitter_ms: Some(0),
1872 reconnect_max_attempts: None,
1873 idle_timeout_ms: None,
1874 };
1875
1876 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1877 .await
1878 .unwrap();
1879
1880 tokio::time::timeout(Duration::from_secs(2), async {
1881 loop {
1882 if client.is_reconnecting() {
1883 break;
1884 }
1885 tokio::time::sleep(Duration::from_millis(10)).await;
1886 }
1887 })
1888 .await
1889 .expect("client did not enter RECONNECT state");
1890
1891 client.disconnect().await;
1892 server.abort();
1893 }
1894
1895 #[rstest]
1896 #[tokio::test]
1897 async fn test_stream_mode_disables_auto_reconnect() {
1898 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1901 let port = listener.local_addr().unwrap().port();
1902
1903 let server = task::spawn(async move {
1904 if let Ok((stream, _)) = listener.accept().await
1905 && let Ok(_ws) = accept_async(stream).await
1906 {
1907 sleep(Duration::from_millis(100)).await;
1909 }
1910 });
1911
1912 let config = WebSocketConfig {
1913 url: format!("ws://127.0.0.1:{port}"),
1914 headers: vec![],
1915 heartbeat: None,
1916 heartbeat_msg: None,
1917 reconnect_timeout_ms: Some(1_000),
1918 reconnect_delay_initial_ms: Some(50),
1919 reconnect_delay_max_ms: Some(100),
1920 reconnect_backoff_factor: Some(1.0),
1921 reconnect_jitter_ms: Some(0),
1922 reconnect_max_attempts: None,
1923 idle_timeout_ms: None,
1924 };
1925
1926 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
1927 .await
1928 .unwrap();
1929
1930 server.abort();
1938 }
1939
1940 #[rstest]
1941 #[tokio::test]
1942 async fn test_message_handler_mode_allows_auto_reconnect() {
1943 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1945 let port = listener.local_addr().unwrap().port();
1946
1947 let server = task::spawn(async move {
1948 if let Ok((stream, _)) = listener.accept().await
1950 && let Ok(ws) = accept_async(stream).await
1951 {
1952 drop(ws);
1953 }
1954 sleep(Duration::from_millis(50)).await;
1955 });
1956
1957 let (handler, _rx) = channel_message_handler();
1958
1959 let config = WebSocketConfig {
1960 url: format!("ws://127.0.0.1:{port}"),
1961 headers: vec![],
1962 heartbeat: None,
1963 heartbeat_msg: None,
1964 reconnect_timeout_ms: Some(1_000),
1965 reconnect_delay_initial_ms: Some(50),
1966 reconnect_delay_max_ms: Some(100),
1967 reconnect_backoff_factor: Some(1.0),
1968 reconnect_jitter_ms: Some(0),
1969 reconnect_max_attempts: None,
1970 idle_timeout_ms: None,
1971 };
1972
1973 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
1974 .await
1975 .unwrap();
1976
1977 tokio::time::timeout(Duration::from_secs(2), async {
1979 loop {
1980 if client.is_reconnecting() || client.is_closed() {
1981 break;
1982 }
1983 tokio::time::sleep(Duration::from_millis(10)).await;
1984 }
1985 })
1986 .await
1987 .expect("client should attempt reconnection or close");
1988
1989 assert!(
1992 client.is_reconnecting() || client.is_closed(),
1993 "Client with message handler should attempt reconnection"
1994 );
1995
1996 client.disconnect().await;
1997 server.abort();
1998 }
1999
2000 #[rstest]
2001 #[tokio::test]
2002 async fn test_handler_mode_reconnect_with_new_connection() {
2003 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2005 let port = listener.local_addr().unwrap().port();
2006
2007 let server = task::spawn(async move {
2008 if let Ok((stream, _)) = listener.accept().await
2010 && let Ok(ws) = accept_async(stream).await
2011 {
2012 drop(ws);
2013 }
2014
2015 sleep(Duration::from_millis(100)).await;
2017
2018 if let Ok((stream, _)) = listener.accept().await
2020 && let Ok(mut ws) = accept_async(stream).await
2021 {
2022 use futures_util::SinkExt;
2023 let _ = ws
2024 .send(Message::Text("reconnected".to_string().into()))
2025 .await;
2026 sleep(Duration::from_secs(1)).await;
2027 }
2028 });
2029
2030 let (handler, mut rx) = channel_message_handler();
2031
2032 let config = WebSocketConfig {
2033 url: format!("ws://127.0.0.1:{port}"),
2034 headers: vec![],
2035 heartbeat: None,
2036 heartbeat_msg: None,
2037 reconnect_timeout_ms: Some(2_000),
2038 reconnect_delay_initial_ms: Some(50),
2039 reconnect_delay_max_ms: Some(200),
2040 reconnect_backoff_factor: Some(1.5),
2041 reconnect_jitter_ms: Some(10),
2042 reconnect_max_attempts: None,
2043 idle_timeout_ms: None,
2044 };
2045
2046 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2047 .await
2048 .unwrap();
2049
2050 let result = tokio::time::timeout(Duration::from_secs(5), async {
2052 loop {
2053 if let Ok(msg) = rx.try_recv()
2054 && matches!(msg, Message::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
2055 {
2056 return true;
2057 }
2058 tokio::time::sleep(Duration::from_millis(10)).await;
2059 }
2060 })
2061 .await;
2062
2063 assert!(
2064 result.is_ok(),
2065 "Should receive message after reconnection within timeout"
2066 );
2067
2068 client.disconnect().await;
2069 server.abort();
2070 }
2071
2072 #[rstest]
2073 #[tokio::test]
2074 async fn test_stream_mode_no_auto_reconnect() {
2075 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2078 let port = listener.local_addr().unwrap().port();
2079
2080 let server = task::spawn(async move {
2081 if let Ok((stream, _)) = listener.accept().await
2083 && let Ok(mut ws) = accept_async(stream).await
2084 {
2085 use futures_util::SinkExt;
2086 let _ = ws.send(Message::Text("hello".to_string().into())).await;
2087 sleep(Duration::from_millis(50)).await;
2088 }
2090 });
2091
2092 let config = WebSocketConfig {
2093 url: format!("ws://127.0.0.1:{port}"),
2094 headers: vec![],
2095 heartbeat: None,
2096 heartbeat_msg: None,
2097 reconnect_timeout_ms: Some(1_000),
2098 reconnect_delay_initial_ms: Some(50),
2099 reconnect_delay_max_ms: Some(100),
2100 reconnect_backoff_factor: Some(1.0),
2101 reconnect_jitter_ms: Some(0),
2102 reconnect_max_attempts: None,
2103 idle_timeout_ms: None,
2104 };
2105
2106 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2107 .await
2108 .unwrap();
2109
2110 assert!(client.is_active(), "Client should start as active");
2112
2113 let msg = reader.next().await;
2115 assert!(
2116 matches!(msg, Some(Ok(Message::Text(ref text))) if AsRef::<str>::as_ref(text) == "hello"),
2117 "Should receive initial message"
2118 );
2119
2120 while let Some(msg) = reader.next().await {
2122 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
2123 break;
2124 }
2125 }
2126
2127 sleep(Duration::from_millis(200)).await;
2130 assert!(
2131 client.is_active(),
2132 "Stream mode client stays ACTIVE before notify_closed()"
2133 );
2134
2135 client.notify_closed();
2137
2138 assert!(
2139 client.is_closed(),
2140 "Stream mode client should be CLOSED after notify_closed()"
2141 );
2142 assert!(
2143 !client.is_reconnecting(),
2144 "Stream mode client should never attempt reconnection"
2145 );
2146
2147 client.disconnect().await;
2148 server.abort();
2149 }
2150
2151 #[rstest]
2152 #[tokio::test]
2153 async fn test_send_timeout_uses_configured_reconnect_timeout() {
2154 use nautilus_common::testing::wait_until_async;
2157
2158 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2159 let port = listener.local_addr().unwrap().port();
2160
2161 let server = task::spawn(async move {
2162 if let Ok((stream, _)) = listener.accept().await
2164 && let Ok(ws) = accept_async(stream).await
2165 {
2166 drop(ws);
2167 }
2168 sleep(Duration::from_secs(60)).await;
2170 });
2171
2172 let (handler, _rx) = channel_message_handler();
2173
2174 let config = WebSocketConfig {
2176 url: format!("ws://127.0.0.1:{port}"),
2177 headers: vec![],
2178 heartbeat: None,
2179 heartbeat_msg: None,
2180 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
2182 reconnect_delay_max_ms: Some(100),
2183 reconnect_backoff_factor: Some(1.0),
2184 reconnect_jitter_ms: Some(0),
2185 reconnect_max_attempts: None,
2186 idle_timeout_ms: None,
2187 };
2188
2189 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2190 .await
2191 .unwrap();
2192
2193 wait_until_async(
2195 || async { client.is_reconnecting() },
2196 Duration::from_secs(3),
2197 )
2198 .await;
2199
2200 let start = std::time::Instant::now();
2202 let send_result = client.send_text("test".to_string(), None).await;
2203 let elapsed = start.elapsed();
2204
2205 assert!(
2206 send_result.is_err(),
2207 "Send should fail when client stuck in RECONNECT"
2208 );
2209 assert!(
2210 matches!(send_result, Err(crate::error::SendError::Timeout)),
2211 "Send should return Timeout error, was: {send_result:?}"
2212 );
2213 assert!(
2216 elapsed >= Duration::from_millis(1800),
2217 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2218 );
2219
2220 client.disconnect().await;
2221 server.abort();
2222 }
2223
2224 #[rstest]
2225 #[tokio::test]
2226 async fn test_send_waits_during_reconnection() {
2227 use nautilus_common::testing::wait_until_async;
2229
2230 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2231 let port = listener.local_addr().unwrap().port();
2232
2233 let server = task::spawn(async move {
2234 if let Ok((stream, _)) = listener.accept().await
2236 && let Ok(ws) = accept_async(stream).await
2237 {
2238 drop(ws);
2239 }
2240
2241 sleep(Duration::from_millis(500)).await;
2243
2244 if let Ok((stream, _)) = listener.accept().await
2246 && let Ok(mut ws) = accept_async(stream).await
2247 {
2248 while let Some(Ok(msg)) = ws.next().await {
2250 if ws.send(msg).await.is_err() {
2251 break;
2252 }
2253 }
2254 }
2255 });
2256
2257 let (handler, _rx) = channel_message_handler();
2258
2259 let config = WebSocketConfig {
2260 url: format!("ws://127.0.0.1:{port}"),
2261 headers: vec![],
2262 heartbeat: None,
2263 heartbeat_msg: None,
2264 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2266 reconnect_delay_max_ms: Some(200),
2267 reconnect_backoff_factor: Some(1.0),
2268 reconnect_jitter_ms: Some(0),
2269 reconnect_max_attempts: None,
2270 idle_timeout_ms: None,
2271 };
2272
2273 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2274 .await
2275 .unwrap();
2276
2277 wait_until_async(
2279 || async { client.is_reconnecting() },
2280 Duration::from_secs(2),
2281 )
2282 .await;
2283
2284 let send_result = tokio::time::timeout(
2286 Duration::from_secs(3),
2287 client.send_text("test_message".to_string(), None),
2288 )
2289 .await;
2290
2291 assert!(
2292 send_result.is_ok() && send_result.unwrap().is_ok(),
2293 "Send should succeed after waiting for reconnection"
2294 );
2295
2296 client.disconnect().await;
2297 server.abort();
2298 }
2299
2300 #[rstest]
2301 #[tokio::test]
2302 async fn test_rate_limiter_before_active_wait() {
2303 use std::{num::NonZeroU32, sync::Arc};
2308
2309 use nautilus_common::testing::wait_until_async;
2310
2311 use crate::ratelimiter::quota::Quota;
2312
2313 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2314 let port = listener.local_addr().unwrap().port();
2315
2316 let server = task::spawn(async move {
2317 if let Ok((stream, _)) = listener.accept().await
2319 && let Ok(mut ws) = accept_async(stream).await
2320 {
2321 if let Some(Ok(_)) = ws.next().await {
2323 drop(ws);
2324 }
2325 }
2326
2327 sleep(Duration::from_millis(500)).await;
2329
2330 if let Ok((stream, _)) = listener.accept().await
2332 && let Ok(mut ws) = accept_async(stream).await
2333 {
2334 while let Some(Ok(msg)) = ws.next().await {
2335 if ws.send(msg).await.is_err() {
2336 break;
2337 }
2338 }
2339 }
2340 });
2341
2342 let (handler, _rx) = channel_message_handler();
2343
2344 let config = WebSocketConfig {
2345 url: format!("ws://127.0.0.1:{port}"),
2346 headers: vec![],
2347 heartbeat: None,
2348 heartbeat_msg: None,
2349 reconnect_timeout_ms: Some(5_000),
2350 reconnect_delay_initial_ms: Some(50),
2351 reconnect_delay_max_ms: Some(100),
2352 reconnect_backoff_factor: Some(1.0),
2353 reconnect_jitter_ms: Some(0),
2354 reconnect_max_attempts: None,
2355 idle_timeout_ms: None,
2356 };
2357
2358 let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2360 .unwrap()
2361 .allow_burst(NonZeroU32::new(1).unwrap());
2362
2363 let client = Arc::new(
2364 WebSocketClient::connect(
2365 config,
2366 Some(handler),
2367 None,
2368 None,
2369 vec![("test_key".to_string(), quota)],
2370 None,
2371 )
2372 .await
2373 .unwrap(),
2374 );
2375
2376 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2378 client
2379 .send_text("msg1".to_string(), Some(test_key.as_slice()))
2380 .await
2381 .unwrap();
2382
2383 wait_until_async(
2385 || async { client.is_reconnecting() },
2386 Duration::from_secs(2),
2387 )
2388 .await;
2389
2390 let start = std::time::Instant::now();
2392 let send_result = client
2393 .send_text("msg2".to_string(), Some(test_key.as_slice()))
2394 .await;
2395 let elapsed = start.elapsed();
2396
2397 assert!(
2399 send_result.is_ok(),
2400 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
2401 );
2402 assert!(
2406 elapsed >= Duration::from_millis(850),
2407 "Should wait for rate limit (~1s), waited {elapsed:?}"
2408 );
2409
2410 client.disconnect().await;
2411 server.abort();
2412 }
2413
2414 #[rstest]
2415 #[tokio::test]
2416 async fn test_disconnect_during_reconnect_exits_cleanly() {
2417 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2420 let port = listener.local_addr().unwrap().port();
2421
2422 let server = task::spawn(async move {
2423 if let Ok((stream, _)) = listener.accept().await
2425 && let Ok(ws) = accept_async(stream).await
2426 {
2427 drop(ws);
2428 }
2429 sleep(Duration::from_secs(60)).await;
2431 });
2432
2433 let (handler, _rx) = channel_message_handler();
2434
2435 let config = WebSocketConfig {
2436 url: format!("ws://127.0.0.1:{port}"),
2437 headers: vec![],
2438 heartbeat: None,
2439 heartbeat_msg: None,
2440 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
2442 reconnect_delay_max_ms: Some(200),
2443 reconnect_backoff_factor: Some(1.0),
2444 reconnect_jitter_ms: Some(0),
2445 reconnect_max_attempts: None,
2446 idle_timeout_ms: None,
2447 };
2448
2449 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2450 .await
2451 .unwrap();
2452
2453 tokio::time::timeout(Duration::from_secs(2), async {
2455 while !client.is_reconnecting() {
2456 sleep(Duration::from_millis(10)).await;
2457 }
2458 })
2459 .await
2460 .expect("Client should enter RECONNECT state");
2461
2462 client.disconnect().await;
2464
2465 assert!(
2467 client.is_disconnected(),
2468 "Client should be cleanly disconnected"
2469 );
2470
2471 server.abort();
2472 }
2473
2474 #[rstest]
2475 #[tokio::test]
2476 async fn test_send_fails_fast_when_closed_before_rate_limit() {
2477 use std::{num::NonZeroU32, sync::Arc};
2480
2481 use nautilus_common::testing::wait_until_async;
2482
2483 use crate::ratelimiter::quota::Quota;
2484
2485 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2486 let port = listener.local_addr().unwrap().port();
2487
2488 let server = task::spawn(async move {
2489 if let Ok((stream, _)) = listener.accept().await
2491 && let Ok(ws) = accept_async(stream).await
2492 {
2493 drop(ws);
2494 }
2495 sleep(Duration::from_secs(60)).await;
2496 });
2497
2498 let (handler, _rx) = channel_message_handler();
2499
2500 let config = WebSocketConfig {
2501 url: format!("ws://127.0.0.1:{port}"),
2502 headers: vec![],
2503 heartbeat: None,
2504 heartbeat_msg: None,
2505 reconnect_timeout_ms: Some(5_000),
2506 reconnect_delay_initial_ms: Some(50),
2507 reconnect_delay_max_ms: Some(100),
2508 reconnect_backoff_factor: Some(1.0),
2509 reconnect_jitter_ms: Some(0),
2510 reconnect_max_attempts: None,
2511 idle_timeout_ms: None,
2512 };
2513
2514 let quota = Quota::with_period(Duration::from_secs(10))
2517 .unwrap()
2518 .allow_burst(NonZeroU32::new(1).unwrap());
2519
2520 let client = Arc::new(
2521 WebSocketClient::connect(
2522 config,
2523 Some(handler),
2524 None,
2525 None,
2526 vec![("test_key".to_string(), quota)],
2527 None,
2528 )
2529 .await
2530 .unwrap(),
2531 );
2532
2533 wait_until_async(
2535 || async { client.is_reconnecting() || client.is_closed() },
2536 Duration::from_secs(2),
2537 )
2538 .await;
2539
2540 client.disconnect().await;
2542 assert!(
2543 !client.is_active(),
2544 "Client should not be active after disconnect"
2545 );
2546
2547 let start = std::time::Instant::now();
2549 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2550 let result = client
2551 .send_text("test".to_string(), Some(test_key.as_slice()))
2552 .await;
2553 let elapsed = start.elapsed();
2554
2555 assert!(result.is_err(), "Send should fail when client is closed");
2557 assert!(
2558 matches!(result, Err(crate::error::SendError::Closed)),
2559 "Send should return Closed error, was: {result:?}"
2560 );
2561
2562 assert!(
2564 elapsed < Duration::from_millis(100),
2565 "Send should fail fast without rate limiting, took {elapsed:?}"
2566 );
2567
2568 server.abort();
2569 }
2570
2571 #[rstest]
2572 #[tokio::test]
2573 async fn test_connect_rejects_none_message_handler() {
2574 let config = WebSocketConfig {
2578 url: "ws://127.0.0.1:9999".to_string(),
2579 headers: vec![],
2580 heartbeat: None,
2581 heartbeat_msg: None,
2582 reconnect_timeout_ms: Some(1_000),
2583 reconnect_delay_initial_ms: Some(100),
2584 reconnect_delay_max_ms: Some(500),
2585 reconnect_backoff_factor: Some(1.5),
2586 reconnect_jitter_ms: Some(0),
2587 reconnect_max_attempts: None,
2588 idle_timeout_ms: None,
2589 };
2590
2591 let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
2593
2594 assert!(
2595 result.is_err(),
2596 "connect() should reject None message_handler"
2597 );
2598
2599 let err = result.unwrap_err();
2600 let err_msg = err.to_string();
2601 assert!(
2602 err_msg.contains("Handler mode requires message_handler"),
2603 "Error should mention missing message_handler, was: {err_msg}"
2604 );
2605 }
2606
2607 #[rstest]
2608 #[tokio::test]
2609 async fn test_client_without_handler_sets_stream_mode() {
2610 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2614 let port = listener.local_addr().unwrap().port();
2615
2616 let server = task::spawn(async move {
2617 if let Ok((stream, _)) = listener.accept().await
2619 && let Ok(ws) = accept_async(stream).await
2620 {
2621 drop(ws); }
2623 });
2624
2625 let config = WebSocketConfig {
2626 url: format!("ws://127.0.0.1:{port}"),
2627 headers: vec![],
2628 heartbeat: None,
2629 heartbeat_msg: None,
2630 reconnect_timeout_ms: Some(1_000),
2631 reconnect_delay_initial_ms: Some(100),
2632 reconnect_delay_max_ms: Some(500),
2633 reconnect_backoff_factor: Some(1.5),
2634 reconnect_jitter_ms: Some(0),
2635 reconnect_max_attempts: None,
2636 idle_timeout_ms: None,
2637 };
2638
2639 let inner = WebSocketClientInner::connect_url(config, None, None)
2641 .await
2642 .unwrap();
2643
2644 assert!(
2646 inner.is_stream_mode,
2647 "Client without handler should have is_stream_mode=true"
2648 );
2649
2650 server.abort();
2654 }
2655
2656 #[rstest]
2657 #[tokio::test]
2658 async fn test_idle_timeout_triggers_reconnect() {
2659 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2660 let port = listener.local_addr().unwrap().port();
2661
2662 let server = task::spawn(async move {
2664 let (stream, _) = listener.accept().await.unwrap();
2665 let _ws = accept_async(stream).await.unwrap();
2666 sleep(Duration::from_secs(5)).await;
2668 });
2669
2670 let (handler, _rx) = channel_message_handler();
2671
2672 let config = WebSocketConfig {
2673 url: format!("ws://127.0.0.1:{port}"),
2674 headers: vec![],
2675 heartbeat: None,
2676 heartbeat_msg: None,
2677 reconnect_timeout_ms: Some(2_000),
2678 reconnect_delay_initial_ms: Some(50),
2679 reconnect_delay_max_ms: Some(100),
2680 reconnect_backoff_factor: Some(1.0),
2681 reconnect_jitter_ms: Some(0),
2682 reconnect_max_attempts: Some(1),
2683 idle_timeout_ms: Some(500),
2684 };
2685
2686 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2687 .await
2688 .unwrap();
2689
2690 assert!(client.is_active());
2691
2692 wait_until_async(
2694 || async { client.is_reconnecting() || client.is_disconnected() },
2695 Duration::from_secs(3),
2696 )
2697 .await;
2698
2699 assert!(
2700 !client.is_active(),
2701 "Client should not be active after idle timeout"
2702 );
2703
2704 client.disconnect().await;
2705 server.abort();
2706 }
2707
2708 #[rstest]
2709 #[tokio::test]
2710 async fn test_idle_timeout_resets_on_data() {
2711 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2712 let port = listener.local_addr().unwrap().port();
2713
2714 let server = task::spawn(async move {
2716 let (stream, _) = listener.accept().await.unwrap();
2717 let mut ws = accept_async(stream).await.unwrap();
2718 for _ in 0..10 {
2719 sleep(Duration::from_millis(200)).await;
2720
2721 if ws
2722 .send(tokio_tungstenite::tungstenite::Message::Text("ping".into()))
2723 .await
2724 .is_err()
2725 {
2726 break;
2727 }
2728 }
2729 });
2730
2731 let (handler, _rx) = channel_message_handler();
2732
2733 let config = WebSocketConfig {
2734 url: format!("ws://127.0.0.1:{port}"),
2735 headers: vec![],
2736 heartbeat: None,
2737 heartbeat_msg: None,
2738 reconnect_timeout_ms: Some(2_000),
2739 reconnect_delay_initial_ms: Some(50),
2740 reconnect_delay_max_ms: Some(100),
2741 reconnect_backoff_factor: Some(1.0),
2742 reconnect_jitter_ms: Some(0),
2743 reconnect_max_attempts: Some(1),
2744 idle_timeout_ms: Some(1_000),
2745 };
2746
2747 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2748 .await
2749 .unwrap();
2750
2751 assert!(client.is_active());
2752
2753 sleep(Duration::from_millis(1_500)).await;
2755
2756 assert!(
2757 client.is_active(),
2758 "Client should remain active when data is flowing"
2759 );
2760
2761 client.disconnect().await;
2762 server.abort();
2763 }
2764
2765 #[rstest]
2766 #[tokio::test]
2767 async fn test_disconnect_during_backoff_exits_promptly() {
2768 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2772 let port = listener.local_addr().unwrap().port();
2773
2774 let server = task::spawn(async move {
2775 if let Ok((stream, _)) = listener.accept().await {
2777 let _ = accept_async(stream).await;
2778 }
2779 sleep(Duration::from_secs(60)).await;
2781 });
2782
2783 let (handler, _rx) = channel_message_handler();
2784
2785 let config = WebSocketConfig {
2786 url: format!("ws://127.0.0.1:{port}"),
2787 headers: vec![],
2788 heartbeat: None,
2789 heartbeat_msg: None,
2790 reconnect_timeout_ms: Some(1_000),
2791 reconnect_delay_initial_ms: Some(10_000), reconnect_delay_max_ms: Some(10_000),
2793 reconnect_backoff_factor: Some(1.0),
2794 reconnect_jitter_ms: Some(0),
2795 reconnect_max_attempts: None,
2796 idle_timeout_ms: None,
2797 };
2798
2799 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2800 .await
2801 .unwrap();
2802
2803 wait_until_async(
2805 || async { client.is_reconnecting() },
2806 Duration::from_secs(3),
2807 )
2808 .await;
2809
2810 sleep(Duration::from_millis(1_500)).await;
2812
2813 let start = std::time::Instant::now();
2815 client.disconnect().await;
2816 let elapsed = start.elapsed();
2817
2818 assert!(client.is_disconnected(), "Client should be disconnected");
2819 assert!(
2821 elapsed < Duration::from_secs(2),
2822 "Disconnect should interrupt backoff sleep, took {elapsed:?}"
2823 );
2824
2825 server.abort();
2826 }
2827
2828 #[rstest]
2829 #[tokio::test]
2830 async fn test_rate_limit_cancelled_on_disconnect() {
2831 use std::{num::NonZeroU32, sync::Arc};
2834
2835 use crate::ratelimiter::quota::Quota;
2836
2837 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2838 let port = listener.local_addr().unwrap().port();
2839
2840 let server = task::spawn(async move {
2841 if let Ok((stream, _)) = listener.accept().await {
2842 let mut ws = accept_async(stream).await.unwrap();
2843 while let Some(Ok(msg)) = ws.next().await {
2845 if ws.send(msg).await.is_err() {
2846 break;
2847 }
2848 }
2849 }
2850 });
2851
2852 let (handler, _rx) = channel_message_handler();
2853
2854 let config = WebSocketConfig {
2855 url: format!("ws://127.0.0.1:{port}"),
2856 headers: vec![],
2857 heartbeat: None,
2858 heartbeat_msg: None,
2859 reconnect_timeout_ms: Some(5_000),
2860 reconnect_delay_initial_ms: Some(100),
2861 reconnect_delay_max_ms: Some(500),
2862 reconnect_backoff_factor: Some(1.5),
2863 reconnect_jitter_ms: Some(0),
2864 reconnect_max_attempts: None,
2865 idle_timeout_ms: None,
2866 };
2867
2868 let quota = Quota::with_period(Duration::from_secs(60))
2870 .unwrap()
2871 .allow_burst(NonZeroU32::new(1).unwrap());
2872
2873 let client = Arc::new(
2874 WebSocketClient::connect(
2875 config,
2876 Some(handler),
2877 None,
2878 None,
2879 vec![("rate_key".to_string(), quota)],
2880 None,
2881 )
2882 .await
2883 .unwrap(),
2884 );
2885
2886 let test_key: [Ustr; 1] = [Ustr::from("rate_key")];
2887
2888 client
2890 .send_text("exhaust".to_string(), Some(test_key.as_slice()))
2891 .await
2892 .unwrap();
2893
2894 let client_clone = client.clone();
2896 let send_handle = task::spawn(async move {
2897 client_clone
2898 .send_text("blocked".to_string(), Some(&[Ustr::from("rate_key")]))
2899 .await
2900 });
2901
2902 sleep(Duration::from_millis(200)).await;
2904
2905 let start = std::time::Instant::now();
2907 client.disconnect().await;
2908 let elapsed_disconnect = start.elapsed();
2909
2910 let result = tokio::time::timeout(Duration::from_secs(2), send_handle)
2912 .await
2913 .expect("Send task should complete quickly")
2914 .expect("Send task should not panic");
2915
2916 assert!(
2917 matches!(result, Err(crate::error::SendError::Closed)),
2918 "Blocked send should return Closed, was: {result:?}"
2919 );
2920
2921 assert!(
2923 elapsed_disconnect < Duration::from_secs(3),
2924 "Disconnect should not wait for rate limiter, took {elapsed_disconnect:?}"
2925 );
2926
2927 server.abort();
2928 }
2929
2930 #[rstest]
2931 #[tokio::test]
2932 async fn test_stream_mode_transitions_to_closed_on_dead_write_task() {
2933 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2937 let port = listener.local_addr().unwrap().port();
2938
2939 let server = task::spawn(async move {
2940 if let Ok((stream, _)) = listener.accept().await
2941 && let Ok(ws) = accept_async(stream).await
2942 {
2943 drop(ws);
2945 }
2946 });
2947
2948 let config = WebSocketConfig {
2949 url: format!("ws://127.0.0.1:{port}"),
2950 headers: vec![],
2951 heartbeat: None,
2952 heartbeat_msg: None,
2953 reconnect_timeout_ms: Some(1_000),
2954 reconnect_delay_initial_ms: Some(50),
2955 reconnect_delay_max_ms: Some(100),
2956 reconnect_backoff_factor: Some(1.0),
2957 reconnect_jitter_ms: Some(0),
2958 reconnect_max_attempts: None,
2959 idle_timeout_ms: None,
2960 };
2961
2962 let (_reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2963 .await
2964 .unwrap();
2965
2966 assert!(client.is_active(), "Client should start active");
2967
2968 sleep(Duration::from_millis(100)).await;
2970
2971 for _ in 0..20 {
2973 let _ = client.send_text("ping".to_string(), None).await;
2974 sleep(Duration::from_millis(50)).await;
2975
2976 if !client.is_active() {
2977 break;
2978 }
2979 }
2980
2981 wait_until_async(|| async { !client.is_active() }, Duration::from_secs(5)).await;
2983
2984 assert!(
2986 client.is_closed() || client.is_disconnected(),
2987 "Stream mode should transition to CLOSED, not RECONNECT. \
2988 is_reconnecting={}, is_closed={}, is_disconnected={}",
2989 client.is_reconnecting(),
2990 client.is_closed(),
2991 client.is_disconnected(),
2992 );
2993 assert!(
2994 !client.is_reconnecting(),
2995 "Stream mode should never attempt reconnection"
2996 );
2997
2998 server.abort();
2999 }
3000
3001 #[rstest]
3002 #[tokio::test]
3003 async fn test_zero_idle_timeout_rejected() {
3004 let (handler, _rx) = channel_message_handler();
3005
3006 let config = WebSocketConfig {
3007 url: "ws://127.0.0.1:9999".to_string(),
3008 headers: vec![],
3009 heartbeat: None,
3010 heartbeat_msg: None,
3011 reconnect_timeout_ms: None,
3012 reconnect_delay_initial_ms: None,
3013 reconnect_delay_max_ms: None,
3014 reconnect_backoff_factor: None,
3015 reconnect_jitter_ms: None,
3016 reconnect_max_attempts: None,
3017 idle_timeout_ms: Some(0),
3018 };
3019
3020 let result =
3021 WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
3022
3023 assert!(result.is_err(), "Zero idle timeout should be rejected");
3024 let err_msg = result.unwrap_err().to_string();
3025 assert!(
3026 err_msg.contains("Idle timeout cannot be zero"),
3027 "Error should mention zero idle timeout, was: {err_msg}"
3028 );
3029 }
3030}