use std::fmt::Debug;
use tokio_tungstenite::tungstenite::stream::Mode;
use super::types::TcpMessageHandler;
use crate::error::{NetworkConfigError, NetworkConfigResult};
/// Configuration for a socket connection: target URL, stream mode, message suffix,
/// optional message handler and heartbeat, and reconnection / idle-timeout tuning.
///
/// Build it with [`SocketConfig::builder`]. The builder's generated finishing method
/// is renamed to a private `build_inner`; the public `build()` wrapper runs
/// [`SocketConfig::validate`] and only returns the config when every check passes.
/// All fields are `pub`, so a value created by struct literal or mutated afterwards
/// skips that check — call `validate()` explicitly in that case.
///
/// NOTE(review): fields ending in `_ms` are presumably milliseconds; this is inferred
/// from the suffix only — confirm against the code that consumes this config.
#[derive(bon::Builder)]
#[builder(finish_fn(name = build_inner, vis = ""))]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct SocketConfig {
    /// Target URL. `validate` rejects an empty or whitespace-only value.
    pub url: String,
    /// `tungstenite` stream mode (the tests use `Mode::Plain`; presumably `Mode::Tls`
    /// selects TLS — confirm).
    pub mode: Mode,
    /// Byte suffix (the tests use `b'\n'`). Presumably the message delimiter — confirm.
    pub suffix: Vec<u8>,
    /// Optional message handler; its semantics are defined by `TcpMessageHandler`
    /// (not shown here). `Debug` renders it only as a `"<function>"` placeholder.
    pub message_handler: Option<TcpMessageHandler>,
    /// Optional heartbeat as `(interval, payload)`. `validate` requires a non-zero
    /// interval. NOTE(review): interval units are not visible here — confirm.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// Reconnect timeout; must be non-zero when set.
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect delay; must be non-zero and must not exceed
    /// `reconnect_delay_max_ms` when both are set.
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect delay; must be non-zero when set.
    pub reconnect_delay_max_ms: Option<u64>,
    /// Reconnect backoff multiplier; must be finite and `>= 1.0` when set.
    pub reconnect_backoff_factor: Option<f64>,
    /// Reconnect jitter; not range-checked by `validate` (zero is accepted).
    pub reconnect_jitter_ms: Option<u64>,
    /// Maximum connection retries; not range-checked by `validate`.
    pub connection_max_retries: Option<u32>,
    /// Maximum reconnect attempts; not range-checked by `validate`.
    pub reconnect_max_attempts: Option<u32>,
    /// Idle timeout; must be non-zero when set.
    pub idle_timeout_ms: Option<u64>,
    /// Optional certificates directory path. Presumably used for TLS — confirm.
    pub certs_dir: Option<String>,
}
impl<S: socket_config_builder::IsComplete> SocketConfigBuilder<S> {
    /// Finishes the builder and validates the assembled [`SocketConfig`].
    ///
    /// This is the public finishing method. The generated `build_inner` is private,
    /// so builder users always go through [`SocketConfig::validate`].
    ///
    /// # Errors
    ///
    /// Returns the error reported by [`SocketConfig::validate`] when any field
    /// violates its constraints.
    pub fn build(self) -> NetworkConfigResult<SocketConfig> {
        let candidate = self.build_inner();
        // `validate` only borrows, so the config can be handed back on success.
        candidate.validate().map(|()| candidate)
    }
}
impl SocketConfig {
    /// Checks every field constraint and reports all violations at once, rather
    /// than stopping at the first one.
    ///
    /// Constraints checked, in this order:
    /// - `url` is not empty or whitespace-only;
    /// - the heartbeat interval (first tuple element) is non-zero;
    /// - `reconnect_timeout_ms`, `reconnect_delay_initial_ms`,
    ///   `reconnect_delay_max_ms` and `idle_timeout_ms` are non-zero when set;
    /// - `reconnect_backoff_factor` is finite and `>= 1.0` when set;
    /// - `reconnect_delay_initial_ms <= reconnect_delay_max_ms` when both are set.
    ///
    /// # Errors
    ///
    /// Returns the result of `NetworkConfigError::collect` over the collected
    /// violations. Per this module's tests, one violation yields
    /// `NetworkConfigError::Invalid` and several yield `NetworkConfigError::Multiple`.
    pub fn validate(&self) -> NetworkConfigResult<()> {
        let mut problems = Vec::new();

        if self.url.trim().is_empty() {
            problems.push(NetworkConfigError::invalid("url", "must not be empty"));
        }

        // Only the interval is constrained; the payload may be anything.
        if matches!(self.heartbeat, Some((0, _))) {
            problems.push(NetworkConfigError::invalid(
                "heartbeat",
                "interval must be positive",
            ));
        }

        let must_be_positive = [
            ("reconnect_timeout_ms", self.reconnect_timeout_ms),
            ("reconnect_delay_initial_ms", self.reconnect_delay_initial_ms),
            ("reconnect_delay_max_ms", self.reconnect_delay_max_ms),
            ("idle_timeout_ms", self.idle_timeout_ms),
        ];
        problems.extend(
            must_be_positive
                .into_iter()
                .filter_map(|(name, setting)| match setting {
                    Some(zero @ 0) => Some(NetworkConfigError::invalid(
                        name,
                        format!("must be positive, was {zero}"),
                    )),
                    _ => None,
                }),
        );

        // `!is_finite()` catches NaN and infinities; for finite values
        // `< 1.0` is the exact negation of `>= 1.0`.
        if let Some(factor) = self.reconnect_backoff_factor
            && (!factor.is_finite() || factor < 1.0)
        {
            problems.push(NetworkConfigError::invalid(
                "reconnect_backoff_factor",
                format!("must be finite and >= 1.0, was {factor}"),
            ));
        }

        if let Some((initial, max)) = self
            .reconnect_delay_initial_ms
            .zip(self.reconnect_delay_max_ms)
            .filter(|(lo, hi)| lo > hi)
        {
            problems.push(NetworkConfigError::invalid(
                "reconnect_delay_initial_ms",
                format!("must not exceed reconnect_delay_max_ms ({max}), was {initial}"),
            ));
        }

        NetworkConfigError::collect(problems)
    }
}
impl Debug for SocketConfig {
    /// Writes every field in declaration order. The message handler is shown only
    /// as a `"<function>"` placeholder (or `None`), never its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct without
        // updating this impl becomes a compile error.
        let Self {
            url,
            mode,
            suffix,
            message_handler,
            heartbeat,
            reconnect_timeout_ms,
            reconnect_delay_initial_ms,
            reconnect_delay_max_ms,
            reconnect_backoff_factor,
            reconnect_jitter_ms,
            connection_max_retries,
            reconnect_max_attempts,
            idle_timeout_ms,
            certs_dir,
        } = self;

