use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::error::{NetworkConfigError, NetworkConfigResult};
/// Selects which WebSocket transport implementation a client uses.
///
/// Serialized in `snake_case` (`"tungstenite"` / `"sockudo"`). The default
/// variant is chosen at compile time by the `transport-sockudo` cargo feature:
/// `Sockudo` when the feature is enabled, `Tungstenite` otherwise.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransportBackend {
    /// Tungstenite transport; the default unless `transport-sockudo` is enabled.
    #[cfg_attr(not(feature = "transport-sockudo"), default)]
    Tungstenite,
    /// Sockudo transport; the default when `transport-sockudo` is enabled.
    #[cfg_attr(feature = "transport-sockudo", default)]
    Sockudo,
}
/// Configuration for a WebSocket client connection.
///
/// Instances can be created through the generated builder
/// (`WebSocketConfig::builder()`), whose `build()` runs
/// [`WebSocketConfig::validate`], or deserialized with serde. Deserialization
/// rejects unknown keys (`deny_unknown_fields`), and every field except `url`
/// may be omitted. Deserialized values are not validated automatically, so
/// call `validate()` on them explicitly.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
#[allow(
    clippy::unsafe_derive_deserialize,
    reason = "PyO3-backed config still needs serde deserialization for strict config decoding"
)]
#[derive(Clone, Debug, Serialize, Deserialize, bon::Builder)]
// The generated finisher is private; the public `build()` wraps it with validation.
#[builder(finish_fn(name = build_inner, vis = ""))]
#[serde(deny_unknown_fields)]
pub struct WebSocketConfig {
    /// Endpoint URL to connect to. Must not be empty or whitespace-only.
    pub url: String,
    /// Additional headers as `(name, value)` pairs; empty by default.
    /// Presumably sent with the opening handshake (confirm in the client).
    #[serde(default)]
    #[builder(default)]
    pub headers: Vec<(String, String)>,
    /// Heartbeat interval; `Some(0)` is rejected by validation.
    /// NOTE(review): the unit is not visible here (no `_ms` suffix, unlike the
    /// other durations). Confirm against the client.
    #[serde(default)]
    pub heartbeat: Option<u64>,
    /// Optional heartbeat message payload.
    #[serde(default)]
    pub heartbeat_msg: Option<String>,
    /// Reconnect timeout in milliseconds; must be positive when set.
    #[serde(default)]
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect delay in milliseconds. Must be positive when set, and
    /// must not exceed `reconnect_delay_max_ms` when both are set.
    #[serde(default)]
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect delay in milliseconds; must be positive when set.
    #[serde(default)]
    pub reconnect_delay_max_ms: Option<u64>,
    /// Reconnect backoff multiplier; must be finite and `>= 1.0` when set.
    #[serde(default)]
    pub reconnect_backoff_factor: Option<f64>,
    /// Reconnect jitter in milliseconds; zero is allowed.
    #[serde(default)]
    pub reconnect_jitter_ms: Option<u64>,
    /// Maximum number of reconnect attempts; not validated here.
    /// NOTE(review): the meaning of `None` / `Some(0)` is defined by the client
    /// (confirm).
    #[serde(default)]
    pub reconnect_max_attempts: Option<u32>,
    /// Idle timeout in milliseconds; must be positive when set.
    #[serde(default)]
    pub idle_timeout_ms: Option<u64>,
    /// Transport implementation; falls back to `TransportBackend::default()`.
    #[serde(default)]
    #[builder(default)]
    pub backend: TransportBackend,
    /// Optional proxy URL; not validated here.
    #[serde(default)]
    pub proxy_url: Option<String>,
}
impl<S: web_socket_config_builder::IsComplete> WebSocketConfigBuilder<S> {
    /// Finishes the builder and returns the config only if it passes validation.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`WebSocketConfig::validate`] when any
    /// field holds an invalid value.
    pub fn build(self) -> NetworkConfigResult<WebSocketConfig> {
        let built = self.build_inner();
        built.validate().map(|()| built)
    }
}
impl WebSocketConfig {
    /// Checks every field for values the WebSocket client cannot use.
    ///
    /// All problems are gathered before returning, so one call reports every
    /// invalid field rather than only the first one encountered.
    ///
    /// # Errors
    ///
    /// Returns the result of `NetworkConfigError::collect` over the detected
    /// problems when any of the following holds:
    /// - `url` is empty or whitespace-only;
    /// - `heartbeat` is `Some(0)`;
    /// - `reconnect_timeout_ms`, `reconnect_delay_initial_ms`,
    ///   `reconnect_delay_max_ms` or `idle_timeout_ms` is `Some(0)`;
    /// - `reconnect_backoff_factor` is NaN, infinite, or below `1.0`;
    /// - both reconnect delays are set and the initial delay exceeds the maximum.
    pub fn validate(&self) -> NetworkConfigResult<()> {
        let mut problems = Vec::new();

        if self.url.trim().is_empty() {
            problems.push(NetworkConfigError::invalid("url", "must not be empty"));
        }

        if self.heartbeat == Some(0) {
            problems.push(NetworkConfigError::invalid(
                "heartbeat",
                "interval must be positive",
            ));
        }

        // Durations that must be strictly positive when set.
        // `reconnect_jitter_ms` is deliberately absent: zero jitter is allowed.
        let positive_durations_ms = [
            ("reconnect_timeout_ms", self.reconnect_timeout_ms),
            ("reconnect_delay_initial_ms", self.reconnect_delay_initial_ms),
            ("reconnect_delay_max_ms", self.reconnect_delay_max_ms),
            ("idle_timeout_ms", self.idle_timeout_ms),
        ];
        problems.extend(
            positive_durations_ms
                .into_iter()
                .filter_map(|(field, value)| {
                    value.filter(|&ms| ms == 0).map(|ms| {
                        NetworkConfigError::invalid(field, format!("must be positive, was {ms}"))
                    })
                }),
        );

        // NaN and infinities fail `is_finite`; finite values must be at least 1.0.
        if let Some(factor) = self
            .reconnect_backoff_factor
            .filter(|f| !f.is_finite() || *f < 1.0)
        {
            problems.push(NetworkConfigError::invalid(
                "reconnect_backoff_factor",
                format!("must be finite and >= 1.0, was {factor}"),
            ));
        }

        if let Some((initial, max)) = self
            .reconnect_delay_initial_ms
            .zip(self.reconnect_delay_max_ms)
            .filter(|&(initial, max)| initial > max)
        {
            problems.push(NetworkConfigError::invalid(
                "reconnect_delay_initial_ms",
                format!("must not exceed reconnect_delay_max_ms ({max}), was {initial}"),
            ));
        }

        NetworkConfigError::collect(problems)
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use serde_json::json;

    use super::{TransportBackend, WebSocketConfig};
    use crate::error::NetworkConfigError;

    /// Returns a config that passes validation; tests mutate one field at a time.
    fn valid_config() -> WebSocketConfig {
        WebSocketConfig::builder()
            .url("wss://example.com/ws".to_string())
            .build()
            .expect("baseline websocket config should be valid")
    }

    #[rstest]
    fn test_deserialize_websocket_config_rejects_unknown_field() {
        let config = json!({
            "url": "wss://example.com/ws",
            "unexpected": true,
        });
        let error = serde_json::from_value::<WebSocketConfig>(config).unwrap_err();
        assert!(error.to_string().contains("unknown field `unexpected`"));
    }

    // A url-only payload must deserialize with every optional field defaulted.
    #[rstest]
    fn test_deserialize_websocket_config_applies_defaults() {
        let config = serde_json::from_value::<WebSocketConfig>(json!({
            "url": "wss://example.com/ws",
        }))
        .expect("url-only config should deserialize");
        assert_eq!(config.url, "wss://example.com/ws");
        assert!(config.headers.is_empty());
        assert!(config.heartbeat.is_none());
        assert!(config.reconnect_max_attempts.is_none());
        assert!(config.proxy_url.is_none());
        assert_eq!(config.backend, TransportBackend::default());
        assert!(config.validate().is_ok());
    }

    #[rstest]
    fn test_builder_accepts_valid_config() {
        let result = WebSocketConfig::builder()
            .url("wss://example.com/ws".to_string())
            .build();
        assert!(result.is_ok());
    }

    // `build()` must run validation rather than returning the raw config.
    #[rstest]
    fn test_builder_rejects_invalid_config() {
        let err = WebSocketConfig::builder()
            .url(String::new())
            .build()
            .expect_err("builder should reject an empty url");
        assert!(matches!(err, NetworkConfigError::Invalid { field, .. } if field == "url"));
    }

    #[rstest]
    fn test_validate_accepts_zero_jitter() {
        let mut config = valid_config();
        config.reconnect_jitter_ms = Some(0);
        assert!(config.validate().is_ok());
    }

    // Inclusive boundaries: a factor of exactly 1.0 and initial == max are valid.
    #[rstest]
    fn test_validate_accepts_boundary_values() {
        let mut config = valid_config();
        config.reconnect_backoff_factor = Some(1.0);
        config.reconnect_delay_initial_ms = Some(1_000);
        config.reconnect_delay_max_ms = Some(1_000);
        assert!(config.validate().is_ok());
    }

    #[rstest]
    #[case::empty_url(|c: &mut WebSocketConfig| c.url = String::new(), "url")]
    #[case::blank_url(|c: &mut WebSocketConfig| c.url = "   ".to_string(), "url")]
    #[case::heartbeat(|c: &mut WebSocketConfig| c.heartbeat = Some(0), "heartbeat")]
    #[case::reconnect_timeout(|c: &mut WebSocketConfig| c.reconnect_timeout_ms = Some(0), "reconnect_timeout_ms")]
    #[case::reconnect_delay_initial(|c: &mut WebSocketConfig| c.reconnect_delay_initial_ms = Some(0), "reconnect_delay_initial_ms")]
    #[case::reconnect_delay_max(|c: &mut WebSocketConfig| c.reconnect_delay_max_ms = Some(0), "reconnect_delay_max_ms")]
    #[case::idle_timeout(|c: &mut WebSocketConfig| c.idle_timeout_ms = Some(0), "idle_timeout_ms")]
    fn test_validate_rejects_invalid_field(
        #[case] mutate: fn(&mut WebSocketConfig),
        #[case] expected_field: &str,
    ) {
        let mut config = valid_config();
        mutate(&mut config);
        let err = config
            .validate()
            .expect_err("invalid value should be rejected");
        assert!(
            matches!(err, NetworkConfigError::Invalid { field, .. } if field == expected_field)
        );
    }

    #[rstest]
    #[case::too_small(0.5)]
    #[case::nan(f64::NAN)]
    #[case::infinite(f64::INFINITY)]
    fn test_validate_rejects_invalid_backoff_factor(#[case] factor: f64) {
        let mut config = valid_config();
        config.reconnect_backoff_factor = Some(factor);
        let err = config
            .validate()
            .expect_err("invalid backoff factor should be rejected");
        assert!(
            matches!(err, NetworkConfigError::Invalid { field, .. } if field == "reconnect_backoff_factor")
        );
    }

    #[rstest]
    fn test_validate_rejects_delay_initial_exceeding_max() {
        let mut config = valid_config();
        config.reconnect_delay_initial_ms = Some(5_000);
        config.reconnect_delay_max_ms = Some(1_000);
        let err = config
            .validate()
            .expect_err("initial delay above max should be rejected");
        assert!(
            matches!(err, NetworkConfigError::Invalid { field, .. } if field == "reconnect_delay_initial_ms")
        );
    }

    #[rstest]
    fn test_validate_collects_multiple_errors() {
        let mut config = valid_config();
        config.url = String::new();
        config.reconnect_timeout_ms = Some(0);
        let err = config.validate().expect_err("multiple invalid fields");
        match err {
            NetworkConfigError::Multiple { errors } => assert_eq!(errors.len(), 2),
            other @ NetworkConfigError::Invalid { .. } => {
                panic!("expected Multiple, was {other:?}")
            }
        }
    }
}