1use std::{
27 collections::VecDeque,
28 fmt::Debug,
29 sync::{
30 Arc, OnceLock,
31 atomic::{AtomicBool, AtomicU8, Ordering},
32 },
33 time::Duration,
34};
35
36use futures_util::{SinkExt, StreamExt};
37use http::HeaderName;
38use nautilus_core::CleanDrop;
39use nautilus_cryptography::providers::install_cryptographic_provider;
40#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
41use rustls::ClientConfig;
42#[cfg(feature = "transport-sockudo")]
43use sockudo_ws::{
44 Config as SockudoConfig, Http1, Role, Stream as SockudoStream,
45 WebSocketStream as SockudoWebSocketStream,
46};
47#[cfg(feature = "transport-sockudo")]
48use tokio::io::{AsyncRead, AsyncWrite};
49#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
50use tokio_rustls::TlsConnector;
51#[cfg(feature = "turmoil")]
52use tokio_tungstenite::MaybeTlsStream;
53#[cfg(feature = "turmoil")]
54use tokio_tungstenite::client_async;
55#[cfg(not(feature = "turmoil"))]
56use tokio_tungstenite::connect_async_with_config;
57use tokio_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue};
58use ustr::Ustr;
59
60#[cfg(not(feature = "turmoil"))]
61use super::proxy::{ProxiedStream, ProxyKind, WsTarget, tunnel_via_proxy};
62use super::{
63 auth::{AuthState, AuthTracker},
64 config::{TransportBackend, WebSocketConfig},
65 consts::{
66 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
67 GRACEFUL_SHUTDOWN_TIMEOUT_SECS,
68 },
69 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
70};
71#[cfg(feature = "turmoil")]
72use crate::net::TcpConnector;
73#[cfg(feature = "transport-sockudo")]
74use crate::net::TcpStream;
75#[cfg(feature = "transport-sockudo")]
76use crate::transport::sockudo::{
77 PrefixedIo, SockudoTransport, client_handshake_with_headers, validate_extra_headers,
78};
79use crate::{
80 RECONNECTED,
81 backoff::ExponentialBackoff,
82 dst,
83 error::SendError,
84 logging::{log_task_aborted, log_task_started, log_task_stopped},
85 mode::ConnectionMode,
86 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
87 transport::{BoxedWsTransport, Message, TransportError, tungstenite::TungsteniteTransport},
88};
89
/// Connection state and background tasks backing a WebSocket client.
///
/// Owns the read task (when a message handler is set), the write task, the optional
/// heartbeat task, and the configuration needed to re-establish the connection in
/// `reconnect`.
pub struct WebSocketClientInner {
    /// Connection configuration (URL, headers, backend, proxy, heartbeat, timeouts).
    config: WebSocketConfig,
    /// Handler invoked for each received text/binary message; `None` in stream mode.
    message_handler: Option<MessageHandler>,
    /// Handler invoked with the payload of each received ping frame.
    ping_handler: Option<PingHandler>,
    /// Task that drains the reader into `message_handler`; `None` when no handler is set.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Task that owns the socket writer and processes commands sent on `writer_tx`.
    write_task: tokio::task::JoinHandle<()>,
    /// Command channel into `write_task` (outgoing messages and writer swaps on reconnect).
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Periodic heartbeat sender; present only when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode`, stored as its `u8` encoding.
    connection_mode: Arc<AtomicU8>,
    /// Notified when the read task exits and when the write task requests a reconnect.
    state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound on how long a single `reconnect` attempt may take.
    reconnect_timeout: Duration,
    /// Backoff policy built from the reconnect delay/factor/jitter settings.
    backoff: ExponentialBackoff,
    /// `true` when no message handler was supplied (the caller consumes the reader itself);
    /// automatic reconnection is disabled in this mode.
    is_stream_mode: bool,
    /// Maximum number of reconnection attempts (NOTE(review): `None` presumably means
    /// unlimited; confirm against the controller that reads it).
    reconnect_max_attempts: Option<u32>,
    /// Number of reconnection attempts made so far (not updated in this impl; presumably
    /// maintained by the controller).
    reconnection_attempt_count: u32,
    /// Lazily initialised authentication tracker; the write task invalidates it when a
    /// send fails and consults it before draining the reconnect buffer.
    auth_tracker: Arc<OnceLock<AuthTracker>>,
    /// When set, messages buffered during a reconnect are held until authentication
    /// succeeds (and discarded if it fails).
    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
}
132
/// What the write task should do with messages buffered during a reconnect.
enum ReconnectBufferAction {
    /// Send the buffered messages on the current writer now.
    Drain,
    /// Keep the messages buffered and check again later (authentication still pending).
    Wait,
    /// Drop the buffered messages (authentication failed).
    Discard,
}
138
139impl WebSocketClientInner {
140 #[expect(
148 clippy::unused_async,
149 reason = "async signature for consistency with connect-based constructors"
150 )]
151 pub async fn new_with_writer(
152 config: WebSocketConfig,
153 writer: MessageWriter,
154 ) -> Result<Self, TransportError> {
155 install_cryptographic_provider();
156
157 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
158 let state_notify = Arc::new(tokio::sync::Notify::new());
159
160 let read_task = None;
162
163 let backoff = ExponentialBackoff::new(
165 Duration::from_secs(2),
166 Duration::from_secs(30),
167 1.5,
168 100,
169 true,
170 )
171 .map_err(|e| {
172 TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
173 })?;
174
175 let auth_tracker = Arc::new(OnceLock::new());
176 let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
177
178 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
179 let write_task = Self::spawn_write_task(
180 connection_mode.clone(),
181 state_notify.clone(),
182 writer,
183 writer_rx,
184 Arc::clone(&auth_tracker),
185 Arc::clone(&reconnect_buffer_waits_for_auth),
186 );
187
188 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
189 Some(Self::spawn_heartbeat_task(
190 connection_mode.clone(),
191 heartbeat_interval,
192 config.heartbeat_msg.clone(),
193 writer_tx.clone(),
194 ))
195 } else {
196 None
197 };
198
199 let reconnect_max_attempts = None; let reconnect_timeout = Duration::from_secs(10);
201
202 Ok(Self {
203 config,
204 message_handler: None, ping_handler: None,
206 writer_tx,
207 connection_mode,
208 state_notify,
209 reconnect_timeout,
210 heartbeat_task,
211 read_task,
212 write_task,
213 backoff,
214 is_stream_mode: true,
215 reconnect_max_attempts,
216 reconnection_attempt_count: 0,
217 auth_tracker,
218 reconnect_buffer_waits_for_auth,
219 })
220 }
221
222 pub async fn connect_url(
230 config: WebSocketConfig,
231 message_handler: Option<MessageHandler>,
232 ping_handler: Option<PingHandler>,
233 ) -> Result<Self, TransportError> {
234 install_cryptographic_provider();
235
236 if config.heartbeat == Some(0) {
237 return Err(TransportError::Io(std::io::Error::new(
238 std::io::ErrorKind::InvalidInput,
239 "Heartbeat interval cannot be zero",
240 )));
241 }
242
243 if config.idle_timeout_ms == Some(0) {
244 return Err(TransportError::Io(std::io::Error::new(
245 std::io::ErrorKind::InvalidInput,
246 "Idle timeout cannot be zero",
247 )));
248 }
249
250 let is_stream_mode = message_handler.is_none();
252 let reconnect_max_attempts = config.reconnect_max_attempts;
253
254 let (writer, reader) = Box::pin(Self::connect_with_server(
255 &config.url,
256 config.headers.clone(),
257 config.backend,
258 config.proxy_url.as_deref(),
259 ))
260 .await?;
261
262 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
263 let state_notify = Arc::new(tokio::sync::Notify::new());
264
265 let read_task = if message_handler.is_some() {
266 Some(Self::spawn_message_handler_task(
267 connection_mode.clone(),
268 state_notify.clone(),
269 reader,
270 message_handler.as_ref(),
271 ping_handler.as_ref(),
272 config.idle_timeout_ms,
273 ))
274 } else {
275 None
276 };
277
278 let auth_tracker = Arc::new(OnceLock::new());
279 let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
280
281 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
282 let write_task = Self::spawn_write_task(
283 connection_mode.clone(),
284 state_notify.clone(),
285 writer,
286 writer_rx,
287 Arc::clone(&auth_tracker),
288 Arc::clone(&reconnect_buffer_waits_for_auth),
289 );
290
291 let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
293 Self::spawn_heartbeat_task(
294 connection_mode.clone(),
295 heartbeat_secs,
296 config.heartbeat_msg.clone(),
297 writer_tx.clone(),
298 )
299 });
300
301 let reconnect_timeout =
302 Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
303 let backoff = ExponentialBackoff::new(
304 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
305 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
306 config.reconnect_backoff_factor.unwrap_or(1.5),
307 config.reconnect_jitter_ms.unwrap_or(100),
308 true, )
310 .map_err(|e| {
311 TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
312 })?;
313
314 Ok(Self {
315 config,
316 message_handler,
317 ping_handler,
318 read_task,
319 write_task,
320 writer_tx,
321 heartbeat_task,
322 connection_mode,
323 state_notify,
324 reconnect_timeout,
325 backoff,
326 is_stream_mode,
328 reconnect_max_attempts,
329 reconnection_attempt_count: 0,
330 auth_tracker,
331 reconnect_buffer_waits_for_auth,
332 })
333 }
334
335 #[inline]
358 pub async fn connect_with_server(
359 url: &str,
360 headers: Vec<(String, String)>,
361 backend: TransportBackend,
362 proxy_url: Option<&str>,
363 ) -> Result<(MessageWriter, MessageReader), TransportError> {
364 if matches!(backend, TransportBackend::Sockudo)
369 && let Some(proxy) = proxy_url
370 {
371 log::warn!("Sockudo backend does not support proxy_url; falling back to Tungstenite");
372 return Box::pin(Self::connect_tungstenite_via_proxy(url, headers, proxy)).await;
373 }
374
375 match backend {
376 TransportBackend::Tungstenite => match proxy_url {
377 Some(proxy) => {
378 Box::pin(Self::connect_tungstenite_via_proxy(url, headers, proxy)).await
379 }
380 None => Self::connect_tungstenite(url, headers).await,
381 },
382 TransportBackend::Sockudo => {
383 #[cfg(feature = "transport-sockudo")]
384 {
385 Self::connect_sockudo(url, headers).await
386 }
387 #[cfg(not(feature = "transport-sockudo"))]
388 {
389 Err(TransportError::Other(
390 "sockudo backend selected but the transport-sockudo \
391 Cargo feature is not enabled"
392 .to_string(),
393 ))
394 }
395 }
396 }
397 }
398
399 #[inline]
402 #[cfg(not(feature = "turmoil"))]
403 async fn connect_tungstenite(
404 url: &str,
405 headers: Vec<(String, String)>,
406 ) -> Result<(MessageWriter, MessageReader), TransportError> {
407 let mut request = url.into_client_request().map_err(TransportError::from)?;
408 let req_headers = request.headers_mut();
409
410 for (key, val) in headers {
411 let header_value = HeaderValue::from_str(&val)
412 .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
413 let header_name: HeaderName = key
414 .parse()
415 .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
416 req_headers.insert(header_name, header_value);
417 }
418
419 let (stream, _resp) = connect_async_with_config(request, None, true)
420 .await
421 .map_err(TransportError::from)?;
422 let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
423 Ok(transport.split())
424 }
425
426 #[inline]
434 #[cfg(not(feature = "turmoil"))]
435 async fn connect_tungstenite_via_proxy(
436 url: &str,
437 headers: Vec<(String, String)>,
438 proxy_url: &str,
439 ) -> Result<(MessageWriter, MessageReader), TransportError> {
440 let proxy = match ProxyKind::parse(proxy_url)? {
441 ProxyKind::Http(target) => target,
442 ProxyKind::Unsupported { scheme } => {
443 log::warn!(
444 "WebSocket proxy_url scheme '{scheme}' is not yet supported; \
445 connecting without a WebSocket proxy"
446 );
447 return Self::connect_tungstenite(url, headers).await;
448 }
449 };
450
451 let mut request = url.into_client_request().map_err(TransportError::from)?;
452 let req_headers = request.headers_mut();
453
454 for (key, val) in headers {
455 let header_value = HeaderValue::from_str(&val)
456 .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
457 let header_name: HeaderName = key
458 .parse()
459 .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
460 req_headers.insert(header_name, header_value);
461 }
462
463 let target = WsTarget::parse(url)?;
464 let stream = tunnel_via_proxy(&target, &proxy).await?;
465
466 #[allow(clippy::match_same_arms)]
474 let transport: BoxedWsTransport = match stream {
475 ProxiedStream::Plain(tcp) => Box::pin(proxied_ws_handshake(request, tcp)).await?,
476 ProxiedStream::PlainOverTlsProxy(s) => {
477 Box::pin(proxied_ws_handshake(request, *s)).await?
478 }
479 ProxiedStream::Tls(s) => Box::pin(proxied_ws_handshake(request, *s)).await?,
480 ProxiedStream::TlsOverTlsProxy(s) => {
481 Box::pin(proxied_ws_handshake(request, *s)).await?
482 }
483 };
484
485 Ok(transport.split())
486 }
487
488 #[inline]
491 #[cfg(feature = "turmoil")]
492 #[expect(
493 clippy::unused_async,
494 reason = "signature mirrors the production variant; both are awaited in the dispatcher"
495 )]
496 async fn connect_tungstenite_via_proxy(
497 _url: &str,
498 _headers: Vec<(String, String)>,
499 _proxy_url: &str,
500 ) -> Result<(MessageWriter, MessageReader), TransportError> {
501 Err(TransportError::Other(
502 "proxy_url is not supported under the turmoil simulator".to_string(),
503 ))
504 }
505
506 #[inline]
509 #[cfg(feature = "turmoil")]
510 async fn connect_tungstenite(
511 url: &str,
512 headers: Vec<(String, String)>,
513 ) -> Result<(MessageWriter, MessageReader), TransportError> {
514 let mut request = url.into_client_request().map_err(TransportError::from)?;
515 let req_headers = request.headers_mut();
516
517 for (key, val) in headers {
518 let header_value = HeaderValue::from_str(&val)
519 .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
520 let header_name: HeaderName = key
521 .parse()
522 .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
523 req_headers.insert(header_name, header_value);
524 }
525
526 let uri = request.uri();
527 let scheme = uri.scheme_str().unwrap_or("ws");
528 let host = uri
529 .host()
530 .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;
531
532 let port = uri
534 .port_u16()
535 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
536
537 let addr = format!("{host}:{port}");
538
539 let connector = crate::net::RealTcpConnector;
541 let tcp_stream = connector.connect(&addr).await?;
542 if let Err(e) = tcp_stream.set_nodelay(true) {
543 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
544 }
545
546 let maybe_tls_stream = if scheme == "wss" {
548 let mut root_store = rustls::RootCertStore::empty();
550 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
551
552 let config = ClientConfig::builder()
553 .with_root_certificates(root_store)
554 .with_no_client_auth();
555
556 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
557 let domain = rustls::pki_types::ServerName::try_from(host.to_string())
558 .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
559
560 let tls_stream = tls_connector
561 .connect(domain, tcp_stream)
562 .await
563 .map_err(TransportError::Io)?;
564 MaybeTlsStream::Rustls(tls_stream)
565 } else {
566 MaybeTlsStream::Plain(tcp_stream)
567 };
568
569 let (stream, _resp) = client_async(request, maybe_tls_stream)
571 .await
572 .map_err(TransportError::from)?;
573 let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
574 Ok(transport.split())
575 }
576
577 #[inline]
586 #[cfg(feature = "transport-sockudo")]
587 async fn connect_sockudo(
588 url: &str,
589 headers: Vec<(String, String)>,
590 ) -> Result<(MessageWriter, MessageReader), TransportError> {
591 let target = SockudoTarget::parse(url)?;
592 validate_extra_headers(&headers).map_err(TransportError::from)?;
593
594 #[cfg(feature = "turmoil")]
595 if target.is_tls {
596 return Err(TransportError::Tls(
597 "wss:// is not supported under the turmoil simulator; use ws://".to_string(),
598 ));
599 }
600
601 let tcp_stream = TcpStream::connect((target.host.as_str(), target.port))
602 .await
603 .map_err(TransportError::Io)?;
604
605 if let Err(e) = tcp_stream.set_nodelay(true) {
606 log::warn!("Failed to enable TCP_NODELAY for sockudo client: {e:?}");
607 }
608
609 #[cfg(not(feature = "turmoil"))]
610 if target.is_tls {
611 let mut root_store = rustls::RootCertStore::empty();
612 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
613 let config = ClientConfig::builder()
614 .with_root_certificates(root_store)
615 .with_no_client_auth();
616 let connector = TlsConnector::from(std::sync::Arc::new(config));
617 let domain = rustls::pki_types::ServerName::try_from(target.host.clone())
618 .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
619 let tls_stream = connector
620 .connect(domain, tcp_stream)
621 .await
622 .map_err(TransportError::Io)?;
623 return Self::finish_sockudo_handshake(tls_stream, &target, &headers).await;
624 }
625
626 Self::finish_sockudo_handshake(tcp_stream, &target, &headers).await
627 }
628
629 #[cfg(feature = "transport-sockudo")]
630 async fn finish_sockudo_handshake<S>(
631 mut stream: S,
632 target: &SockudoTarget,
633 headers: &[(String, String)],
634 ) -> Result<(MessageWriter, MessageReader), TransportError>
635 where
636 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
637 {
638 let handshake = client_handshake_with_headers(
642 &mut stream,
643 &target.host_header,
644 &target.path,
645 None,
646 headers,
647 )
648 .await
649 .map_err(TransportError::from)?;
650
651 let stream = match handshake.leftover {
654 Some(prefix) => SockudoStream::<Http1>::new(PrefixedIo::new(stream, prefix)),
655 None => SockudoStream::<Http1>::new(stream),
656 };
657 let ws = SockudoWebSocketStream::from_raw(stream, Role::Client, SockudoConfig::default());
658 let transport: BoxedWsTransport = Box::pin(SockudoTransport::new(ws));
659 Ok(transport.split())
660 }
661}
662
663#[cfg(not(feature = "turmoil"))]
668async fn proxied_ws_handshake<S>(
669 request: tokio_tungstenite::tungstenite::handshake::client::Request,
670 stream: S,
671) -> Result<BoxedWsTransport, TransportError>
672where
673 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
674{
675 let (ws, _resp) = tokio_tungstenite::client_async(request, stream)
676 .await
677 .map_err(TransportError::from)?;
678 Ok(Box::pin(TungsteniteTransport::new(ws)))
679}
680
/// Connection target parsed from a `ws://` / `wss://` URL for the sockudo backend.
#[cfg(feature = "transport-sockudo")]
#[derive(Debug, PartialEq, Eq)]
struct SockudoTarget {
    /// Host to dial, with the brackets of an IPv6 literal removed (e.g. `::1`).
    host: String,
    /// Value for the HTTP `Host` header: the host as written in the URL (brackets kept),
    /// followed by `:port` only when the URL reports an explicit port.
    host_header: String,
    /// TCP port; 80 for `ws` and 443 for `wss` unless the URL gives one.
    port: u16,
    /// Request target sent in the handshake: the URL path (`/` if empty) plus `?query`.
    path: String,
    /// `true` for `wss://`.
    is_tls: bool,
}
698
#[cfg(feature = "transport-sockudo")]
impl SockudoTarget {
    /// Parses a `ws://` or `wss://` URL into the pieces needed to dial and handshake.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::InvalidUrl` if the URL does not parse, uses a scheme other
    /// than `ws`/`wss`, or has no host.
    fn parse(url: &str) -> Result<Self, TransportError> {
        let parsed =
            url::Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;

        let is_tls = match parsed.scheme() {
            "ws" => false,
            "wss" => true,
            other => {
                return Err(TransportError::InvalidUrl(format!(
                    "expected ws:// or wss:// scheme, was {other}"
                )));
            }
        };

        let raw_host = parsed
            .host_str()
            .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;

        // `host_str` keeps the brackets around IPv6 literals. They belong in the `Host`
        // header but not in the address handed to the TCP connector.
        let host = raw_host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(raw_host)
            .to_string();

        // `Url` drops a port equal to the scheme default, so `port()` is `Some` only for a
        // non-default port, and only then does the `Host` header need it.
        let explicit_port = parsed.port();
        let port = explicit_port.unwrap_or(if is_tls { 443 } else { 80 });
        let host_header = match explicit_port {
            Some(p) => format!("{raw_host}:{p}"),
            None => raw_host.to_string(),
        };

        // Request target: an empty path becomes `/`, and the query is kept in every case.
        let mut path = match parsed.path() {
            "" => String::from("/"),
            p => p.to_string(),
        };
        if let Some(query) = parsed.query() {
            path.push('?');
            path.push_str(query);
        }

        Ok(Self {
            host,
            host_header,
            port,
            path,
            is_tls,
        })
    }
}
758
759impl WebSocketClientInner {
    /// Re-establishes the connection and swaps it into the running tasks.
    ///
    /// Stream-mode clients are not reconnected: the mode is set to `Closed` and `Ok(())` is
    /// returned. Otherwise the whole attempt is bounded by `reconnect_timeout` and runs these
    /// steps:
    /// 1. Connect a new socket.
    /// 2. Hand the new writer to the write task and wait for its acknowledgement.
    /// 3. Pause for `GRACEFUL_SHUTDOWN_DELAY_MS`.
    /// 4. Abort the old read task.
    /// 5. Move the mode from `Reconnect` to `Active`.
    /// 6. Spawn a read task on the new reader (only when a message handler is set).
    ///
    /// If the mode is a disconnect state at any checkpoint, or is no longer `Reconnect` at
    /// step 5, the attempt is abandoned and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails, the write task cannot receive or rejects the
    /// new writer, or the attempt exceeds `reconnect_timeout`.
    pub async fn reconnect(&mut self) -> Result<(), TransportError> {
        log::debug!("Reconnecting");

        if self.is_stream_mode {
            log::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                stream users must manually reconnect by creating a new connection"
            );
            // Mark the client closed so the write/heartbeat tasks stop.
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        dst::time::timeout(self.reconnect_timeout, async {
            let (new_writer, reader) = Self::connect_with_server(
                &self.config.url,
                self.config.headers.clone(),
                self.config.backend,
                self.config.proxy_url.as_deref(),
            )
            .await?;

            // A disconnect may have been requested while we were connecting.
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // Ask the write task to swap in the new writer and wait for its confirmation.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(TransportError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            match rx.await {
                Ok(true) => log::debug!("Writer confirmed socket update"),
                Ok(false) => {
                    log::warn!("Writer rejected socket update, aborting reconnect");
                    return Err(TransportError::Io(std::io::Error::other(
                        "Failed to update reconnection writer",
                    )));
                }
                Err(e) => {
                    // The write task exited (e.g. on disconnect) without answering.
                    log::error!("Writer dropped update channel: {e}");
                    return Err(TransportError::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // The old read task is bound to the previous socket; stop it before replacing it.
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Only promote to Active if nobody changed the mode in the meantime. Note: on
            // failure the writer has already been swapped and the old read task is gone.
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            self.read_task = if self.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    self.state_notify.clone(),
                    reader,
                    self.message_handler.as_ref(),
                    self.ping_handler.as_ref(),
                    self.config.idle_timeout_ms,
                ))
            } else {
                None
            };

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        .map_err(|_| {
            TransportError::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
895
896 #[inline]
902 #[must_use]
903 pub fn is_alive(&self) -> bool {
904 match &self.read_task {
905 Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
906 None => !self.write_task.is_finished(),
907 }
908 }
909
910 fn spawn_message_handler_task(
911 connection_state: Arc<AtomicU8>,
912 state_notify: Arc<tokio::sync::Notify>,
913 mut reader: MessageReader,
914 message_handler: Option<&MessageHandler>,
915 ping_handler: Option<&PingHandler>,
916 idle_timeout_ms: Option<u64>,
917 ) -> tokio::task::JoinHandle<()> {
918 log::debug!("Started message handler task 'read'");
919
920 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
921 let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
922
923 let message_handler = message_handler.cloned();
925 let ping_handler = ping_handler.cloned();
926
927 tokio::task::spawn(async move {
928 let mut last_data_time = dst::time::Instant::now();
929
930 loop {
931 if !ConnectionMode::from_atomic(&connection_state).is_active() {
932 break;
933 }
934
935 match dst::time::timeout(check_interval, reader.next()).await {
936 Ok(Some(Ok(Message::Binary(data)))) => {
937 log::trace!("Received message <binary> {} bytes", data.len());
938 last_data_time = dst::time::Instant::now();
939
940 if let Some(ref handler) = message_handler {
941 handler(Message::Binary(data));
942 }
943 }
944 Ok(Some(Ok(Message::Text(data)))) => {
945 log::trace!("Received message: {data:?}");
946 last_data_time = dst::time::Instant::now();
947
948 if let Some(ref handler) = message_handler {
949 handler(Message::Text(data));
950 }
951 }
952 Ok(Some(Ok(Message::Ping(ping_data)))) => {
953 log::trace!("Received ping: {ping_data:?}");
954 if let Some(ref handler) = ping_handler {
958 handler(ping_data.to_vec());
959 }
960 }
961 Ok(Some(Ok(Message::Pong(_)))) => {
962 log::trace!("Received pong");
963 }
965 Ok(Some(Ok(Message::Close(_)))) => {
966 log::debug!("Received close message - terminating");
967 break;
968 }
969 Ok(Some(Err(e))) => {
970 log::error!("Received error message - terminating: {e}");
971 break;
972 }
973 Ok(None) => {
974 log::debug!("No message received - terminating");
975 break;
976 }
977 Err(_) => {
978 if let Some(timeout) = idle_timeout {
979 let idle_duration = last_data_time.elapsed();
980 if idle_duration >= timeout {
981 log::warn!(
982 "Read idle timeout: no data received for {:.1}s",
983 idle_duration.as_secs_f64()
984 );
985 break;
986 }
987 }
988 }
989 }
990 }
991
992 state_notify.notify_one();
994 })
995 }
996
997 async fn drain_reconnect_buffer(
1002 buffer: &mut VecDeque<Message>,
1003 writer: &mut MessageWriter,
1004 ) -> bool {
1005 if buffer.is_empty() {
1006 return false;
1007 }
1008
1009 let initial_buffer_len = buffer.len();
1010 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
1011
1012 let mut send_error_occurred = false;
1013
1014 while let Some(buffered_msg) = buffer.front() {
1015 let msg_to_send = buffered_msg.clone();
1017
1018 if let Err(e) = writer.send(msg_to_send).await {
1019 log::error!(
1020 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
1021 buffer.len()
1022 );
1023 send_error_occurred = true;
1024 break; }
1026
1027 buffer.pop_front();
1029 }
1030
1031 if buffer.is_empty() {
1032 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
1033 }
1034
1035 send_error_occurred
1036 }
1037
1038 fn can_drain_reconnect_buffer(
1039 reconnect_buffer_waits_for_auth: &AtomicBool,
1040 auth_tracker: &Arc<OnceLock<AuthTracker>>,
1041 ) -> ReconnectBufferAction {
1042 if !reconnect_buffer_waits_for_auth.load(Ordering::Acquire) {
1043 return ReconnectBufferAction::Drain;
1044 }
1045
1046 match auth_tracker.get().map(AuthTracker::auth_state) {
1047 Some(AuthState::Authenticated) => ReconnectBufferAction::Drain,
1048 Some(AuthState::Failed) => ReconnectBufferAction::Discard,
1049 Some(AuthState::Unauthenticated) | None => ReconnectBufferAction::Wait,
1050 }
1051 }
1052
    /// Spawns the write task, which owns the active socket writer and processes
    /// `WriterCommand`s from `writer_rx`.
    ///
    /// The connection mode is re-read from `connection_state` on every iteration:
    /// - `Disconnect`: discard buffered messages, close the writer (bounded by
    ///   `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`) and exit.
    /// - `Closed`: discard buffered messages and exit.
    /// - `Reconnect`: `Send` commands are appended to an in-memory reconnect buffer.
    /// - `Active`: buffered messages are handled first (as decided by
    ///   `can_drain_reconnect_buffer`), then `Send` commands are written directly.
    ///
    /// `Update` commands replace the writer (closing the old one) and acknowledge on the
    /// provided oneshot. When a send fails, the task:
    /// - re-buffers the message (the direct-send path only);
    /// - invalidates the auth tracker;
    /// - switches the mode to `Reconnect` and notifies `state_notify`.
    ///
    /// When the loop ends the writer is closed again (bounded by the same timeout).
    fn spawn_write_task(
        connection_state: Arc<AtomicU8>,
        state_notify: Arc<tokio::sync::Notify>,
        writer: MessageWriter,
        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
        auth_tracker: Arc<OnceLock<AuthTracker>>,
        reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
    ) -> tokio::task::JoinHandle<()> {
        log_task_started("write");

        // Upper bound on how long the task waits for a command before re-checking the mode.
        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        tokio::task::spawn(async move {
            let mut active_writer = writer;
            // Messages accepted while reconnecting, replayed once the connection is active.
            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();

            loop {
                let mode = ConnectionMode::from_atomic(&connection_state);

                match mode {
                    ConnectionMode::Disconnect => {
                        // Buffered messages cannot be delivered after an explicit disconnect.
                        if !reconnect_buffer.is_empty() {
                            log::warn!(
                                "Discarding {} buffered messages due to disconnect",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                        }

                        _ = dst::time::timeout(
                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                            active_writer.close(),
                        )
                        .await;
                        break;
                    }
                    ConnectionMode::Closed => {
                        if !reconnect_buffer.is_empty() {
                            log::warn!(
                                "Discarding {} buffered messages due to closed connection",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                        }
                        break;
                    }
                    _ => {}
                }

                // Replay (or discard) the reconnect buffer before accepting new commands.
                if mode.is_active() && !reconnect_buffer.is_empty() {
                    match Self::can_drain_reconnect_buffer(
                        reconnect_buffer_waits_for_auth.as_ref(),
                        &auth_tracker,
                    ) {
                        ReconnectBufferAction::Drain => {
                            let send_error = Self::drain_reconnect_buffer(
                                &mut reconnect_buffer,
                                &mut active_writer,
                            )
                            .await;

                            if send_error {
                                // The new socket already failed: request another reconnect.
                                if let Some(tracker) = auth_tracker.get() {
                                    tracker.invalidate();
                                }
                                connection_state
                                    .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
                                state_notify.notify_one();
                            }

                            continue;
                        }
                        ReconnectBufferAction::Discard => {
                            log::warn!(
                                "Discarding {} buffered messages after authentication failed",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                            continue;
                        }
                        // Authentication still pending: keep the buffer and serve commands.
                        ReconnectBufferAction::Wait => {}
                    }
                }

                match dst::time::timeout(check_interval, writer_rx.recv()).await {
                    Ok(Some(msg)) => {
                        // Re-read the mode: it may have changed while waiting for the command.
                        let mode = ConnectionMode::from_atomic(&connection_state);
                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
                            break;
                        }

                        match msg {
                            WriterCommand::Update(new_writer, tx) => {
                                log::debug!("Received new writer");

                                // Short fixed pause before closing the old writer
                                // (NOTE(review): presumably lets in-flight writes settle).
                                dst::time::sleep(Duration::from_millis(100)).await;

                                _ = dst::time::timeout(
                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                                    active_writer.close(),
                                )
                                .await;

                                active_writer = new_writer;
                                log::debug!("Updated writer");

                                if let Err(e) = tx.send(true) {
                                    log::error!(
                                        "Failed to report writer update to controller: {e:?}"
                                    );
                                }
                            }
                            WriterCommand::Send(msg) if mode.is_reconnect() => {
                                log::debug!(
                                    "Buffering message during reconnection (buffer size: {})",
                                    reconnect_buffer.len() + 1
                                );
                                reconnect_buffer.push_back(msg);
                            }
                            WriterCommand::Send(msg) => {
                                // Send a clone so the message can be buffered if the write fails.
                                if let Err(e) = active_writer.send(msg.clone()).await {
                                    log::error!("Failed to send message: {e}");
                                    log::warn!("Writer triggering reconnect");
                                    reconnect_buffer.push_back(msg);

                                    if let Some(tracker) = auth_tracker.get() {
                                        tracker.invalidate();
                                    }
                                    connection_state
                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
                                    state_notify.notify_one();
                                }
                            }
                        }
                    }
                    Ok(None) => {
                        // Every sender has been dropped; no further commands can arrive.
                        log::debug!("Writer channel closed, terminating writer task");
                        break;
                    }
                    Err(_) => {
                        // No command within `check_interval`; loop to re-check the mode.
                    }
                }
            }

            _ = dst::time::timeout(
                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                active_writer.close(),
            )
            .await;

            log_task_stopped("write");
        })
    }
1222
1223 fn spawn_heartbeat_task(
1224 connection_state: Arc<AtomicU8>,
1225 heartbeat_secs: u64,
1226 message: Option<String>,
1227 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
1228 ) -> tokio::task::JoinHandle<()> {
1229 log_task_started("heartbeat");
1230
1231 tokio::task::spawn(async move {
1232 let interval = Duration::from_secs(heartbeat_secs);
1233
1234 loop {
1235 dst::time::sleep(interval).await;
1236
1237 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
1238 ConnectionMode::Active => {
1239 let msg = match &message {
1240 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
1241 None => WriterCommand::Send(Message::Ping(vec![].into())),
1242 };
1243
1244 match writer_tx.send(msg) {
1245 Ok(()) => log::trace!("Sent heartbeat to writer task"),
1246 Err(e) => {
1247 log::error!("Failed to send heartbeat to writer task: {e}");
1248 }
1249 }
1250 }
1251 ConnectionMode::Reconnect => {}
1252 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
1253 }
1254 }
1255
1256 log_task_stopped("heartbeat");
1257 })
1258 }
1259}
1260
impl Drop for WebSocketClientInner {
    /// Aborts any still-running background tasks and releases the handlers via `clean_drop`.
    fn drop(&mut self) {
        self.clean_drop();
    }
}
1267
1268impl CleanDrop for WebSocketClientInner {
1270 fn clean_drop(&mut self) {
1271 if let Some(ref read_task) = self.read_task.take()
1272 && !read_task.is_finished()
1273 {
1274 read_task.abort();
1275 log_task_aborted("read");
1276 }
1277
1278 if !self.write_task.is_finished() {
1279 self.write_task.abort();
1280 log_task_aborted("write");
1281 }
1282
1283 if let Some(ref handle) = self.heartbeat_task.take()
1284 && !handle.is_finished()
1285 {
1286 handle.abort();
1287 log_task_aborted("heartbeat");
1288 }
1289
1290 self.message_handler = None;
1292 self.ping_handler = None;
1293 }
1294}
1295
1296#[expect(
1297 clippy::missing_fields_in_debug,
1298 reason = "handler closures and internal task handles are intentionally omitted"
1299)]
1300impl Debug for WebSocketClientInner {
1301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1302 f.debug_struct(stringify!(WebSocketClientInner))
1303 .field("config", &self.config)
1304 .field(
1305 "connection_mode",
1306 &ConnectionMode::from_atomic(&self.connection_mode),
1307 )
1308 .field("reconnect_timeout", &self.reconnect_timeout)
1309 .field("is_stream_mode", &self.is_stream_mode)
1310 .finish()
1311 }
1312}
1313
/// WebSocket client whose connection lifecycle (reconnects, disconnects, state transitions)
/// is driven by a background controller task.
///
/// Outbound frames pass through a rate limiter and are forwarded to a writer task over
/// `writer_tx`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct WebSocketClient {
    /// Controller task that owns the inner client; aborted on drop if still running.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection state, stored as `ConnectionMode::as_u8()`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Wakes tasks waiting for a connection-state change.
    pub(crate) state_notify: Arc<tokio::sync::Notify>,
    /// Maximum time a send waits for the client to become `Active`.
    pub(crate) reconnect_timeout: Duration,
    /// Outbound rate limiter keyed by `Ustr`.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    /// Command channel to the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    // Auth tracker, settable at most once; invalidated when a dead connection is detected.
    auth_tracker: Arc<OnceLock<AuthTracker>>,
    // Flag written by `set_auth_tracker`; presumably makes post-reconnect buffered sends
    // wait for re-authentication — TODO confirm in `WebSocketClientInner`.
    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
}
1336
1337impl Debug for WebSocketClient {
1338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1339 f.debug_struct(stringify!(WebSocketClient)).finish()
1340 }
1341}
1342
1343impl WebSocketClient {
1344 pub async fn connect_stream(
1360 config: WebSocketConfig,
1361 keyed_quotas: Vec<(String, Quota)>,
1362 default_quota: Option<Quota>,
1363 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
1364 ) -> Result<(MessageReader, Self), TransportError> {
1365 install_cryptographic_provider();
1366
1367 let (writer, reader) = WebSocketClientInner::connect_with_server(
1369 &config.url,
1370 config.headers.clone(),
1371 config.backend,
1372 config.proxy_url.as_deref(),
1373 )
1374 .await?;
1375
1376 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
1378
1379 let connection_mode = inner.connection_mode.clone();
1380 let state_notify = inner.state_notify.clone();
1381 let reconnect_timeout = inner.reconnect_timeout;
1382 let auth_tracker = Arc::clone(&inner.auth_tracker);
1383 let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1384 let keyed_quotas = keyed_quotas
1385 .into_iter()
1386 .map(|(key, quota)| (Ustr::from(&key), quota))
1387 .collect();
1388 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1389 let writer_tx = inner.writer_tx.clone();
1390
1391 let controller_task = Self::spawn_controller_task(
1392 inner,
1393 connection_mode.clone(),
1394 state_notify.clone(),
1395 post_reconnect,
1396 Arc::clone(&auth_tracker),
1397 );
1398
1399 Ok((
1400 reader,
1401 Self {
1402 controller_task,
1403 connection_mode,
1404 state_notify,
1405 reconnect_timeout,
1406 rate_limiter,
1407 writer_tx,
1408 auth_tracker,
1409 reconnect_buffer_waits_for_auth,
1410 },
1411 ))
1412 }
1413
1414 pub async fn connect(
1432 config: WebSocketConfig,
1433 message_handler: Option<MessageHandler>,
1434 ping_handler: Option<PingHandler>,
1435 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1436 keyed_quotas: Vec<(String, Quota)>,
1437 default_quota: Option<Quota>,
1438 ) -> Result<Self, TransportError> {
1439 if message_handler.is_none() {
1441 return Err(TransportError::Io(std::io::Error::new(
1442 std::io::ErrorKind::InvalidInput,
1443 "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
1444 )));
1445 }
1446
1447 log::debug!("Connecting");
1448 let inner =
1449 WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
1450 let connection_mode = inner.connection_mode.clone();
1451 let state_notify = inner.state_notify.clone();
1452 let writer_tx = inner.writer_tx.clone();
1453 let reconnect_timeout = inner.reconnect_timeout;
1454 let auth_tracker = Arc::clone(&inner.auth_tracker);
1455 let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1456
1457 let controller_task = Self::spawn_controller_task(
1458 inner,
1459 connection_mode.clone(),
1460 state_notify.clone(),
1461 post_reconnection,
1462 Arc::clone(&auth_tracker),
1463 );
1464
1465 let keyed_quotas = keyed_quotas
1466 .into_iter()
1467 .map(|(key, quota)| (Ustr::from(&key), quota))
1468 .collect();
1469 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1470
1471 Ok(Self {
1472 controller_task,
1473 connection_mode,
1474 state_notify,
1475 reconnect_timeout,
1476 rate_limiter,
1477 writer_tx,
1478 auth_tracker,
1479 reconnect_buffer_waits_for_auth,
1480 })
1481 }
1482
1483 #[must_use]
1485 pub fn connection_mode(&self) -> ConnectionMode {
1486 ConnectionMode::from_atomic(&self.connection_mode)
1487 }
1488
1489 #[must_use]
1494 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1495 Arc::clone(&self.connection_mode)
1496 }
1497
1498 #[inline]
1503 #[must_use]
1504 pub fn is_active(&self) -> bool {
1505 self.connection_mode().is_active()
1506 }
1507
1508 #[must_use]
1510 pub fn is_disconnected(&self) -> bool {
1511 self.controller_task.is_finished()
1512 }
1513
1514 #[inline]
1519 #[must_use]
1520 pub fn is_reconnecting(&self) -> bool {
1521 self.connection_mode().is_reconnect()
1522 }
1523
1524 pub fn set_auth_tracker(&self, tracker: AuthTracker, reconnect_buffer_waits_for_auth: bool) {
1534 let _ = self.auth_tracker.set(tracker);
1535 self.reconnect_buffer_waits_for_auth
1536 .store(reconnect_buffer_waits_for_auth, Ordering::Release);
1537 }
1538
1539 #[inline]
1543 #[must_use]
1544 pub fn is_disconnecting(&self) -> bool {
1545 self.connection_mode().is_disconnect()
1546 }
1547
1548 #[inline]
1554 #[must_use]
1555 pub fn is_closed(&self) -> bool {
1556 self.connection_mode().is_closed()
1557 }
1558
1559 #[inline]
1563 fn check_not_terminal(&self) -> Result<(), SendError> {
1564 match self.connection_mode() {
1565 ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
1566 _ => Ok(()),
1567 }
1568 }
1569
1570 async fn await_rate_limit_or_closed(&self, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1572 const CHECK_INTERVAL_MS: u64 = 100;
1573
1574 tokio::select! {
1575 biased;
1576 () = self.rate_limiter.await_keys_ready(keys) => Ok(()),
1577 () = async {
1578 loop {
1579 let notified = self.state_notify.notified();
1580
1581 if matches!(self.connection_mode(), ConnectionMode::Disconnect | ConnectionMode::Closed) {
1582 break;
1583 }
1584 tokio::select! {
1585 biased;
1586 () = notified => {}
1587 () = dst::time::sleep(Duration::from_millis(CHECK_INTERVAL_MS)) => {}
1588 }
1589 }
1590 } => Err(SendError::Closed),
1591 }
1592 }
1593
1594 async fn wait_for_active(&self) -> Result<(), SendError> {
1600 const FALLBACK_INTERVAL_MS: u64 = 100;
1601
1602 let mode = self.connection_mode();
1603 if mode.is_active() {
1604 return Ok(());
1605 }
1606
1607 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1608 return Err(SendError::Closed);
1609 }
1610
1611 log::debug!("Waiting for client to become ACTIVE before sending...");
1612
1613 let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
1614
1615 dst::time::timeout(self.reconnect_timeout, async {
1616 loop {
1617 let notified = self.state_notify.notified();
1620
1621 let mode = self.connection_mode();
1622 if mode.is_active() {
1623 return Ok(());
1624 }
1625
1626 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1627 return Err(());
1628 }
1629
1630 tokio::select! {
1631 biased;
1632 () = notified => {}
1633 () = dst::time::sleep(fallback_interval) => {}
1634 }
1635 }
1636 })
1637 .await
1638 .map_err(|_| SendError::Timeout)?
1639 .map_err(|()| SendError::Closed)
1640 }
1641
1642 pub fn notify_closed(&self) {
1653 let mode = self.connection_mode();
1654 if mode.is_disconnect() || mode.is_closed() {
1655 return;
1656 }
1657
1658 log::debug!("Stream reader signalled EOF, transitioning to CLOSED");
1659
1660 self.connection_mode
1661 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1662 self.state_notify.notify_waiters();
1663 }
1664
1665 pub async fn disconnect(&self) {
1670 log::debug!("Disconnecting");
1671 self.connection_mode
1672 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1673 self.state_notify.notify_waiters();
1674
1675 if dst::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1676 while !self.is_disconnected() {
1677 dst::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1678 }
1679
1680 if !self.controller_task.is_finished() {
1681 self.controller_task.abort();
1682 log_task_aborted("controller");
1683 }
1684 })
1685 .await
1686 == Ok(())
1687 {
1688 log::debug!("Controller task finished");
1689 } else {
1690 log::error!("Timeout waiting for controller task to finish");
1691
1692 if !self.controller_task.is_finished() {
1693 self.controller_task.abort();
1694 log_task_aborted("controller");
1695 }
1696 self.connection_mode
1697 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1698 }
1699 }
1700
1701 #[allow(unused_variables)]
1711 pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1712 self.check_not_terminal()?;
1713
1714 self.await_rate_limit_or_closed(keys).await?;
1715 self.wait_for_active().await?;
1716
1717 log::trace!("Sending text: {data:?}");
1718
1719 let msg = Message::Text(data.into());
1720 self.writer_tx
1721 .send(WriterCommand::Send(msg))
1722 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1723 }
1724
1725 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1731 self.wait_for_active().await?;
1732
1733 log::trace!("Sending pong frame ({} bytes)", data.len());
1734
1735 let msg = Message::Pong(data.into());
1736 self.writer_tx
1737 .send(WriterCommand::Send(msg))
1738 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1739 }
1740
1741 #[allow(unused_variables)]
1751 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1752 self.check_not_terminal()?;
1753
1754 self.await_rate_limit_or_closed(keys).await?;
1755 self.wait_for_active().await?;
1756
1757 log::trace!("Sending bytes: {data:?}");
1758
1759 let msg = Message::Binary(data.into());
1760 self.writer_tx
1761 .send(WriterCommand::Send(msg))
1762 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1763 }
1764
1765 pub async fn send_close_message(&self) -> Result<(), SendError> {
1771 self.wait_for_active().await?;
1772
1773 let msg = Message::Close(None);
1774 self.writer_tx
1775 .send(WriterCommand::Send(msg))
1776 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1777 }
1778
1779 fn spawn_controller_task(
1780 mut inner: WebSocketClientInner,
1781 connection_mode: Arc<AtomicU8>,
1782 state_notify: Arc<tokio::sync::Notify>,
1783 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1784 auth_tracker: Arc<OnceLock<AuthTracker>>,
1785 ) -> tokio::task::JoinHandle<()> {
1786 const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;
1787
1788 tokio::task::spawn(async move {
1789 log_task_started("controller");
1790
1791 let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);
1792
1793 loop {
1794 tokio::select! {
1795 biased;
1796 () = state_notify.notified() => {}
1797 () = dst::time::sleep(fallback_interval) => {}
1798 }
1799
1800 let mut mode = ConnectionMode::from_atomic(&connection_mode);
1801
1802 if mode.is_disconnect() {
1803 log::debug!("Disconnecting");
1804
1805 let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1806 if dst::time::timeout(timeout, async {
1807 dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1809
1810 if let Some(task) = &inner.read_task
1811 && !task.is_finished()
1812 {
1813 task.abort();
1814 log_task_aborted("read");
1815 }
1816
1817 if let Some(task) = &inner.heartbeat_task
1818 && !task.is_finished()
1819 {
1820 task.abort();
1821 log_task_aborted("heartbeat");
1822 }
1823 })
1824 .await
1825 .is_err()
1826 {
1827 log::error!("Shutdown timed out after {}s", timeout.as_secs());
1828 }
1829
1830 log::debug!("Closed");
1831 break; }
1833
1834 if mode.is_closed() {
1835 log::debug!("Connection closed");
1836 break;
1837 }
1838
1839 if mode.is_active() && !inner.is_alive() {
1840 let target = if inner.is_stream_mode {
1841 ConnectionMode::Closed
1842 } else {
1843 ConnectionMode::Reconnect
1844 };
1845
1846 if connection_mode
1847 .compare_exchange(
1848 ConnectionMode::Active.as_u8(),
1849 target.as_u8(),
1850 Ordering::SeqCst,
1851 Ordering::SeqCst,
1852 )
1853 .is_ok()
1854 {
1855 if let Some(tracker) = auth_tracker.get() {
1856 tracker.invalidate();
1857 }
1858 log::debug!("Detected dead connection, transitioning to {target:?}");
1859 }
1860 mode = ConnectionMode::from_atomic(&connection_mode);
1861 }
1862
1863 if mode.is_reconnect() {
1864 if let Some(max_attempts) = inner.reconnect_max_attempts
1866 && inner.reconnection_attempt_count >= max_attempts
1867 {
1868 log::error!(
1869 "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1870 );
1871 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1872 state_notify.notify_waiters();
1873 break;
1874 }
1875
1876 inner.reconnection_attempt_count += 1;
1877 log::debug!(
1878 "Reconnection attempt {} of {}",
1879 inner.reconnection_attempt_count,
1880 inner
1881 .reconnect_max_attempts
1882 .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1883 );
1884
1885 let reconnect_result = tokio::select! {
1887 biased;
1888 result = inner.reconnect() => Some(result),
1889 () = async {
1890 loop {
1891 state_notify.notified().await;
1892
1893 if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1894 break;
1895 }
1896 }
1897 } => None,
1898 };
1899
1900 match reconnect_result {
1901 None => {
1902 log::debug!("Reconnect interrupted by disconnect");
1903 }
1904 Some(Ok(())) => {
1905 inner.backoff.reset();
1906 inner.reconnection_attempt_count = 0;
1907
1908 state_notify.notify_waiters();
1909
1910 if ConnectionMode::from_atomic(&connection_mode).is_active() {
1911 if let Some(ref handler) = inner.message_handler {
1912 let reconnected_msg =
1913 Message::Text(RECONNECTED.to_string().into());
1914 handler(reconnected_msg);
1915 log::debug!("Sent reconnected message to handler");
1916 }
1917
1918 if let Some(ref callback) = post_reconnection {
1920 callback();
1921 log::debug!("Called `post_reconnection` handler");
1922 }
1923
1924 log::debug!("Reconnected successfully");
1925 } else {
1926 log::debug!(
1927 "Skipping post_reconnection handlers due to disconnect state"
1928 );
1929 }
1930 }
1931 Some(Err(e)) => {
1932 let duration = inner.backoff.next_duration();
1933 log::warn!(
1934 "Reconnect attempt {} failed: {e}",
1935 inner.reconnection_attempt_count
1936 );
1937
1938 if !duration.is_zero() {
1939 log::warn!("Backing off for {}s...", duration.as_secs_f64());
1940 tokio::select! {
1942 biased;
1943 () = dst::time::sleep(duration) => {}
1944 () = async {
1945 loop {
1946 state_notify.notified().await;
1947
1948 if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1949 break;
1950 }
1951 }
1952 } => {
1953 log::debug!("Backoff interrupted by disconnect");
1954 }
1955 }
1956 }
1957 }
1958 }
1959 }
1960 }
1961 inner
1962 .connection_mode
1963 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1964
1965 log_task_stopped("controller");
1966 })
1967 }
1968}
1969
1970impl Drop for WebSocketClient {
1972 fn drop(&mut self) {
1973 if !self.controller_task.is_finished() {
1974 self.controller_task.abort();
1975 log_task_aborted("controller");
1976 }
1977 }
1978}
1979
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
#[cfg(not(all(feature = "simulation", madsim)))]
#[cfg(target_os = "linux")]
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message as WsMessage,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{TransportBackend, WebSocketClient, WebSocketConfig},
    };

    /// Local echo server; its accept loop is aborted when the server is dropped.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the client sent header `key` with value `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        #[expect(
            clippy::panic_in_result_fn,
            reason = "assertion failures should fail the test"
        )]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // One assertion covers both "header present" and "header has the expected value".
            assert_eq!(request.headers().get(&self.key), Some(&self.value));
            Ok(response)
        }
    }

    impl TestServer {
        /// Binds an ephemeral localhost port and spawns an accept loop.
        ///
        /// Every handshake must carry the `test: test` header. Each connection then echoes
        /// text/binary frames, closes on the text `"close-now"`, and answers close frames.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = server.local_addr().unwrap().port();

            let test_call_back = TestCallback {
                key: "test".to_string(),
                value: HeaderValue::from_static("test"),
            };

            let task = task::spawn(async move {
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        #[expect(clippy::collapsible_match)]
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                WsMessage::Text(txt) if txt == "close-now" => {
                                    log::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo data frames back to the client.
                                WsMessage::Text(_) | WsMessage::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                WsMessage::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    /// Default client config for these tests: the `test: test` header the server asserts on,
    /// no heartbeat, default reconnect settings, and the tungstenite backend.
    fn test_config(port: u16) -> WebSocketConfig {
        WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        }
    }

    /// Connects in handler mode with a no-op message handler, panicking on failure.
    async fn connect_test_client(config: WebSocketConfig) -> WebSocketClient {
        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        connect_test_client(test_config(port)).await
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        // Enable a 1s heartbeat so the heartbeat task actually runs during the sleep below.
        let config = WebSocketConfig {
            heartbeat: Some(1),
            ..test_config(server.port)
        };
        let client = connect_test_client(config).await;

        // Several heartbeat intervals elapse; the connection must survive them.
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it so nothing is listening there, instead of
        // relying on a hard-coded port being free.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let config = WebSocketConfig {
            headers: vec![],
            ..test_config(port)
        };
        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        client.send_text("Hello".into(), None).await.unwrap();

        // The server closes this connection; the client should reconnect, not shut down.
        client.send_text("close-now".into(), None).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();

        let client = WebSocketClient::connect(
            test_config(server.port),
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];

        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
2254
2255#[cfg(test)]
2256#[cfg(not(feature = "turmoil"))]
2257#[cfg(not(all(feature = "simulation", madsim)))] mod rust_tests {
2259 use std::sync::{
2260 Arc, OnceLock,
2261 atomic::{AtomicBool, AtomicU8, Ordering},
2262 };
2263
2264 use futures_util::{SinkExt, StreamExt};
2265 use nautilus_common::testing::wait_until_async;
2266 use rstest::rstest;
2267 #[cfg(feature = "transport-sockudo")]
2268 use sockudo_ws::handshake as sockudo_handshake;
2269 #[cfg(feature = "transport-sockudo")]
2270 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
2271 use tokio::{
2272 net::TcpListener,
2273 task::{self, JoinHandle},
2274 time::{Duration, sleep},
2275 };
2276 use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
2277 #[cfg(feature = "transport-sockudo")]
2278 use tokio_tungstenite::{
2279 accept_hdr_async,
2280 tungstenite::{
2281 handshake::server::{self, Callback},
2282 http::HeaderValue,
2283 },
2284 };
2285
2286 use super::*;
2287 use crate::websocket::types::channel_message_handler;
2288
    /// Local WebSocket server that records the text frames it receives.
    struct RecordingServer {
        // Accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        // Localhost port the server is bound to (chosen by the OS).
        port: u16,
        // Text payloads received from all accepted connections.
        messages: Arc<tokio::sync::Mutex<Vec<String>>>,
    }
2294
2295 #[cfg(feature = "transport-sockudo")]
2296 async fn read_http_request<S>(stream: &mut S) -> Vec<u8>
2297 where
2298 S: AsyncRead + Unpin,
2299 {
2300 let mut buf = Vec::new();
2301 let mut chunk = [0u8; 256];
2302
2303 loop {
2304 let n = stream.read(&mut chunk).await.unwrap();
2305 assert!(n > 0, "HTTP request closed before headers completed");
2306 buf.extend_from_slice(&chunk[..n]);
2307 if buf.windows(4).any(|window| window == b"\r\n\r\n") {
2308 return buf;
2309 }
2310 }
2311 }
2312
2313 #[cfg(feature = "transport-sockudo")]
2314 fn extract_header<'a>(request: &'a str, name: &str) -> Option<&'a str> {
2315 request.lines().find_map(|line| {
2316 let (header_name, header_value) = line.split_once(':')?;
2317 if header_name.eq_ignore_ascii_case(name) {
2318 Some(header_value.trim())
2319 } else {
2320 None
2321 }
2322 })
2323 }
2324
    /// Handshake callback asserting that the request carries header `key` set to `value`.
    #[cfg(feature = "transport-sockudo")]
    #[derive(Debug, Clone)]
    struct HeaderAssertCallback {
        // Header name to look up in the handshake request.
        key: String,
        // Exact value the header must have.
        value: HeaderValue,
    }
2331
2332 #[cfg(feature = "transport-sockudo")]
2333 impl Callback for HeaderAssertCallback {
2334 #[expect(
2335 clippy::panic_in_result_fn,
2336 reason = "assertion failures should fail the test"
2337 )]
2338 fn on_request(
2339 self,
2340 request: &server::Request,
2341 response: server::Response,
2342 ) -> Result<server::Response, server::ErrorResponse> {
2343 assert_eq!(request.headers().get(&self.key), Some(&self.value));
2344 Ok(response)
2345 }
2346 }
2347
2348 impl RecordingServer {
2349 async fn setup() -> Self {
2350 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2351 let port = listener.local_addr().unwrap().port();
2352 let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
2353 let messages_clone = Arc::clone(&messages);
2354
2355 let task = task::spawn(async move {
2356 loop {
2357 let (stream, _) = listener.accept().await.unwrap();
2358 let mut websocket = accept_async(stream).await.unwrap();
2359 let messages = Arc::clone(&messages_clone);
2360
2361 task::spawn(async move {
2362 while let Some(Ok(msg)) = websocket.next().await {
2363 match msg {
2364 WsMessage::Text(text) => {
2365 messages.lock().await.push(text.to_string());
2366 }
2367 WsMessage::Close(_) => {
2368 let _ = websocket.close(None).await;
2369 break;
2370 }
2371 _ => {}
2372 }
2373 }
2374 });
2375 }
2376 });
2377
2378 Self {
2379 task,
2380 port,
2381 messages,
2382 }
2383 }
2384
2385 async fn messages(&self) -> Vec<String> {
2386 self.messages.lock().await.clone()
2387 }
2388 }
2389
2390 impl Drop for RecordingServer {
2391 fn drop(&mut self) {
2392 self.task.abort();
2393 }
2394 }
2395
2396 #[rstest]
2397 #[tokio::test]
2398 async fn test_reconnect_then_disconnect() {
2399 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2401 let port = listener.local_addr().unwrap().port();
2402
2403 let server = task::spawn(async move {
2405 let (stream, _) = listener.accept().await.unwrap();
2406 let ws = accept_async(stream).await.unwrap();
2407 drop(ws);
2408 sleep(Duration::from_secs(1)).await;
2410 });
2411
2412 let (handler, _rx) = channel_message_handler();
2414
2415 let config = WebSocketConfig {
2417 url: format!("ws://127.0.0.1:{port}"),
2418 headers: vec![],
2419 heartbeat: None,
2420 heartbeat_msg: None,
2421 reconnect_timeout_ms: Some(1_000),
2422 reconnect_delay_initial_ms: Some(50),
2423 reconnect_delay_max_ms: Some(100),
2424 reconnect_backoff_factor: Some(1.0),
2425 reconnect_jitter_ms: Some(0),
2426 reconnect_max_attempts: None,
2427 idle_timeout_ms: None,
2428 backend: TransportBackend::Tungstenite,
2429 proxy_url: None,
2430 };
2431
2432 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2434 .await
2435 .unwrap();
2436
2437 sleep(Duration::from_millis(100)).await;
2439 client.disconnect().await;
2441 assert!(client.is_disconnected());
2442 server.abort();
2443 }
2444
2445 #[rstest]
2446 #[tokio::test]
2447 async fn test_reconnect_state_flips_when_reader_stops() {
2448 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2450 let port = listener.local_addr().unwrap().port();
2451
2452 let server = task::spawn(async move {
2453 if let Ok((stream, _)) = listener.accept().await
2454 && let Ok(ws) = accept_async(stream).await
2455 {
2456 drop(ws);
2457 }
2458 sleep(Duration::from_millis(50)).await;
2459 });
2460
2461 let (handler, _rx) = channel_message_handler();
2462
2463 let config = WebSocketConfig {
2464 url: format!("ws://127.0.0.1:{port}"),
2465 headers: vec![],
2466 heartbeat: None,
2467 heartbeat_msg: None,
2468 reconnect_timeout_ms: Some(1_000),
2469 reconnect_delay_initial_ms: Some(50),
2470 reconnect_delay_max_ms: Some(100),
2471 reconnect_backoff_factor: Some(1.0),
2472 reconnect_jitter_ms: Some(0),
2473 reconnect_max_attempts: None,
2474 idle_timeout_ms: None,
2475 backend: TransportBackend::Tungstenite,
2476 proxy_url: None,
2477 };
2478
2479 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2480 .await
2481 .unwrap();
2482
2483 tokio::time::timeout(Duration::from_secs(2), async {
2484 loop {
2485 if client.is_reconnecting() {
2486 break;
2487 }
2488 tokio::time::sleep(Duration::from_millis(10)).await;
2489 }
2490 })
2491 .await
2492 .expect("client did not enter RECONNECT state");
2493
2494 client.disconnect().await;
2495 server.abort();
2496 }
2497
2498 #[rstest]
2499 #[tokio::test]
2500 async fn test_stream_mode_disables_auto_reconnect() {
2501 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2504 let port = listener.local_addr().unwrap().port();
2505
2506 let server = task::spawn(async move {
2507 if let Ok((stream, _)) = listener.accept().await
2508 && let Ok(_ws) = accept_async(stream).await
2509 {
2510 sleep(Duration::from_millis(100)).await;
2512 }
2513 });
2514
2515 let config = WebSocketConfig {
2516 url: format!("ws://127.0.0.1:{port}"),
2517 headers: vec![],
2518 heartbeat: None,
2519 heartbeat_msg: None,
2520 reconnect_timeout_ms: Some(1_000),
2521 reconnect_delay_initial_ms: Some(50),
2522 reconnect_delay_max_ms: Some(100),
2523 reconnect_backoff_factor: Some(1.0),
2524 reconnect_jitter_ms: Some(0),
2525 reconnect_max_attempts: None,
2526 idle_timeout_ms: None,
2527 backend: TransportBackend::Tungstenite,
2528 proxy_url: None,
2529 };
2530
2531 let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
2532 .await
2533 .unwrap();
2534
2535 server.abort();
2543 }
2544
2545 #[rstest]
2546 #[tokio::test]
2547 async fn test_message_handler_mode_allows_auto_reconnect() {
2548 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2550 let port = listener.local_addr().unwrap().port();
2551
2552 let server = task::spawn(async move {
2553 if let Ok((stream, _)) = listener.accept().await
2555 && let Ok(ws) = accept_async(stream).await
2556 {
2557 drop(ws);
2558 }
2559 sleep(Duration::from_millis(50)).await;
2560 });
2561
2562 let (handler, _rx) = channel_message_handler();
2563
2564 let config = WebSocketConfig {
2565 url: format!("ws://127.0.0.1:{port}"),
2566 headers: vec![],
2567 heartbeat: None,
2568 heartbeat_msg: None,
2569 reconnect_timeout_ms: Some(1_000),
2570 reconnect_delay_initial_ms: Some(50),
2571 reconnect_delay_max_ms: Some(100),
2572 reconnect_backoff_factor: Some(1.0),
2573 reconnect_jitter_ms: Some(0),
2574 reconnect_max_attempts: None,
2575 idle_timeout_ms: None,
2576 backend: TransportBackend::Tungstenite,
2577 proxy_url: None,
2578 };
2579
2580 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2581 .await
2582 .unwrap();
2583
2584 tokio::time::timeout(Duration::from_secs(2), async {
2586 loop {
2587 if client.is_reconnecting() || client.is_closed() {
2588 break;
2589 }
2590 tokio::time::sleep(Duration::from_millis(10)).await;
2591 }
2592 })
2593 .await
2594 .expect("client should attempt reconnection or close");
2595
2596 assert!(
2599 client.is_reconnecting() || client.is_closed(),
2600 "Client with message handler should attempt reconnection"
2601 );
2602
2603 client.disconnect().await;
2604 server.abort();
2605 }
2606
2607 #[rstest]
2608 #[tokio::test]
2609 async fn test_handler_mode_reconnect_with_new_connection() {
2610 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2612 let port = listener.local_addr().unwrap().port();
2613
2614 let server = task::spawn(async move {
2615 if let Ok((stream, _)) = listener.accept().await
2617 && let Ok(ws) = accept_async(stream).await
2618 {
2619 drop(ws);
2620 }
2621
2622 sleep(Duration::from_millis(100)).await;
2624
2625 if let Ok((stream, _)) = listener.accept().await
2627 && let Ok(mut ws) = accept_async(stream).await
2628 {
2629 use futures_util::SinkExt;
2630 let _ = ws
2631 .send(WsMessage::Text("reconnected".to_string().into()))
2632 .await;
2633 sleep(Duration::from_secs(1)).await;
2634 }
2635 });
2636
2637 let (handler, mut rx) = channel_message_handler();
2638
2639 let config = WebSocketConfig {
2640 url: format!("ws://127.0.0.1:{port}"),
2641 headers: vec![],
2642 heartbeat: None,
2643 heartbeat_msg: None,
2644 reconnect_timeout_ms: Some(2_000),
2645 reconnect_delay_initial_ms: Some(50),
2646 reconnect_delay_max_ms: Some(200),
2647 reconnect_backoff_factor: Some(1.5),
2648 reconnect_jitter_ms: Some(10),
2649 reconnect_max_attempts: None,
2650 idle_timeout_ms: None,
2651 backend: TransportBackend::Tungstenite,
2652 proxy_url: None,
2653 };
2654
2655 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2656 .await
2657 .unwrap();
2658
2659 let result = tokio::time::timeout(Duration::from_secs(5), async {
2661 loop {
2662 if let Ok(msg) = rx.try_recv()
2663 && matches!(msg, WsMessage::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
2664 {
2665 return true;
2666 }
2667 tokio::time::sleep(Duration::from_millis(10)).await;
2668 }
2669 })
2670 .await;
2671
2672 assert!(
2673 result.is_ok(),
2674 "Should receive message after reconnection within timeout"
2675 );
2676
2677 client.disconnect().await;
2678 server.abort();
2679 }
2680
2681 #[rstest]
2682 #[tokio::test]
2683 async fn test_stream_mode_no_auto_reconnect() {
2684 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2687 let port = listener.local_addr().unwrap().port();
2688
2689 let server = task::spawn(async move {
2690 if let Ok((stream, _)) = listener.accept().await
2692 && let Ok(mut ws) = accept_async(stream).await
2693 {
2694 use futures_util::SinkExt;
2695 let _ = ws.send(WsMessage::Text("hello".to_string().into())).await;
2696 sleep(Duration::from_millis(50)).await;
2697 }
2699 });
2700
2701 let config = WebSocketConfig {
2702 url: format!("ws://127.0.0.1:{port}"),
2703 headers: vec![],
2704 heartbeat: None,
2705 heartbeat_msg: None,
2706 reconnect_timeout_ms: Some(1_000),
2707 reconnect_delay_initial_ms: Some(50),
2708 reconnect_delay_max_ms: Some(100),
2709 reconnect_backoff_factor: Some(1.0),
2710 reconnect_jitter_ms: Some(0),
2711 reconnect_max_attempts: None,
2712 idle_timeout_ms: None,
2713 backend: TransportBackend::Tungstenite,
2714 proxy_url: None,
2715 };
2716
2717 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2718 .await
2719 .unwrap();
2720
2721 assert!(client.is_active(), "Client should start as active");
2723
2724 let msg = reader.next().await;
2726 assert!(
2727 matches!(&msg, Some(Ok(Message::Text(bytes))) if bytes.as_ref() == b"hello"),
2728 "Should receive initial message"
2729 );
2730
2731 while let Some(msg) = reader.next().await {
2733 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
2734 break;
2735 }
2736 }
2737
2738 sleep(Duration::from_millis(200)).await;
2741 assert!(
2742 client.is_active(),
2743 "Stream mode client stays ACTIVE before notify_closed()"
2744 );
2745
2746 client.notify_closed();
2748
2749 assert!(
2750 client.is_closed(),
2751 "Stream mode client should be CLOSED after notify_closed()"
2752 );
2753 assert!(
2754 !client.is_reconnecting(),
2755 "Stream mode client should never attempt reconnection"
2756 );
2757
2758 client.disconnect().await;
2759 server.abort();
2760 }
2761
2762 #[rstest]
2763 #[tokio::test]
2764 async fn test_send_timeout_uses_configured_reconnect_timeout() {
2765 use nautilus_common::testing::wait_until_async;
2768
2769 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2770 let port = listener.local_addr().unwrap().port();
2771
2772 let server = task::spawn(async move {
2773 if let Ok((stream, _)) = listener.accept().await
2775 && let Ok(ws) = accept_async(stream).await
2776 {
2777 drop(ws);
2778 }
2779 sleep(Duration::from_mins(1)).await;
2781 });
2782
2783 let (handler, _rx) = channel_message_handler();
2784
2785 let config = WebSocketConfig {
2787 url: format!("ws://127.0.0.1:{port}"),
2788 headers: vec![],
2789 heartbeat: None,
2790 heartbeat_msg: None,
2791 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
2793 reconnect_delay_max_ms: Some(100),
2794 reconnect_backoff_factor: Some(1.0),
2795 reconnect_jitter_ms: Some(0),
2796 reconnect_max_attempts: None,
2797 idle_timeout_ms: None,
2798 backend: TransportBackend::Tungstenite,
2799 proxy_url: None,
2800 };
2801
2802 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2803 .await
2804 .unwrap();
2805
2806 wait_until_async(
2808 || async { client.is_reconnecting() },
2809 Duration::from_secs(3),
2810 )
2811 .await;
2812
2813 let start = std::time::Instant::now();
2815 let send_result = client.send_text("test".to_string(), None).await;
2816 let elapsed = start.elapsed();
2817
2818 assert!(
2819 send_result.is_err(),
2820 "Send should fail when client stuck in RECONNECT"
2821 );
2822 assert!(
2823 matches!(send_result, Err(crate::error::SendError::Timeout)),
2824 "Send should return Timeout error, was: {send_result:?}"
2825 );
2826 assert!(
2829 elapsed >= Duration::from_millis(1800),
2830 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2831 );
2832
2833 client.disconnect().await;
2834 server.abort();
2835 }
2836
2837 #[rstest]
2838 #[tokio::test]
2839 async fn test_send_waits_during_reconnection() {
2840 use nautilus_common::testing::wait_until_async;
2842
2843 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2844 let port = listener.local_addr().unwrap().port();
2845
2846 let server = task::spawn(async move {
2847 if let Ok((stream, _)) = listener.accept().await
2849 && let Ok(ws) = accept_async(stream).await
2850 {
2851 drop(ws);
2852 }
2853
2854 sleep(Duration::from_millis(500)).await;
2856
2857 if let Ok((stream, _)) = listener.accept().await
2859 && let Ok(mut ws) = accept_async(stream).await
2860 {
2861 while let Some(Ok(msg)) = ws.next().await {
2863 if ws.send(msg).await.is_err() {
2864 break;
2865 }
2866 }
2867 }
2868 });
2869
2870 let (handler, _rx) = channel_message_handler();
2871
2872 let config = WebSocketConfig {
2873 url: format!("ws://127.0.0.1:{port}"),
2874 headers: vec![],
2875 heartbeat: None,
2876 heartbeat_msg: None,
2877 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2879 reconnect_delay_max_ms: Some(200),
2880 reconnect_backoff_factor: Some(1.0),
2881 reconnect_jitter_ms: Some(0),
2882 reconnect_max_attempts: None,
2883 idle_timeout_ms: None,
2884 backend: TransportBackend::Tungstenite,
2885 proxy_url: None,
2886 };
2887
2888 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2889 .await
2890 .unwrap();
2891
2892 wait_until_async(
2894 || async { client.is_reconnecting() },
2895 Duration::from_secs(2),
2896 )
2897 .await;
2898
2899 let send_result = tokio::time::timeout(
2901 Duration::from_secs(3),
2902 client.send_text("test_message".to_string(), None),
2903 )
2904 .await;
2905
2906 assert!(
2907 send_result.is_ok() && send_result.unwrap().is_ok(),
2908 "Send should succeed after waiting for reconnection"
2909 );
2910
2911 client.disconnect().await;
2912 server.abort();
2913 }
2914
2915 #[rstest]
2916 #[tokio::test]
2917 async fn test_rate_limiter_before_active_wait() {
2918 use std::{num::NonZeroU32, sync::Arc};
2923
2924 use nautilus_common::testing::wait_until_async;
2925
2926 use crate::ratelimiter::quota::Quota;
2927
2928 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2929 let port = listener.local_addr().unwrap().port();
2930
2931 let server = task::spawn(async move {
2932 if let Ok((stream, _)) = listener.accept().await
2934 && let Ok(mut ws) = accept_async(stream).await
2935 {
2936 if let Some(Ok(_)) = ws.next().await {
2938 drop(ws);
2939 }
2940 }
2941
2942 sleep(Duration::from_millis(500)).await;
2944
2945 if let Ok((stream, _)) = listener.accept().await
2947 && let Ok(mut ws) = accept_async(stream).await
2948 {
2949 while let Some(Ok(msg)) = ws.next().await {
2950 if ws.send(msg).await.is_err() {
2951 break;
2952 }
2953 }
2954 }
2955 });
2956
2957 let (handler, _rx) = channel_message_handler();
2958
2959 let config = WebSocketConfig {
2960 url: format!("ws://127.0.0.1:{port}"),
2961 headers: vec![],
2962 heartbeat: None,
2963 heartbeat_msg: None,
2964 reconnect_timeout_ms: Some(5_000),
2965 reconnect_delay_initial_ms: Some(50),
2966 reconnect_delay_max_ms: Some(100),
2967 reconnect_backoff_factor: Some(1.0),
2968 reconnect_jitter_ms: Some(0),
2969 reconnect_max_attempts: None,
2970 idle_timeout_ms: None,
2971 backend: TransportBackend::Tungstenite,
2972 proxy_url: None,
2973 };
2974
2975 let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2977 .unwrap()
2978 .allow_burst(NonZeroU32::new(1).unwrap());
2979
2980 let client = Arc::new(
2981 WebSocketClient::connect(
2982 config,
2983 Some(handler),
2984 None,
2985 None,
2986 vec![("test_key".to_string(), quota)],
2987 None,
2988 )
2989 .await
2990 .unwrap(),
2991 );
2992
2993 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2995 client
2996 .send_text("msg1".to_string(), Some(test_key.as_slice()))
2997 .await
2998 .unwrap();
2999
3000 wait_until_async(
3002 || async { client.is_reconnecting() },
3003 Duration::from_secs(2),
3004 )
3005 .await;
3006
3007 let start = std::time::Instant::now();
3009 let send_result = client
3010 .send_text("msg2".to_string(), Some(test_key.as_slice()))
3011 .await;
3012 let elapsed = start.elapsed();
3013
3014 assert!(
3016 send_result.is_ok(),
3017 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
3018 );
3019 assert!(
3023 elapsed >= Duration::from_millis(850),
3024 "Should wait for rate limit (~1s), waited {elapsed:?}"
3025 );
3026
3027 client.disconnect().await;
3028 server.abort();
3029 }
3030
3031 #[rstest]
3032 #[tokio::test]
3033 async fn test_disconnect_during_reconnect_exits_cleanly() {
3034 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3037 let port = listener.local_addr().unwrap().port();
3038
3039 let server = task::spawn(async move {
3040 if let Ok((stream, _)) = listener.accept().await
3042 && let Ok(ws) = accept_async(stream).await
3043 {
3044 drop(ws);
3045 }
3046 sleep(Duration::from_mins(1)).await;
3048 });
3049
3050 let (handler, _rx) = channel_message_handler();
3051
3052 let config = WebSocketConfig {
3053 url: format!("ws://127.0.0.1:{port}"),
3054 headers: vec![],
3055 heartbeat: None,
3056 heartbeat_msg: None,
3057 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
3059 reconnect_delay_max_ms: Some(200),
3060 reconnect_backoff_factor: Some(1.0),
3061 reconnect_jitter_ms: Some(0),
3062 reconnect_max_attempts: None,
3063 idle_timeout_ms: None,
3064 backend: TransportBackend::Tungstenite,
3065 proxy_url: None,
3066 };
3067
3068 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3069 .await
3070 .unwrap();
3071
3072 tokio::time::timeout(Duration::from_secs(2), async {
3074 while !client.is_reconnecting() {
3075 sleep(Duration::from_millis(10)).await;
3076 }
3077 })
3078 .await
3079 .expect("Client should enter RECONNECT state");
3080
3081 client.disconnect().await;
3083
3084 assert!(
3086 client.is_disconnected(),
3087 "Client should be cleanly disconnected"
3088 );
3089
3090 server.abort();
3091 }
3092
3093 #[rstest]
3094 #[tokio::test]
3095 async fn test_send_fails_fast_when_closed_before_rate_limit() {
3096 use std::{num::NonZeroU32, sync::Arc};
3099
3100 use nautilus_common::testing::wait_until_async;
3101
3102 use crate::ratelimiter::quota::Quota;
3103
3104 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3105 let port = listener.local_addr().unwrap().port();
3106
3107 let server = task::spawn(async move {
3108 if let Ok((stream, _)) = listener.accept().await
3110 && let Ok(ws) = accept_async(stream).await
3111 {
3112 drop(ws);
3113 }
3114 sleep(Duration::from_mins(1)).await;
3115 });
3116
3117 let (handler, _rx) = channel_message_handler();
3118
3119 let config = WebSocketConfig {
3120 url: format!("ws://127.0.0.1:{port}"),
3121 headers: vec![],
3122 heartbeat: None,
3123 heartbeat_msg: None,
3124 reconnect_timeout_ms: Some(5_000),
3125 reconnect_delay_initial_ms: Some(50),
3126 reconnect_delay_max_ms: Some(100),
3127 reconnect_backoff_factor: Some(1.0),
3128 reconnect_jitter_ms: Some(0),
3129 reconnect_max_attempts: None,
3130 idle_timeout_ms: None,
3131 backend: TransportBackend::Tungstenite,
3132 proxy_url: None,
3133 };
3134
3135 let quota = Quota::with_period(Duration::from_secs(10))
3138 .unwrap()
3139 .allow_burst(NonZeroU32::new(1).unwrap());
3140
3141 let client = Arc::new(
3142 WebSocketClient::connect(
3143 config,
3144 Some(handler),
3145 None,
3146 None,
3147 vec![("test_key".to_string(), quota)],
3148 None,
3149 )
3150 .await
3151 .unwrap(),
3152 );
3153
3154 wait_until_async(
3156 || async { client.is_reconnecting() || client.is_closed() },
3157 Duration::from_secs(2),
3158 )
3159 .await;
3160
3161 client.disconnect().await;
3163 assert!(
3164 !client.is_active(),
3165 "Client should not be active after disconnect"
3166 );
3167
3168 let start = std::time::Instant::now();
3170 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
3171 let result = client
3172 .send_text("test".to_string(), Some(test_key.as_slice()))
3173 .await;
3174 let elapsed = start.elapsed();
3175
3176 assert!(result.is_err(), "Send should fail when client is closed");
3178 assert!(
3179 matches!(result, Err(crate::error::SendError::Closed)),
3180 "Send should return Closed error, was: {result:?}"
3181 );
3182
3183 assert!(
3185 elapsed < Duration::from_millis(100),
3186 "Send should fail fast without rate limiting, took {elapsed:?}"
3187 );
3188
3189 server.abort();
3190 }
3191
3192 #[rstest]
3193 #[tokio::test]
3194 async fn test_connect_rejects_none_message_handler() {
3195 let config = WebSocketConfig {
3199 url: "ws://127.0.0.1:9999".to_string(),
3200 headers: vec![],
3201 heartbeat: None,
3202 heartbeat_msg: None,
3203 reconnect_timeout_ms: Some(1_000),
3204 reconnect_delay_initial_ms: Some(100),
3205 reconnect_delay_max_ms: Some(500),
3206 reconnect_backoff_factor: Some(1.5),
3207 reconnect_jitter_ms: Some(0),
3208 reconnect_max_attempts: None,
3209 idle_timeout_ms: None,
3210 backend: TransportBackend::Tungstenite,
3211 proxy_url: None,
3212 };
3213
3214 let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
3216
3217 assert!(
3218 result.is_err(),
3219 "connect() should reject None message_handler"
3220 );
3221
3222 let err = result.unwrap_err();
3223 let err_msg = err.to_string();
3224 assert!(
3225 err_msg.contains("Handler mode requires message_handler"),
3226 "Error should mention missing message_handler, was: {err_msg}"
3227 );
3228 }
3229
3230 #[rstest]
3231 #[tokio::test]
3232 async fn test_client_without_handler_sets_stream_mode() {
3233 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3237 let port = listener.local_addr().unwrap().port();
3238
3239 let server = task::spawn(async move {
3240 if let Ok((stream, _)) = listener.accept().await
3242 && let Ok(ws) = accept_async(stream).await
3243 {
3244 drop(ws); }
3246 });
3247
3248 let config = WebSocketConfig {
3249 url: format!("ws://127.0.0.1:{port}"),
3250 headers: vec![],
3251 heartbeat: None,
3252 heartbeat_msg: None,
3253 reconnect_timeout_ms: Some(1_000),
3254 reconnect_delay_initial_ms: Some(100),
3255 reconnect_delay_max_ms: Some(500),
3256 reconnect_backoff_factor: Some(1.5),
3257 reconnect_jitter_ms: Some(0),
3258 reconnect_max_attempts: None,
3259 idle_timeout_ms: None,
3260 backend: TransportBackend::Tungstenite,
3261 proxy_url: None,
3262 };
3263
3264 let inner = WebSocketClientInner::connect_url(config, None, None)
3266 .await
3267 .unwrap();
3268
3269 assert!(
3271 inner.is_stream_mode,
3272 "Client without handler should have is_stream_mode=true"
3273 );
3274
3275 server.abort();
3279 }
3280
3281 #[rstest]
3282 #[tokio::test]
3283 async fn test_idle_timeout_triggers_reconnect() {
3284 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3285 let port = listener.local_addr().unwrap().port();
3286
3287 let server = task::spawn(async move {
3289 let (stream, _) = listener.accept().await.unwrap();
3290 let _ws = accept_async(stream).await.unwrap();
3291 sleep(Duration::from_secs(5)).await;
3293 });
3294
3295 let (handler, _rx) = channel_message_handler();
3296
3297 let config = WebSocketConfig {
3298 url: format!("ws://127.0.0.1:{port}"),
3299 headers: vec![],
3300 heartbeat: None,
3301 heartbeat_msg: None,
3302 reconnect_timeout_ms: Some(2_000),
3303 reconnect_delay_initial_ms: Some(50),
3304 reconnect_delay_max_ms: Some(100),
3305 reconnect_backoff_factor: Some(1.0),
3306 reconnect_jitter_ms: Some(0),
3307 reconnect_max_attempts: Some(1),
3308 idle_timeout_ms: Some(500),
3309 backend: TransportBackend::Tungstenite,
3310 proxy_url: None,
3311 };
3312
3313 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3314 .await
3315 .unwrap();
3316
3317 assert!(client.is_active());
3318
3319 wait_until_async(
3321 || async { client.is_reconnecting() || client.is_disconnected() },
3322 Duration::from_secs(3),
3323 )
3324 .await;
3325
3326 assert!(
3327 !client.is_active(),
3328 "Client should not be active after idle timeout"
3329 );
3330
3331 client.disconnect().await;
3332 server.abort();
3333 }
3334
3335 #[rstest]
3336 #[tokio::test]
3337 async fn test_idle_timeout_resets_on_data() {
3338 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3339 let port = listener.local_addr().unwrap().port();
3340
3341 let server = task::spawn(async move {
3343 let (stream, _) = listener.accept().await.unwrap();
3344 let mut ws = accept_async(stream).await.unwrap();
3345
3346 for _ in 0..10 {
3347 sleep(Duration::from_millis(200)).await;
3348
3349 if ws.send(WsMessage::Text("ping".into())).await.is_err() {
3350 break;
3351 }
3352 }
3353 });
3354
3355 let (handler, _rx) = channel_message_handler();
3356
3357 let config = WebSocketConfig {
3358 url: format!("ws://127.0.0.1:{port}"),
3359 headers: vec![],
3360 heartbeat: None,
3361 heartbeat_msg: None,
3362 reconnect_timeout_ms: Some(2_000),
3363 reconnect_delay_initial_ms: Some(50),
3364 reconnect_delay_max_ms: Some(100),
3365 reconnect_backoff_factor: Some(1.0),
3366 reconnect_jitter_ms: Some(0),
3367 reconnect_max_attempts: Some(1),
3368 idle_timeout_ms: Some(1_000),
3369 backend: TransportBackend::Tungstenite,
3370 proxy_url: None,
3371 };
3372
3373 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3374 .await
3375 .unwrap();
3376
3377 assert!(client.is_active());
3378
3379 sleep(Duration::from_millis(1_500)).await;
3381
3382 assert!(
3383 client.is_active(),
3384 "Client should remain active when data is flowing"
3385 );
3386
3387 client.disconnect().await;
3388 server.abort();
3389 }
3390
3391 #[rstest]
3392 #[tokio::test]
3393 async fn test_idle_timeout_fires_when_only_pings_received() {
3394 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3400 let port = listener.local_addr().unwrap().port();
3401
3402 let server = task::spawn(async move {
3403 let (stream, _) = listener.accept().await.unwrap();
3404 let mut ws = accept_async(stream).await.unwrap();
3405
3406 for _ in 0..60 {
3407 sleep(Duration::from_millis(100)).await;
3408
3409 if ws.send(WsMessage::Ping(Vec::new().into())).await.is_err() {
3410 break;
3411 }
3412 }
3413 });
3414
3415 let (handler, _rx) = channel_message_handler();
3416
3417 let config = WebSocketConfig {
3418 url: format!("ws://127.0.0.1:{port}"),
3419 headers: vec![],
3420 heartbeat: None,
3421 heartbeat_msg: None,
3422 reconnect_timeout_ms: Some(2_000),
3423 reconnect_delay_initial_ms: Some(50),
3424 reconnect_delay_max_ms: Some(100),
3425 reconnect_backoff_factor: Some(1.0),
3426 reconnect_jitter_ms: Some(0),
3427 reconnect_max_attempts: Some(1),
3428 idle_timeout_ms: Some(500),
3429 backend: TransportBackend::Tungstenite,
3430 proxy_url: None,
3431 };
3432
3433 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3434 .await
3435 .unwrap();
3436
3437 assert!(client.is_active());
3438
3439 wait_until_async(
3443 || async { client.is_reconnecting() || client.is_disconnected() },
3444 Duration::from_millis(1_500),
3445 )
3446 .await;
3447
3448 assert!(
3449 !client.is_active(),
3450 "Client should not be active after idle timeout when only pings/pongs flow"
3451 );
3452
3453 client.disconnect().await;
3454 server.abort();
3455 }
3456
3457 #[rstest]
3458 #[tokio::test]
3459 async fn test_idle_timeout_fires_when_only_pongs_received() {
3460 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3465 let port = listener.local_addr().unwrap().port();
3466
3467 let server = task::spawn(async move {
3468 let (stream, _) = listener.accept().await.unwrap();
3469 let mut ws = accept_async(stream).await.unwrap();
3470
3471 let deadline = tokio::time::Instant::now() + Duration::from_secs(6);
3475 while tokio::time::Instant::now() < deadline {
3476 if let Ok(Some(Err(_)) | None) =
3477 tokio::time::timeout(Duration::from_millis(100), ws.next()).await
3478 {
3479 break;
3480 }
3481 }
3482 });
3483
3484 let (handler, _rx) = channel_message_handler();
3485
3486 let config = WebSocketConfig {
3487 url: format!("ws://127.0.0.1:{port}"),
3488 headers: vec![],
3489 heartbeat: Some(1),
3490 heartbeat_msg: None,
3491 reconnect_timeout_ms: Some(2_000),
3492 reconnect_delay_initial_ms: Some(50),
3493 reconnect_delay_max_ms: Some(100),
3494 reconnect_backoff_factor: Some(1.0),
3495 reconnect_jitter_ms: Some(0),
3496 reconnect_max_attempts: Some(1),
3497 idle_timeout_ms: Some(1_500),
3498 backend: TransportBackend::Tungstenite,
3499 proxy_url: None,
3500 };
3501
3502 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3503 .await
3504 .unwrap();
3505
3506 assert!(client.is_active());
3507
3508 wait_until_async(
3512 || async { client.is_reconnecting() || client.is_disconnected() },
3513 Duration::from_millis(2_500),
3514 )
3515 .await;
3516
3517 assert!(
3518 !client.is_active(),
3519 "Client should not be active after idle timeout when only pongs flow"
3520 );
3521
3522 client.disconnect().await;
3523 server.abort();
3524 }
3525
3526 #[rstest]
3527 #[tokio::test]
3528 async fn test_disconnect_during_backoff_exits_promptly() {
3529 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3533 let port = listener.local_addr().unwrap().port();
3534
3535 let server = task::spawn(async move {
3536 if let Ok((stream, _)) = listener.accept().await {
3538 let _ = accept_async(stream).await;
3539 }
3540 sleep(Duration::from_mins(1)).await;
3542 });
3543
3544 let (handler, _rx) = channel_message_handler();
3545
3546 let config = WebSocketConfig {
3547 url: format!("ws://127.0.0.1:{port}"),
3548 headers: vec![],
3549 heartbeat: None,
3550 heartbeat_msg: None,
3551 reconnect_timeout_ms: Some(1_000),
3552 reconnect_delay_initial_ms: Some(10_000), reconnect_delay_max_ms: Some(10_000),
3554 reconnect_backoff_factor: Some(1.0),
3555 reconnect_jitter_ms: Some(0),
3556 reconnect_max_attempts: None,
3557 idle_timeout_ms: None,
3558 backend: TransportBackend::Tungstenite,
3559 proxy_url: None,
3560 };
3561
3562 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3563 .await
3564 .unwrap();
3565
3566 wait_until_async(
3568 || async { client.is_reconnecting() },
3569 Duration::from_secs(3),
3570 )
3571 .await;
3572
3573 sleep(Duration::from_millis(1_500)).await;
3575
3576 let start = std::time::Instant::now();
3578 client.disconnect().await;
3579 let elapsed = start.elapsed();
3580
3581 assert!(client.is_disconnected(), "Client should be disconnected");
3582 assert!(
3584 elapsed < Duration::from_secs(2),
3585 "Disconnect should interrupt backoff sleep, took {elapsed:?}"
3586 );
3587
3588 server.abort();
3589 }
3590
3591 #[rstest]
3592 #[tokio::test]
3593 async fn test_rate_limit_cancelled_on_disconnect() {
3594 use std::{num::NonZeroU32, sync::Arc};
3597
3598 use crate::ratelimiter::quota::Quota;
3599
3600 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3601 let port = listener.local_addr().unwrap().port();
3602
3603 let server = task::spawn(async move {
3604 if let Ok((stream, _)) = listener.accept().await {
3605 let mut ws = accept_async(stream).await.unwrap();
3606 while let Some(Ok(msg)) = ws.next().await {
3608 if ws.send(msg).await.is_err() {
3609 break;
3610 }
3611 }
3612 }
3613 });
3614
3615 let (handler, _rx) = channel_message_handler();
3616
3617 let config = WebSocketConfig {
3618 url: format!("ws://127.0.0.1:{port}"),
3619 headers: vec![],
3620 heartbeat: None,
3621 heartbeat_msg: None,
3622 reconnect_timeout_ms: Some(5_000),
3623 reconnect_delay_initial_ms: Some(100),
3624 reconnect_delay_max_ms: Some(500),
3625 reconnect_backoff_factor: Some(1.5),
3626 reconnect_jitter_ms: Some(0),
3627 reconnect_max_attempts: None,
3628 idle_timeout_ms: None,
3629 backend: TransportBackend::Tungstenite,
3630 proxy_url: None,
3631 };
3632
3633 let quota = Quota::with_period(Duration::from_mins(1))
3635 .unwrap()
3636 .allow_burst(NonZeroU32::new(1).unwrap());
3637
3638 let client = Arc::new(
3639 WebSocketClient::connect(
3640 config,
3641 Some(handler),
3642 None,
3643 None,
3644 vec![("rate_key".to_string(), quota)],
3645 None,
3646 )
3647 .await
3648 .unwrap(),
3649 );
3650
3651 let test_key: [Ustr; 1] = [Ustr::from("rate_key")];
3652
3653 client
3655 .send_text("exhaust".to_string(), Some(test_key.as_slice()))
3656 .await
3657 .unwrap();
3658
3659 let client_clone = client.clone();
3661 let send_handle = task::spawn(async move {
3662 client_clone
3663 .send_text("blocked".to_string(), Some(&[Ustr::from("rate_key")]))
3664 .await
3665 });
3666
3667 sleep(Duration::from_millis(200)).await;
3669
3670 let start = std::time::Instant::now();
3672 client.disconnect().await;
3673 let elapsed_disconnect = start.elapsed();
3674
3675 let result = tokio::time::timeout(Duration::from_secs(2), send_handle)
3677 .await
3678 .expect("Send task should complete quickly")
3679 .expect("Send task should not panic");
3680
3681 assert!(
3682 matches!(result, Err(crate::error::SendError::Closed)),
3683 "Blocked send should return Closed, was: {result:?}"
3684 );
3685
3686 assert!(
3688 elapsed_disconnect < Duration::from_secs(3),
3689 "Disconnect should not wait for rate limiter, took {elapsed_disconnect:?}"
3690 );
3691
3692 server.abort();
3693 }
3694
3695 #[rstest]
3696 #[tokio::test]
3697 async fn test_stream_mode_transitions_to_closed_on_dead_write_task() {
3698 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3702 let port = listener.local_addr().unwrap().port();
3703
3704 let server = task::spawn(async move {
3705 if let Ok((stream, _)) = listener.accept().await
3706 && let Ok(ws) = accept_async(stream).await
3707 {
3708 drop(ws);
3710 }
3711 });
3712
3713 let config = WebSocketConfig {
3714 url: format!("ws://127.0.0.1:{port}"),
3715 headers: vec![],
3716 heartbeat: None,
3717 heartbeat_msg: None,
3718 reconnect_timeout_ms: Some(1_000),
3719 reconnect_delay_initial_ms: Some(50),
3720 reconnect_delay_max_ms: Some(100),
3721 reconnect_backoff_factor: Some(1.0),
3722 reconnect_jitter_ms: Some(0),
3723 reconnect_max_attempts: None,
3724 idle_timeout_ms: None,
3725 backend: TransportBackend::Tungstenite,
3726 proxy_url: None,
3727 };
3728
3729 let (_reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
3730 .await
3731 .unwrap();
3732
3733 assert!(client.is_active(), "Client should start active");
3734
3735 sleep(Duration::from_millis(100)).await;
3737
3738 for _ in 0..20 {
3740 let _ = client.send_text("ping".to_string(), None).await;
3741 sleep(Duration::from_millis(50)).await;
3742
3743 if !client.is_active() {
3744 break;
3745 }
3746 }
3747
3748 wait_until_async(|| async { !client.is_active() }, Duration::from_secs(5)).await;
3750
3751 assert!(
3753 client.is_closed() || client.is_disconnected(),
3754 "Stream mode should transition to CLOSED, not RECONNECT. \
3755 is_reconnecting={}, is_closed={}, is_disconnected={}",
3756 client.is_reconnecting(),
3757 client.is_closed(),
3758 client.is_disconnected(),
3759 );
3760 assert!(
3761 !client.is_reconnecting(),
3762 "Stream mode should never attempt reconnection"
3763 );
3764
3765 server.abort();
3766 }
3767
    /// With `reconnect_buffer_waits_for_auth` enabled, a message queued while
    /// the connection state is RECONNECT must not reach the server after the
    /// writer is swapped and the state returns to ACTIVE. It is delivered only
    /// once the auth tracker reports success.
    #[tokio::test]
    async fn test_write_task_waits_for_auth_before_replaying_buffer() {
        use nautilus_common::testing::wait_until_async;

        let server = RecordingServer::setup().await;
        let url = format!("ws://127.0.0.1:{}", server.port);
        // Initial writer handed to the write task.
        let (writer, _reader) = WebSocketClientInner::connect_with_server(
            &url,
            vec![],
            TransportBackend::Tungstenite,
            None,
        )
        .await
        .unwrap();

        // Shared state observed by the write task. It starts in RECONNECT so
        // the message sent below is expected to be held rather than written.
        let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
        let state_notify = Arc::new(tokio::sync::Notify::new());
        let auth_tracker = Arc::new(OnceLock::new());
        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
        let tracker = AuthTracker::new();
        // Install a clone; the test later drives auth through `tracker`, so clones
        // presumably share state — confirm against AuthTracker.
        auth_tracker.set(tracker.clone()).unwrap();

        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
        let write_task = WebSocketClientInner::spawn_write_task(
            Arc::clone(&connection_state),
            Arc::clone(&state_notify),
            writer,
            writer_rx,
            Arc::clone(&auth_tracker),
            Arc::clone(&reconnect_buffer_waits_for_auth),
        );

        // Queued while the state is RECONNECT.
        writer_tx
            .send(WriterCommand::Send(Message::Text("stale".into())))
            .unwrap();

        // Swap in a fresh writer, as a completed reconnect would. The test expects
        // `true` back on the oneshot. Shadowing `_reader` keeps the first reader
        // alive until the end of the test.
        let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
            &url,
            vec![],
            TransportBackend::Tungstenite,
            None,
        )
        .await
        .unwrap();
        let (tx, rx) = tokio::sync::oneshot::channel();
        writer_tx
            .send(WriterCommand::Update(new_writer, tx))
            .unwrap();
        assert!(rx.await.unwrap());

        // Back to ACTIVE, but authentication has not completed yet.
        connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);

        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(
            server.messages().await.is_empty(),
            "buffered messages should wait for re-authentication"
        );

        tracker.succeed();

        // Wait for the held message to reach the recording server.
        wait_until_async(
            || {
                let messages = Arc::clone(&server.messages);
                async move { !messages.lock().await.is_empty() }
            },
            Duration::from_secs(3),
        )
        .await;

        assert_eq!(server.messages().await, vec!["stale".to_string()]);

        // Teardown: mark CLOSED, wake any waiters, close the command channel, stop the task.
        connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
        state_notify.notify_waiters();
        drop(writer_tx);
        write_task.abort();
    }