        let handler_placeholder = message_handler.as_ref().map(|_| "<function>");

        f.debug_struct("SocketConfig")
            .field("url", url)
            .field("mode", mode)
            .field("suffix", suffix)
            .field("message_handler", &handler_placeholder)
            .field("heartbeat", heartbeat)
            .field("reconnect_timeout_ms", reconnect_timeout_ms)
            .field("reconnect_delay_initial_ms", reconnect_delay_initial_ms)
            .field("reconnect_delay_max_ms", reconnect_delay_max_ms)
            .field("reconnect_backoff_factor", reconnect_backoff_factor)
            .field("reconnect_jitter_ms", reconnect_jitter_ms)
            .field("connection_max_retries", connection_max_retries)
            .field("reconnect_max_attempts", reconnect_max_attempts)
            .field("idle_timeout_ms", idle_timeout_ms)
            .field("certs_dir", certs_dir)
            .finish()
    }
}
impl Clone for SocketConfig {
    /// Produces a field-by-field copy. Owned buffers and the handler are cloned,
    /// and the scalar settings are copied.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: a new field cannot be silently left out.
        let Self {
            url,
            mode,
            suffix,
            message_handler,
            heartbeat,
            reconnect_timeout_ms,
            reconnect_delay_initial_ms,
            reconnect_delay_max_ms,
            reconnect_backoff_factor,
            reconnect_jitter_ms,
            connection_max_retries,
            reconnect_max_attempts,
            idle_timeout_ms,
            certs_dir,
        } = self;

