use std::{
collections::VecDeque,
fmt::Debug,
path::Path,
sync::{
Arc,
atomic::{AtomicU8, Ordering},
},
time::Duration,
};
use bytes::Bytes;
use nautilus_core::CleanDrop;
use nautilus_cryptography::providers::install_cryptographic_provider;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
use super::{SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand};
use crate::{
backoff::ExponentialBackoff,
error::SendError,
logging::{log_task_aborted, log_task_started, log_task_stopped},
mode::ConnectionMode,
net::TcpStream,
tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
};
/// Period (ms) at which the read/write loops and `close()` re-check the shared connection mode.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (ms) taken before tasks are torn down on the reconnect and disconnect paths.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) for graceful shutdown steps such as closing a writer or stopping the controller.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Largest amount of not-yet-framed data (bytes) the read task buffers before dropping the connection (10 MiB).
const MAX_READ_BUFFER_BYTES: usize = 10 << 20;
/// Connection state owned by the controller task (it is moved into `spawn_controller_task`).
///
/// Holds what is needed to reconnect (config and optional TLS connector), the handles of the
/// spawned read/write/heartbeat tasks, the channel used to send commands to the writer task,
/// and the reconnect policy (timeout, backoff, attempt limit and counter).
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Original configuration; `reconnect` reads the URL, mode, suffix and idle timeout from it.
    config: SocketConfig,
    /// Rustls connector built from `certs_dir`, or `None` when no certs directory is configured.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced with a fresh task after each successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the writer task, which persists across reconnects (its writer is swapped via
    /// `WriterCommand::Update`).
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the writer task's command channel.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Heartbeat task handle, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared `ConnectionMode`, stored as its `u8` representation.
    connection_mode: Arc<AtomicU8>,
    /// Used to wake the controller and waiting senders when the mode changes.
    state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound for a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Delay policy between failed reconnect attempts; reset after a successful reconnect.
    backoff: ExponentialBackoff,
    /// Callback for each received frame; handed to every newly spawned read task.
    handler: Option<TcpMessageHandler>,
    /// Maximum consecutive reconnect attempts; `None` means no limit.
    reconnect_max_attempts: Option<u32>,
    /// Reconnect attempts made since the last successful (re)connection.
    reconnect_attempt_count: u32,
}
impl SocketClientInner {
pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
const CONNECTION_TIMEOUT_SECS: u64 = 10;
install_cryptographic_provider();
if config.suffix.is_empty() {
anyhow::bail!("Socket suffix cannot be empty: suffix is required for message framing");
}
if let Some((interval_secs, _)) = &config.heartbeat
&& *interval_secs == 0
{
anyhow::bail!("Heartbeat interval cannot be zero");
}
if config.idle_timeout_ms == Some(0) {
anyhow::bail!("Idle timeout cannot be zero");
}
let SocketConfig {
url,
mode,
heartbeat,
suffix,
message_handler,
reconnect_timeout_ms,
reconnect_delay_initial_ms,
reconnect_delay_max_ms,
reconnect_backoff_factor,
reconnect_jitter_ms,
connection_max_retries,
reconnect_max_attempts,
idle_timeout_ms,
certs_dir,
} = &config.clone();
let connector = if let Some(dir) = certs_dir {
let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
Some(Connector::Rustls(Arc::new(config)))
} else {
None
};
let max_retries = connection_max_retries.unwrap_or(5);
let mut backoff = ExponentialBackoff::new(
Duration::from_millis(500),
Duration::from_millis(5000),
2.0,
250,
false,
)?;
#[allow(unused_assignments)]
let mut last_error = String::new();
let mut attempt = 0;
let (reader, writer) = loop {
attempt += 1;
match tokio::time::timeout(
Duration::from_secs(CONNECTION_TIMEOUT_SECS),
Self::tls_connect_with_server(url, *mode, connector.clone()),
)
.await
{
Ok(Ok(result)) => {
if attempt > 1 {
log::info!("Socket connection established after {attempt} attempts");
}
break result;
}
Ok(Err(e)) => {
last_error = e.to_string();
log::warn!(
"Socket connection attempt {attempt}/{max_retries} to {url} failed: {last_error}"
);
}
Err(_) => {
last_error = format!(
"Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
);
log::warn!(
"Socket connection attempt {attempt}/{max_retries} to {url} timed out"
);
}
}
if attempt >= max_retries {
anyhow::bail!(
"Failed to connect to {} after {} attempts: {}. \
If this is a DNS error, check your network configuration and DNS settings.",
url,
max_retries,
if last_error.is_empty() {
"unknown error"
} else {
&last_error
}
);
}
let delay = backoff.next_duration();
log::debug!(
"Retrying in {delay:?} (attempt {}/{})",
attempt + 1,
max_retries
);
tokio::time::sleep(delay).await;
};
log::debug!("Connected");
let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
let state_notify = Arc::new(tokio::sync::Notify::new());
let read_task = Arc::new(Self::spawn_read_task(
connection_mode.clone(),
reader,
message_handler.clone(),
suffix.clone(),
*idle_timeout_ms,
));
let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
let write_task = Self::spawn_write_task(
connection_mode.clone(),
state_notify.clone(),
writer,
writer_rx,
suffix.clone(),
);
let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
Self::spawn_heartbeat_task(
connection_mode.clone(),
heartbeat.clone(),
writer_tx.clone(),
)
});
let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
let backoff = ExponentialBackoff::new(
Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
reconnect_backoff_factor.unwrap_or(1.5),
reconnect_jitter_ms.unwrap_or(100),
true, )?;
Ok(Self {
config,
connector,
read_task,
write_task,
writer_tx,
heartbeat_task,
connection_mode,
state_notify,
reconnect_timeout,
backoff,
handler: message_handler.clone(),
reconnect_max_attempts: *reconnect_max_attempts,
reconnect_attempt_count: 0,
})
}
/// Splits `url` into a `host:port` address to dial and a URL to build the client request from.
///
/// A bare address with no `://` is dialled unchanged. Its request URL gets `wss://` (TLS mode)
/// or `ws://` (plain mode) prepended.
///
/// A URL with a scheme is parsed as an `http::Uri`. If it has no explicit port, the port
/// defaults to 443 for `wss`/`https` and 80 for `ws`/`http`. Any other scheme falls back to
/// 443 in TLS mode and 80 in plain mode. The request URL is the input unchanged.
///
/// # Errors
///
/// Returns an `InvalidInput` I/O error if the URL cannot be parsed or has no host.
fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
    let (mode_scheme, mode_port) = match mode {
        Mode::Tls => ("wss", 443),
        Mode::Plain => ("ws", 80),
    };

    if !url.contains("://") {
        return Ok((url.to_string(), format!("{mode_scheme}://{url}")));
    }

    let invalid_input = |message: String| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            message,
        ))
    };

    let uri: http::Uri = url
        .parse()
        .map_err(|e| invalid_input(format!("Invalid URL: {e}")))?;
    let host = uri
        .host()
        .ok_or_else(|| invalid_input("URL missing host".to_string()))?;

    let scheme_port = match uri.scheme_str() {
        Some("wss" | "https") => 443,
        Some("ws" | "http") => 80,
        _ => mode_port,
    };
    let port = uri.port_u16().unwrap_or(scheme_port);

    Ok((format!("{host}:{port}"), url.to_string()))
}
pub async fn tls_connect_with_server(
url: &str,
mode: Mode,
connector: Option<Connector>,
) -> Result<(TcpReader, TcpWriter), Error> {
log::debug!("Connecting to {url}");
let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
let tcp_result = TcpStream::connect(&socket_addr).await;
match tcp_result {
Ok(stream) => {
log::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
if let Err(e) = stream.set_nodelay(true) {
log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
}
let request = request_url.into_client_request()?;
tcp_tls(&request, mode, stream, connector)
.await
.map(tokio::io::split)
}
Err(e) => {
log::error!("TCP connection failed to {socket_addr}: {e:?}");
Err(Error::Io(e))
}
}
}
/// Re-establishes the connection after a failure. The whole attempt is bounded by
/// `reconnect_timeout`.
///
/// Sequence:
/// 1. Open a new stream.
/// 2. Hand its write half to the writer task (`WriterCommand::Update`) and wait for the writer
///    to report whether its buffered messages were flushed.
/// 3. Pause `GRACEFUL_SHUTDOWN_DELAY_MS`.
/// 4. Abort the old read task.
/// 5. Move the mode from `Reconnect` to `Active`.
/// 6. Spawn a read task on the new read half.
///
/// Returns `Ok(())` early, without completing the swap, if a disconnect is observed at any
/// checkpoint or the mode is no longer `Reconnect` at the CAS.
///
/// # Errors
///
/// Returns an error if:
/// - the connect fails;
/// - the writer channel is closed;
/// - the writer reports a failed buffer drain or drops the reply channel;
/// - the attempt exceeds `reconnect_timeout`.
async fn reconnect(&mut self) -> Result<(), Error> {
    log::debug!("Reconnecting");

    if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
        log::debug!("Reconnect aborted due to disconnect state");
        return Ok(());
    }

    tokio::time::timeout(self.reconnect_timeout, async {
        // Only the connection target and framing settings are needed here.
        let SocketConfig {
            url,
            mode,
            heartbeat: _,
            suffix,
            message_handler: _,
            reconnect_timeout_ms: _,
            reconnect_delay_initial_ms: _,
            reconnect_backoff_factor: _,
            reconnect_delay_max_ms: _,
            reconnect_jitter_ms: _,
            connection_max_retries: _,
            reconnect_max_attempts: _,
            idle_timeout_ms,
            certs_dir: _,
        } = &self.config;

        let connector = self.connector.clone();
        let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

        // A disconnect requested while connecting wins: the new streams are simply dropped.
        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted mid-flight (after connect)");
            return Ok(());
        }

        log::debug!("Connected");

        // Swap the writer inside the writer task; it answers on `rx` whether draining its
        // reconnect buffer succeeded.
        let (tx, rx) = tokio::sync::oneshot::channel();
        if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
            log::error!("{e}");
            return Err(Error::Io(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("Failed to send update command: {e}"),
            )));
        }

        match rx.await {
            Ok(true) => log::debug!("Writer confirmed buffer drain success"),
            Ok(false) => {
                log::warn!("Writer failed to drain buffer, aborting reconnect");
                return Err(Error::Io(std::io::Error::other(
                    "Failed to drain reconnection buffer",
                )));
            }
            Err(e) => {
                log::error!("Writer dropped update channel: {e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Writer task dropped response channel",
                )));
            }
        }

        tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted mid-flight (after delay)");
            return Ok(());
        }

        // Stop the old reader (normally already exited, since it stops once the mode leaves
        // Active) before a replacement is spawned.
        if !self.read_task.is_finished() {
            self.read_task.abort();
            log_task_aborted("read");
        }

        // Only flip RECONNECT -> ACTIVE. If anything else changed the mode meanwhile, leave it
        // alone and report success without spawning a reader.
        if self
            .connection_mode
            .compare_exchange(
                ConnectionMode::Reconnect.as_u8(),
                ConnectionMode::Active.as_u8(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_err()
        {
            log::debug!("Reconnect aborted (state changed during reconnect)");
            return Ok(());
        }

        // Spawned after the mode is Active, because the read loop exits while not Active.
        self.read_task = Arc::new(Self::spawn_read_task(
            self.connection_mode.clone(),
            reader,
            self.handler.clone(),
            suffix.clone(),
            *idle_timeout_ms,
        ));

        log::debug!("Reconnect succeeded");
        Ok(())
    })
    .await
    .map_err(|_| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            format!(
                "reconnection timed out after {}s",
                self.reconnect_timeout.as_secs_f64()
            ),
        ))
    })?
}
/// Reports whether both the read and the write task are still running.
///
/// The controller uses a `false` result while `Active` as its signal to start reconnecting.
#[inline]
#[must_use]
pub fn is_alive(&self) -> bool {
    let reader_done = self.read_task.is_finished();
    let writer_done = self.write_task.is_finished();
    !(reader_done || writer_done)
}
#[must_use]
fn spawn_read_task(
connection_state: Arc<AtomicU8>,
mut reader: TcpReader,
handler: Option<TcpMessageHandler>,
suffix: Vec<u8>,
idle_timeout_ms: Option<u64>,
) -> tokio::task::JoinHandle<()> {
log_task_started("read");
let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
tokio::task::spawn(async move {
let mut buf = Vec::new();
let mut last_data_time = tokio::time::Instant::now();
loop {
if !ConnectionMode::from_atomic(&connection_state).is_active() {
break;
}
match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
Ok(Ok(0)) => {
log::debug!("Connection closed by server");
break;
}
Ok(Err(e)) => {
log::debug!("Connection ended: {e}");
break;
}
Ok(Ok(bytes)) => {
log::trace!("Received <binary> {bytes} bytes");
last_data_time = tokio::time::Instant::now();
while let Some((i, _)) = &buf
.windows(suffix.len())
.enumerate()
.find(|(_, pair)| pair.eq(&suffix))
{
let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
data.truncate(data.len() - suffix.len());
if let Some(ref handler) = handler {
handler(&data);
}
}
if buf.len() > MAX_READ_BUFFER_BYTES {
log::error!(
"Read buffer exceeded maximum size ({MAX_READ_BUFFER_BYTES} bytes), closing connection"
);
break;
}
}
Err(_) => {
if let Some(timeout) = idle_timeout {
let idle_duration = last_data_time.elapsed();
if idle_duration >= timeout {
log::warn!(
"Read idle timeout: no data received for {:.1}s",
idle_duration.as_secs_f64()
);
break;
}
}
}
}
}
log_task_stopped("read");
})
}
/// Writes every message queued during a reconnect to `writer`, each followed by `suffix`.
///
/// Messages are sent oldest first. A message is removed from `buffer` only after it has been
/// written. On the first write error, the failing message and everything after it stay queued.
///
/// Returns `true` if a write error occurred and `false` otherwise (including when the buffer
/// was already empty).
async fn drain_reconnect_buffer(
    buffer: &mut VecDeque<Bytes>,
    writer: &mut TcpWriter,
    suffix: &[u8],
) -> bool {
    let initial_buffer_len = buffer.len();
    if initial_buffer_len == 0 {
        return false;
    }

    log::info!("Sending {initial_buffer_len} buffered messages after reconnection");

    let mut failed = false;
    while let Some(pending) = buffer.front() {
        let mut frame = Vec::with_capacity(pending.len() + suffix.len());
        frame.extend_from_slice(pending);
        frame.extend_from_slice(suffix);

        match writer.write_all(&frame).await {
            Ok(()) => {
                buffer.pop_front();
            }
            Err(e) => {
                log::error!(
                    "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
                    buffer.len()
                );
                failed = true;
                break;
            }
        }
    }

    if buffer.is_empty() {
        log::info!("Successfully sent all {initial_buffer_len} buffered messages");
    }

    failed
}
fn spawn_write_task(
connection_state: Arc<AtomicU8>,
state_notify: Arc<tokio::sync::Notify>,
writer: TcpWriter,
mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
suffix: Vec<u8>,
) -> tokio::task::JoinHandle<()> {
log_task_started("write");
let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
tokio::task::spawn(async move {
let mut active_writer = writer;
let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
let mut write_buf: Vec<u8> = Vec::new();
loop {
if matches!(
ConnectionMode::from_atomic(&connection_state),
ConnectionMode::Disconnect | ConnectionMode::Closed
) {
break;
}
match tokio::time::timeout(check_interval, writer_rx.recv()).await {
Ok(Some(msg)) => {
let mode = ConnectionMode::from_atomic(&connection_state);
if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
break;
}
match msg {
WriterCommand::Update(new_writer, tx) => {
log::debug!("Received new writer");
tokio::time::sleep(Duration::from_millis(100)).await;
_ = tokio::time::timeout(
Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
active_writer.shutdown(),
)
.await;
active_writer = new_writer;
log::debug!("Updated writer");
let send_error = Self::drain_reconnect_buffer(
&mut reconnect_buffer,
&mut active_writer,
&suffix,
)
.await;
if let Err(e) = tx.send(!send_error) {
log::error!(
"Failed to report drain status to controller: {e:?}"
);
}
}
_ if mode.is_reconnect() => {
if let WriterCommand::Send(data) = msg {
log::debug!(
"Buffering message while reconnecting ({} bytes)",
data.len()
);
reconnect_buffer.push_back(data);
}
}
WriterCommand::Send(msg) => {
write_buf.clear();
write_buf.extend_from_slice(&msg);
write_buf.extend_from_slice(&suffix);
if let Err(e) = active_writer.write_all(&write_buf).await {
log::error!("Failed to send message: {e}");
log::warn!("Writer triggering reconnect");
reconnect_buffer.push_back(msg);
connection_state
.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
state_notify.notify_one();
}
}
}
}
Ok(None) => {
log::debug!("Writer channel closed, terminating writer task");
break;
}
Err(_) => {
}
}
}
_ = tokio::time::timeout(
Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
active_writer.shutdown(),
)
.await;
log_task_stopped("write");
})
}
/// Spawns a task that sends the configured heartbeat payload to the writer task every
/// `interval_secs` seconds.
///
/// A heartbeat is sent only while the connection is `Active`; ticks during `Reconnect` are
/// skipped. The task ends once the mode becomes `Disconnect` or `Closed`. A failed hand-off to
/// the writer channel is logged and the task keeps running.
fn spawn_heartbeat_task(
    connection_state: Arc<AtomicU8>,
    heartbeat: (u64, Vec<u8>),
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
) -> tokio::task::JoinHandle<()> {
    log_task_started("heartbeat");

    let (interval_secs, payload) = heartbeat;
    let period = Duration::from_secs(interval_secs);

    tokio::task::spawn(async move {
        loop {
            tokio::time::sleep(period).await;

            let mode = ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst));
            match mode {
                ConnectionMode::Disconnect | ConnectionMode::Closed => break,
                ConnectionMode::Reconnect => continue,
                ConnectionMode::Active => {
                    let command = WriterCommand::Send(payload.clone().into());
                    if let Err(e) = writer_tx.send(command) {
                        log::error!("Failed to send heartbeat to writer task: {e}");
                    } else {
                        log::trace!("Sent heartbeat to writer task");
                    }
                }
            }
        }

        log_task_stopped("heartbeat");
    })
}
}
impl Drop for SocketClientInner {
fn drop(&mut self) {
self.clean_drop();
}
}
impl CleanDrop for SocketClientInner {
    /// Aborts the read, write and heartbeat tasks if they are still running.
    ///
    /// With the `python` feature enabled, this also releases the message handler. Both copies
    /// are cleared: the one in `config` and the clone kept in `handler` for future read tasks.
    /// Clearing only one would leave the callback referenced.
    fn clean_drop(&mut self) {
        if !self.read_task.is_finished() {
            self.read_task.abort();
            log_task_aborted("read");
        }

        if !self.write_task.is_finished() {
            self.write_task.abort();
            log_task_aborted("write");
        }

        if let Some(ref handle) = self.heartbeat_task.take()
            && !handle.is_finished()
        {
            handle.abort();
            log_task_aborted("heartbeat");
        }

        #[cfg(feature = "python")]
        {
            self.config.message_handler = None;
            self.handler = None;
        }
    }
}
/// Handle to a suffix-framed TCP/TLS socket connection with automatic reconnection.
///
/// The connection itself is driven by a background controller task. This handle holds the
/// shared state needed to observe the connection mode, wait for it to become active, queue
/// outgoing messages, and request shutdown.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct SocketClient {
    /// Controller task that owns the connection state and performs reconnects.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared `ConnectionMode`, stored as its `u8` representation.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Notified on mode changes; used to wake waiters such as `wait_for_active`.
    pub(crate) state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound for how long a send waits for the connection to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Sender for commands to the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
