1use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::{Py, prelude::*};
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23 mode::ConnectionMode,
24 socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28impl SocketConfig {
29 #[new]
30 #[allow(clippy::too_many_arguments)]
31 #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, connection_max_retries=5, reconnect_max_attempts=None, idle_timeout_ms=None, certs_dir=None))]
32 fn py_new(
33 url: String,
34 ssl: bool,
35 suffix: Vec<u8>,
36 handler: Py<PyAny>,
37 heartbeat: Option<(u64, Vec<u8>)>,
38 reconnect_timeout_ms: Option<u64>,
39 reconnect_delay_initial_ms: Option<u64>,
40 reconnect_delay_max_ms: Option<u64>,
41 reconnect_backoff_factor: Option<f64>,
42 reconnect_jitter_ms: Option<u64>,
43 connection_max_retries: Option<u32>,
44 reconnect_max_attempts: Option<u32>,
45 idle_timeout_ms: Option<u64>,
46 certs_dir: Option<String>,
47 ) -> Self {
48 let mode = if ssl { Mode::Tls } else { Mode::Plain };
49
50 let handler_clone = clone_py_object(&handler);
52 let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
53 Python::attach(|py| {
54 if let Err(e) = handler_clone.call1(py, (data,)) {
55 log::error!("Error calling Python message handler: {e}");
56 }
57 });
58 });
59
60 Self {
61 url,
62 mode,
63 suffix,
64 message_handler: Some(message_handler),
65 heartbeat,
66 reconnect_timeout_ms,
67 reconnect_delay_initial_ms,
68 reconnect_delay_max_ms,
69 reconnect_backoff_factor,
70 reconnect_jitter_ms,
71 connection_max_retries,
72 reconnect_max_attempts,
73 idle_timeout_ms,
74 certs_dir,
75 }
76 }
77}
78
79#[pymethods]
80impl SocketClient {
81 #[staticmethod]
87 #[pyo3(name = "connect")]
88 #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
89 fn py_connect(
90 config: SocketConfig,
91 post_connection: Option<Py<PyAny>>,
92 post_reconnection: Option<Py<PyAny>>,
93 post_disconnection: Option<Py<PyAny>>,
94 py: Python<'_>,
95 ) -> PyResult<Bound<'_, PyAny>> {
96 let post_connection_fn = post_connection.map(|callback| {
98 let callback_clone = clone_py_object(&callback);
99 std::sync::Arc::new(move || {
100 Python::attach(|py| {
101 if let Err(e) = callback_clone.call0(py) {
102 log::error!("Error calling post_connection handler: {e}");
103 }
104 });
105 }) as std::sync::Arc<dyn Fn() + Send + Sync>
106 });
107
108 let post_reconnection_fn = post_reconnection.map(|callback| {
109 let callback_clone = clone_py_object(&callback);
110 std::sync::Arc::new(move || {
111 Python::attach(|py| {
112 if let Err(e) = callback_clone.call0(py) {
113 log::error!("Error calling post_reconnection handler: {e}");
114 }
115 });
116 }) as std::sync::Arc<dyn Fn() + Send + Sync>
117 });
118
119 let post_disconnection_fn = post_disconnection.map(|callback| {
120 let callback_clone = clone_py_object(&callback);
121 std::sync::Arc::new(move || {
122 Python::attach(|py| {
123 if let Err(e) = callback_clone.call0(py) {
124 log::error!("Error calling post_disconnection handler: {e}");
125 }
126 });
127 }) as std::sync::Arc<dyn Fn() + Send + Sync>
128 });
129
130 pyo3_async_runtimes::tokio::future_into_py(py, async move {
131 Self::connect(
132 config,
133 post_connection_fn,
134 post_reconnection_fn,
135 post_disconnection_fn,
136 )
137 .await
138 .map_err(to_pyruntime_err)
139 })
140 }
141
142 #[pyo3(name = "is_active")]
152 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
153 slf.is_active()
154 }
155
156 #[pyo3(name = "is_reconnecting")]
157 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
158 slf.is_reconnecting()
159 }
160
161 #[pyo3(name = "is_disconnecting")]
162 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
163 slf.is_disconnecting()
164 }
165
166 #[pyo3(name = "is_closed")]
167 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
168 slf.is_closed()
169 }
170
171 #[pyo3(name = "mode")]
172 fn py_mode(slf: PyRef<'_, Self>) -> String {
173 slf.connection_mode().to_string()
174 }
175
176 #[pyo3(name = "reconnect")]
178 fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
179 let mode = slf.connection_mode.clone();
180 let mode_str = ConnectionMode::from_atomic(&mode).to_string();
181 log::debug!("Reconnect from mode {mode_str}");
182
183 pyo3_async_runtimes::tokio::future_into_py(py, async move {
184 match ConnectionMode::from_atomic(&mode) {
185 ConnectionMode::Reconnect => {
186 log::warn!("Cannot reconnect - socket already reconnecting");
187 }
188 ConnectionMode::Disconnect => {
189 log::warn!("Cannot reconnect - socket disconnecting");
190 }
191 ConnectionMode::Closed => {
192 log::warn!("Cannot reconnect - socket closed");
193 }
194 _ => {
195 mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
196 let timeout = tokio::time::timeout(Duration::from_secs(30), async {
197 loop {
198 let current = ConnectionMode::from_atomic(&mode);
199 if current.is_active() {
200 return Ok(());
201 }
202 if current.is_closed() || current.is_disconnect() {
203 return Err("Connection closed during reconnect");
204 }
205 tokio::time::sleep(Duration::from_millis(10)).await;
206 }
207 })
208 .await;
209
210 match timeout {
211 Ok(Ok(())) => log::debug!("Reconnected successfully"),
212 Ok(Err(e)) => log::warn!("Reconnect aborted: {e}"),
213 Err(_) => log::error!("Reconnect timed out after 30s"),
214 }
215 }
216 }
217
218 Ok(())
219 })
220 }
221
222 #[pyo3(name = "close")]
232 fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
233 let mode = slf.connection_mode.clone();
234 let mode_str = ConnectionMode::from_atomic(&mode).to_string();
235 log::debug!("Close from mode {mode_str}");
236
237 pyo3_async_runtimes::tokio::future_into_py(py, async move {
238 match ConnectionMode::from_atomic(&mode) {
239 ConnectionMode::Closed => {
240 log::debug!("Socket already closed");
241 }
242 ConnectionMode::Disconnect => {
243 log::debug!("Socket already disconnecting");
244 }
245 _ => {
246 mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
247
248 let timeout = tokio::time::timeout(Duration::from_secs(5), async {
249 while !ConnectionMode::from_atomic(&mode).is_closed() {
250 tokio::time::sleep(Duration::from_millis(10)).await;
251 }
252 })
253 .await;
254
255 if timeout.is_err() {
256 log::error!("Timeout waiting for socket to close, forcing closed state");
257 mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
258 }
259 }
260 }
261
262 Ok(())
263 })
264 }
265
266 #[pyo3(name = "send")]
272 fn py_send<'py>(
273 slf: PyRef<'_, Self>,
274 data: Vec<u8>,
275 py: Python<'py>,
276 ) -> PyResult<Bound<'py, PyAny>> {
277 log::trace!("Sending {}", String::from_utf8_lossy(&data));
278
279 let mode = slf.connection_mode.clone();
280 let writer_tx = slf.writer_tx.clone();
281
282 pyo3_async_runtimes::tokio::future_into_py(py, async move {
283 if ConnectionMode::from_atomic(&mode).is_closed() {
284 let msg = format!(
285 "Cannot send data ({}): socket closed",
286 String::from_utf8_lossy(&data)
287 );
288
289 let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
290 return Err(to_pyruntime_err(io_err));
291 }
292
293 let timeout = Duration::from_secs(2);
294 let check_interval = Duration::from_millis(1);
295
296 if !ConnectionMode::from_atomic(&mode).is_active() {
297 log::debug!("Waiting for client to become ACTIVE before sending (2s)...");
298 match tokio::time::timeout(timeout, async {
299 while !ConnectionMode::from_atomic(&mode).is_active() {
300 if matches!(
301 ConnectionMode::from_atomic(&mode),
302 ConnectionMode::Disconnect | ConnectionMode::Closed
303 ) {
304 return Err("Client disconnected waiting to send");
305 }
306
307 tokio::time::sleep(check_interval).await;
308 }
309
310 Ok(())
311 })
312 .await
313 {
314 Ok(Ok(())) => log::debug!("Client now active"),
315 Ok(Err(e)) => {
316 let err_msg = format!(
317 "Failed sending data ({}): {e}",
318 String::from_utf8_lossy(&data)
319 );
320
321 let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
322 return Err(to_pyruntime_err(io_err));
323 }
324 Err(_) => {
325 let err_msg = format!(
326 "Failed sending data ({}): timeout waiting to become ACTIVE",
327 String::from_utf8_lossy(&data)
328 );
329
330 let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
331 return Err(to_pyruntime_err(io_err));
332 }
333 }
334 }
335
336 let msg = WriterCommand::Send(data.into());
337 writer_tx.send(msg).map_err(to_pyruntime_err)
338 })
339 }
340}