1use std::{
17 sync::{
18 Arc,
19 atomic::{AtomicU8, Ordering},
20 },
21 time::Duration,
22};
23
24use nautilus_core::{
25 collections::into_ustr_vec,
26 python::{clone_py_object, to_pyruntime_err, to_pyvalue_err},
27};
28use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
29
30use crate::{
31 RECONNECTED,
32 mode::ConnectionMode,
33 ratelimiter::quota::Quota,
34 transport::{Message, TransportError},
35 websocket::{
36 TransportBackend, WebSocketClient, WebSocketConfig,
37 types::{MessageHandler, PingHandler, WriterCommand},
38 },
39};
40
// Declares the Python exception type `WebSocketClientError`, deriving from
// Python's base `Exception`, under the `network` module name. It is the
// exception produced by `to_websocket_pyerr` for transport failures.
create_exception!(network, WebSocketClientError, PyException);
42
43#[expect(clippy::needless_pass_by_value)]
44fn to_websocket_pyerr(e: TransportError) -> PyErr {
45 PyErr::new::<WebSocketClientError, _>(e.to_string())
46}
47
#[pymethods]
#[pyo3_stub_gen::derive::gen_stub_pymethods]
impl WebSocketConfig {
    /// Creates a new WebSocket configuration (the Python `__new__`).
    ///
    /// Every argument is stored unchanged in the field of the same name. The
    /// transport `backend` is not exposed to Python and is always set to
    /// `TransportBackend::default()`.
    ///
    /// # Parameters
    ///
    /// - `url`: Endpoint the client connects to.
    /// - `headers`: `(name, value)` pairs; presumably attached to the
    ///   connection request (the tests in this file assert one of them during
    ///   the server handshake).
    /// - `heartbeat`: Optional heartbeat interval. NOTE(review): the unit is
    ///   not visible in this file (presumably seconds) — confirm.
    /// - `heartbeat_msg`: Optional heartbeat payload.
    /// - `reconnect_timeout_ms`, `reconnect_delay_initial_ms`,
    ///   `reconnect_delay_max_ms`, `reconnect_backoff_factor`,
    ///   `reconnect_jitter_ms`, `reconnect_max_attempts`: Optional
    ///   reconnection tuning values.
    /// - `idle_timeout_ms`: Optional idle timeout.
    /// - `proxy_url`: Optional proxy URL.
    ///
    /// # Defaults
    ///
    /// When called from Python, omitted arguments take the defaults from the
    /// `#[pyo3(signature)]` below (`10_000`, `2_000`, `30_000`, `1.5`, `100`
    /// for the first five reconnect settings; `None` for the rest). Passing
    /// `None` explicitly, or calling this function directly from Rust, stores
    /// `None` rather than those defaults.
    #[new]
    #[expect(clippy::too_many_arguments)]
    #[pyo3(signature = (
        url,
        headers,
        heartbeat=None,
        heartbeat_msg=None,
        reconnect_timeout_ms=10_000,
        reconnect_delay_initial_ms=2_000,
        reconnect_delay_max_ms=30_000,
        reconnect_backoff_factor=1.5,
        reconnect_jitter_ms=100,
        reconnect_max_attempts=None,
        idle_timeout_ms=None,
        proxy_url=None,
    ))]
    fn py_new(
        url: String,
        headers: Vec<(String, String)>,
        heartbeat: Option<u64>,
        heartbeat_msg: Option<String>,
        reconnect_timeout_ms: Option<u64>,
        reconnect_delay_initial_ms: Option<u64>,
        reconnect_delay_max_ms: Option<u64>,
        reconnect_backoff_factor: Option<f64>,
        reconnect_jitter_ms: Option<u64>,
        reconnect_max_attempts: Option<u32>,
        idle_timeout_ms: Option<u64>,
        proxy_url: Option<String>,
    ) -> Self {
        Self {
            url,
            headers,
            heartbeat,
            heartbeat_msg,
            reconnect_timeout_ms,
            reconnect_delay_initial_ms,
            reconnect_delay_max_ms,
            reconnect_backoff_factor,
            reconnect_jitter_ms,
            reconnect_max_attempts,
            idle_timeout_ms,
            // Not configurable from Python
            backend: TransportBackend::default(),
            proxy_url,
        }
    }
}
121
122#[pymethods]
123#[pyo3_stub_gen::derive::gen_stub_pymethods]
124impl WebSocketClient {
125 #[staticmethod]
137 #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
138 #[expect(clippy::too_many_arguments, clippy::needless_pass_by_value)]
139 fn py_connect(
140 loop_: Py<PyAny>,
141 config: WebSocketConfig,
142 handler: Py<PyAny>,
143 ping_handler: Option<Py<PyAny>>,
144 post_reconnection: Option<Py<PyAny>>,
145 keyed_quotas: Vec<(String, Quota)>,
146 default_quota: Option<Quota>,
147 py: Python<'_>,
148 ) -> PyResult<Bound<'_, PyAny>> {
149 let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
150 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
151 let handler_clone = clone_py_object(&handler);
152
153 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
154 if matches!(msg, Message::Text(ref text) if text.as_ref() == RECONNECTED.as_bytes()) {
155 return;
156 }
157
158 Python::attach(|py| {
159 let py_bytes = match &msg {
160 Message::Binary(data) | Message::Text(data) => PyBytes::new(py, data.as_ref()),
161 _ => return,
162 };
163
164 if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
165 log::error!("Error scheduling message handler on event loop: {e}");
166 }
167 });
168 });
169
170 let ping_handler_fn = ping_handler.map(|ping_handler| {
171 let ping_handler_clone = clone_py_object(&ping_handler);
172 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
173
174 let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
175 Python::attach(|py| {
176 let py_bytes = PyBytes::new(py, &data);
177 if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
178 log::error!("Error scheduling ping handler on event loop: {e}");
179 }
180 });
181 });
182 ping_handler_fn
183 });
184
185 let post_reconnection_fn = post_reconnection.map(|callback| {
186 let callback_clone = clone_py_object(&callback);
187 Arc::new(move || {
188 Python::attach(|py| {
189 if let Err(e) = callback_clone.call0(py) {
190 log::error!("Error calling post_reconnection handler: {e}");
191 }
192 });
193 }) as std::sync::Arc<dyn Fn() + Send + Sync>
194 });
195
196 pyo3_async_runtimes::tokio::future_into_py(py, async move {
197 Box::pin(Self::connect(
198 config,
199 Some(message_handler),
200 ping_handler_fn,
201 post_reconnection_fn,
202 keyed_quotas,
203 default_quota,
204 ))
205 .await
206 .map_err(to_websocket_pyerr)
207 })
208 }
209
210 #[pyo3(name = "disconnect")]
215 #[expect(clippy::needless_pass_by_value)]
216 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
217 let connection_mode = slf.connection_mode.clone();
218 let state_notify = slf.state_notify.clone();
219 let mode = ConnectionMode::from_atomic(&connection_mode);
220 log::debug!("Close from mode {mode}");
221
222 pyo3_async_runtimes::tokio::future_into_py(py, async move {
223 match ConnectionMode::from_atomic(&connection_mode) {
224 ConnectionMode::Closed => {
225 log::debug!("WebSocket already closed");
226 }
227 ConnectionMode::Disconnect => {
228 log::debug!("WebSocket already disconnecting");
229 }
230 _ => {
231 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
232 state_notify.notify_one();
233
234 let timeout = tokio::time::timeout(Duration::from_secs(5), async {
235 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
236 tokio::time::sleep(Duration::from_millis(10)).await;
237 }
238 })
239 .await;
240
241 if timeout.is_err() {
242 log::error!("Timeout waiting for WebSocket to close, forcing closed state");
243 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
244 }
245 }
246 }
247
248 Ok(())
249 })
250 }
251
    /// Returns `true` while the client's controller task has not finished.
    ///
    /// This checks task liveness, not the connection mode: the result depends
    /// only on `controller_task.is_finished()`.
    #[pyo3(name = "is_active")]
    #[expect(clippy::needless_pass_by_value)]
    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
        !slf.controller_task.is_finished()
    }