3844
    /// With `reconnect_buffer_waits_for_auth` enabled, a message queued during
    /// RECONNECT must be dropped (not written) when authentication fails. It must
    /// also not be replayed by a later successful authentication.
    #[tokio::test]
    async fn test_write_task_discards_buffer_after_auth_failure() {
        let server = RecordingServer::setup().await;
        let url = format!("ws://127.0.0.1:{}", server.port);
        // Initial writer handed to the write task.
        let (writer, _reader) = WebSocketClientInner::connect_with_server(
            &url,
            vec![],
            TransportBackend::Tungstenite,
            None,
        )
        .await
        .unwrap();

        // Shared state observed by the write task. It starts in RECONNECT so the
        // message sent below is expected to be held rather than written.
        let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
        let state_notify = Arc::new(tokio::sync::Notify::new());
        let auth_tracker = Arc::new(OnceLock::new());
        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
        let tracker = AuthTracker::new();
        // Install a clone; the test later drives auth through `tracker`, so clones
        // presumably share state — confirm against AuthTracker.
        auth_tracker.set(tracker.clone()).unwrap();

        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
        let write_task = WebSocketClientInner::spawn_write_task(
            Arc::clone(&connection_state),
            Arc::clone(&state_notify),
            writer,
            writer_rx,
            Arc::clone(&auth_tracker),
            Arc::clone(&reconnect_buffer_waits_for_auth),
        );

        // Queued while the state is RECONNECT.
        writer_tx
            .send(WriterCommand::Send(Message::Text("stale".into())))
            .unwrap();

        // Swap in a fresh writer, as a completed reconnect would. The test expects
        // `true` back on the oneshot.
        let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
            &url,
            vec![],
            TransportBackend::Tungstenite,
            None,
        )
        .await
        .unwrap();
        let (tx, rx) = tokio::sync::oneshot::channel();
        writer_tx
            .send(WriterCommand::Update(new_writer, tx))
            .unwrap();
        assert!(rx.await.unwrap());

        // Back to ACTIVE, then report an authentication failure.
        connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
        tracker.fail("rejected");
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(
            server.messages().await.is_empty(),
            "buffered messages should be discarded after authentication failure"
        );

        // Start and succeed a new auth round; the discarded message must stay gone.
        // The receiver is bound so it is not dropped before `succeed()`.
        let _auth_receiver = tracker.begin();
        tracker.succeed();
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(
            server.messages().await.is_empty(),
            "discarded buffered messages should not replay on a later auth success"
        );

        // Teardown: mark CLOSED, wake any waiters, close the command channel, stop the task.
        connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
        state_notify.notify_waiters();
        drop(writer_tx);
        write_task.abort();
    }
