Skip to main content

nautilus_network/python/
socket.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::{Py, prelude::*};
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23    mode::ConnectionMode,
24    socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28impl SocketConfig {
29    #[new]
30    #[allow(clippy::too_many_arguments)]
31    #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, connection_max_retries=5, reconnect_max_attempts=None, idle_timeout_ms=None, certs_dir=None))]
32    fn py_new(
33        url: String,
34        ssl: bool,
35        suffix: Vec<u8>,
36        handler: Py<PyAny>,
37        heartbeat: Option<(u64, Vec<u8>)>,
38        reconnect_timeout_ms: Option<u64>,
39        reconnect_delay_initial_ms: Option<u64>,
40        reconnect_delay_max_ms: Option<u64>,
41        reconnect_backoff_factor: Option<f64>,
42        reconnect_jitter_ms: Option<u64>,
43        connection_max_retries: Option<u32>,
44        reconnect_max_attempts: Option<u32>,
45        idle_timeout_ms: Option<u64>,
46        certs_dir: Option<String>,
47    ) -> Self {
48        let mode = if ssl { Mode::Tls } else { Mode::Plain };
49
50        // Create function pointer that calls Python handler
51        let handler_clone = clone_py_object(&handler);
52        let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
53            Python::attach(|py| {
54                if let Err(e) = handler_clone.call1(py, (data,)) {
55                    log::error!("Error calling Python message handler: {e}");
56                }
57            });
58        });
59
60        Self {
61            url,
62            mode,
63            suffix,
64            message_handler: Some(message_handler),
65            heartbeat,
66            reconnect_timeout_ms,
67            reconnect_delay_initial_ms,
68            reconnect_delay_max_ms,
69            reconnect_backoff_factor,
70            reconnect_jitter_ms,
71            connection_max_retries,
72            reconnect_max_attempts,
73            idle_timeout_ms,
74            certs_dir,
75        }
76    }
77}
78
79#[pymethods]
80impl SocketClient {
81    /// Create a socket client.
82    ///
83    /// # Errors
84    ///
85    /// - Throws an Exception if it is unable to make socket connection.
86    #[staticmethod]
87    #[pyo3(name = "connect")]
88    #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
89    fn py_connect(
90        config: SocketConfig,
91        post_connection: Option<Py<PyAny>>,
92        post_reconnection: Option<Py<PyAny>>,
93        post_disconnection: Option<Py<PyAny>>,
94        py: Python<'_>,
95    ) -> PyResult<Bound<'_, PyAny>> {
96        // Convert Python callbacks to function pointers
97        let post_connection_fn = post_connection.map(|callback| {
98            let callback_clone = clone_py_object(&callback);
99            std::sync::Arc::new(move || {
100                Python::attach(|py| {
101                    if let Err(e) = callback_clone.call0(py) {
102                        log::error!("Error calling post_connection handler: {e}");
103                    }
104                });
105            }) as std::sync::Arc<dyn Fn() + Send + Sync>
106        });
107
108        let post_reconnection_fn = post_reconnection.map(|callback| {
109            let callback_clone = clone_py_object(&callback);
110            std::sync::Arc::new(move || {
111                Python::attach(|py| {
112                    if let Err(e) = callback_clone.call0(py) {
113                        log::error!("Error calling post_reconnection handler: {e}");
114                    }
115                });
116            }) as std::sync::Arc<dyn Fn() + Send + Sync>
117        });
118
119        let post_disconnection_fn = post_disconnection.map(|callback| {
120            let callback_clone = clone_py_object(&callback);
121            std::sync::Arc::new(move || {
122                Python::attach(|py| {
123                    if let Err(e) = callback_clone.call0(py) {
124                        log::error!("Error calling post_disconnection handler: {e}");
125                    }
126                });
127            }) as std::sync::Arc<dyn Fn() + Send + Sync>
128        });
129
130        pyo3_async_runtimes::tokio::future_into_py(py, async move {
131            Self::connect(
132                config,
133                post_connection_fn,
134                post_reconnection_fn,
135                post_disconnection_fn,
136            )
137            .await
138            .map_err(to_pyruntime_err)
139        })
140    }
141
142    /// Check if the client is still alive.
143    ///
144    /// Even if the connection is disconnected the client will still be alive
145    /// and trying to reconnect.
146    ///
147    /// This is particularly useful for check why a `send` failed. It could
148    /// be because the connection disconnected and the client is still alive
149    /// and reconnecting. In such cases the send can be retried after some
150    /// delay
151    #[pyo3(name = "is_active")]
152    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
153        slf.is_active()
154    }
155
156    #[pyo3(name = "is_reconnecting")]
157    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
158        slf.is_reconnecting()
159    }
160
161    #[pyo3(name = "is_disconnecting")]
162    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
163        slf.is_disconnecting()
164    }
165
166    #[pyo3(name = "is_closed")]
167    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
168        slf.is_closed()
169    }
170
171    #[pyo3(name = "mode")]
172    fn py_mode(slf: PyRef<'_, Self>) -> String {
173        slf.connection_mode().to_string()
174    }
175
176    /// Reconnect the client.
177    #[pyo3(name = "reconnect")]
178    fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
179        let mode = slf.connection_mode.clone();
180        let mode_str = ConnectionMode::from_atomic(&mode).to_string();
181        log::debug!("Reconnect from mode {mode_str}");
182
183        pyo3_async_runtimes::tokio::future_into_py(py, async move {
184            match ConnectionMode::from_atomic(&mode) {
185                ConnectionMode::Reconnect => {
186                    log::warn!("Cannot reconnect - socket already reconnecting");
187                }
188                ConnectionMode::Disconnect => {
189                    log::warn!("Cannot reconnect - socket disconnecting");
190                }
191                ConnectionMode::Closed => {
192                    log::warn!("Cannot reconnect - socket closed");
193                }
194                _ => {
195                    mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
196                    let timeout = tokio::time::timeout(Duration::from_secs(30), async {
197                        loop {
198                            let current = ConnectionMode::from_atomic(&mode);
199                            if current.is_active() {
200                                return Ok(());
201                            }
202                            if current.is_closed() || current.is_disconnect() {
203                                return Err("Connection closed during reconnect");
204                            }
205                            tokio::time::sleep(Duration::from_millis(10)).await;
206                        }
207                    })
208                    .await;
209
210                    match timeout {
211                        Ok(Ok(())) => log::debug!("Reconnected successfully"),
212                        Ok(Err(e)) => log::warn!("Reconnect aborted: {e}"),
213                        Err(_) => log::error!("Reconnect timed out after 30s"),
214                    }
215                }
216            }
217
218            Ok(())
219        })
220    }
221
222    /// Close the client.
223    ///
224    /// The connection is not completely closed until all references
225    /// to the client are gone and the client is dropped.
226    ///
227    /// # Safety
228    ///
229    /// - The client should not be used after closing it
230    /// - Any auto-reconnect job should be aborted before closing the client
231    #[pyo3(name = "close")]
232    fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
233        let mode = slf.connection_mode.clone();
234        let mode_str = ConnectionMode::from_atomic(&mode).to_string();
235        log::debug!("Close from mode {mode_str}");
236
237        pyo3_async_runtimes::tokio::future_into_py(py, async move {
238            match ConnectionMode::from_atomic(&mode) {
239                ConnectionMode::Closed => {
240                    log::debug!("Socket already closed");
241                }
242                ConnectionMode::Disconnect => {
243                    log::debug!("Socket already disconnecting");
244                }
245                _ => {
246                    mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
247
248                    let timeout = tokio::time::timeout(Duration::from_secs(5), async {
249                        while !ConnectionMode::from_atomic(&mode).is_closed() {
250                            tokio::time::sleep(Duration::from_millis(10)).await;
251                        }
252                    })
253                    .await;
254
255                    if timeout.is_err() {
256                        log::error!("Timeout waiting for socket to close, forcing closed state");
257                        mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
258                    }
259                }
260            }
261
262            Ok(())
263        })
264    }
265
266    /// Send bytes data to the connection.
267    ///
268    /// # Errors
269    ///
270    /// - Throws an Exception if it is not able to send data.
271    #[pyo3(name = "send")]
272    fn py_send<'py>(
273        slf: PyRef<'_, Self>,
274        data: Vec<u8>,
275        py: Python<'py>,
276    ) -> PyResult<Bound<'py, PyAny>> {
277        log::trace!("Sending {}", String::from_utf8_lossy(&data));
278
279        let mode = slf.connection_mode.clone();
280        let writer_tx = slf.writer_tx.clone();
281
282        pyo3_async_runtimes::tokio::future_into_py(py, async move {
283            if ConnectionMode::from_atomic(&mode).is_closed() {
284                let msg = format!(
285                    "Cannot send data ({}): socket closed",
286                    String::from_utf8_lossy(&data)
287                );
288
289                let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
290                return Err(to_pyruntime_err(io_err));
291            }
292
293            let timeout = Duration::from_secs(2);
294            let check_interval = Duration::from_millis(1);
295
296            if !ConnectionMode::from_atomic(&mode).is_active() {
297                log::debug!("Waiting for client to become ACTIVE before sending (2s)...");
298                match tokio::time::timeout(timeout, async {
299                    while !ConnectionMode::from_atomic(&mode).is_active() {
300                        if matches!(
301                            ConnectionMode::from_atomic(&mode),
302                            ConnectionMode::Disconnect | ConnectionMode::Closed
303                        ) {
304                            return Err("Client disconnected waiting to send");
305                        }
306
307                        tokio::time::sleep(check_interval).await;
308                    }
309
310                    Ok(())
311                })
312                .await
313                {
314                    Ok(Ok(())) => log::debug!("Client now active"),
315                    Ok(Err(e)) => {
316                        let err_msg = format!(
317                            "Failed sending data ({}): {e}",
318                            String::from_utf8_lossy(&data)
319                        );
320
321                        let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
322                        return Err(to_pyruntime_err(io_err));
323                    }
324                    Err(_) => {
325                        let err_msg = format!(
326                            "Failed sending data ({}): timeout waiting to become ACTIVE",
327                            String::from_utf8_lossy(&data)
328                        );
329
330                        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
331                        return Err(to_pyruntime_err(io_err));
332                    }
333                }
334            }
335
336            let msg = WriterCommand::Send(data.into());
337            writer_tx.send(msg).map_err(to_pyruntime_err)
338        })
339    }
340}