impl Debug for SocketClient {
    /// Prints only the type name; the handle's internals are not useful for debugging.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SocketClient");
        builder.finish()
    }
}
impl SocketClient {
/// Connects using `config` and starts the controller task that supervises the connection.
///
/// `post_connection` is invoked once, after the controller has been spawned.
/// `post_reconnection` and `post_disconnection` are handed to the controller for later
/// invocation.
///
/// # Errors
///
/// Returns an error if the initial connection cannot be established (see
/// `SocketClientInner::connect_url`).
pub async fn connect(
    config: SocketConfig,
    post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
    post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
) -> anyhow::Result<Self> {
    let inner = SocketClientInner::connect_url(config).await?;

    // Take the shared handles before `inner` moves into the controller task.
    let writer_tx = inner.writer_tx.clone();
    let reconnect_timeout = inner.reconnect_timeout;
    let connection_mode = Arc::clone(&inner.connection_mode);
    let state_notify = Arc::clone(&inner.state_notify);

    let controller_task = Self::spawn_controller_task(
        inner,
        Arc::clone(&connection_mode),
        Arc::clone(&state_notify),
        post_reconnection,
        post_disconnection,
    );

    if let Some(on_connected) = post_connection {
        on_connected();
        log::debug!("Called `post_connection` handler");
    }

    Ok(Self {
        controller_task,
        connection_mode,
        state_notify,
        reconnect_timeout,
        writer_tx,
    })
}
/// Current connection mode, decoded from the shared atomic.
#[must_use]
pub fn connection_mode(&self) -> ConnectionMode {
    let state = &self.connection_mode;
    ConnectionMode::from_atomic(state)
}

/// `true` while the connection is `Active`.
#[inline]
#[must_use]
pub fn is_active(&self) -> bool {
    ConnectionMode::from_atomic(&self.connection_mode).is_active()
}

/// `true` while a reconnect is in progress.
#[inline]
#[must_use]
pub fn is_reconnecting(&self) -> bool {
    ConnectionMode::from_atomic(&self.connection_mode).is_reconnect()
}

/// `true` once a disconnect has been requested but not yet completed.
#[inline]
#[must_use]
pub fn is_disconnecting(&self) -> bool {
    ConnectionMode::from_atomic(&self.connection_mode).is_disconnect()
}

/// `true` once the connection has been fully closed.
#[inline]
#[must_use]
pub fn is_closed(&self) -> bool {
    ConnectionMode::from_atomic(&self.connection_mode).is_closed()
}
/// Requests a graceful disconnect and waits for the controller task to finish.
///
/// The mode is moved to `Disconnect` and waiters on `state_notify` are woken. The method then
/// polls every `CONNECTION_STATE_CHECK_INTERVAL_MS` until the controller reports `Closed`, for
/// at most `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`. On timeout, the controller task is aborted and the
/// mode is forced to `Closed`.
///
/// If the client is already `Closed` (after an earlier `close()`, or after reconnect attempts
/// were exhausted), this returns immediately. The controller has already exited in that case,
/// so replacing `Closed` with `Disconnect` would only stall until the timeout and log a
/// spurious error.
pub async fn close(&self) {
    let disconnect = ConnectionMode::Disconnect.as_u8();
    let closed = ConnectionMode::Closed.as_u8();

    // Atomically move to DISCONNECT unless already CLOSED.
    if self
        .connection_mode
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
            (current != closed).then_some(disconnect)
        })
        .is_err()
    {
        log::debug!("Client already closed");
        return;
    }
    self.state_notify.notify_waiters();

    let graceful = tokio::time::timeout(
        Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
        async {
            while !self.is_closed() {
                tokio::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS))
                    .await;
            }
            if !self.controller_task.is_finished() {
                self.controller_task.abort();
                log_task_aborted("controller");
            }
        },
    )
    .await;

    if graceful.is_ok() {
        log_task_stopped("controller");
    } else {
        log::error!("Timeout waiting for controller task to finish");
        if !self.controller_task.is_finished() {
            self.controller_task.abort();
            log_task_aborted("controller");
        }
        self.connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
    }
}
/// Fails with `SendError::Closed` when the client is disconnecting or closed.
#[inline]
fn check_not_terminal(&self) -> Result<(), SendError> {
    let terminal = matches!(
        self.connection_mode(),
        ConnectionMode::Disconnect | ConnectionMode::Closed
    );
    if terminal { Err(SendError::Closed) } else { Ok(()) }
}
/// Waits until the client is `Active` so that a message can be handed to the writer.
///
/// Returns immediately when already active. Returns `SendError::Closed` if the client is, or
/// becomes, disconnecting/closed. Returns `SendError::Timeout` if it is not active within
/// `reconnect_timeout`.
///
/// Mode changes are picked up through `state_notify`, with a 100 ms periodic re-check as a
/// fallback.
async fn wait_for_active(&self) -> Result<(), SendError> {
    const FALLBACK_INTERVAL_MS: u64 = 100;

    let initial = self.connection_mode();
    if initial.is_active() {
        return Ok(());
    }
    if matches!(initial, ConnectionMode::Disconnect | ConnectionMode::Closed) {
        return Err(SendError::Closed);
    }

    log::debug!("Waiting for client to become ACTIVE before sending...");

    let poll_every = Duration::from_millis(FALLBACK_INTERVAL_MS);
    let became_active = async {
        loop {
            // The notification future is created before the mode is read, so a change that
            // happens in between still wakes this loop.
            let wakeup = self.state_notify.notified();
            let current = self.connection_mode();
            if current.is_active() {
                break true;
            }
            if matches!(current, ConnectionMode::Disconnect | ConnectionMode::Closed) {
                break false;
            }
            tokio::select! {
                () = wakeup => {}
                () = tokio::time::sleep(poll_every) => {}
            }
        }
    };

    match tokio::time::timeout(self.reconnect_timeout, became_active).await {
        Ok(true) => Ok(()),
        Ok(false) => Err(SendError::Closed),
        Err(_) => Err(SendError::Timeout),
    }
}
/// Queues `data` for the writer task, which appends the configured suffix.
///
/// Waits for the connection to be `Active` first.
///
/// # Errors
///
/// - `SendError::Closed` if the client is disconnecting or closed.
/// - `SendError::Timeout` if the client does not become active in time.
/// - `SendError::BrokenPipe` if the writer channel has been closed.
pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
    self.check_not_terminal()?;
    self.wait_for_active().await?;

    let command = WriterCommand::Send(data.into());
    match self.writer_tx.send(command) {
        Ok(()) => Ok(()),
        Err(e) => Err(SendError::BrokenPipe(e.to_string())),
    }
}
/// Spawns the supervisor task that owns `inner` and drives the connection state machine.
///
/// Each iteration waits for a `state_notify` wake-up or a 100 ms fallback tick, then acts on
/// the current mode:
/// - `Disconnect`: pauses `GRACEFUL_SHUTDOWN_DELAY_MS`, then aborts the read and heartbeat
///   tasks. This step is bounded by `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`. It then calls
///   `post_disconnection` and exits. The writer task is not aborted here; it exits on its own
///   when it sees `Disconnect`.
/// - `Closed`: exits.
/// - `Active`, but the read or write task has finished: switches to `Reconnect`.
/// - `Reconnect`, attempt limit reached: stores `Closed` and exits.
/// - `Reconnect`, otherwise: runs one reconnect attempt, abandoned if a disconnect is
///   signalled. On success it resets the backoff and calls `post_reconnection` if still
///   `Active`. On failure it sleeps the backoff delay, which a disconnect can interrupt.
///
/// On exit the mode is stored as `Closed`. Dropping `inner` at the end of the task then aborts
/// any task that is still running, via `CleanDrop`.
fn spawn_controller_task(
    mut inner: SocketClientInner,
    connection_mode: Arc<AtomicU8>,
    state_notify: Arc<tokio::sync::Notify>,
    post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
) -> tokio::task::JoinHandle<()> {
    const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;

    tokio::task::spawn(async move {
        log_task_started("controller");

        let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);

        loop {
            // Wake on an explicit state change, or periodically in case a notification is missed.
            tokio::select! {
                () = state_notify.notified() => {}
                () = tokio::time::sleep(fallback_interval) => {}
            }

            let mut mode = ConnectionMode::from_atomic(&connection_mode);

            if mode.is_disconnect() {
                log::debug!("Disconnecting");

                let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                if tokio::time::timeout(timeout, async {
                    tokio::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                    if !inner.read_task.is_finished() {
                        inner.read_task.abort();
                        log_task_aborted("read");
                    }

                    if let Some(task) = &inner.heartbeat_task
                        && !task.is_finished()
                    {
                        task.abort();
                        log_task_aborted("heartbeat");
                    }
                })
                .await
                .is_err()
                {
                    log::error!("Shutdown timed out after {}s", timeout.as_secs());
                }

                log::debug!("Closed");

                if let Some(ref handler) = post_disconnection {
                    handler();
                    log::debug!("Called `post_disconnection` handler");
                }

                break;
            }

            if mode.is_closed() {
                log::debug!("Connection closed");
                break;
            }

            // `is_alive` is false when either the read or the write task has finished (the log
            // message below names only the read task). CAS so a concurrent transition by the
            // writer or `close()` is not overwritten.
            if mode.is_active() && !inner.is_alive() {
                if connection_mode
                    .compare_exchange(
                        ConnectionMode::Active.as_u8(),
                        ConnectionMode::Reconnect.as_u8(),
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
                {
                    log::debug!("Detected dead read task, transitioning to RECONNECT");
                }
                mode = ConnectionMode::from_atomic(&connection_mode);
            }

            if mode.is_reconnect() {
                // `reconnect_max_attempts == None` means retry forever.
                if let Some(max_attempts) = inner.reconnect_max_attempts
                    && inner.reconnect_attempt_count >= max_attempts
                {
                    log::error!(
                        "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                    );
                    connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                    state_notify.notify_waiters();
                    break;
                }

                inner.reconnect_attempt_count += 1;

                // Race the attempt against a disconnect request; if the disconnect wins, the
                // in-progress `reconnect()` future is dropped.
                let reconnect_result = tokio::select! {
                    result = inner.reconnect() => Some(result),
                    () = async {
                        loop {
                            state_notify.notified().await;
                            if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
                                break;
                            }
                        }
                    } => None,
                };

                match reconnect_result {
                    None => {
                        log::debug!("Reconnect interrupted by disconnect");
                    }
                    Some(Ok(())) => {
                        log::debug!("Reconnected successfully");
                        inner.backoff.reset();
                        inner.reconnect_attempt_count = 0;
                        // Wake senders blocked in `wait_for_active`.
                        state_notify.notify_waiters();

                        // `reconnect()` also returns Ok when it bails out early, so only run the
                        // handler if the connection really is ACTIVE again.
                        if ConnectionMode::from_atomic(&connection_mode).is_active() {
                            if let Some(ref handler) = post_reconnection {
                                handler();
                                log::debug!("Called `post_reconnection` handler");
                            }
                        } else {
                            log::debug!(
                                "Skipping post_reconnection handlers due to disconnect state"
                            );
                        }
                    }
                    Some(Err(e)) => {
                        let duration = inner.backoff.next_duration();

                        log::warn!(
                            "Reconnect attempt {} failed: {e}",
                            inner.reconnect_attempt_count
                        );

                        if !duration.is_zero() {
                            log::warn!("Backing off for {}s...", duration.as_secs_f64());

                            // Sleep the backoff delay, but cut it short on a disconnect request.
                            tokio::select! {
                                () = tokio::time::sleep(duration) => {}
                                () = async {
                                    loop {
                                        state_notify.notified().await;
                                        if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
                                            break;
                                        }
                                    }
                                } => {
                                    log::debug!("Backoff interrupted by disconnect");
                                }
                            }
                        }
                    }
                }
            }
        }

        inner
            .connection_mode
            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

        log_task_stopped("controller");
    })
}
}
impl Drop for SocketClient {
    /// Aborts the controller task if it is still running when the handle is dropped.
    fn drop(&mut self) {
        if self.controller_task.is_finished() {
            return;
        }
        self.controller_task.abort();
        log_task_aborted("controller");
    }
}
#[cfg(test)]
#[cfg(feature = "python")]
#[cfg(target_os = "linux")] mod tests {
use nautilus_common::testing::wait_until_async;
use pyo3::Python;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::Mutex,
task,
time::{Duration, sleep},
};
use super::*;
/// Binds a listener on an OS-assigned localhost port and returns that port with the listener.
async fn bind_test_server() -> (u16, TcpListener) {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("Failed to bind ephemeral port");
    let port = listener.local_addr().map(|addr| addr.port()).unwrap();
    (port, listener)
}
/// Minimal `\r\n`-delimited echo server used by the tests.
///
/// Each complete line is written back with its terminator. A `close` line shuts the socket
/// down and stops the server. The server also stops when the peer closes the stream, on a
/// read error, or when an echo write fails.
async fn run_echo_server(mut socket: TcpStream) {
    let mut buf = Vec::new();
    loop {
        match socket.read_buf(&mut buf).await {
            Ok(0) => {
                break;
            }
            Ok(_n) => {
                while let Some(idx) = buf.array_windows().position(|w| w == b"\r\n") {
                    let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                    line.truncate(line.len() - 2);

                    if line == b"close" {
                        let _ = socket.shutdown().await;
                        return;
                    }

                    let mut echo_data = line;
                    echo_data.extend_from_slice(b"\r\n");
                    if socket.write_all(&echo_data).await.is_err() {
                        // The peer is gone: stop serving. A `break` here would only leave the
                        // line loop and go back to reading from a dead connection.
                        return;
                    }
                }
            }
            Err(e) => {
                eprintln!("Server read error: {e}");
                break;
            }
        }
    }
}
// Sends two messages to an echo server, then `close` to make the server shut down, and checks
// that the client has not reached CLOSED.
// NOTE(review): `message_handler` is None, so the echoed data itself is never verified.
#[tokio::test]
async fn test_basic_send_receive() {
    Python::initialize();
    let (port, listener) = bind_test_server().await;
    let server_task = task::spawn(async move {
        let (socket, _) = listener.accept().await.unwrap();
        run_echo_server(socket).await;
    });
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_backoff_factor: None,
        reconnect_delay_max_ms: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };
    let client = SocketClient::connect(config, None, None, None)
        .await
        .expect("Client connect failed unexpectedly");
    client.send_bytes(b"Hello".into()).await.unwrap();
    client.send_bytes(b"World".into()).await.unwrap();
    sleep(Duration::from_millis(100)).await;
    client.send_bytes(b"close".into()).await.unwrap();
    // The server task ends once it has processed `close`.
    server_task.await.unwrap();
    assert!(!client.is_closed());
}
// Binds a port, then drops the listener so nothing is listening there, and expects
// `SocketClient::connect` to fail.
// NOTE(review): despite the name, this exercises the initial connection retry limit
// (`connection_max_retries: Some(1)`), not exhaustion of reconnect attempts.
#[tokio::test]
async fn test_reconnect_fail_exhausted() {
    Python::initialize();
    let (port, listener) = bind_test_server().await;
    drop(listener);
    // Wait until the port actually refuses connections.
    wait_until_async(
        || async {
            TcpStream::connect(format!("127.0.0.1:{port}"))
                .await
                .is_err()
        },
        Duration::from_secs(2),
    )
    .await;
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(100),
        reconnect_delay_initial_ms: Some(50),
        reconnect_backoff_factor: Some(1.0),
        reconnect_delay_max_ms: Some(50),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };
    let client_res = SocketClient::connect(config, None, None, None).await;
    assert!(
        client_res.is_err(),
        "Should fail quickly with no server listening"
    );
}
// Connects to a server that accepts and then idles, calls `close()`, and expects the client to
// end up CLOSED.
#[tokio::test]
async fn test_user_disconnect() {
    Python::initialize();
    let (port, listener) = bind_test_server().await;
    let server_task = task::spawn(async move {
        let (socket, _) = listener.accept().await.unwrap();
        let mut buf = [0u8; 1024];
        // Non-blocking read; the result is ignored. The server then keeps the socket open.
        let _ = socket.try_read(&mut buf);
        loop {
            sleep(Duration::from_secs(1)).await;
        }
    });
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_backoff_factor: None,
        reconnect_delay_max_ms: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };
    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();
    client.close().await;
    assert!(client.is_closed());
    server_task.abort();
}
// Configures a 1 s heartbeat, records every `\r\n`-delimited line the server receives for 3 s,
// and expects at least two `ping` heartbeats.
#[tokio::test]
async fn test_heartbeat() {
    Python::initialize();
    let (port, listener) = bind_test_server().await;
    let received = Arc::new(Mutex::new(Vec::new()));
    let received2 = received.clone();
    let server_task = task::spawn(async move {
        let (socket, _) = listener.accept().await.unwrap();
        let mut buf = Vec::new();
        loop {
            match socket.try_read_buf(&mut buf) {
                Ok(0) => break,
                Ok(_) => {
                    while let Some(idx) = buf.array_windows().position(|w| w == b"\r\n") {
                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
                        line.truncate(line.len() - 2);
                        received2.lock().await.push(line);
                    }
                }
                // Any error (not only WouldBlock) is treated as "no data yet"; the task is
                // aborted at the end of the test.
                Err(_) => {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
            }
        }
    });
    let heartbeat = Some((1, b"ping".to_vec()));
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_backoff_factor: None,
        reconnect_delay_max_ms: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };
    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();
    sleep(Duration::from_secs(3)).await;
    {
        let lock = received.lock().await;
        let pings = lock
            .iter()
            .filter(|line| line == &&b"ping".to_vec())
            .count();
        assert!(
            pings >= 2,
            "Expected at least 2 heartbeat pings; got {pings}"
        );
    }
    client.close().await;
    server_task.abort();
}
// The server drops the first connection after 500 ms and serves echo on the second one; the
// client is expected to end up ACTIVE and be able to send.
// NOTE(review): `wait_until_async(is_active)` runs right after asserting `is_active()`, so it
// is satisfied immediately and does not actually observe a reconnect cycle.
#[tokio::test]
async fn test_reconnect_success() {
    Python::initialize();
    let (port, listener) = bind_test_server().await;
    let server_task = task::spawn(async move {
        let (mut socket, _) = listener.accept().await.expect("First accept failed");
        sleep(Duration::from_millis(500)).await;
        let _ = socket.shutdown().await;
        sleep(Duration::from_millis(500)).await;
        let (socket, _) = listener.accept().await.expect("Second accept failed");
        run_echo_server(socket).await;
    });
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(5_000),
        reconnect_delay_initial_ms: Some(500),
        reconnect_delay_max_ms: Some(5_000),
        reconnect_backoff_factor: Some(2.0),
        reconnect_jitter_ms: Some(50),
        reconnect_max_attempts: None,
        connection_max_retries: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };
    let client = SocketClient::connect(config, None, None, None)
        .await
        .expect("Client connect failed unexpectedly");
    assert!(client.is_active(), "Client should start as active");
    wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
    client
        .send_bytes(b"TestReconnect".into())
        .await
        .expect("Send failed");
    client.close().await;
    server_task.abort();
}
}
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
mod rust_tests {
use nautilus_common::testing::wait_until_async;
use rstest::rstest;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
task,
time::{Duration, sleep},
};
use super::*;
// The server accepts one connection and immediately shuts it down, pushing the client into
// RECONNECT. Closing while reconnecting must leave the client CLOSED.
#[rstest]
#[tokio::test]
async fn test_reconnect_then_close() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let server = task::spawn(async move {
        if let Ok((mut sock, _)) = listener.accept().await {
            // `shutdown()` returns a future that must be awaited for the write side to be
            // shut down; dropping it unpolled does nothing.
            let _ = sock.shutdown().await;
        }
        sleep(Duration::from_secs(1)).await;
    });
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(1_000),
        reconnect_delay_initial_ms: Some(50),
        reconnect_delay_max_ms: Some(100),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };
    let client = SocketClient::connect(config.clone(), None, None, None)
        .await
        .unwrap();
    wait_until_async(
        || async { client.is_reconnecting() },
        Duration::from_secs(2),
    )
    .await;
    client.close().await;
    assert!(client.is_closed());
    server.abort();
}
// The server accepts and immediately drops the socket, so the client's read task sees the
// stream end. The test waits for the client to report RECONNECT, then closes it.
#[rstest]
#[tokio::test]
async fn test_reconnect_state_flips_when_reader_stops() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let server = task::spawn(async move {
        if let Ok((sock, _)) = listener.accept().await {
            drop(sock);
        }
        sleep(Duration::from_millis(50)).await;
    });
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(1_000),
        reconnect_delay_initial_ms: Some(50),
        reconnect_delay_max_ms: Some(100),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };
    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();
    // NOTE(review): the result of the wait is not asserted; confirm whether `wait_until_async`
    // panics on timeout.
    wait_until_async(
        || async { client.is_reconnecting() },
        Duration::from_secs(2),
    )
    .await;
    client.close().await;
    server.abort();
}
// A bare `host:port` is dialled unchanged, and the request URL gets the mode's scheme.
#[rstest]
fn test_parse_socket_url_raw_address() {
    let cases = [
        ("example.com:6130", Mode::Tls, "example.com:6130", "wss://example.com:6130"),
        ("localhost:8080", Mode::Plain, "localhost:8080", "ws://localhost:8080"),
    ];
    for (input, mode, expected_addr, expected_request) in cases {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url(input, mode).unwrap();
        assert_eq!(socket_addr, expected_addr);
        assert_eq!(request_url, expected_request);
    }
}
#[rstest]
fn test_parse_socket_url_with_scheme() {
    // When the input already carries a scheme, the request URL is kept as given (path
    // included) and the socket address is reduced to `host:port`.
    let cases = [
        (
            "wss://example.com:443/path",
            Mode::Tls,
            "example.com:443",
            "wss://example.com:443/path",
        ),
        ("ws://localhost:8080", Mode::Plain, "localhost:8080", "ws://localhost:8080"),
    ];

    for (input, mode, want_addr, want_request) in cases {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url(input, mode).unwrap();
        assert_eq!(socket_addr, want_addr);
        assert_eq!(request_url, want_request);
    }
}
#[rstest]
fn test_parse_socket_url_default_ports() {
    // Secure schemes (`wss`, `https`) default to port 443; plain ones (`ws`, `http`) to 80.
    let cases = [
        ("wss://example.com", Mode::Tls, "example.com:443"),
        ("ws://example.com", Mode::Plain, "example.com:80"),
        ("https://example.com", Mode::Tls, "example.com:443"),
        ("http://example.com", Mode::Plain, "example.com:80"),
    ];

    for (input, mode, want_addr) in cases {
        let (socket_addr, _) = SocketClientInner::parse_socket_url(input, mode).unwrap();
        assert_eq!(socket_addr, want_addr);
    }
}
#[rstest]
fn test_parse_socket_url_unknown_scheme_uses_mode() {
    // An unrecognised scheme gives no port hint, so the default port follows the mode.
    let cases = [
        (Mode::Tls, "example.com:443"),
        (Mode::Plain, "example.com:80"),
    ];

    for (mode, want_addr) in cases {
        let (socket_addr, _) =
            SocketClientInner::parse_socket_url("custom://example.com", mode).unwrap();
        assert_eq!(socket_addr, want_addr);
    }
}
#[rstest]
fn test_parse_socket_url_ipv6() {
    // Bracketed IPv6 literals must survive parsing with their brackets intact, both as a
    // raw address and behind an explicit scheme.
    let raw = SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
    assert_eq!(raw.0, "[::1]:8080");
    assert_eq!(raw.1, "ws://[::1]:8080");

    let with_scheme =
        SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
    assert_eq!(with_scheme.0, "[::1]:8080");
}
#[rstest]
#[tokio::test]
async fn test_url_parsing_raw_socket_address() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let server = task::spawn(async move {
        if let Ok((sock, _)) = listener.accept().await {
            drop(sock);
        }
        sleep(Duration::from_millis(50)).await;
    });

    // Bare `host:port` with no scheme.
    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(1_000),
        reconnect_delay_initial_ms: Some(50),
        reconnect_delay_max_ms: Some(100),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };

    // Unwrap with the underlying error in the panic message rather than a bare `is_ok()`
    // assertion, which would hide why the connection failed.
    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap_or_else(|e| {
            panic!("Client should connect with raw socket address format: {e}")
        });

    client.close().await;
    server.abort();
}
#[rstest]
#[tokio::test]
async fn test_url_parsing_with_scheme() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let server = task::spawn(async move {
        if let Ok((sock, _)) = listener.accept().await {
            drop(sock);
        }
        sleep(Duration::from_millis(50)).await;
    });

    // Address given with an explicit `ws://` scheme.
    let config = SocketConfig {
        url: format!("ws://127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(1_000),
        reconnect_delay_initial_ms: Some(50),
        reconnect_delay_max_ms: Some(100),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };

    // Unwrap with the underlying error in the panic message rather than a bare `is_ok()`
    // assertion, which would hide why the connection failed.
    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap_or_else(|e| panic!("Client should connect with URL scheme format: {e}"));

    client.close().await;
    server.abort();
}
#[rstest]
fn test_parse_socket_url_ipv6_with_zone() {
    // IPv6 link-local literals carrying a `%zone` suffix must round-trip unchanged, with
    // or without a leading scheme.
    let cases = [
        ("[fe80::1%eth0]:8080", "[fe80::1%eth0]:8080", "ws://[fe80::1%eth0]:8080"),
        ("ws://[fe80::1%lo]:9090", "[fe80::1%lo]:9090", "ws://[fe80::1%lo]:9090"),
    ];

    for (input, want_addr, want_request) in cases {
        let (socket_addr, request_url) =
            SocketClientInner::parse_socket_url(input, Mode::Plain).unwrap();
        assert_eq!(socket_addr, want_addr);
        assert_eq!(request_url, want_request);
    }
}
#[rstest]
#[tokio::test]
async fn test_ipv6_loopback_connection() {
    // Bind once: a failed bind means IPv6 is unavailable on this host, so skip. Binding a
    // throwaway probe first and then binding again is redundant and racy.
    let Ok(listener) = TcpListener::bind("[::1]:0").await else {
        eprintln!("IPv6 not available, skipping test");
        return;
    };
    let port = listener.local_addr().unwrap().port();

    // Echo back the first chunk read from the accepted connection.
    let server = task::spawn(async move {
        if let Ok((mut sock, _)) = listener.accept().await {
            let mut buf = vec![0u8; 1024];
            if let Ok(n) = sock.read(&mut buf).await {
                let _ = sock.write_all(&buf[..n]).await;
            }
        }
        sleep(Duration::from_millis(50)).await;
    });

    let config = SocketConfig {
        url: format!("[::1]:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(1_000),
        reconnect_delay_initial_ms: Some(50),
        reconnect_delay_max_ms: Some(100),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap_or_else(|e| panic!("Client should connect to IPv6 loopback address: {e}"));

    client.close().await;
    server.abort();
}
#[rstest]
#[tokio::test]
async fn test_send_waits_during_reconnection() {
use nautilus_common::testing::wait_until_async;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let server = task::spawn(async move {
if let Ok((sock, _)) = listener.accept().await {
drop(sock);
}
sleep(Duration::from_millis(500)).await;
if let Ok((mut sock, _)) = listener.accept().await {
let mut buf = vec![0u8; 1024];
while let Ok(n) = sock.read(&mut buf).await {
if n == 0 {
break;
}
if sock.write_all(&buf[..n]).await.is_err() {
break;
}
}
}
});
let config = SocketConfig {
url: format!("127.0.0.1:{port}"),
mode: Mode::Plain,
suffix: b"\r\n".to_vec(),
message_handler: None,
heartbeat: None,
reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
reconnect_delay_max_ms: Some(200),
reconnect_backoff_factor: Some(1.0),
reconnect_jitter_ms: Some(0),
connection_max_retries: Some(1),
reconnect_max_attempts: None,
idle_timeout_ms: None,
certs_dir: None,
};
let client = SocketClient::connect(config, None, None, None)
.await
.unwrap();
wait_until_async(
|| async { client.is_reconnecting() },
Duration::from_secs(2),
)
.await;
let send_result = tokio::time::timeout(
Duration::from_secs(3),
client.send_bytes(b"test_message".to_vec()),
)
.await;
assert!(
send_result.is_ok() && send_result.unwrap().is_ok(),
"Send should succeed after waiting for reconnection"
);
client.close().await;
server.abort();
}
#[rstest]
#[tokio::test]
async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
use nautilus_common::testing::wait_until_async;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let server = task::spawn(async move {
if let Ok((sock, _)) = listener.accept().await {
drop(sock);
}
drop(listener);
sleep(Duration::from_secs(60)).await;
});
let config = SocketConfig {
url: format!("127.0.0.1:{port}"),
mode: Mode::Plain,
suffix: b"\r\n".to_vec(),
message_handler: None,
heartbeat: None,
reconnect_timeout_ms: Some(1_000), reconnect_delay_initial_ms: Some(200), reconnect_delay_max_ms: Some(200),
reconnect_backoff_factor: Some(1.0),
reconnect_jitter_ms: Some(0),
connection_max_retries: Some(1),
reconnect_max_attempts: None,
idle_timeout_ms: None,
certs_dir: None,
};
let client = SocketClient::connect(config, None, None, None)
.await
.unwrap();
wait_until_async(
|| async { client.is_reconnecting() },
Duration::from_secs(3),
)
.await;
let start = std::time::Instant::now();
let send_result = client.send_bytes(b"test".to_vec()).await;
let elapsed = start.elapsed();
assert!(
send_result.is_err(),
"Send should fail when client stuck in RECONNECT, was: {send_result:?}"
);
assert!(
matches!(send_result, Err(crate::error::SendError::Timeout)),
"Send should return Timeout error, was: {send_result:?}"
);
assert!(
elapsed >= Duration::from_millis(900),
"Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
);
client.close().await;
server.abort();
}
#[rstest]
#[tokio::test]
async fn test_idle_timeout_triggers_reconnect() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local_addr = listener.local_addr().unwrap();

    // Accept one connection and hold it open for 5s without writing anything, so the
    // client receives no data and its 500ms idle timeout should fire.
    let server = task::spawn(async move {
        let (_silent_conn, _) = listener.accept().await.unwrap();
        sleep(Duration::from_secs(5)).await;
    });

    let config = SocketConfig {
        url: format!("127.0.0.1:{}", local_addr.port()),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(2_000),
        reconnect_delay_initial_ms: Some(50),
        reconnect_delay_max_ms: Some(100),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: Some(1),
        idle_timeout_ms: Some(500),
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();
    assert!(client.is_active());

    // Either outcome shows the idle timeout pulled the client out of the active state.
    let left_active_state = || async { client.is_reconnecting() || client.is_closed() };
    wait_until_async(left_active_state, Duration::from_secs(3)).await;

    assert!(
        !client.is_active(),
        "Client should not be active after idle timeout"
    );

    client.close().await;
    server.abort();
}
#[rstest]
#[tokio::test]
async fn test_idle_timeout_resets_on_data() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local_addr = listener.local_addr().unwrap();

    // Send up to ten suffix-terminated pings, 200ms apart, which keeps every gap well
    // inside the client's 1s idle window; stop early if the write side fails.
    let server = task::spawn(async move {
        let (mut conn, _) = listener.accept().await.unwrap();
        for _ping in 0..10 {
            sleep(Duration::from_millis(200)).await;
            let write_failed = conn.write_all(b"ping\r\n").await.is_err();
            if write_failed {
                break;
            }
        }
    });

    let config = SocketConfig {
        url: format!("127.0.0.1:{}", local_addr.port()),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(2_000),
        reconnect_delay_initial_ms: Some(50),
        reconnect_delay_max_ms: Some(100),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: Some(1),
        reconnect_max_attempts: Some(1),
        idle_timeout_ms: Some(1_000),
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();
    assert!(client.is_active());

    // Wait past one full idle-timeout period; steady data should have kept the link alive.
    sleep(Duration::from_millis(1_500)).await;
    assert!(
        client.is_active(),
        "Client should remain active when data is flowing"
    );

    client.close().await;
    server.abort();
}
#[rstest]
#[tokio::test]
async fn test_close_during_backoff_exits_promptly() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let server = task::spawn(async move {
        if let Ok((mut sock, _)) = listener.accept().await {
            // `AsyncWriteExt::shutdown` returns a lazy future: it must be awaited, otherwise
            // (as with the previous `drop(sock.shutdown())`) the shutdown is never performed.
            let _ = sock.shutdown().await;
        }
        sleep(Duration::from_secs(60)).await;
    });

    let config = SocketConfig {
        url: format!("127.0.0.1:{port}"),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: Some(1_000),
        // 10s backoff: a close that returns in under 2s must have interrupted the sleep.
        reconnect_delay_initial_ms: Some(10_000),
        reconnect_delay_max_ms: Some(10_000),
        reconnect_backoff_factor: Some(1.0),
        reconnect_jitter_ms: Some(0),
        connection_max_retries: None,
        reconnect_max_attempts: None,
        idle_timeout_ms: None,
        certs_dir: None,
    };

    let client = SocketClient::connect(config, None, None, None)
        .await
        .unwrap();

    wait_until_async(
        || async { client.is_reconnecting() },
        Duration::from_secs(3),
    )
    .await;

    // Give the reconnect loop time to enter its long backoff sleep before closing.
    sleep(Duration::from_millis(1_500)).await;

    let start = std::time::Instant::now();
    client.close().await;
    let elapsed = start.elapsed();

    assert!(client.is_closed(), "Client should be closed");
    assert!(
        elapsed < Duration::from_secs(2),
        "Close should interrupt backoff sleep, took {elapsed:?}"
    );

    server.abort();
}
#[rstest]
#[tokio::test]
async fn test_zero_idle_timeout_rejected() {
    // `idle_timeout_ms: Some(0)` is invalid config; connect must refuse it up front.
    let config = SocketConfig {
        url: "127.0.0.1:9999".to_string(),
        mode: Mode::Plain,
        suffix: b"\r\n".to_vec(),
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_delay_max_ms: None,
        reconnect_backoff_factor: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: Some(1),
        idle_timeout_ms: Some(0),
        certs_dir: None,
    };

    let Err(err) = SocketClient::connect(config, None, None, None).await else {
        panic!("Zero idle timeout should be rejected");
    };

    let err_msg = err.to_string();
    assert!(
        err_msg.contains("Idle timeout cannot be zero"),
        "Error should mention zero idle timeout, was: {err_msg}"
    );
}
#[rstest]
#[tokio::test]
async fn test_empty_suffix_rejected() {
    // An empty suffix leaves nothing to frame messages on, so connect must fail.
    let config = SocketConfig {
        url: "127.0.0.1:9999".to_string(),
        mode: Mode::Plain,
        suffix: vec![],
        message_handler: None,
        heartbeat: None,
        reconnect_timeout_ms: None,
        reconnect_delay_initial_ms: None,
        reconnect_delay_max_ms: None,
        reconnect_backoff_factor: None,
        reconnect_jitter_ms: None,
        reconnect_max_attempts: None,
        connection_max_retries: Some(1),
        idle_timeout_ms: None,
        certs_dir: None,
    };

    let Err(err) = SocketClient::connect(config, None, None, None).await else {
        panic!("Empty suffix should cause connection to fail");
    };

    let err_msg = err.to_string();
    assert!(
        err_msg.contains("suffix cannot be empty"),
        "Error should mention empty suffix, was: {err_msg}"
    );
}
}