1use std::{
17 sync::{
18 Arc,
19 atomic::{AtomicU8, Ordering},
20 },
21 time::Duration,
22};
23
24use nautilus_core::{
25 collections::into_ustr_vec,
26 python::{clone_py_object, to_pyruntime_err, to_pyvalue_err},
27};
28use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
29use tokio_tungstenite::tungstenite::{Message, Utf8Bytes};
30
31use crate::{
32 RECONNECTED,
33 mode::ConnectionMode,
34 ratelimiter::quota::Quota,
35 websocket::{
36 WebSocketClient, WebSocketConfig,
37 types::{MessageHandler, PingHandler, WriterCommand},
38 },
39};
40
41create_exception!(network, WebSocketClientError, PyException);
42
43#[allow(clippy::needless_pass_by_value)]
44fn to_websocket_pyerr(e: tokio_tungstenite::tungstenite::Error) -> PyErr {
45 PyErr::new::<WebSocketClientError, _>(e.to_string())
46}
47
48#[pymethods]
49#[pyo3_stub_gen::derive::gen_stub_pymethods]
50impl WebSocketConfig {
51 #[new]
75 #[allow(clippy::too_many_arguments)]
76 #[pyo3(signature = (
77 url,
78 headers,
79 heartbeat=None,
80 heartbeat_msg=None,
81 reconnect_timeout_ms=10_000,
82 reconnect_delay_initial_ms=2_000,
83 reconnect_delay_max_ms=30_000,
84 reconnect_backoff_factor=1.5,
85 reconnect_jitter_ms=100,
86 reconnect_max_attempts=None,
87 idle_timeout_ms=None,
88 ))]
89 fn py_new(
90 url: String,
91 headers: Vec<(String, String)>,
92 heartbeat: Option<u64>,
93 heartbeat_msg: Option<String>,
94 reconnect_timeout_ms: Option<u64>,
95 reconnect_delay_initial_ms: Option<u64>,
96 reconnect_delay_max_ms: Option<u64>,
97 reconnect_backoff_factor: Option<f64>,
98 reconnect_jitter_ms: Option<u64>,
99 reconnect_max_attempts: Option<u32>,
100 idle_timeout_ms: Option<u64>,
101 ) -> Self {
102 Self {
103 url,
104 headers,
105 heartbeat,
106 heartbeat_msg,
107 reconnect_timeout_ms,
108 reconnect_delay_initial_ms,
109 reconnect_delay_max_ms,
110 reconnect_backoff_factor,
111 reconnect_jitter_ms,
112 reconnect_max_attempts,
113 idle_timeout_ms,
114 }
115 }
116}
117
118#[pymethods]
119#[pyo3_stub_gen::derive::gen_stub_pymethods]
120impl WebSocketClient {
121 #[staticmethod]
133 #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
134 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
135 fn py_connect(
136 loop_: Py<PyAny>,
137 config: WebSocketConfig,
138 handler: Py<PyAny>,
139 ping_handler: Option<Py<PyAny>>,
140 post_reconnection: Option<Py<PyAny>>,
141 keyed_quotas: Vec<(String, Quota)>,
142 default_quota: Option<Quota>,
143 py: Python<'_>,
144 ) -> PyResult<Bound<'_, PyAny>> {
145 let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
146 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
147 let handler_clone = clone_py_object(&handler);
148
149 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
150 if matches!(msg, Message::Text(ref text) if text.as_str() == RECONNECTED) {
151 return;
152 }
153
154 Python::attach(|py| {
155 let py_bytes = match &msg {
156 Message::Binary(data) => PyBytes::new(py, data),
157 Message::Text(text) => PyBytes::new(py, text.as_bytes()),
158 _ => return,
159 };
160
161 if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
162 log::error!("Error scheduling message handler on event loop: {e}");
163 }
164 });
165 });
166
167 let ping_handler_fn = ping_handler.map(|ping_handler| {
168 let ping_handler_clone = clone_py_object(&ping_handler);
169 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
170
171 let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
172 Python::attach(|py| {
173 let py_bytes = PyBytes::new(py, &data);
174 if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
175 log::error!("Error scheduling ping handler on event loop: {e}");
176 }
177 });
178 });
179 ping_handler_fn
180 });
181
182 let post_reconnection_fn = post_reconnection.map(|callback| {
183 let callback_clone = clone_py_object(&callback);
184 Arc::new(move || {
185 Python::attach(|py| {
186 if let Err(e) = callback_clone.call0(py) {
187 log::error!("Error calling post_reconnection handler: {e}");
188 }
189 });
190 }) as std::sync::Arc<dyn Fn() + Send + Sync>
191 });
192
193 pyo3_async_runtimes::tokio::future_into_py(py, async move {
194 Self::connect(
195 config,
196 Some(message_handler),
197 ping_handler_fn,
198 post_reconnection_fn,
199 keyed_quotas,
200 default_quota,
201 )
202 .await
203 .map_err(to_websocket_pyerr)
204 })
205 }
206
207 #[pyo3(name = "disconnect")]
212 #[allow(clippy::needless_pass_by_value)]
213 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
214 let connection_mode = slf.connection_mode.clone();
215 let state_notify = slf.state_notify.clone();
216 let mode = ConnectionMode::from_atomic(&connection_mode);
217 log::debug!("Close from mode {mode}");
218
219 pyo3_async_runtimes::tokio::future_into_py(py, async move {
220 match ConnectionMode::from_atomic(&connection_mode) {
221 ConnectionMode::Closed => {
222 log::debug!("WebSocket already closed");
223 }
224 ConnectionMode::Disconnect => {
225 log::debug!("WebSocket already disconnecting");
226 }
227 _ => {
228 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
229 state_notify.notify_one();
230
231 let timeout = tokio::time::timeout(Duration::from_secs(5), async {
232 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
233 tokio::time::sleep(Duration::from_millis(10)).await;
234 }
235 })
236 .await;
237
238 if timeout.is_err() {
239 log::error!("Timeout waiting for WebSocket to close, forcing closed state");
240 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
241 }
242 }
243 }
244
245 Ok(())
246 })
247 }
248
249 #[pyo3(name = "is_active")]
254 #[allow(clippy::needless_pass_by_value)]
255 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
256 !slf.controller_task.is_finished()
257 }
258
259 #[pyo3(name = "is_reconnecting")]
264 #[allow(clippy::needless_pass_by_value)]
265 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
266 slf.is_reconnecting()
267 }
268
269 #[pyo3(name = "is_disconnecting")]
273 #[allow(clippy::needless_pass_by_value)]
274 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
275 slf.is_disconnecting()
276 }
277
278 #[pyo3(name = "is_closed")]
284 #[allow(clippy::needless_pass_by_value)]
285 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
286 slf.is_closed()
287 }
288
289 #[pyo3(name = "send")]
295 #[pyo3(signature = (data, keys=None))]
296 #[allow(clippy::needless_pass_by_value)]
297 fn py_send<'py>(
298 slf: PyRef<'_, Self>,
299 data: Vec<u8>,
300 py: Python<'py>,
301 keys: Option<Vec<String>>,
302 ) -> PyResult<Bound<'py, PyAny>> {
303 let rate_limiter = slf.rate_limiter.clone();
304 let writer_tx = slf.writer_tx.clone();
305 let mode = slf.connection_mode.clone();
306 let keys = keys.map(into_ustr_vec);
307
308 pyo3_async_runtimes::tokio::future_into_py(py, async move {
309 if !ConnectionMode::from_atomic(&mode).is_active() {
310 let msg = "Cannot send data: connection not active".to_string();
311 log::error!("{msg}");
312 return Err(to_pyruntime_err(std::io::Error::new(
313 std::io::ErrorKind::NotConnected,
314 msg,
315 )));
316 }
317
318 tokio::select! {
319 () = rate_limiter.await_keys_ready(keys.as_deref()) => {}
320 () = poll_until_closed(&mode) => {
321 return Err(to_pyruntime_err(std::io::Error::new(
322 std::io::ErrorKind::ConnectionAborted,
323 "Connection closed while waiting for rate limit",
324 )));
325 }
326 }
327
328 log::trace!("Sending binary: {data:?}");
329
330 let msg = Message::Binary(data.into());
331 writer_tx
332 .send(WriterCommand::Send(msg))
333 .map_err(to_pyruntime_err)
334 })
335 }
336
337 #[pyo3(name = "send_text")]
343 #[pyo3(signature = (data, keys=None))]
344 #[allow(clippy::needless_pass_by_value)]
345 fn py_send_text<'py>(
346 slf: PyRef<'_, Self>,
347 data: Vec<u8>,
348 py: Python<'py>,
349 keys: Option<Vec<String>>,
350 ) -> PyResult<Bound<'py, PyAny>> {
351 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
352 let data = Utf8Bytes::from(data_str);
353 let rate_limiter = slf.rate_limiter.clone();
354 let writer_tx = slf.writer_tx.clone();
355 let mode = slf.connection_mode.clone();
356 let keys = keys.map(into_ustr_vec);
357
358 pyo3_async_runtimes::tokio::future_into_py(py, async move {
359 if !ConnectionMode::from_atomic(&mode).is_active() {
360 let e = std::io::Error::new(
361 std::io::ErrorKind::NotConnected,
362 "Cannot send text: connection not active",
363 );
364 return Err(to_pyruntime_err(e));
365 }
366
367 tokio::select! {
368 () = rate_limiter.await_keys_ready(keys.as_deref()) => {}
369 () = poll_until_closed(&mode) => {
370 return Err(to_pyruntime_err(std::io::Error::new(
371 std::io::ErrorKind::ConnectionAborted,
372 "Connection closed while waiting for rate limit",
373 )));
374 }
375 }
376
377 log::trace!("Sending text: {data}");
378
379 let msg = Message::Text(data);
380 writer_tx
381 .send(WriterCommand::Send(msg))
382 .map_err(to_pyruntime_err)
383 })
384 }
385
386 #[pyo3(name = "send_pong")]
388 #[allow(clippy::needless_pass_by_value)]
389 fn py_send_pong<'py>(
390 slf: PyRef<'_, Self>,
391 data: Vec<u8>,
392 py: Python<'py>,
393 ) -> PyResult<Bound<'py, PyAny>> {
394 let writer_tx = slf.writer_tx.clone();
395 let mode = slf.connection_mode.clone();
396 let data_len = data.len();
397
398 pyo3_async_runtimes::tokio::future_into_py(py, async move {
399 if !ConnectionMode::from_atomic(&mode).is_active() {
400 let e = std::io::Error::new(
401 std::io::ErrorKind::NotConnected,
402 "Cannot send pong: connection not active",
403 );
404 return Err(to_pyruntime_err(e));
405 }
406 log::trace!("Sending pong frame ({data_len} bytes)");
407
408 let msg = Message::Pong(data.into());
409 writer_tx
410 .send(WriterCommand::Send(msg))
411 .map_err(to_pyruntime_err)
412 })
413 }
414}
415
416async fn poll_until_closed(mode: &Arc<AtomicU8>) {
417 loop {
418 if matches!(
419 ConnectionMode::from_atomic(mode),
420 ConnectionMode::Disconnect | ConnectionMode::Closed
421 ) {
422 break;
423 }
424
425 tokio::time::sleep(Duration::from_millis(100)).await;
426 }
427}
428
429#[cfg(test)]
430#[cfg(target_os = "linux")] mod tests {
432 use std::ffi::CString;
433
434 use futures_util::{SinkExt, StreamExt};
435 use nautilus_core::python::IntoPyObjectNautilusExt;
436 use pyo3::{prelude::*, types::PyBytes};
437 use tokio::{
438 net::TcpListener,
439 task::{self, JoinHandle},
440 time::{Duration, sleep},
441 };
442 use tokio_tungstenite::{
443 accept_hdr_async,
444 tungstenite::{
445 Message,
446 handshake::server::{self, Callback},
447 http::HeaderValue,
448 },
449 };
450
451 use crate::websocket::{MessageHandler, WebSocketClient, WebSocketConfig};
452
453 struct TestServer {
454 task: JoinHandle<()>,
455 port: u16,
456 }
457
458 #[derive(Debug, Clone)]
459 struct TestCallback {
460 key: String,
461 value: HeaderValue,
462 }
463
464 impl Callback for TestCallback {
465 #[allow(clippy::panic_in_result_fn)]
466 fn on_request(
467 self,
468 request: &server::Request,
469 response: server::Response,
470 ) -> Result<server::Response, server::ErrorResponse> {
471 let _ = response;
472 let value = request.headers().get(&self.key);
473 assert!(value.is_some());
474
475 if let Some(value) = request.headers().get(&self.key) {
476 assert_eq!(value, self.value);
477 }
478
479 Ok(response)
480 }
481 }
482
483 impl TestServer {
484 async fn setup(key: String, value: String) -> Self {
485 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
486 let port = TcpListener::local_addr(&server).unwrap().port();
487
488 let test_call_back = TestCallback {
489 key,
490 value: HeaderValue::from_str(&value).unwrap(),
491 };
492
493 let task = task::spawn(async move {
495 loop {
497 let (conn, _) = server.accept().await.unwrap();
498 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
499 .await
500 .unwrap();
501
502 task::spawn(async move {
503 while let Some(Ok(msg)) = websocket.next().await {
504 match msg {
505 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
506 if txt == "close-now" =>
507 {
508 log::debug!("Forcibly closing from server side");
509 let _ = websocket.close(None).await;
511 break;
512 }
513 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
515 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
516 if websocket.send(msg).await.is_err() {
517 break;
518 }
519 }
520 tokio_tungstenite::tungstenite::protocol::Message::Close(
522 _frame,
523 ) => {
524 let _ = websocket.close(None).await;
525 break;
526 }
527 _ => {}
529 }
530 }
531 });
532 }
533 });
534
535 Self { task, port }
536 }
537 }
538
539 impl Drop for TestServer {
540 fn drop(&mut self) {
541 self.task.abort();
542 }
543 }
544
545 fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
546 let code_raw = "
547class Counter:
548 def __init__(self):
549 self.count = 0
550 self.check = False
551
552 def handler(self, bytes):
553 msg = bytes.decode()
554 if msg == 'ping':
555 self.count += 1
556 elif msg == 'heartbeat message':
557 self.check = True
558
559 def get_check(self):
560 return self.check
561
562 def get_count(self):
563 return self.count
564
565counter = Counter()
566";
567
568 let code = CString::new(code_raw).unwrap();
569 let filename = CString::new("test".to_string()).unwrap();
570 let module = CString::new("test".to_string()).unwrap();
571 Python::attach(|py| {
572 let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
573
574 let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
575 let handler = counter
576 .getattr(py, "handler")
577 .unwrap()
578 .into_py_any_unwrap(py);
579
580 (counter, handler)
581 })
582 }
583
584 #[tokio::test]
585 async fn basic_client_test() {
586 const N: usize = 10;
587
588 Python::initialize();
589
590 let mut success_count = 0;
591 let header_key = "hello-custom-key".to_string();
592 let header_value = "hello-custom-value".to_string();
593
594 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
595 let (counter, handler) = create_test_handler();
596
597 let config = WebSocketConfig::py_new(
598 format!("ws://127.0.0.1:{}", server.port),
599 vec![(header_key, header_value)],
600 None,
601 None,
602 None,
603 None,
604 None,
605 None,
606 None,
607 None,
608 None,
609 );
610
611 let handler_clone = Python::attach(|py| handler.clone_ref(py));
612
613 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
614 Python::attach(|py| {
615 let data = match msg {
616 Message::Binary(data) => data.to_vec(),
617 Message::Text(text) => text.as_bytes().to_vec(),
618 _ => return,
619 };
620 let py_bytes = PyBytes::new(py, &data);
621 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
622 log::error!("Error calling handler: {e}");
623 }
624 });
625 });
626
627 let client =
628 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
629 .await
630 .unwrap();
631
632 for _ in 0..N {
633 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
634 success_count += 1;
635 }
636
637 sleep(Duration::from_secs(1)).await;
638 let count_value: usize = Python::attach(|py| {
639 counter
640 .getattr(py, "get_count")
641 .unwrap()
642 .call0(py)
643 .unwrap()
644 .extract(py)
645 .unwrap()
646 });
647 assert_eq!(count_value, success_count);
648
649 client.send_close_message().await.unwrap();
651
652 sleep(Duration::from_secs(2)).await;
654 for _ in 0..N {
655 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
656 success_count += 1;
657 }
658
659 sleep(Duration::from_secs(1)).await;
660 let count_value: usize = Python::attach(|py| {
661 counter
662 .getattr(py, "get_count")
663 .unwrap()
664 .call0(py)
665 .unwrap()
666 .extract(py)
667 .unwrap()
668 });
669 assert_eq!(count_value, success_count);
670 assert_eq!(success_count, N + N);
671
672 client.disconnect().await;
673 assert!(client.is_disconnected());
674 }
675
676 #[tokio::test]
677 async fn message_ping_test() {
678 Python::initialize();
679
680 let header_key = "hello-custom-key".to_string();
681 let header_value = "hello-custom-value".to_string();
682
683 let (checker, handler) = create_test_handler();
684
685 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
686 let config = WebSocketConfig::py_new(
687 format!("ws://127.0.0.1:{}", server.port),
688 vec![(header_key, header_value)],
689 Some(1),
690 Some("heartbeat message".to_string()),
691 None,
692 None,
693 None,
694 None,
695 None,
696 None,
697 None,
698 );
699
700 let handler_clone = Python::attach(|py| handler.clone_ref(py));
701
702 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
703 Python::attach(|py| {
704 let data = match msg {
705 Message::Binary(data) => data.to_vec(),
706 Message::Text(text) => text.as_bytes().to_vec(),
707 _ => return,
708 };
709 let py_bytes = PyBytes::new(py, &data);
710 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
711 log::error!("Error calling handler: {e}");
712 }
713 });
714 });
715
716 let client =
717 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
718 .await
719 .unwrap();
720
721 sleep(Duration::from_secs(2)).await;
722 let check_value: bool = Python::attach(|py| {
723 checker
724 .getattr(py, "get_check")
725 .unwrap()
726 .call0(py)
727 .unwrap()
728 .extract(py)
729 .unwrap()
730 });
731 assert!(check_value);
732
733 client.disconnect().await;
734 assert!(client.is_disconnected());
735 }
736}