        Self {
            url: url.clone(),
            mode: *mode,
            suffix: suffix.clone(),
            message_handler: message_handler.clone(),
            heartbeat: heartbeat.clone(),
            reconnect_timeout_ms: *reconnect_timeout_ms,
            reconnect_delay_initial_ms: *reconnect_delay_initial_ms,
            reconnect_delay_max_ms: *reconnect_delay_max_ms,
            reconnect_backoff_factor: *reconnect_backoff_factor,
            reconnect_jitter_ms: *reconnect_jitter_ms,
            connection_max_retries: *connection_max_retries,
            reconnect_max_attempts: *reconnect_max_attempts,
            idle_timeout_ms: *idle_timeout_ms,
            certs_dir: certs_dir.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use tokio_tungstenite::tungstenite::stream::Mode;

    use super::SocketConfig;
    use crate::error::NetworkConfigError;

    /// Minimal configuration that passes validation; tests mutate a copy of it.
    fn valid_config() -> SocketConfig {
        SocketConfig::builder()
            .url("tcp://127.0.0.1:8080".to_string())
            .mode(Mode::Plain)
            .suffix(vec![b'\n'])
            .build()
            .expect("baseline socket config should be valid")
    }

    /// Asserts that `err` is a single `Invalid` error naming `expected_field`.
    ///
    /// The error is rendered before `matches!` takes ownership of it, so a failure
    /// message shows what was actually returned.
    fn assert_invalid_field(err: NetworkConfigError, expected_field: &str) {
        let rendered = format!("{err:?}");
        assert!(
            matches!(err, NetworkConfigError::Invalid { field, .. } if field == expected_field),
            "expected Invalid error for `{expected_field}`, was {rendered}"
        );
    }

    #[rstest]
    fn test_builder_accepts_valid_config() {
        let result = SocketConfig::builder()
            .url("tcp://127.0.0.1:8080".to_string())
            .mode(Mode::Plain)
            .suffix(vec![b'\n'])
            .build();
        assert!(result.is_ok(), "expected valid config, was {result:?}");
    }

    #[rstest]
    fn test_builder_rejects_invalid_config() {
        // `build()` must run `validate()`, not just assemble the struct.
        let result = SocketConfig::builder()
            .url(String::new())
            .mode(Mode::Plain)
            .suffix(vec![b'\n'])
            .build();
        assert!(result.is_err(), "empty url should fail build, was {result:?}");
    }

    #[rstest]
    fn test_validate_accepts_zero_jitter() {
        let mut config = valid_config();
        config.reconnect_jitter_ms = Some(0);
        config
            .validate()
            .expect("zero jitter should be accepted");
    }

    #[rstest]
    fn test_validate_accepts_unit_backoff_factor() {
        // 1.0 is the inclusive lower bound.
        let mut config = valid_config();
        config.reconnect_backoff_factor = Some(1.0);
        config
            .validate()
            .expect("backoff factor of exactly 1.0 should be accepted");
    }

    #[rstest]
    fn test_validate_accepts_delay_initial_equal_to_max() {
        // Only initial > max is rejected; equality is allowed.
        let mut config = valid_config();
        config.reconnect_delay_initial_ms = Some(1_000);
        config.reconnect_delay_max_ms = Some(1_000);
        config
            .validate()
            .expect("initial delay equal to max should be accepted");
    }

    #[rstest]
    #[case::empty_url(|c: &mut SocketConfig| c.url = String::new(), "url")]
    #[case::whitespace_url(|c: &mut SocketConfig| c.url = "   ".to_string(), "url")]
    #[case::heartbeat(|c: &mut SocketConfig| c.heartbeat = Some((0, vec![])), "heartbeat")]
    #[case::reconnect_timeout(|c: &mut SocketConfig| c.reconnect_timeout_ms = Some(0), "reconnect_timeout_ms")]
    #[case::reconnect_delay_initial(|c: &mut SocketConfig| c.reconnect_delay_initial_ms = Some(0), "reconnect_delay_initial_ms")]
    #[case::reconnect_delay_max(|c: &mut SocketConfig| c.reconnect_delay_max_ms = Some(0), "reconnect_delay_max_ms")]
    #[case::idle_timeout(|c: &mut SocketConfig| c.idle_timeout_ms = Some(0), "idle_timeout_ms")]
    fn test_validate_rejects_invalid_field(
        #[case] mutate: fn(&mut SocketConfig),
        #[case] expected_field: &str,
    ) {
        let mut config = valid_config();
        mutate(&mut config);
        let err = config
            .validate()
            .expect_err("invalid value should be rejected");
        assert_invalid_field(err, expected_field);
    }

    #[rstest]
    #[case::too_small(0.5)]
    #[case::nan(f64::NAN)]
    #[case::infinite(f64::INFINITY)]
    fn test_validate_rejects_invalid_backoff_factor(#[case] factor: f64) {
        let mut config = valid_config();
        config.reconnect_backoff_factor = Some(factor);
        let err = config
            .validate()
            .expect_err("invalid backoff factor should be rejected");
        assert_invalid_field(err, "reconnect_backoff_factor");
    }

    #[rstest]
    fn test_validate_rejects_delay_initial_exceeding_max() {
        let mut config = valid_config();
        config.reconnect_delay_initial_ms = Some(5_000);
        config.reconnect_delay_max_ms = Some(1_000);
        let err = config
            .validate()
            .expect_err("initial delay above max should be rejected");
        assert_invalid_field(err, "reconnect_delay_initial_ms");
    }

    #[rstest]
    fn test_validate_collects_multiple_errors() {
        let mut config = valid_config();
        config.url = String::new();
        config.reconnect_timeout_ms = Some(0);
        let err = config.validate().expect_err("multiple invalid fields");
        match err {
            NetworkConfigError::Multiple { errors } => assert_eq!(errors.len(), 2),
            other @ NetworkConfigError::Invalid { .. } => {
                panic!("expected Multiple, was {other:?}")
            }
        }
    }
}