261
    /// Returns whether the client is currently reconnecting.
    ///
    /// Python-facing wrapper that delegates to the client's `is_reconnecting`.
    #[pyo3(name = "is_reconnecting")]
    #[expect(clippy::needless_pass_by_value)]
    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
        slf.is_reconnecting()
    }
271
    /// Returns whether the client is currently disconnecting.
    ///
    /// Python-facing wrapper that delegates to the client's `is_disconnecting`.
    #[pyo3(name = "is_disconnecting")]
    #[expect(clippy::needless_pass_by_value)]
    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
        slf.is_disconnecting()
    }
280
    /// Returns whether the client is closed.
    ///
    /// Python-facing wrapper that delegates to the client's `is_closed`.
    #[pyo3(name = "is_closed")]
    #[expect(clippy::needless_pass_by_value)]
    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
        slf.is_closed()
    }
291
292 #[pyo3(name = "send")]
300 #[pyo3(signature = (data, keys=None))]
301 #[expect(clippy::needless_pass_by_value)]
302 fn py_send<'py>(
303 slf: PyRef<'_, Self>,
304 data: Vec<u8>,
305 py: Python<'py>,
306 keys: Option<Vec<String>>,
307 ) -> PyResult<Bound<'py, PyAny>> {
308 let rate_limiter = slf.rate_limiter.clone();
309 let writer_tx = slf.writer_tx.clone();
310 let mode = slf.connection_mode.clone();
311 let keys = keys.map(into_ustr_vec);
312
313 pyo3_async_runtimes::tokio::future_into_py(py, async move {
314 if !ConnectionMode::from_atomic(&mode).is_active() {
315 let msg = "Cannot send data: connection not active".to_string();
316 log::error!("{msg}");
317 return Err(to_websocket_pyerr(TransportError::Io(std::io::Error::new(
318 std::io::ErrorKind::NotConnected,
319 msg,
320 ))));
321 }
322
323 tokio::select! {
324 biased;
325 () = rate_limiter.await_keys_ready(keys.as_deref()) => {}
326 () = poll_until_closed(&mode) => {
327 return Err(to_websocket_pyerr(TransportError::Io(std::io::Error::new(
328 std::io::ErrorKind::ConnectionAborted,
329 "Connection closed while waiting for rate limit",
330 ))));
331 }
332 }
333
334 log::trace!("Sending binary: {data:?}");
335
336 let msg = Message::Binary(data.into());
337 writer_tx
338 .send(WriterCommand::Send(msg))
339 .map_err(to_pyruntime_err)
340 })
341 }
342
343 #[pyo3(name = "send_text")]
349 #[pyo3(signature = (data, keys=None))]
350 #[expect(clippy::needless_pass_by_value)]
351 fn py_send_text<'py>(
352 slf: PyRef<'_, Self>,
353 data: Vec<u8>,
354 py: Python<'py>,
355 keys: Option<Vec<String>>,
356 ) -> PyResult<Bound<'py, PyAny>> {
357 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
358 let rate_limiter = slf.rate_limiter.clone();
359 let writer_tx = slf.writer_tx.clone();
360 let mode = slf.connection_mode.clone();
361 let keys = keys.map(into_ustr_vec);
362
363 pyo3_async_runtimes::tokio::future_into_py(py, async move {
364 if !ConnectionMode::from_atomic(&mode).is_active() {
365 return Err(to_websocket_pyerr(TransportError::Io(std::io::Error::new(
366 std::io::ErrorKind::NotConnected,
367 "Cannot send text: connection not active",
368 ))));
369 }
370
371 tokio::select! {
372 biased;
373 () = rate_limiter.await_keys_ready(keys.as_deref()) => {}
374 () = poll_until_closed(&mode) => {
375 return Err(to_websocket_pyerr(TransportError::Io(std::io::Error::new(
376 std::io::ErrorKind::ConnectionAborted,
377 "Connection closed while waiting for rate limit",
378 ))));
379 }
380 }
381
382 log::trace!("Sending text: {data_str}");
383
384 let msg = Message::Text(data_str.into());
385 writer_tx
386 .send(WriterCommand::Send(msg))
387 .map_err(to_pyruntime_err)
388 })
389 }
390
391 #[pyo3(name = "send_pong")]
393 #[expect(clippy::needless_pass_by_value)]
394 fn py_send_pong<'py>(
395 slf: PyRef<'_, Self>,
396 data: Vec<u8>,
397 py: Python<'py>,
398 ) -> PyResult<Bound<'py, PyAny>> {
399 let writer_tx = slf.writer_tx.clone();
400 let mode = slf.connection_mode.clone();
401 let data_len = data.len();
402
403 pyo3_async_runtimes::tokio::future_into_py(py, async move {
404 if !ConnectionMode::from_atomic(&mode).is_active() {
405 log::debug!("Skipping pong: connection not active");
406 return Ok(());
407 }
408 log::trace!("Sending pong frame ({data_len} bytes)");
409
410 let msg = Message::Pong(data.into());
411 writer_tx
412 .send(WriterCommand::Send(msg))
413 .map_err(to_pyruntime_err)
414 })
415 }
416}
417
418async fn poll_until_closed(mode: &Arc<AtomicU8>) {
419 loop {
420 if matches!(
421 ConnectionMode::from_atomic(mode),
422 ConnectionMode::Disconnect | ConnectionMode::Closed
423 ) {
424 break;
425 }
426
427 tokio::time::sleep(Duration::from_millis(100)).await;
428 }
429}
430
431#[cfg(test)]
432#[cfg(target_os = "linux")] mod tests {
434 use std::ffi::CString;
435
436 use futures_util::{SinkExt, StreamExt};
437 use nautilus_core::python::IntoPyObjectNautilusExt;
438 use pyo3::{prelude::*, types::PyBytes};
439 use tokio::{
440 net::TcpListener,
441 task::{self, JoinHandle},
442 time::{Duration, sleep},
443 };
444 use tokio_tungstenite::{
445 accept_hdr_async,
446 tungstenite::{
447 handshake::server::{self, Callback},
448 http::HeaderValue,
449 },
450 };
451
452 use crate::{
453 transport::Message,
454 websocket::{MessageHandler, WebSocketClient, WebSocketConfig},
455 };
456
    /// Handle to the local WebSocket server used by the tests in this module.
    struct TestServer {
        /// Accept-loop task; aborted when the server handle is dropped.
        task: JoinHandle<()>,
        /// Local port the listener is bound to.
        port: u16,
    }