3914
3915 #[rstest]
3916 #[tokio::test]
3917 async fn test_zero_idle_timeout_rejected() {
3918 let (handler, _rx) = channel_message_handler();
3919
3920 let config = WebSocketConfig {
3921 url: "ws://127.0.0.1:9999".to_string(),
3922 headers: vec![],
3923 heartbeat: None,
3924 heartbeat_msg: None,
3925 reconnect_timeout_ms: None,
3926 reconnect_delay_initial_ms: None,
3927 reconnect_delay_max_ms: None,
3928 reconnect_backoff_factor: None,
3929 reconnect_jitter_ms: None,
3930 reconnect_max_attempts: None,
3931 idle_timeout_ms: Some(0),
3932 backend: TransportBackend::Tungstenite,
3933 proxy_url: None,
3934 };
3935
3936 let result =
3937 WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
3938
3939 assert!(result.is_err(), "Zero idle timeout should be rejected");
3940 let err_msg = result.unwrap_err().to_string();
3941 assert!(
3942 err_msg.contains("Idle timeout cannot be zero"),
3943 "Error should mention zero idle timeout, was: {err_msg}"
3944 );
3945 }
3946
3947 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3948 #[rstest]
3949 #[tokio::test]
3950 async fn test_sockudo_backend_rejects_reserved_headers_before_connect() {
3951 let (handler, _rx) = channel_message_handler();
3952
3953 let config = WebSocketConfig {
3954 url: "ws://127.0.0.1:1".to_string(),
3955 headers: vec![("Host".to_string(), "example.com".to_string())],
3956 heartbeat: None,
3957 heartbeat_msg: None,
3958 reconnect_timeout_ms: None,
3959 reconnect_delay_initial_ms: None,
3960 reconnect_delay_max_ms: None,
3961 reconnect_backoff_factor: None,
3962 reconnect_jitter_ms: None,
3963 reconnect_max_attempts: None,
3964 idle_timeout_ms: None,
3965 backend: TransportBackend::Sockudo,
3966 proxy_url: None,
3967 };
3968
3969 let err = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3970 .await
3971 .expect_err("reserved header should fail before TCP connect");
3972
3973 assert!(
3974 err.to_string()
3975 .contains("reserved upgrade header not allowed in extra_headers"),
3976 "expected reserved-header failure, was: {err}"
3977 );
3978 }
3979
3980 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3981 #[rstest]
3982 #[tokio::test]
3983 async fn test_sockudo_backend_replays_leftover_without_custom_headers() {
3984 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3985 let port = listener.local_addr().unwrap().port();
3986
3987 let server = task::spawn(async move {
3988 if let Ok((mut stream, _)) = listener.accept().await {
3989 let request = read_http_request(&mut stream).await;
3990 let request = String::from_utf8(request).unwrap();
3991 let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
3992 let accept = sockudo_handshake::generate_accept_key(sec_websocket_key);
3993 let mut response = format!(
3994 concat!(
3995 "HTTP/1.1 101 Switching Protocols\r\n",
3996 "Upgrade: websocket\r\n",
3997 "Connection: Upgrade\r\n",
3998 "Sec-WebSocket-Accept: {}\r\n",
3999 "\r\n",
4000 ),
4001 accept
4002 )
4003 .into_bytes();
4004 response.extend_from_slice(b"\x81\x05hello");
4005 stream.write_all(&response).await.unwrap();
4006 }
4007 });
4008
4009 let (handler, mut rx) = channel_message_handler();
4010
4011 let config = WebSocketConfig {
4012 url: format!("ws://127.0.0.1:{port}/ws"),
4013 headers: vec![],
4014 heartbeat: None,
4015 heartbeat_msg: None,
4016 reconnect_timeout_ms: Some(2_000),
4017 reconnect_delay_initial_ms: Some(50),
4018 reconnect_delay_max_ms: Some(100),
4019 reconnect_backoff_factor: Some(1.0),
4020 reconnect_jitter_ms: Some(0),
4021 reconnect_max_attempts: None,
4022 idle_timeout_ms: None,
4023 backend: TransportBackend::Sockudo,
4024 proxy_url: None,
4025 };
4026
4027 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4028 .await
4029 .expect("sockudo connect without custom headers");
4030
4031 let received = tokio::time::timeout(Duration::from_secs(3), async {
4032 loop {
4033 if let Ok(msg) = rx.try_recv() {
4034 return msg;
4035 }
4036 tokio::time::sleep(Duration::from_millis(10)).await;
4037 }
4038 })
4039 .await
4040 .expect("did not receive leftover frame before timeout");
4041
4042 match received {
4043 WsMessage::Text(t) => assert_eq!(t.as_str(), "hello"),
4044 other => panic!("expected text, was {other:?}"),
4045 }
4046
4047 client.disconnect().await;
4048 tokio::time::timeout(Duration::from_secs(3), server)
4049 .await
4050 .expect("server did not close before timeout")
4051 .unwrap();
4052 }
4053
4054 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4055 #[rstest]
4056 #[tokio::test]
4057 async fn test_sockudo_backend_sends_custom_headers() {
4058 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4059 let port = listener.local_addr().unwrap().port();
4060
4061 let server = task::spawn(async move {
4062 if let Ok((stream, _)) = listener.accept().await {
4063 let callback = HeaderAssertCallback {
4064 key: "X-Test".to_string(),
4065 value: HeaderValue::from_static("value"),
4066 };
4067
4068 if let Ok(mut ws) = accept_hdr_async(stream, callback).await {
4069 while let Some(Ok(msg)) = ws.next().await {
4070 if msg.is_text() || msg.is_binary() {
4071 if ws.send(msg).await.is_err() {
4072 break;
4073 }
4074
4075 continue;
4076 }
4077
4078 if msg.is_close() {
4079 let _ = ws.close(None).await;
4080 break;
4081 }
4082 }
4083 }
4084 }
4085 });
4086
4087 let (handler, mut rx) = channel_message_handler();
4088
4089 let config = WebSocketConfig {
4090 url: format!("ws://127.0.0.1:{port}"),
4091 headers: vec![("X-Test".to_string(), "value".to_string())],
4092 heartbeat: None,
4093 heartbeat_msg: None,
4094 reconnect_timeout_ms: Some(2_000),
4095 reconnect_delay_initial_ms: Some(50),
4096 reconnect_delay_max_ms: Some(100),
4097 reconnect_backoff_factor: Some(1.0),
4098 reconnect_jitter_ms: Some(0),
4099 reconnect_max_attempts: None,
4100 idle_timeout_ms: None,
4101 backend: TransportBackend::Sockudo,
4102 proxy_url: None,
4103 };
4104
4105 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4106 .await
4107 .expect("sockudo connect with custom headers");
4108
4109 client.send_text("ping".to_string(), None).await.unwrap();
4110
4111 let received = tokio::time::timeout(Duration::from_secs(3), async {
4112 loop {
4113 if let Ok(msg) = rx.try_recv() {
4114 return msg;
4115 }
4116 tokio::time::sleep(Duration::from_millis(10)).await;
4117 }
4118 })
4119 .await
4120 .expect("did not receive echo before timeout");
4121
4122 match received {
4123 WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
4124 other => panic!("expected text, was {other:?}"),
4125 }
4126
4127 client.disconnect().await;
4128 tokio::time::timeout(Duration::from_secs(3), server)
4129 .await
4130 .expect("server did not close before timeout")
4131 .unwrap();
4132 }
4133
4134 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4135 #[rstest]
4136 #[tokio::test]
4137 async fn test_sockudo_backend_round_trip_text() {
4138 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4140 let port = listener.local_addr().unwrap().port();
4141
4142 let server = task::spawn(async move {
4143 if let Ok((stream, _)) = listener.accept().await
4144 && let Ok(mut ws) = accept_async(stream).await
4145 {
4146 while let Some(Ok(msg)) = ws.next().await {
4147 #[expect(clippy::collapsible_match)]
4149 match msg {
4150 WsMessage::Text(_) | WsMessage::Binary(_) => {
4151 if ws.send(msg).await.is_err() {
4152 break;
4153 }
4154 }
4155 WsMessage::Close(_) => {
4156 let _ = ws.close(None).await;
4157 break;
4158 }
4159 _ => {}
4160 }
4161 }
4162 }
4163 });
4164
4165 let (handler, mut rx) = channel_message_handler();
4166 let config = WebSocketConfig {
4167 url: format!("ws://127.0.0.1:{port}"),
4168 headers: vec![],
4169 heartbeat: None,
4170 heartbeat_msg: None,
4171 reconnect_timeout_ms: Some(2_000),
4172 reconnect_delay_initial_ms: Some(50),
4173 reconnect_delay_max_ms: Some(100),
4174 reconnect_backoff_factor: Some(1.0),
4175 reconnect_jitter_ms: Some(0),
4176 reconnect_max_attempts: None,
4177 idle_timeout_ms: None,
4178 backend: TransportBackend::Sockudo,
4179 proxy_url: None,
4180 };
4181
4182 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4183 .await
4184 .expect("sockudo connect");
4185
4186 client.send_text("ping".to_string(), None).await.unwrap();
4187
4188 let received = tokio::time::timeout(Duration::from_secs(3), async {
4189 loop {
4190 if let Ok(msg) = rx.try_recv() {
4191 return msg;
4192 }
4193 tokio::time::sleep(Duration::from_millis(10)).await;
4194 }
4195 })
4196 .await
4197 .expect("did not receive echo before timeout");
4198
4199 match received {
4200 WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
4201 other => panic!("expected text, was {other:?}"),
4202 }
4203
4204 client.disconnect().await;
4205 server.abort();
4206 }
4207
4208 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4209 #[rstest]
4210 #[case::ws_default_port("ws://example.com/ws", "example.com", "example.com", 80, "/ws", false)]
4211 #[case::wss_default_port(
4212 "wss://example.com/ws",
4213 "example.com",
4214 "example.com",
4215 443,
4216 "/ws",
4217 true
4218 )]
4219 #[case::ws_explicit_default(
4222 "ws://example.com:80/ws",
4223 "example.com",
4224 "example.com",
4225 80,
4226 "/ws",
4227 false
4228 )]
4229 #[case::ws_non_default(
4230 "ws://example.com:8443/feed",
4231 "example.com",
4232 "example.com:8443",
4233 8443,
4234 "/feed",
4235 false
4236 )]
4237 #[case::wss_non_default(
4238 "wss://example.com:9443/feed",
4239 "example.com",
4240 "example.com:9443",
4241 9443,
4242 "/feed",
4243 true
4244 )]
4245 #[case::root_path(
4246 "ws://example.com:9000/",
4247 "example.com",
4248 "example.com:9000",
4249 9000,
4250 "/",
4251 false
4252 )]
4253 #[case::query_string(
4254 "ws://example.com/feed?token=abc&channel=trades",
4255 "example.com",
4256 "example.com",
4257 80,
4258 "/feed?token=abc&channel=trades",
4259 false
4260 )]
4261 #[case::ipv6_default("ws://[::1]/feed", "::1", "[::1]", 80, "/feed", false)]
4263 #[case::ipv6_explicit_port("ws://[::1]:9000/feed", "::1", "[::1]:9000", 9000, "/feed", false)]
4264 #[case::ipv6_wss(
4265 "wss://[2001:db8::1]:8443/",
4266 "2001:db8::1",
4267 "[2001:db8::1]:8443",
4268 8443,
4269 "/",
4270 true
4271 )]
4272 fn sockudo_target_parses_url(
4273 #[case] url: &str,
4274 #[case] host: &str,
4275 #[case] host_header: &str,
4276 #[case] port: u16,
4277 #[case] path: &str,
4278 #[case] is_tls: bool,
4279 ) {
4280 let target = super::SockudoTarget::parse(url).expect("parse should succeed");
4281 assert_eq!(target.host, host);
4282 assert_eq!(target.host_header, host_header);
4283 assert_eq!(target.port, port);
4284 assert_eq!(target.path, path);
4285 assert_eq!(target.is_tls, is_tls);
4286 }
4287
4288 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4289 #[rstest]
4290 fn sockudo_target_rejects_unsupported_scheme() {
4291 let err = super::SockudoTarget::parse("http://example.com/feed").expect_err("not a ws URL");
4292 let msg = err.to_string();
4293 assert!(
4294 msg.contains("expected ws:// or wss://"),
4295 "unexpected error: {msg}"
4296 );
4297 }
4298
4299 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4300 #[rstest]
4301 fn sockudo_target_rejects_malformed_url() {
4302 let err = super::SockudoTarget::parse("not a url").expect_err("malformed URL");
4303 assert!(
4304 matches!(err, super::TransportError::InvalidUrl(_)),
4305 "expected InvalidUrl, was: {err:?}"
4306 );
4307 }
4308}
4309
// Turmoil network-simulation tests for the reconnect buffer / authentication interaction.
//
// A simulated `server` host drops the client's first connection, so the real
// `WebSocketClient` goes through its reconnect path. The tests then check what happens to a
// message queued during that window once authentication succeeds or fails.
#[cfg(test)]
#[cfg(feature = "turmoil")]
mod turmoil_tests {
    use std::{sync::Arc, time::Duration};

