use std::{sync::atomic::Ordering, time::Duration};
use nautilus_core::python::{clone_py_object, to_pyruntime_err};
use pyo3::{Py, prelude::*};
use tokio_tungstenite::tungstenite::stream::Mode;
use crate::{
mode::ConnectionMode,
socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
};
#[pymethods]
#[pyo3_stub_gen::derive::gen_stub_pymethods]
impl SocketConfig {
#[new]
#[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
#[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, connection_max_retries=5, reconnect_max_attempts=None, idle_timeout_ms=None, certs_dir=None))]
fn py_new(
url: String,
ssl: bool,
suffix: Vec<u8>,
handler: Py<PyAny>,
heartbeat: Option<(u64, Vec<u8>)>,
reconnect_timeout_ms: Option<u64>,
reconnect_delay_initial_ms: Option<u64>,
reconnect_delay_max_ms: Option<u64>,
reconnect_backoff_factor: Option<f64>,
reconnect_jitter_ms: Option<u64>,
connection_max_retries: Option<u32>,
reconnect_max_attempts: Option<u32>,
idle_timeout_ms: Option<u64>,
certs_dir: Option<String>,
) -> Self {
let mode = if ssl { Mode::Tls } else { Mode::Plain };
let handler_clone = clone_py_object(&handler);
let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
Python::attach(|py| {
if let Err(e) = handler_clone.call1(py, (data,)) {
log::error!("Error calling Python message handler: {e}");
}
});
});
Self {
url,
mode,
suffix,
message_handler: Some(message_handler),
heartbeat,
reconnect_timeout_ms,
reconnect_delay_initial_ms,
reconnect_delay_max_ms,
reconnect_backoff_factor,
reconnect_jitter_ms,
connection_max_retries,
reconnect_max_attempts,
idle_timeout_ms,
certs_dir,
}
}
}
#[pymethods]
#[pyo3_stub_gen::derive::gen_stub_pymethods]
impl SocketClient {
#[staticmethod]
#[pyo3(name = "connect")]
#[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
fn py_connect(
config: SocketConfig,
post_connection: Option<Py<PyAny>>,
post_reconnection: Option<Py<PyAny>>,
post_disconnection: Option<Py<PyAny>>,
py: Python<'_>,
) -> PyResult<Bound<'_, PyAny>> {
let post_connection_fn = post_connection.map(|callback| {
let callback_clone = clone_py_object(&callback);
std::sync::Arc::new(move || {
Python::attach(|py| {
if let Err(e) = callback_clone.call0(py) {
log::error!("Error calling post_connection handler: {e}");
}
});
}) as std::sync::Arc<dyn Fn() + Send + Sync>
});
let post_reconnection_fn = post_reconnection.map(|callback| {
let callback_clone = clone_py_object(&callback);
std::sync::Arc::new(move || {
Python::attach(|py| {
if let Err(e) = callback_clone.call0(py) {
log::error!("Error calling post_reconnection handler: {e}");
}
});
}) as std::sync::Arc<dyn Fn() + Send + Sync>
});
let post_disconnection_fn = post_disconnection.map(|callback| {
let callback_clone = clone_py_object(&callback);
std::sync::Arc::new(move || {
Python::attach(|py| {
if let Err(e) = callback_clone.call0(py) {
log::error!("Error calling post_disconnection handler: {e}");
}
});
}) as std::sync::Arc<dyn Fn() + Send + Sync>
});
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Self::connect(
config,
post_connection_fn,
post_reconnection_fn,
post_disconnection_fn,
)
.await
.map_err(to_pyruntime_err)
})
}
#[pyo3(name = "is_active")]
#[allow(clippy::needless_pass_by_value)]
fn py_is_active(slf: PyRef<'_, Self>) -> bool {
slf.is_active()
}
#[pyo3(name = "is_reconnecting")]
#[allow(clippy::needless_pass_by_value)]
fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
slf.is_reconnecting()
}
#[pyo3(name = "is_disconnecting")]
#[allow(clippy::needless_pass_by_value)]
fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
slf.is_disconnecting()
}
#[pyo3(name = "is_closed")]
#[allow(clippy::needless_pass_by_value)]
fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
slf.is_closed()
}
#[pyo3(name = "mode")]
#[allow(clippy::needless_pass_by_value)]
fn py_mode(slf: PyRef<'_, Self>) -> String {
slf.connection_mode().to_string()
}
#[pyo3(name = "reconnect")]
#[allow(clippy::needless_pass_by_value)]
fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let connection_mode = slf.connection_mode.clone();
let state_notify = slf.state_notify.clone();
let mode_str = ConnectionMode::from_atomic(&connection_mode).to_string();
log::debug!("Reconnect from mode {mode_str}");
pyo3_async_runtimes::tokio::future_into_py(py, async move {
match ConnectionMode::from_atomic(&connection_mode) {
ConnectionMode::Reconnect => {
log::warn!("Cannot reconnect - socket already reconnecting");
}
ConnectionMode::Disconnect => {
log::warn!("Cannot reconnect - socket disconnecting");
}
ConnectionMode::Closed => {
log::warn!("Cannot reconnect - socket closed");
}
_ => {
connection_mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
state_notify.notify_one();
let fallback_interval = Duration::from_millis(100);
let timeout = tokio::time::timeout(Duration::from_secs(30), async {
loop {
let notified = state_notify.notified();
let current = ConnectionMode::from_atomic(&connection_mode);
if current.is_active() {
return Ok(());
}
if current.is_closed() || current.is_disconnect() {
return Err("Connection closed during reconnect");
}
tokio::select! {
() = notified => {}
() = tokio::time::sleep(fallback_interval) => {}
}
}
})
.await;
match timeout {
Ok(Ok(())) => log::debug!("Reconnected successfully"),
Ok(Err(e)) => log::warn!("Reconnect aborted: {e}"),
Err(_) => log::error!("Reconnect timed out after 30s"),
}
}
}
Ok(())
})
}
#[pyo3(name = "close")]
#[allow(clippy::needless_pass_by_value)]
fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let connection_mode = slf.connection_mode.clone();
let state_notify = slf.state_notify.clone();
let mode_str = ConnectionMode::from_atomic(&connection_mode).to_string();
log::debug!("Close from mode {mode_str}");
pyo3_async_runtimes::tokio::future_into_py(py, async move {
match ConnectionMode::from_atomic(&connection_mode) {
ConnectionMode::Closed => {
log::debug!("Socket already closed");
}
ConnectionMode::Disconnect => {
log::debug!("Socket already disconnecting");
}
_ => {
connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
state_notify.notify_one();
let timeout = tokio::time::timeout(Duration::from_secs(5), async {
while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await;
if timeout.is_err() {
log::error!("Timeout waiting for socket to close, forcing closed state");
connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
}
}
}
Ok(())
})
}
#[pyo3(name = "send")]
#[allow(clippy::needless_pass_by_value)]
fn py_send<'py>(
slf: PyRef<'_, Self>,
data: Vec<u8>,
py: Python<'py>,
) -> PyResult<Bound<'py, PyAny>> {
log::trace!("Sending {}", String::from_utf8_lossy(&data));
let connection_mode = slf.connection_mode.clone();
let state_notify = slf.state_notify.clone();
let writer_tx = slf.writer_tx.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
match ConnectionMode::from_atomic(&connection_mode) {
ConnectionMode::Disconnect | ConnectionMode::Closed => {
let msg = format!(
"Cannot send data ({}): socket closed",
String::from_utf8_lossy(&data)
);
let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
return Err(to_pyruntime_err(io_err));
}
mode if !mode.is_active() => {
let timeout = Duration::from_secs(2);
let fallback_interval = Duration::from_millis(100);
log::debug!("Waiting for client to become ACTIVE before sending (2s)...");
match tokio::time::timeout(timeout, async {
loop {
let notified = state_notify.notified();
let mode = ConnectionMode::from_atomic(&connection_mode);
if mode.is_active() {
return Ok(());
}
if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
return Err("Client disconnected waiting to send");
}
tokio::select! {
() = notified => {}
() = tokio::time::sleep(fallback_interval) => {}
}
}
})
.await
{
Ok(Ok(())) => log::debug!("Client now active"),
Ok(Err(e)) => {
let err_msg = format!(
"Failed sending data ({}): {e}",
String::from_utf8_lossy(&data)
);
let io_err =
std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
return Err(to_pyruntime_err(io_err));
}
Err(_) => {
let err_msg = format!(
"Failed sending data ({}): timeout waiting to become ACTIVE",
String::from_utf8_lossy(&data)
);
let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
return Err(to_pyruntime_err(io_err));
}
}
}
_ => {}
}
let msg = WriterCommand::Send(data.into());
writer_tx.send(msg).map_err(to_pyruntime_err)
})
}
}