461
    /// Handshake callback asserting that a request carries an expected header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the request.
        key: String,
        /// Value the header must have.
        value: HeaderValue,
    }
467
468 impl Callback for TestCallback {
469 #[expect(clippy::panic_in_result_fn)]
470 fn on_request(
471 self,
472 request: &server::Request,
473 response: server::Response,
474 ) -> Result<server::Response, server::ErrorResponse> {
475 let _ = response;
476 let value = request.headers().get(&self.key);
477 assert!(value.is_some());
478
479 if let Some(value) = request.headers().get(&self.key) {
480 assert_eq!(value, self.value);
481 }
482
483 Ok(response)
484 }
485 }
486
    impl TestServer {
        /// Starts a WebSocket echo server on an OS-assigned local port.
        ///
        /// Every handshake is checked by a `TestCallback` built from `key` and
        /// `value`. Each accepted connection is served on its own task, which:
        ///
        /// - closes the connection when it receives the text `close-now`,
        /// - echoes any other text or binary message back to the sender,
        /// - answers a close message by closing the connection,
        /// - ignores every other message kind.
        ///
        /// # Panics
        ///
        /// Panics if binding fails or `value` is not a valid header value. The
        /// spawned accept task panics if accepting or the handshake fails.
        async fn setup(key: String, value: String) -> Self {
            // Port 0 lets the OS choose a free port
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let test_call_back = TestCallback {
                key,
                value: HeaderValue::from_str(&value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting connections until the task is aborted
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        // The loop ends on stream end or on the first read error
                        #[expect(clippy::collapsible_match)]
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
                                    if txt == "close-now" =>
                                {
                                    log::debug!("Forcibly closing from server side");
                                    // Best-effort close: the result is ignored
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary messages back to the client
                                tokio_tungstenite::tungstenite::protocol::Message::Text(_)
                                | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // Client-initiated close: close our side and stop
                                tokio_tungstenite::tungstenite::protocol::Message::Close(
                                    _frame,
                                ) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Other message kinds are ignored
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }
544
    impl Drop for TestServer {
        /// Aborts the accept-loop task when the server handle goes out of scope.
        fn drop(&mut self) {
            self.task.abort();
        }
    }
550
    /// Builds a Python `Counter` object and returns it with its bound `handler`.
    ///
    /// The embedded Python module defines a `Counter` whose `handler(bytes)`
    /// decodes the payload, increments `count` for each `ping` message, and
    /// sets `check` when it sees `heartbeat message`. The first element of the
    /// returned tuple is the `counter` instance (for `get_count` / `get_check`)
    /// and the second is its `handler` method.
    ///
    /// # Panics
    ///
    /// Panics if the Python module cannot be compiled or the expected
    /// attributes are missing.
    fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
        let code_raw = "
class Counter:
    def __init__(self):
        self.count = 0
        self.check = False

    def handler(self, bytes):
        msg = bytes.decode()
        if msg == 'ping':
            self.count += 1
        elif msg == 'heartbeat message':
            self.check = True

    def get_check(self):
        return self.check

    def get_count(self):
        return self.count

counter = Counter()
";

        let code = CString::new(code_raw).unwrap();
        let filename = CString::new("test".to_string()).unwrap();
        let module = CString::new("test".to_string()).unwrap();
        Python::attach(|py| {
            let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();

            let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
            let handler = counter
                .getattr(py, "handler")
                .unwrap()
                .into_py_any_unwrap(py);

            (counter, handler)
        })
    }
589
590 #[tokio::test]
591 async fn basic_client_test() {
592 const N: usize = 10;
593
594 Python::initialize();
595
596 let mut success_count = 0;
597 let header_key = "hello-custom-key".to_string();
598 let header_value = "hello-custom-value".to_string();
599
600 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
601 let (counter, handler) = create_test_handler();
602
603 let config = WebSocketConfig::py_new(
604 format!("ws://127.0.0.1:{}", server.port),
605 vec![(header_key, header_value)],
606 None,
607 None,
608 None,
609 None,
610 None,
611 None,
612 None,
613 None,
614 None,
615 None,
616 );
617
618 let handler_clone = Python::attach(|py| handler.clone_ref(py));
619
620 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
621 Python::attach(|py| {
622 let data = match msg {
623 Message::Binary(data) | Message::Text(data) => data.to_vec(),
624 _ => return,
625 };
626 let py_bytes = PyBytes::new(py, &data);
627 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
628 log::error!("Error calling handler: {e}");
629 }
630 });
631 });
632
633 let client =
634 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
635 .await
636 .unwrap();
637
638 for _ in 0..N {
639 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
640 success_count += 1;
641 }
642
643 sleep(Duration::from_secs(1)).await;
644 let count_value: usize = Python::attach(|py| {
645 counter
646 .getattr(py, "get_count")
647 .unwrap()
648 .call0(py)
649 .unwrap()
650 .extract(py)
651 .unwrap()
652 });
653 assert_eq!(count_value, success_count);
654
655 client.send_close_message().await.unwrap();
657
658 sleep(Duration::from_secs(2)).await;
660
661 for _ in 0..N {
662 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
663 success_count += 1;
664 }
665
666 sleep(Duration::from_secs(1)).await;
667 let count_value: usize = Python::attach(|py| {
668 counter
669 .getattr(py, "get_count")
670 .unwrap()
671 .call0(py)
672 .unwrap()
673 .extract(py)
674 .unwrap()
675 });
676 assert_eq!(count_value, success_count);
677 assert_eq!(success_count, N + N);
678
679 client.disconnect().await;
680 assert!(client.is_disconnected());
681 }
682
683 #[tokio::test]
684 async fn message_ping_test() {
685 Python::initialize();
686
687 let header_key = "hello-custom-key".to_string();
688 let header_value = "hello-custom-value".to_string();
689
690 let (checker, handler) = create_test_handler();
691
692 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
693 let config = WebSocketConfig::py_new(
694 format!("ws://127.0.0.1:{}", server.port),
695 vec![(header_key, header_value)],
696 Some(1),
697 Some("heartbeat message".to_string()),
698 None,
699 None,
700 None,
701 None,
702 None,
703 None,
704 None,
705 None,
706 );
707
708 let handler_clone = Python::attach(|py| handler.clone_ref(py));
709
710 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
711 Python::attach(|py| {
712 let data = match msg {
713 Message::Binary(data) | Message::Text(data) => data.to_vec(),
714 _ => return,
715 };
716 let py_bytes = PyBytes::new(py, &data);
717 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
718 log::error!("Error calling handler: {e}");
719 }
720 });
721 });
722
723 let client =
724 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
725 .await
726 .unwrap();
727
728 sleep(Duration::from_secs(2)).await;
729 let check_value: bool = Python::attach(|py| {
730 checker
731 .getattr(py, "get_check")
732 .unwrap()
733 .call0(py)
734 .unwrap()
735 .extract(py)
736 .unwrap()
737 });
738 assert!(check_value);
739
740 client.disconnect().await;
741 assert!(client.is_disconnected());
742 }
743}