    use futures_util::{SinkExt, StreamExt};
    use nautilus_common::testing::wait_until_async;
    use rstest::rstest;
    use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
    use turmoil::{Builder, net};

    use super::*;
    use crate::websocket::types::channel_message_handler;

    /// A message queued while the client is reconnecting must not reach the server until the
    /// auth tracker reports success. Once it does, the message is delivered exactly once.
    #[rstest]
    fn test_turmoil_reconnect_buffer_waits_for_auth() {
        let mut sim = Builder::new().build();
        // Text frames recorded by the server on its second connection.
        let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let server_messages = Arc::clone(&messages);

        sim.host("server", move || {
            let messages = Arc::clone(&server_messages);
            auth_buffer_server(messages)
        });

        sim.client("client", async move {
            let tracker = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let client = WebSocketClient::connect(
                turmoil_websocket_config(),
                Some(handler),
                None,
                None,
                vec![],
                None,
            )
            .await
            .expect("Should connect");

            // NOTE(review): the `true` flag presumably makes the reconnect buffer wait for
            // authentication before replaying; confirm against `set_auth_tracker`.
            client.set_auth_tracker(tracker.clone(), true);
            assert!(client.is_active(), "Client should start active");

            // The server drops the first connection; wait until the client starts reconnecting.
            wait_until_async(
                || async { client.is_reconnecting() },
                Duration::from_secs(3),
            )
            .await;

            // Push straight onto the writer channel while reconnecting; the assertions below
            // treat this as a buffered message.
            client
                .writer_tx
                .send(WriterCommand::Send(Message::Text("stale".into())))
                .unwrap();

            wait_until_async(|| async { client.is_active() }, Duration::from_secs(3)).await;

            // Start an auth attempt. The receiver is bound to a name so it lives to the end of
            // the block instead of being dropped immediately.
            let _auth_receiver = tracker.begin();

            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                messages.lock().await.is_empty(),
                "buffered messages should wait for auth after reconnect"
            );

            // Completing auth should release the buffered message exactly once.
            tracker.succeed();

            wait_until_async(
                || {
                    let messages = Arc::clone(&messages);
                    async move { messages.lock().await.as_slice() == ["stale"] }
                },
                Duration::from_secs(3),
            )
            .await;

            assert_eq!(messages.lock().await.as_slice(), ["stale"]);

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }

    /// A message queued while reconnecting is discarded when authentication fails. It is not
    /// replayed even if a later authentication attempt succeeds.
    #[rstest]
    fn test_turmoil_reconnect_buffer_discards_after_auth_failure() {
        let mut sim = Builder::new().build();
        // Text frames recorded by the server on its second connection.
        let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let server_messages = Arc::clone(&messages);

        sim.host("server", move || {
            let messages = Arc::clone(&server_messages);
            auth_buffer_server(messages)
        });

        sim.client("client", async move {
            let tracker = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let client = WebSocketClient::connect(
                turmoil_websocket_config(),
                Some(handler),
                None,
                None,
                vec![],
                None,
            )
            .await
            .expect("Should connect");

            // NOTE(review): the `true` flag presumably makes the reconnect buffer wait for
            // authentication before replaying; confirm against `set_auth_tracker`.
            client.set_auth_tracker(tracker.clone(), true);
            assert!(client.is_active(), "Client should start active");

            // The server drops the first connection; wait until the client starts reconnecting.
            wait_until_async(
                || async { client.is_reconnecting() },
                Duration::from_secs(3),
            )
            .await;

            // Push straight onto the writer channel while reconnecting; the assertions below
            // treat this as a buffered message.
            client
                .writer_tx
                .send(WriterCommand::Send(Message::Text("stale".into())))
                .unwrap();

            wait_until_async(|| async { client.is_active() }, Duration::from_secs(3)).await;

            // First auth attempt fails: the buffered message must be dropped.
            let _auth_receiver = tracker.begin();
            tracker.fail("rejected");

            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                messages.lock().await.is_empty(),
                "buffered messages should be discarded after auth failure"
            );

            // A retry that succeeds must not bring the discarded message back.
            let _retry_auth_receiver = tracker.begin();
            tracker.succeed();

            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                messages.lock().await.is_empty(),
                "discarded messages should not replay on a later auth success"
            );

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }

    /// Client configuration targeting the simulated `server` host on port 8080.
    ///
    /// Reconnect delays are short and jitter-free: 50 ms initial, 200 ms max, backoff factor
    /// 1.0, with a 5 s reconnect timeout.
    fn turmoil_websocket_config() -> WebSocketConfig {
        WebSocketConfig {
            url: "ws://server:8080".to_string(),
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(200),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        }
    }

    /// Simulated server shared by both tests.
    ///
    /// 1. Accepts a first connection, sends `"first"`, and drops it without a close
    ///    handshake, so the client loses the connection and reconnects.
    /// 2. Accepts a second connection and appends the text of every text frame to `messages`.
    ///
    /// Recording stops on a close frame (which is answered) or a read error.
    async fn auth_buffer_server(
        messages: Arc<tokio::sync::Mutex<Vec<String>>>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let listener = net::TcpListener::bind("0.0.0.0:8080").await?;

        let (stream, _) = listener.accept().await?;
        let mut websocket = accept_async(stream).await?;
        let _ = websocket.send(WsMessage::Text("first".into())).await;
        // Dropped without `close()`: an abrupt disconnect from the client's point of view.
        drop(websocket);

        // NOTE(review): presumably gives the client time to notice the drop and enter
        // reconnect mode before the second accept; confirm the intent.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let (stream, _) = listener.accept().await?;
        let mut websocket = accept_async(stream).await?;

        while let Some(msg) = websocket.next().await {
            match msg {
                Ok(WsMessage::Text(text)) => {
                    messages.lock().await.push(text.to_string());
                }
                Ok(WsMessage::Close(_)) => {
                    let _ = websocket.close(None).await;
                    break;
                }
                Ok(_) => {}
                Err(_) => break,
            }
        }

        Ok(())
    }
}