1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, future::try_join_all, stream::FuturesUnordered};
4use http::header::USER_AGENT;
5use serde::de::DeserializeOwned;
6use serde_json::{Value, json};
7use std::{
8 collections::{BTreeMap, HashMap, VecDeque},
9 io::Read,
10 marker::PhantomData,
11 mem::take,
12 sync::{
13 Arc,
14 atomic::{AtomicBool, AtomicUsize, Ordering},
15 },
16 time::Duration,
17};
18use tokio::{
19 net::TcpStream,
20 select, spawn,
21 sync::{
22 Mutex, Notify,
23 mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
24 oneshot,
25 },
26 task::JoinHandle,
27 time::{sleep, timeout},
28};
29use tokio_tungstenite::{
30 Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
31 tungstenite::{
32 Message,
33 client::IntoClientRequest,
34 protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
35 },
36};
37use tokio_util::time::DelayQueue;
38use tracing::{debug, error, info, warn};
39
40use super::{
41 config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
42 errors::{WebsocketConnectionFailureReason, WebsocketError},
43 models::{StreamId, WebsocketApiResponse, WebsocketEvent, WebsocketMode},
44 utils::{build_websocket_api_message, normalize_stream_id, random_string, validate_time_unit},
45};
46
47pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;
48
/// How long a connection is kept before it is proactively renewed: 23 hours
/// (82 800 seconds).
const MAX_CONN_DURATION: Duration = Duration::from_secs(82_800);
50
51pub struct Subscription {
52 handle: JoinHandle<()>,
53}
54
55impl Subscription {
56 pub fn unsubscribe(self) {
71 self.handle.abort();
72 }
73}
74
75#[derive(Clone)]
76pub enum WebsocketBase {
77 WebsocketApi(Arc<WebsocketApi>),
78 WebsocketStreams(Arc<WebsocketStreams>),
79}
80
81pub struct WebsocketEventEmitter {
82 subscribers: Arc<std::sync::Mutex<Vec<UnboundedSender<WebsocketEvent>>>>,
83}
84
85impl Default for WebsocketEventEmitter {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl WebsocketEventEmitter {
92 #[must_use]
93 pub fn new() -> Self {
94 Self {
95 subscribers: Arc::new(std::sync::Mutex::new(Vec::new())),
96 }
97 }
98
99 pub fn subscribe<F>(&self, mut callback: F) -> Subscription
126 where
127 F: FnMut(WebsocketEvent) + Send + 'static,
128 {
129 let (tx, mut rx) = unbounded_channel();
130 let mut guard = match self.subscribers.lock() {
131 Ok(guard) => guard,
132 Err(poisoned) => poisoned.into_inner(),
133 };
134 guard.push(tx);
135 drop(guard);
136
137 let handle = spawn(async move {
138 while let Some(event) = rx.recv().await {
139 callback(event);
140 }
141 });
142 Subscription { handle }
143 }
144
145 pub fn emit(&self, event: &WebsocketEvent) {
155 let mut guard = match self.subscribers.lock() {
156 Ok(guard) => guard,
157 Err(poisoned) => poisoned.into_inner(),
158 };
159
160 guard.retain(|tx| {
161 if tx.send(event.clone()).is_ok() {
162 true
163 } else {
164 warn!("subscriber dropped without unsubscribing");
165 false
166 }
167 });
168 }
169}
170
171#[async_trait]
186pub trait WebsocketHandler: Send + Sync + 'static {
187 async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
188 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
189 async fn get_reconnect_url(
190 &self,
191 default_url: String,
192 connection: Arc<WebsocketConnection>,
193 ) -> String;
194}
195
196pub struct PendingRequest {
197 pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
198}
199
200#[derive(Clone)]
201pub struct WebsocketSessionLogonReq {
202 pub method: String,
203 pub payload: BTreeMap<String, Value>,
204 pub options: WebsocketMessageSendOptions,
205}
206
207pub struct WebsocketConnectionState {
208 pub reconnection_pending: bool,
209 pub renewal_pending: bool,
210 pub close_initiated: bool,
211 pub pending_requests: HashMap<String, PendingRequest>,
212 pub pending_subscriptions: VecDeque<String>,
213 pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
214 pub is_session_logged_on: bool,
215 pub session_logon_req: Option<WebsocketSessionLogonReq>,
216 pub url_path: Option<String>,
217 pub handler: Option<Arc<dyn WebsocketHandler>>,
218 pub ws_write_tx: Option<UnboundedSender<Message>>,
219}
220
221impl Default for WebsocketConnectionState {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227impl WebsocketConnectionState {
228 #[must_use]
229 pub fn new() -> Self {
230 Self {
231 reconnection_pending: false,
232 renewal_pending: false,
233 close_initiated: false,
234 pending_requests: HashMap::new(),
235 pending_subscriptions: VecDeque::new(),
236 stream_callbacks: HashMap::new(),
237 is_session_logged_on: false,
238 session_logon_req: None,
239 url_path: None,
240 handler: None,
241 ws_write_tx: None,
242 }
243 }
244}
245
246pub struct WebsocketConnection {
247 pub id: String,
248 pub drain_notify: Notify,
249 pub state: Mutex<WebsocketConnectionState>,
250}
251
252impl WebsocketConnection {
253 pub fn new(id: impl Into<String>) -> Arc<Self> {
254 Arc::new(Self {
255 id: id.into(),
256 drain_notify: Notify::new(),
257 state: Mutex::new(WebsocketConnectionState::new()),
258 })
259 }
260
261 pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
262 let mut conn_state = self.state.lock().await;
263 conn_state.handler = Some(handler);
264 }
265}
266
/// A request, sent to the reconnect loop, to (re)dial one pool connection.
struct ReconnectEntry {
    /// Id of the pool connection to dial.
    connection_id: String,
    /// URL to dial.
    url: String,
    /// `true` for a proactive renewal (dialled immediately), `false` for a
    /// failure-driven reconnect (dialled after `reconnect_delay`).
    is_renewal: bool,
}
272
273pub struct WebsocketCommon {
274 pub events: WebsocketEventEmitter,
275 mode: WebsocketMode,
276 round_robin_index: AtomicUsize,
277 connection_pool: Vec<Arc<WebsocketConnection>>,
278 reconnect_tx: Sender<ReconnectEntry>,
279 renewal_tx: Sender<(String, String)>,
280 reconnect_delay: usize,
281 agent: Option<AgentConnector>,
282 user_agent: Option<String>,
283}
284
285impl WebsocketCommon {
286 #[must_use]
287 pub fn new(
288 mut initial_pool: Vec<Arc<WebsocketConnection>>,
289 mode: WebsocketMode,
290 reconnect_delay: usize,
291 agent: Option<AgentConnector>,
292 user_agent: Option<String>,
293 ) -> Arc<Self> {
294 if initial_pool.is_empty() {
295 for _ in 0..mode.pool_size() {
296 let id = random_string();
297 initial_pool.push(WebsocketConnection::new(id));
298 }
299 }
300
301 let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
302 let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
303
304 let common = Arc::new(Self {
305 events: WebsocketEventEmitter::new(),
306 mode,
307 round_robin_index: AtomicUsize::new(0),
308 connection_pool: initial_pool,
309 reconnect_tx,
310 renewal_tx,
311 reconnect_delay,
312 agent,
313 user_agent,
314 });
315
316 Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
317 Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
318
319 common
320 }
321
322 fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
340 spawn(async move {
341 while let Some(entry) = reconnect_rx.recv().await {
342 info!("Scheduling reconnect for id {}", entry.connection_id);
343
344 if !entry.is_renewal {
345 sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
346 }
347
348 if let Some(conn_arc) = common
349 .connection_pool
350 .iter()
351 .find(|c| c.id == entry.connection_id)
352 .cloned()
353 {
354 let common_clone = Arc::clone(&common);
355 if let Err(err) = common_clone
356 .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
357 .await
358 {
359 error!(
360 "Reconnect failed for {} → {}: {:?}",
361 entry.connection_id, entry.url, err
362 );
363 }
364
365 sleep(Duration::from_secs(1)).await;
366 } else {
367 warn!("No connection {} found for reconnect", entry.connection_id);
368 }
369 }
370 });
371 }
372
373 fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
387 let common = Arc::clone(common);
388 spawn(async move {
389 let mut dq = DelayQueue::new();
390 let mut renewal_rx = renewal_rx;
391
392 loop {
393 select! {
394 Some((conn_id, url)) = renewal_rx.recv() => {
395 debug!("Scheduling renewal for {}", conn_id);
396 dq.insert((conn_id, url), MAX_CONN_DURATION);
397 }
398
399 Some(expired) = dq.next() => {
400 let (conn_id, default_url) = expired.into_inner();
401
402 if let Some(conn_arc) = common
403 .connection_pool
404 .iter()
405 .find(|c| c.id == conn_id)
406 .cloned()
407 {
408 debug!("Renewing connection {}", conn_id);
409 let url = common
410 .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
411 .await;
412 if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
413 connection_id: conn_id.clone(),
414 url,
415 is_renewal: true,
416 }).await {
417 error!(
418 "Failed to enqueue renewal for {}: {:?}",
419 conn_id, e
420 );
421 }
422 } else {
423 warn!("No connection {} found for renewal", conn_id);
424 }
425 }
426 }
427 }
428 });
429 }
430
431 pub async fn is_connection_ready(
449 &self,
450 connection: &WebsocketConnection,
451 allow_non_established: bool,
452 ) -> bool {
453 let conn_state = connection.state.lock().await;
454 (allow_non_established || conn_state.ws_write_tx.is_some())
455 && !conn_state.reconnection_pending
456 && !conn_state.close_initiated
457 }
458
459 async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
475 if let Some(conn_arc) = connection {
476 return self.is_connection_ready(conn_arc, false).await;
477 }
478
479 for conn_arc in &self.connection_pool {
480 if self.is_connection_ready(conn_arc, false).await {
481 return true;
482 }
483 }
484
485 false
486 }
487
488 async fn get_available_connections(
505 &self,
506 allow_non_established: bool,
507 url_path: Option<&str>,
508 ) -> Vec<Arc<WebsocketConnection>> {
509 if matches!(self.mode, WebsocketMode::Single) && url_path.is_none() {
510 return vec![Arc::clone(&self.connection_pool[0])];
511 }
512
513 let mut ready = Vec::new();
514 for conn in &self.connection_pool {
515 if self.is_connection_ready(conn, allow_non_established).await {
516 ready.push(Arc::clone(conn));
517 }
518 }
519
520 ready
521 }
522
523 async fn get_connection(
544 &self,
545 allow_non_established: bool,
546 url_path: Option<&str>,
547 ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
548 let candidates = self
549 .get_available_connections(allow_non_established, url_path)
550 .await;
551
552 let mut ready = Vec::new();
553 for conn in candidates {
554 if let Some(path) = url_path {
555 let st = conn.state.lock().await;
556 if st.url_path.as_deref() != Some(path) {
557 continue;
558 }
559 }
560 ready.push(conn);
561 }
562
563 if ready.is_empty() {
564 return Err(WebsocketError::NotConnected);
565 }
566
567 let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
568
569 Ok(Arc::clone(&ready[idx]))
570 }
571
572 async fn close_connection_gracefully(
589 &self,
590 ws_write_tx_to_close: UnboundedSender<Message>,
591 connection: Arc<WebsocketConnection>,
592 ) -> Result<(), WebsocketError> {
593 debug!("Waiting for pending requests to complete before disconnecting.");
594
595 let drain = async {
596 loop {
597 {
598 let conn_state = connection.state.lock().await;
599 if conn_state.pending_requests.is_empty() {
600 debug!("All pending requests completed, proceeding to close.");
601 break;
602 }
603 }
604 connection.drain_notify.notified().await;
605 }
606 };
607
608 if timeout(Duration::from_secs(30), drain).await.is_err() {
609 warn!("Timeout waiting for pending requests; forcing close.");
610 }
611
612 info!("Closing WebSocket connection for {}", connection.id);
613 let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
614 code: CloseCode::Normal,
615 reason: "".into(),
616 })));
617
618 Ok(())
619 }
620
621 async fn get_reconnect_url(
638 &self,
639 default_url: &str,
640 connection: Arc<WebsocketConnection>,
641 ) -> String {
642 if let Some(handler) = {
643 let conn_state = connection.state.lock().await;
644 conn_state.handler.clone()
645 } {
646 return handler
647 .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
648 .await;
649 }
650
651 default_url.to_string()
652 }
653
654 async fn on_open(
676 &self,
677 url: String,
678 connection: Arc<WebsocketConnection>,
679 old_ws_writer: Option<UnboundedSender<Message>>,
680 ) {
681 if let Some(handler) = {
682 let conn_state = connection.state.lock().await;
683 conn_state.handler.clone()
684 } {
685 handler.on_open(url.clone(), Arc::clone(&connection)).await;
686 }
687
688 let conn_id = &connection.id;
689 info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
690
691 {
692 let mut conn_state = connection.state.lock().await;
693
694 if conn_state.renewal_pending {
695 conn_state.renewal_pending = false;
696 drop(conn_state);
697 if let Some(tx) = old_ws_writer {
698 info!("Connection renewal in progress; closing previous connection.");
699 let _ = self
700 .close_connection_gracefully(tx, Arc::clone(&connection))
701 .await;
702 }
703 return;
704 }
705
706 if conn_state.close_initiated {
707 drop(conn_state);
708 if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
709 info!("Close initiated; closing connection.");
710 let _ = self
711 .close_connection_gracefully(tx, Arc::clone(&connection))
712 .await;
713 }
714 return;
715 }
716
717 self.events.emit(&WebsocketEvent::Open);
718 }
719 }
720
721 async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
733 if let Some(handler) = connection.state.lock().await.handler.clone() {
734 let handler_clone = handler.clone();
735 let data = msg.clone();
736 let conn_clone = connection.clone();
737 spawn(async move {
738 handler_clone.on_message(data, conn_clone).await;
739 });
740 }
741 self.events.emit(&WebsocketEvent::Message(msg));
742 }
743
744 async fn create_websocket(
767 url: &str,
768 agent: Option<AgentConnector>,
769 user_agent: Option<String>,
770 ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
771 let mut req = url
772 .into_client_request()
773 .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
774
775 if let Some(ua) = user_agent {
776 req.headers_mut().insert(USER_AGENT, ua.parse().unwrap());
777 }
778
779 let ws_config: Option<WebSocketConfig> = None;
780 let disable_nagle = false;
781 let connector: Option<Connector> = agent.map(|dbg| dbg.0);
782
783 let timeout_duration = Duration::from_secs(10);
784 let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
785 match timeout(timeout_duration, handshake).await {
786 Ok(Ok((ws_stream, response))) => {
787 debug!("WebSocket connected: {:?}", response);
788 Ok(ws_stream)
789 }
790 Ok(Err(e)) => {
791 let msg = e.to_string();
792 error!("WebSocket handshake failed: {}", msg);
793 Err(WebsocketError::Handshake(msg))
794 }
795 Err(_) => {
796 error!(
797 "WebSocket connection timed out after {}s",
798 timeout_duration.as_secs()
799 );
800 Err(WebsocketError::Timeout)
801 }
802 }
803 }
804
805 async fn connect_pool(
825 self: Arc<Self>,
826 url: &str,
827 connections: Option<Vec<Arc<WebsocketConnection>>>,
828 ) -> Result<(), WebsocketError> {
829 let pool: Vec<Arc<WebsocketConnection>> = match connections {
830 Some(v) => v,
831 None => self.connection_pool.clone(),
832 };
833
834 let mut tasks = FuturesUnordered::new();
835
836 for conn in pool {
837 let common = Arc::clone(&self);
838 let url = url.to_owned();
839
840 tasks.push(async move {
841 match common.init_connect(&url, false, Some(conn)).await {
842 Ok(()) => {
843 info!("Successfully connected to {}", url);
844 Ok(())
845 }
846 Err(err) => {
847 error!("Failed to connect to {}: {:?}", url, err);
848 Err(err)
849 }
850 }
851 });
852 }
853
854 while let Some(result) = tasks.next().await {
855 result?;
856 }
857
858 Ok(())
859 }
860
861 async fn init_connect(
882 self: Arc<Self>,
883 url: &str,
884 is_renewal: bool,
885 connection: Option<Arc<WebsocketConnection>>,
886 ) -> Result<(), WebsocketError> {
887 let conn = connection.unwrap_or(self.get_connection(true, None).await?);
888
889 {
890 let mut conn_state = conn.state.lock().await;
891 if conn_state.renewal_pending && is_renewal {
892 info!("Renewal in progress {}→{}", conn.id, url);
893 return Ok(());
894 }
895 if conn_state.ws_write_tx.is_some() && !is_renewal && !conn_state.reconnection_pending {
896 info!("Exists {}; skipping {}", conn.id, url);
897 return Ok(());
898 }
899 if is_renewal {
900 conn_state.renewal_pending = true;
901 }
902
903 conn_state.is_session_logged_on = false;
904 }
905
906 let ws = Self::create_websocket(url, self.agent.clone(), self.user_agent.clone())
907 .await
908 .map_err(|e| {
909 error!("Handshake failed {}: {:?}", url, e);
910 e
911 })?;
912
913 info!("Established {} → {}", conn.id, url);
914
915 if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
916 error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
917 }
918
919 let (write_half, mut read_half) = ws.split();
920 let (tx, mut rx) = unbounded_channel::<Message>();
921
922 let old_writer = {
923 let mut conn_state = conn.state.lock().await;
924 conn_state.reconnection_pending = false;
925 conn_state.ws_write_tx.replace(tx.clone())
926 };
927
928 {
929 let wconn = conn.clone();
930 let common_clone = self.clone();
931 let writer_url = url.to_string();
932
933 spawn(async move {
934 let mut sink = write_half;
935 while let Some(msg) = rx.recv().await {
936 if let Err(e) = sink.send(msg).await {
937 let failure_reason =
938 WebsocketConnectionFailureReason::from_tungstenite_error(&e);
939
940 error!(
941 "Write error on {}: {:?}, classified as {:?}",
942 wconn.id, e, failure_reason
943 );
944
945 let mut conn_state = wconn.state.lock().await;
947 if !conn_state.close_initiated
948 && !is_renewal
949 && failure_reason.should_reconnect()
950 {
951 info!(
952 "Writer connection {} has recoverable error, attempting reconnection: {:?}",
953 wconn.id, failure_reason
954 );
955 conn_state.reconnection_pending = true;
956 conn_state.is_session_logged_on = false;
957 drop(conn_state);
958 let reconnect_url = common_clone
959 .get_reconnect_url(&writer_url, Arc::clone(&wconn))
960 .await;
961
962 let _ = common_clone
963 .reconnect_tx
964 .send(ReconnectEntry {
965 connection_id: wconn.id.clone(),
966 url: reconnect_url,
967 is_renewal: false,
968 })
969 .await;
970 } else {
971 warn!(
972 "Writer connection {} has permanent error, will not reconnect: {:?}",
973 wconn.id, failure_reason
974 );
975 }
976
977 break;
978 }
979 }
980 debug!("Writer {} exit", wconn.id);
981 });
982 }
983
984 {
985 let common = self.clone();
986 let conn = conn.clone();
987 let url = url.to_string();
988 spawn(async move {
989 common.on_open(url, conn, old_writer).await;
990 });
991 }
992
993 {
994 let common = self.clone();
995 let reader_conn = conn.clone();
996 let read_url = url.to_string();
997
998 spawn(async move {
999 while let Some(item) = read_half.next().await {
1000 match item {
1001 Ok(Message::Text(msg)) => {
1002 common
1003 .on_message(msg.to_string(), Arc::clone(&reader_conn))
1004 .await;
1005 }
1006 Ok(Message::Binary(bin)) => {
1007 let mut decoder = ZlibDecoder::new(&bin[..]);
1008 let mut decompressed = String::new();
1009 if let Err(err) = decoder.read_to_string(&mut decompressed) {
1010 error!("Binary message decompress failed: {:?}", err);
1011 continue;
1012 }
1013 common
1014 .on_message(decompressed, Arc::clone(&reader_conn))
1015 .await;
1016 }
1017 Ok(Message::Ping(payload)) => {
1018 info!("PING received from server on {}", reader_conn.id);
1019 common.events.emit(&WebsocketEvent::Ping);
1020 if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
1021 let _ = tx.send(Message::Pong(payload));
1022 info!(
1023 "Responded PONG to server's PING message on {}",
1024 reader_conn.id
1025 );
1026 }
1027 }
1028 Ok(Message::Pong(_)) => {
1029 info!("Received PONG from server on {}", reader_conn.id);
1030 common.events.emit(&WebsocketEvent::Pong);
1031 }
1032 Ok(Message::Close(frame)) => {
1033 let (code, reason) = frame
1034 .map_or((1000, String::new()), |CloseFrame { code, reason }| {
1035 (code.into(), reason.to_string())
1036 });
1037 common
1038 .events
1039 .emit(&WebsocketEvent::Close(code, reason.clone()));
1040
1041 let user_initiated = {
1043 let conn_state = reader_conn.state.lock().await;
1044 conn_state.close_initiated
1045 };
1046
1047 let failure_reason = WebsocketConnectionFailureReason::from_close_code(
1048 code,
1049 user_initiated,
1050 );
1051
1052 info!(
1053 "Connection {} received close frame: code={}, reason='{}', classified as {:?}",
1054 reader_conn.id, code, reason, failure_reason
1055 );
1056
1057 let mut conn_state = reader_conn.state.lock().await;
1058 if !conn_state.close_initiated
1059 && !is_renewal
1060 && failure_reason.should_reconnect()
1061 {
1062 info!(
1063 "Connection {} received close frame with reconnectable failure: {:?}",
1064 reader_conn.id, failure_reason
1065 );
1066 conn_state.reconnection_pending = true;
1067 conn_state.is_session_logged_on = false;
1068 drop(conn_state);
1069 let reconnect_url = common
1070 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1071 .await;
1072
1073 let _ = common
1074 .reconnect_tx
1075 .send(ReconnectEntry {
1076 connection_id: reader_conn.id.clone(),
1077 url: reconnect_url,
1078 is_renewal: false,
1079 })
1080 .await;
1081 } else {
1082 warn!(
1083 "Connection {} received close frame with non-reconnectable failure: {:?}",
1084 reader_conn.id, failure_reason
1085 );
1086
1087 common.events.emit(&WebsocketEvent::Error(format!(
1089 "[CRITICAL] Connection {} permanently failed: {:?}",
1090 reader_conn.id, failure_reason
1091 )));
1092 }
1093
1094 break;
1095 }
1096 Err(e) => {
1097 let failure_reason =
1099 WebsocketConnectionFailureReason::from_tungstenite_error(&e);
1100
1101 error!(
1102 "WebSocket error on {}: {:?}, classified as {:?}",
1103 reader_conn.id, e, failure_reason
1104 );
1105
1106 common.events.emit(&WebsocketEvent::Error(e.to_string()));
1107
1108 let mut conn_state = reader_conn.state.lock().await;
1110 if !conn_state.close_initiated
1111 && !is_renewal
1112 && failure_reason.should_reconnect()
1113 {
1114 info!(
1115 "Connection {} has recoverable error, attempting reconnection: {:?}",
1116 reader_conn.id, failure_reason
1117 );
1118 conn_state.reconnection_pending = true;
1119 conn_state.is_session_logged_on = false;
1120 drop(conn_state);
1121 let reconnect_url = common
1122 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1123 .await;
1124
1125 let _ = common
1126 .reconnect_tx
1127 .send(ReconnectEntry {
1128 connection_id: reader_conn.id.clone(),
1129 url: reconnect_url,
1130 is_renewal: false,
1131 })
1132 .await;
1133 } else {
1134 warn!(
1135 "Connection {} has permanent error, will not reconnect: {:?}",
1136 reader_conn.id, failure_reason
1137 );
1138
1139 common.events.emit(&WebsocketEvent::Error(format!(
1141 "[CRITICAL] Connection {} permanently failed: {:?}",
1142 reader_conn.id, failure_reason
1143 )));
1144 }
1145
1146 break;
1147 }
1148 _ => {}
1149 }
1150 }
1151
1152 info!("WebSocket stream ended for connection {}", reader_conn.id);
1154
1155 let failure_reason = WebsocketConnectionFailureReason::StreamEnded;
1157
1158 info!(
1159 "WebSocket stream ended for connection {}, classified as {:?}",
1160 reader_conn.id, failure_reason
1161 );
1162
1163 let mut conn_state = reader_conn.state.lock().await;
1164 if !conn_state.close_initiated && !is_renewal && failure_reason.should_reconnect() {
1165 info!(
1166 "Connection {} stream ended unexpectedly, attempting reconnection",
1167 reader_conn.id
1168 );
1169 conn_state.reconnection_pending = true;
1170 conn_state.is_session_logged_on = false;
1171 conn_state.ws_write_tx = None;
1172 drop(conn_state);
1173 let reconnect_url = common
1174 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1175 .await;
1176
1177 let _ = common
1178 .reconnect_tx
1179 .send(ReconnectEntry {
1180 connection_id: reader_conn.id.clone(),
1181 url: reconnect_url,
1182 is_renewal: false,
1183 })
1184 .await;
1185 } else {
1186 debug!(
1187 "Connection {} stream ended normally (close_initiated={}, is_renewal={})",
1188 reader_conn.id, conn_state.close_initiated, is_renewal
1189 );
1190 }
1191
1192 debug!("Reader actor for {} exiting", reader_conn.id);
1193 });
1194 }
1195
1196 Ok(())
1197 }
1198 async fn disconnect(&self) -> Result<(), WebsocketError> {
1213 if !self.is_connected(None).await {
1214 warn!("No active connection to close.");
1215 return Ok(());
1216 }
1217
1218 let mut shutdowns = FuturesUnordered::new();
1219 for conn in &self.connection_pool {
1220 {
1221 let mut conn_state = conn.state.lock().await;
1222 conn_state.close_initiated = true;
1223 if let Some(tx) = &conn_state.ws_write_tx {
1224 shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
1225 }
1226 }
1227 }
1228
1229 let close_all = async {
1230 while let Some(result) = shutdowns.next().await {
1231 result?;
1232 }
1233 Ok::<(), WebsocketError>(())
1234 };
1235
1236 match timeout(Duration::from_secs(30), close_all).await {
1237 Ok(Ok(())) => {
1238 info!("Disconnected all WebSocket connections successfully.");
1239 for conn in &self.connection_pool {
1240 let mut st = conn.state.lock().await;
1241 st.is_session_logged_on = false;
1242 st.session_logon_req = None;
1243 }
1244 Ok(())
1245 }
1246 Ok(Err(err)) => {
1247 error!("Error while disconnecting: {:?}", err);
1248 Err(err)
1249 }
1250 Err(_) => {
1251 error!("Timed out while disconnecting WebSocket connections.");
1252 Err(WebsocketError::Timeout)
1253 }
1254 }
1255 }
1256
1257 async fn ping_server(&self) {
1270 let mut ready = Vec::new();
1271 for conn in &self.connection_pool {
1272 if self.is_connection_ready(conn, false).await {
1273 let id = conn.id.clone();
1274 let ws_write_tx = {
1275 let conn_state = conn.state.lock().await;
1276 conn_state.ws_write_tx.clone()
1277 };
1278 ready.push((id, ws_write_tx));
1279 }
1280 }
1281
1282 if ready.is_empty() {
1283 warn!("No ready connections for PING.");
1284 return;
1285 }
1286 info!("Sending PING to {} WebSocket connections.", ready.len());
1287
1288 let mut tasks = FuturesUnordered::new();
1289 for (id, ws_write_tx_opt) in ready {
1290 if let Some(tx) = ws_write_tx_opt {
1291 tasks.push(async move {
1292 if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1293 error!("Failed to send PING to {}: {:?}", id, e);
1294 } else {
1295 debug!("Sent PING to connection {}", id);
1296 }
1297 });
1298 } else {
1299 error!("Connection {} was ready but has no write channel", id);
1300 }
1301 }
1302
1303 while tasks.next().await.is_some() {}
1304 }
1305
1306 async fn send(
1324 &self,
1325 payload: String,
1326 id: Option<String>,
1327 wait_for_reply: bool,
1328 timeout: Duration,
1329 connection: Option<Arc<WebsocketConnection>>,
1330 ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1331 let conn = if let Some(c) = connection {
1332 c
1333 } else {
1334 self.get_connection(false, None).await?
1335 };
1336
1337 if !self.is_connected(Some(&conn)).await {
1338 warn!("Send attempted on a non-connected socket");
1339 return Err(WebsocketError::NotConnected);
1340 }
1341
1342 let ws_write_tx = {
1343 let conn_state = conn.state.lock().await;
1344 conn_state
1345 .ws_write_tx
1346 .clone()
1347 .ok_or(WebsocketError::NotConnected)?
1348 };
1349
1350 debug!("Sending message to WebSocket on connection {}", conn.id);
1351
1352 ws_write_tx
1353 .send(Message::Text(payload.clone().into()))
1354 .map_err(|_| WebsocketError::NotConnected)?;
1355
1356 if !wait_for_reply {
1357 return Ok(None);
1358 }
1359
1360 let request_id = id.ok_or_else(|| {
1361 error!("id is required when waiting for a reply");
1362 WebsocketError::NotConnected
1363 })?;
1364
1365 let (tx, rx) = oneshot::channel();
1366 {
1367 let mut conn_state = conn.state.lock().await;
1368 conn_state
1369 .pending_requests
1370 .insert(request_id.clone(), PendingRequest { completion: tx });
1371 }
1372
1373 let conn_clone = Arc::clone(&conn);
1374 spawn(async move {
1375 sleep(timeout).await;
1376 let mut conn_state = conn_clone.state.lock().await;
1377 if let Some(pending_req) = conn_state.pending_requests.remove(&request_id) {
1378 let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1379 }
1380 });
1381
1382 Ok(Some(rx))
1383 }
1384}
1385
/// Per-message flags controlling how a WebSocket API request is built.
///
/// Start from `WebsocketMessageSendOptions::new()` (all flags off) and chain the builder
/// methods below.
#[derive(Debug, Default, Clone)]
pub struct WebsocketMessageSendOptions {
    /// Attach the API key to the request.
    pub with_api_key: bool,
    /// Sign the request.
    pub is_signed: bool,
    /// `Some(true)` marks a session logon request.
    pub is_session_logon: Option<bool>,
    /// `Some(true)` marks a session logout request.
    pub is_session_logout: Option<bool>,
}

impl WebsocketMessageSendOptions {
    /// Returns options with every flag disabled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables attaching the API key.
    #[must_use]
    pub fn with_api_key(self) -> Self {
        Self {
            with_api_key: true,
            ..self
        }
    }

    /// Enables request signing.
    #[must_use]
    pub fn signed(self) -> Self {
        Self {
            is_signed: true,
            ..self
        }
    }

    /// Marks the request as a session logon.
    #[must_use]
    pub fn session_logon(self) -> Self {
        Self {
            is_session_logon: Some(true),
            ..self
        }
    }

    /// Marks the request as a session logout.
    #[must_use]
    pub fn session_logout(self) -> Self {
        Self {
            is_session_logout: Some(true),
            ..self
        }
    }
}
1424
1425#[derive(Debug)]
1426pub enum SendWebsocketMessageResult<R> {
1427 Single(WebsocketApiResponse<R>),
1428 Multiple(Vec<WebsocketApiResponse<R>>),
1429}
1430
1431impl<R> IntoIterator for SendWebsocketMessageResult<R> {
1432 type Item = WebsocketApiResponse<R>;
1433 type IntoIter = std::vec::IntoIter<Self::Item>;
1434
1435 fn into_iter(self) -> Self::IntoIter {
1436 match self {
1437 SendWebsocketMessageResult::Single(resp) => vec![resp].into_iter(),
1438 SendWebsocketMessageResult::Multiple(v) => v.into_iter(),
1439 }
1440 }
1441}
1442
1443pub struct WebsocketApi {
1444 pub common: Arc<WebsocketCommon>,
1445 configuration: ConfigurationWebsocketApi,
1446 is_connecting: Arc<Mutex<bool>>,
1447 stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
1448}
1449
1450impl WebsocketApi {
1451 #[must_use]
1452 pub fn new(
1473 configuration: ConfigurationWebsocketApi,
1474 connection_pool: Vec<Arc<WebsocketConnection>>,
1475 ) -> Arc<Self> {
1476 let agent_clone = configuration.agent.clone();
1477 let user_agent_clone = configuration.user_agent.clone();
1478 let common = WebsocketCommon::new(
1479 connection_pool,
1480 configuration.mode.clone(),
1481 usize::try_from(configuration.reconnect_delay)
1482 .expect("reconnect_delay should fit in usize"),
1483 agent_clone,
1484 Some(user_agent_clone),
1485 );
1486
1487 Arc::new(Self {
1488 common: Arc::clone(&common),
1489 configuration,
1490 is_connecting: Arc::new(Mutex::new(false)),
1491 stream_callbacks: Mutex::new(HashMap::new()),
1492 })
1493 }
1494
1495 pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1517 if self.common.is_connected(None).await {
1518 info!("WebSocket connection already established");
1519 return Ok(());
1520 }
1521
1522 {
1523 let mut flag = self.is_connecting.lock().await;
1524 if *flag {
1525 info!("Already connecting...");
1526 return Ok(());
1527 }
1528 *flag = true;
1529 }
1530
1531 let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1532
1533 let handler: Arc<dyn WebsocketHandler> = self.clone();
1534 for slot in &self.common.connection_pool {
1535 slot.set_handler(handler.clone()).await;
1536 }
1537
1538 let result = select! {
1539 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1540 r = self.common.clone().connect_pool(&url, None) => r,
1541 };
1542
1543 {
1544 let mut flag = self.is_connecting.lock().await;
1545 *flag = false;
1546 }
1547
1548 result
1549 }
1550
1551 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1564 self.common.disconnect().await
1565 }
1566
1567 pub async fn is_connected(&self) -> bool {
1573 self.common.is_connected(None).await
1574 }
1575
1576 pub async fn ping_server(&self) {
1581 self.common.ping_server().await;
1582 }
1583
1584 pub async fn send_message<R>(
1614 &self,
1615 method: &str,
1616 payload: BTreeMap<String, Value>,
1617 options: WebsocketMessageSendOptions,
1618 ) -> Result<SendWebsocketMessageResult<R>, WebsocketError>
1619 where
1620 R: DeserializeOwned + Send + Sync + 'static,
1621 {
1622 if !self.common.is_connected(None).await {
1623 return Err(WebsocketError::NotConnected);
1624 }
1625
1626 let do_multi =
1627 options.is_session_logon.unwrap_or(false) || options.is_session_logout.unwrap_or(false);
1628
1629 let connections = if do_multi {
1630 self.common.get_available_connections(false, None).await
1631 } else {
1632 vec![self.common.get_connection(false, None).await?]
1633 };
1634
1635 let skip_auth = if do_multi {
1636 false
1637 } else {
1638 let connection = &connections[0];
1639 let conn_state = connection.state.lock().await;
1640 self.configuration.auto_session_relogon && conn_state.is_session_logged_on
1641 };
1642
1643 let payload_clone = payload.clone();
1644
1645 let (id, request) =
1646 build_websocket_api_message(&self.configuration, method, payload, &options, skip_auth);
1647 let raw_payload = serde_json::to_string(&request).unwrap();
1648 debug!("Sending message to WebSocket API: {:?}", request);
1649
1650 let timeout = Duration::from_millis(self.configuration.timeout);
1651
1652 let mut receivers = Vec::with_capacity(connections.len());
1653 for connection in &connections {
1654 let opt_rx = self
1655 .common
1656 .send(
1657 raw_payload.clone(),
1658 Some(id.clone()),
1659 true,
1660 timeout,
1661 Some(connection.clone()),
1662 )
1663 .await?;
1664 receivers.push((connection.clone(), opt_rx));
1665 }
1666
1667 let mut raw_msgs = Vec::with_capacity(receivers.len());
1668 for (_conn, opt_rx) in receivers {
1669 let rx = opt_rx.ok_or(WebsocketError::NoResponse)?;
1670 let msg = rx.await.unwrap_or(Err(WebsocketError::Timeout))?;
1671 raw_msgs.push(msg);
1672 }
1673
1674 let mut responses = Vec::with_capacity(raw_msgs.len());
1675 for msg in raw_msgs {
1676 let raw = msg
1677 .get("result")
1678 .or_else(|| msg.get("response"))
1679 .cloned()
1680 .unwrap_or(Value::Null);
1681
1682 let rate_limits = msg
1683 .get("rateLimits")
1684 .and_then(Value::as_array)
1685 .map(|arr| {
1686 arr.iter()
1687 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1688 .collect()
1689 })
1690 .unwrap_or_default();
1691
1692 responses.push(WebsocketApiResponse {
1693 raw,
1694 rate_limits,
1695 _marker: PhantomData,
1696 });
1697 }
1698
1699 if do_multi && self.configuration.auto_session_relogon {
1700 for connection in &connections {
1701 let mut state = connection.state.lock().await;
1702 if options.is_session_logon.unwrap_or(false) {
1703 state.is_session_logged_on = true;
1704 state.session_logon_req = Some(WebsocketSessionLogonReq {
1705 method: method.to_string(),
1706 payload: payload_clone.clone(),
1707 options: options.clone(),
1708 });
1709 } else {
1710 state.is_session_logged_on = false;
1711 state.session_logon_req = None;
1712 }
1713 }
1714 }
1715
1716 Ok(if responses.len() == 1 && !do_multi {
1717 SendWebsocketMessageResult::Single(responses.into_iter().next().unwrap())
1718 } else {
1719 SendWebsocketMessageResult::Multiple(responses)
1720 })
1721 }
1722
1723 fn prepare_url(&self, ws_url: &str) -> String {
1738 let mut url = ws_url.to_string();
1739
1740 let time_unit = match &self.configuration.time_unit {
1741 Some(u) => u.to_string(),
1742 None => return url,
1743 };
1744
1745 match validate_time_unit(&time_unit) {
1746 Ok(Some(validated)) => {
1747 let sep = if url.contains('?') { '&' } else { '?' };
1748 url.push(sep);
1749 url.push_str("timeUnit=");
1750 url.push_str(validated);
1751 }
1752 Ok(None) => {}
1753 Err(e) => {
1754 error!("Invalid time unit provided: {:?}", e);
1755 }
1756 }
1757
1758 url
1759 }
1760}
1761
1762#[async_trait]
1763impl WebsocketHandler for WebsocketApi {
1764 async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
1782 let session_req = {
1783 let conn_state = connection.state.lock().await;
1784 conn_state.session_logon_req.clone()
1785 };
1786
1787 let Some(req) = session_req else {
1788 return;
1789 };
1790
1791 let already_logged_on = {
1792 let conn_state = connection.state.lock().await;
1793 conn_state.is_session_logged_on
1794 };
1795
1796 if already_logged_on {
1797 debug!(
1798 "Connection {} already logged on, skipping re-logon",
1799 connection.id
1800 );
1801 return;
1802 }
1803
1804 let conn = connection.clone();
1805 let common = Arc::clone(&self.common);
1806 let configuration = self.configuration.clone();
1807 let method = req.method.clone();
1808 let payload = req.payload.clone();
1809 let options = req.options.clone();
1810
1811 spawn(async move {
1812 let (id, json_msg) =
1813 build_websocket_api_message(&configuration, &method, payload, &options, false);
1814
1815 let raw_message = match serde_json::to_string(&json_msg) {
1816 Ok(msg) => msg,
1817 Err(e) => {
1818 warn!(
1819 "Failed to serialize session logon message for connection {}: {}",
1820 conn.id, e
1821 );
1822 return;
1823 }
1824 };
1825
1826 debug!(
1827 "Session re-logon on connection {}: {}",
1828 conn.id, raw_message
1829 );
1830
1831 let rx = match common
1832 .send(
1833 raw_message,
1834 Some(id.clone()),
1835 true,
1836 Duration::from_millis(configuration.timeout),
1837 Some(conn.clone()),
1838 )
1839 .await
1840 {
1841 Ok(Some(rx)) => rx,
1842 Ok(None) => {
1843 warn!(
1844 "Session re-logon dispatch returned None for connection {}",
1845 conn.id
1846 );
1847 return;
1848 }
1849 Err(e) => {
1850 warn!(
1851 "Session re-logon dispatch failed on connection {}: {}",
1852 conn.id, e
1853 );
1854 return;
1855 }
1856 };
1857
1858 let Ok(result) = timeout(Duration::from_millis(configuration.timeout), rx).await else {
1859 warn!("Session re-logon timed out on connection {}", conn.id);
1860 return;
1861 };
1862
1863 let final_result = match result {
1864 Ok(final_result) => final_result,
1865 Err(e) => {
1866 warn!(
1867 "Session re-logon receiver error on connection {}: {}",
1868 conn.id, e
1869 );
1870 return;
1871 }
1872 };
1873
1874 let payload = match final_result {
1875 Ok(payload) => payload,
1876 Err(e) => {
1877 warn!(
1878 "Session re-logon payload error on connection {}: {}",
1879 conn.id, e
1880 );
1881 return;
1882 }
1883 };
1884
1885 debug!(
1886 "Session re-logon succeeded on connection {}: {}",
1887 conn.id, payload
1888 );
1889 let mut conn_state = conn.state.lock().await;
1890 conn_state.is_session_logged_on = true;
1891 });
1892 }
1893
1894 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1914 let msg: Value = match serde_json::from_str(&data) {
1915 Ok(v) => v,
1916 Err(err) => {
1917 error!("Failed to parse WebSocket message {} – {}", data, err);
1918 return;
1919 }
1920 };
1921
1922 if let Some(id) = msg.get("id").and_then(Value::as_str) {
1923 let maybe_sender = {
1924 let mut conn_state = connection.state.lock().await;
1925 conn_state.pending_requests.remove(id)
1926 };
1927
1928 if let Some(PendingRequest { completion }) = maybe_sender {
1929 connection.drain_notify.notify_one();
1930 let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1931 if status >= 400 {
1932 let error_map = msg
1933 .get("error")
1934 .and_then(Value::as_object)
1935 .unwrap_or(&serde_json::Map::new())
1936 .clone();
1937
1938 let code = error_map
1939 .get("code")
1940 .and_then(Value::as_i64)
1941 .unwrap_or(status.try_into().unwrap());
1942
1943 let message = error_map
1944 .get("msg")
1945 .and_then(Value::as_str)
1946 .unwrap_or("Unknown error")
1947 .to_string();
1948
1949 let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1950 } else {
1951 let _ = completion.send(Ok(msg.clone()));
1952 }
1953 }
1954
1955 return;
1956 }
1957
1958 if let Some(event) = msg.get("event") {
1959 if event.get("e").is_some() {
1960 for callbacks in self.stream_callbacks.lock().await.values() {
1961 for callback in callbacks {
1962 callback(event);
1963 }
1964 }
1965
1966 return;
1967 }
1968 }
1969
1970 warn!(
1971 "Received response for unknown or timed-out request: {}",
1972 data
1973 );
1974 }
1975
1976 async fn get_reconnect_url(
1987 &self,
1988 default_url: String,
1989 _connection: Arc<WebsocketConnection>,
1990 ) -> String {
1991 default_url
1992 }
1993}
1994
/// Client for the market-data stream endpoint(s), built on a shared `WebsocketCommon` pool.
pub struct WebsocketStreams {
    /// Shared connection-pool management (connect, send, disconnect, reconnect).
    pub common: Arc<WebsocketCommon>,
    /// Passed to `normalize_stream_id` when building SUBSCRIBE/UNSUBSCRIBE request ids.
    pub stream_id_is_strictly_number: AtomicBool,
    /// Optional URL path segments. Each path gets its own contiguous slice of
    /// `mode.pool_size()` connections from the pool.
    url_paths: Vec<String>,
    /// Set while `connect` is in progress, so concurrent `connect` calls return early.
    is_connecting: Mutex<bool>,
    /// Maps a stream key (`"{url_path}::{stream}"`, or just `stream`) to the connection that
    /// carries it.
    connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
    /// Endpoint, mode, reconnect delay, time unit and agent settings.
    configuration: ConfigurationWebsocketStreams,
}
2003
2004impl WebsocketStreams {
2005 #[must_use]
2021 pub fn new(
2022 configuration: ConfigurationWebsocketStreams,
2023 mut connection_pool: Vec<Arc<WebsocketConnection>>,
2024 url_paths: Vec<String>,
2025 ) -> Arc<Self> {
2026 if !url_paths.is_empty() {
2027 let base_pool_size = configuration.mode.pool_size();
2028 let expected = base_pool_size * url_paths.len();
2029
2030 while connection_pool.len() < expected {
2031 connection_pool.push(WebsocketConnection::new(random_string()));
2032 }
2033 }
2034
2035 let agent_clone = configuration.agent.clone();
2036 let user_agent_clone = configuration.user_agent.clone();
2037 let common = WebsocketCommon::new(
2038 connection_pool,
2039 configuration.mode.clone(),
2040 usize::try_from(configuration.reconnect_delay)
2041 .expect("reconnect_delay should fit in usize"),
2042 agent_clone,
2043 Some(user_agent_clone),
2044 );
2045 Arc::new(Self {
2046 common,
2047 is_connecting: Mutex::new(false),
2048 connection_streams: Mutex::new(HashMap::new()),
2049 configuration,
2050 stream_id_is_strictly_number: AtomicBool::new(false),
2051 url_paths,
2052 })
2053 }
2054
2055 pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
2072 if self.common.is_connected(None).await {
2073 info!("WebSocket connection already established");
2074 return Ok(());
2075 }
2076
2077 {
2078 let mut flag = self.is_connecting.lock().await;
2079 if *flag {
2080 info!("Already connecting...");
2081 return Ok(());
2082 }
2083 *flag = true;
2084 }
2085
2086 let handler: Arc<dyn WebsocketHandler> = self.clone();
2087 for conn in &self.common.connection_pool {
2088 conn.set_handler(handler.clone()).await;
2089 }
2090
2091 let base_pool_size = self.configuration.mode.pool_size();
2092
2093 let connect_fut = async {
2094 if self.url_paths.is_empty() {
2095 let url = self.prepare_url(&streams, None);
2096 self.common.clone().connect_pool(&url, None).await
2097 } else {
2098 let mut futures = Vec::with_capacity(self.url_paths.len());
2099
2100 for (i, path) in self.url_paths.iter().enumerate() {
2101 let start = i * base_pool_size;
2102
2103 let subset: Vec<Arc<WebsocketConnection>> = self
2104 .common
2105 .connection_pool
2106 .iter()
2107 .skip(start)
2108 .take(base_pool_size)
2109 .cloned()
2110 .collect();
2111
2112 if subset.len() != base_pool_size {
2113 return Err(WebsocketError::ServerError(format!(
2114 "connection_pool too small for url_paths: need {} per path, got {} for path index {}",
2115 base_pool_size,
2116 subset.len(),
2117 i
2118 )));
2119 }
2120
2121 for c in &subset {
2122 let mut st = c.state.lock().await;
2123 st.url_path = Some(path.clone());
2124 }
2125
2126 let url = self.prepare_url(&streams, Some(path.as_str()));
2127 let common = self.common.clone();
2128
2129 futures.push(async move { common.connect_pool(&url, Some(subset)).await });
2130 }
2131
2132 try_join_all(futures).await?;
2133 Ok(())
2134 }
2135 };
2136
2137 let connect_res = select! {
2138 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
2139 r = connect_fut => r,
2140 };
2141
2142 {
2143 let mut flag = self.is_connecting.lock().await;
2144 *flag = false;
2145 }
2146
2147 connect_res
2148 }
2149
2150 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
2166 for connection in &self.common.connection_pool {
2167 let mut conn_state = connection.state.lock().await;
2168 conn_state.stream_callbacks.clear();
2169 conn_state.pending_subscriptions.clear();
2170 }
2171 self.connection_streams.lock().await.clear();
2172 self.common.disconnect().await
2173 }
2174
2175 pub async fn is_connected(&self) -> bool {
2181 self.common.is_connected(None).await
2182 }
2183
2184 pub async fn ping_server(&self) {
2193 self.common.ping_server().await;
2194 }
2195
2196 pub async fn subscribe(
2216 self: Arc<Self>,
2217 streams: Vec<String>,
2218 id: Option<StreamId>,
2219 url_path: Option<&str>,
2220 ) {
2221 let streams: Vec<String> = {
2222 let map = self.connection_streams.lock().await;
2223 streams
2224 .into_iter()
2225 .filter(|s| {
2226 let key = self.stream_key(s, url_path);
2227 !map.contains_key(&key)
2228 })
2229 .collect()
2230 };
2231
2232 if streams.is_empty() {
2233 return;
2234 }
2235
2236 let connection_streams = self.handle_stream_assignment(streams, url_path).await;
2237
2238 for (conn, assigned_streams) in connection_streams {
2239 if !self.common.is_connected(Some(&conn)).await {
2240 info!(
2241 "Connection {} is not ready. Queuing subscription for streams: {:?}",
2242 conn.id, assigned_streams
2243 );
2244
2245 let mut conn_state = conn.state.lock().await;
2246 conn_state
2247 .pending_subscriptions
2248 .extend(assigned_streams.iter().cloned());
2249
2250 continue;
2251 }
2252
2253 self.send_subscription_payload(&conn, &assigned_streams, id.clone());
2254 }
2255 }
2256
2257 pub async fn unsubscribe(
2286 &self,
2287 streams: Vec<String>,
2288 id: Option<StreamId>,
2289 url_path: Option<&str>,
2290 ) {
2291 let request_id = normalize_stream_id(
2292 id.clone(),
2293 self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2294 );
2295
2296 for stream in streams {
2297 let key = self.stream_key(&stream, url_path);
2298 let maybe_conn = { self.connection_streams.lock().await.get(&key).cloned() };
2299
2300 let conn = match maybe_conn {
2301 Some(c) => {
2302 if !self.common.is_connected(Some(&c)).await {
2303 warn!(
2304 "Stream {} not associated with an active connection.",
2305 stream
2306 );
2307 continue;
2308 }
2309 c
2310 }
2311 None => {
2312 warn!("Stream {} was not subscribed.", stream);
2313 continue;
2314 }
2315 };
2316
2317 let has_callbacks = {
2318 let conn_state = conn.state.lock().await;
2319 conn_state
2320 .stream_callbacks
2321 .get(&key)
2322 .is_some_and(|v| !v.is_empty())
2323 };
2324
2325 if has_callbacks {
2326 continue;
2327 }
2328
2329 let payload = json!({
2330 "method": "UNSUBSCRIBE",
2331 "params": [stream.clone()],
2332 "id": request_id,
2333 });
2334
2335 info!("UNSUBSCRIBE → {:?}", payload);
2336
2337 let common = Arc::clone(&self.common);
2338 let conn_clone = Arc::clone(&conn);
2339 let msg = serde_json::to_string(&payload).unwrap();
2340 spawn(async move {
2341 let _ = common
2342 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2343 .await;
2344 });
2345
2346 {
2347 let mut connection_streams = self.connection_streams.lock().await;
2348 connection_streams.remove(&key);
2349 }
2350 {
2351 let mut conn_state = conn.state.lock().await;
2352 conn_state.stream_callbacks.remove(&key);
2353 }
2354 }
2355 }
2356
2357 pub async fn is_subscribed(&self, stream: &str) -> bool {
2371 let map = self.connection_streams.lock().await;
2372
2373 if map.contains_key(stream) {
2374 return true;
2375 }
2376
2377 let suffix = format!("::{}", stream);
2378 map.keys().any(|k| k.ends_with(&suffix))
2379 }
2380
2381 fn stream_key(&self, stream: &str, url_path: Option<&str>) -> String {
2393 match url_path {
2394 Some(p) if !p.is_empty() => format!("{p}::{stream}"),
2395 _ => stream.to_string(),
2396 }
2397 }
2398
2399 fn prepare_url(&self, streams: &[String], url_path: Option<&str>) -> String {
2416 let mut url = format!(
2417 "{}/stream?streams={}",
2418 match url_path {
2419 Some(path) => format!(
2420 "{}/{}",
2421 self.configuration.ws_url.as_deref().unwrap_or(""),
2422 path
2423 ),
2424 None => self
2425 .configuration
2426 .ws_url
2427 .as_deref()
2428 .unwrap_or("")
2429 .to_string(),
2430 },
2431 streams.join("/")
2432 );
2433
2434 let time_unit = match &self.configuration.time_unit {
2435 Some(u) => u.to_string(),
2436 None => return url,
2437 };
2438
2439 match validate_time_unit(&time_unit) {
2440 Ok(Some(validated)) => {
2441 let sep = if url.contains('?') { '&' } else { '?' };
2442 url.push(sep);
2443 url.push_str("timeUnit=");
2444 url.push_str(validated);
2445 }
2446 Ok(None) => {}
2447 Err(e) => {
2448 error!("Invalid time unit provided: {:?}", e);
2449 }
2450 }
2451
2452 url
2453 }
2454
2455 async fn handle_stream_assignment(
2474 &self,
2475 streams: Vec<String>,
2476 url_path: Option<&str>,
2477 ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
2478 let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
2479
2480 for stream in streams {
2481 let key = self.stream_key(&stream, url_path);
2482
2483 let mut conn_opt = {
2484 let map = self.connection_streams.lock().await;
2485 map.get(&key).cloned()
2486 };
2487
2488 let need_new = if let Some(conn) = &conn_opt {
2489 let state = conn.state.lock().await;
2490 state.close_initiated || state.reconnection_pending
2491 } else {
2492 true
2493 };
2494
2495 if need_new {
2496 match self.common.get_connection(true, url_path).await {
2497 Ok(new_conn) => {
2498 let mut map = self.connection_streams.lock().await;
2499 map.insert(key.clone(), new_conn.clone());
2500 conn_opt = Some(new_conn);
2501 }
2502 Err(err) => {
2503 warn!(
2504 "No available WebSocket connection to subscribe stream `{}` (key `{}`): {:?}",
2505 stream, key, err
2506 );
2507 continue;
2508 }
2509 }
2510 }
2511
2512 if let Some(conn) = conn_opt {
2513 {
2514 let mut conn_state = conn.state.lock().await;
2515 conn_state.stream_callbacks.entry(key.clone()).or_default();
2516 }
2517 connection_streams.push((stream, conn));
2518 }
2519 }
2520
2521 let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
2522 for (stream, conn) in connection_streams {
2523 if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
2524 vec.push(stream);
2525 } else {
2526 groups.push((conn, vec![stream]));
2527 }
2528 }
2529
2530 groups
2531 }
2532
2533 fn send_subscription_payload(
2546 &self,
2547 connection: &Arc<WebsocketConnection>,
2548 streams: &Vec<String>,
2549 id: Option<StreamId>,
2550 ) {
2551 let request_id = normalize_stream_id(
2552 id.clone(),
2553 self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2554 );
2555
2556 let payload = json!({
2557 "method": "SUBSCRIBE",
2558 "params": streams,
2559 "id": request_id,
2560 });
2561
2562 info!("SUBSCRIBE → {:?}", payload);
2563
2564 let common = Arc::clone(&self.common);
2565 let msg = match serde_json::to_string(&payload) {
2566 Ok(s) => s,
2567 Err(e) => {
2568 error!("Failed to serialize SUBSCRIBE payload: {}", e);
2569 return;
2570 }
2571 };
2572 let conn_clone = Arc::clone(connection);
2573
2574 spawn(async move {
2575 let _ = common
2576 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2577 .await;
2578 });
2579 }
2580}
2581
2582#[async_trait]
2583impl WebsocketHandler for WebsocketStreams {
2584 async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2601 let pending_subs: Vec<String> = {
2602 let mut conn_state = connection.state.lock().await;
2603 take(&mut conn_state.pending_subscriptions)
2604 .into_iter()
2605 .collect()
2606 };
2607
2608 if !pending_subs.is_empty() {
2609 info!("Processing queued subscriptions for connection");
2610 self.send_subscription_payload(&connection, &pending_subs, None);
2611 }
2612 }
2613
2614 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2631 let msg: Value = match serde_json::from_str(&data) {
2632 Ok(v) => v,
2633 Err(err) => {
2634 error!(
2635 "Failed to parse WebSocket stream message {} – {}",
2636 data, err
2637 );
2638 return;
2639 }
2640 };
2641
2642 let (stream_name, payload) = match (
2643 msg.get("stream").and_then(Value::as_str),
2644 msg.get("data").cloned(),
2645 ) {
2646 (Some(name), Some(data)) => (name.to_string(), data),
2647 _ => return,
2648 };
2649
2650 let callbacks = {
2651 let conn_state = connection.state.lock().await;
2652 let key = self.stream_key(&stream_name, conn_state.url_path.as_deref());
2653 conn_state
2654 .stream_callbacks
2655 .get(&key)
2656 .cloned()
2657 .unwrap_or_else(Vec::new)
2658 };
2659
2660 for callback in callbacks {
2661 callback(&payload);
2662 }
2663 }
2664
2665 async fn get_reconnect_url(
2676 &self,
2677 _default_url: String,
2678 connection: Arc<WebsocketConnection>,
2679 ) -> String {
2680 let connection_streams = self.connection_streams.lock().await;
2681 let reconnect_streams = connection_streams
2682 .iter()
2683 .filter_map(|(key, conn_arc)| {
2684 if Arc::ptr_eq(conn_arc, &connection) {
2685 let stream = match key.split_once("::") {
2686 Some((_prefix, rest)) => rest.to_string(),
2687 None => key.clone(),
2688 };
2689 Some(stream)
2690 } else {
2691 None
2692 }
2693 })
2694 .collect::<Vec<_>>();
2695
2696 let url_path = {
2697 let st = connection.state.lock().await;
2698 st.url_path.as_deref().map(std::string::ToString::to_string)
2699 };
2700
2701 self.prepare_url(&reconnect_streams, url_path.as_deref())
2702 }
2703}
2704
/// Handle to a single subscribed stream whose payloads are delivered as `T`.
pub struct WebsocketStream<T> {
    /// The client (API or streams) this stream belongs to.
    websocket_base: WebsocketBase,
    /// Stream name (streams client), or the key into `WebsocketApi::stream_callbacks`
    /// (API client).
    stream_or_id: String,
    /// Optional URL path used to build the stream key on the streams client.
    url_path: Option<String>,
    /// The currently registered callback wrapper. It is kept so that `unsubscribe` can remove
    /// exactly this registration by pointer identity.
    callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
    /// Request id passed along to SUBSCRIBE/UNSUBSCRIBE.
    pub id: Option<StreamId>,
    /// Marks the payload type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
2713
2714impl<T> WebsocketStream<T>
2715where
2716 T: DeserializeOwned + Send + 'static,
2717{
2718 async fn on<F>(&self, event: &str, callback_fn: F)
2739 where
2740 F: Fn(T) + Send + Sync + 'static,
2741 {
2742 if event != "message" {
2743 return;
2744 }
2745
2746 let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2747 Arc::new(
2748 move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2749 Ok(data) => callback_fn(data),
2750 Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2751 },
2752 );
2753
2754 {
2755 let mut guard = self.callback.lock().await;
2756 *guard = Some(cb_wrapper.clone());
2757 }
2758
2759 match &self.websocket_base {
2760 WebsocketBase::WebsocketStreams(ws_streams) => {
2761 let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2762 let conn = {
2763 let map = ws_streams.connection_streams.lock().await;
2764 map.get(&key).cloned().expect("stream must be subscribed")
2765 };
2766
2767 {
2768 let mut conn_state = conn.state.lock().await;
2769 let entry = conn_state.stream_callbacks.entry(key).or_default();
2770
2771 if !entry
2772 .iter()
2773 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2774 {
2775 entry.push(cb_wrapper);
2776 }
2777 }
2778 }
2779 WebsocketBase::WebsocketApi(ws_api) => {
2780 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2781 let entry = stream_callbacks
2782 .entry(self.stream_or_id.clone())
2783 .or_default();
2784
2785 if !entry
2786 .iter()
2787 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2788 {
2789 entry.push(cb_wrapper);
2790 }
2791 }
2792 }
2793 }
2794
2795 pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2814 where
2815 T: Send + Sync,
2816 F: Fn(T) + Send + Sync + 'static,
2817 {
2818 let handler: Arc<Self> = Arc::clone(self);
2819
2820 std::thread::spawn(move || {
2821 let rt = tokio::runtime::Builder::new_current_thread()
2822 .enable_all()
2823 .build()
2824 .expect("failed to build Tokio runtime");
2825
2826 rt.block_on(handler.on("message", callback_fn));
2827 })
2828 .join()
2829 .expect("on_message thread panicked");
2830 }
2831
2832 pub async fn unsubscribe(&self) {
2847 let maybe_cb = {
2848 let mut guard = self.callback.lock().await;
2849 guard.take()
2850 };
2851
2852 if let Some(cb) = maybe_cb {
2853 match &self.websocket_base {
2854 WebsocketBase::WebsocketStreams(ws_streams) => {
2855 let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2856 let conn = {
2857 let map = ws_streams.connection_streams.lock().await;
2858 map.get(&key)
2859 .cloned()
2860 .expect("stream must have been subscribed")
2861 };
2862
2863 {
2864 let mut conn_state = conn.state.lock().await;
2865 if let Some(list) = conn_state.stream_callbacks.get_mut(&key) {
2866 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2867 }
2868 }
2869
2870 let stream = self.stream_or_id.clone();
2871 let id = self.id.clone();
2872 let url_path = self.url_path.clone();
2873 let websocket_streams_base = Arc::clone(ws_streams);
2874 spawn(async move {
2875 websocket_streams_base
2876 .unsubscribe(vec![stream], id, url_path.as_deref())
2877 .await;
2878 });
2879 }
2880 WebsocketBase::WebsocketApi(ws_api) => {
2881 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2882 if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2883 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2884 }
2885 }
2886 }
2887 }
2888 }
2889}
2890
2891pub async fn create_stream_handler<T>(
2906 websocket_base: WebsocketBase,
2907 stream_or_id: String,
2908 id: Option<StreamId>,
2909 url_path: Option<String>,
2910) -> Arc<WebsocketStream<T>>
2911where
2912 T: DeserializeOwned + Send + 'static,
2913{
2914 match &websocket_base {
2915 WebsocketBase::WebsocketStreams(ws_streams) => {
2916 ws_streams
2917 .clone()
2918 .subscribe(vec![stream_or_id.clone()], id.clone(), url_path.as_deref())
2919 .await;
2920 }
2921 WebsocketBase::WebsocketApi(_) => {}
2922 }
2923
2924 Arc::new(WebsocketStream {
2925 websocket_base,
2926 stream_or_id,
2927 url_path,
2928 id,
2929 callback: Mutex::new(None),
2930 _phantom: PhantomData,
2931 })
2932}
2933
2934#[cfg(test)]
2935mod tests {
2936 use crate::TOKIO_SHARED_RT;
2937 use crate::common::utils::{SignatureGenerator, build_user_agent};
2938 use crate::common::websocket::{
2939 PendingRequest, ReconnectEntry, SendWebsocketMessageResult, WebsocketApi, WebsocketBase,
2940 WebsocketCommon, WebsocketConnection, WebsocketEvent, WebsocketEventEmitter,
2941 WebsocketHandler, WebsocketMessageSendOptions, WebsocketMode, WebsocketSessionLogonReq,
2942 WebsocketStream, WebsocketStreams, create_stream_handler,
2943 };
2944 use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2945 use crate::errors::{WebsocketConnectionFailureReason, WebsocketError};
2946 use crate::models::{StreamId, TimeUnit};
2947 use async_trait::async_trait;
2948 use futures::{SinkExt, StreamExt};
2949 use http::header::USER_AGENT;
2950 use regex::Regex;
2951 use serde_json::{Value, json};
2952 use std::collections::{BTreeMap, HashSet};
2953 use std::marker::PhantomData;
2954 use std::net::SocketAddr;
2955 use std::sync::{
2956 Arc,
2957 atomic::{AtomicBool, AtomicUsize, Ordering},
2958 };
2959 use tokio::net::TcpListener;
2960 use tokio::sync::{
2961 Mutex,
2962 mpsc::{Receiver, unbounded_channel},
2963 oneshot,
2964 };
2965 use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
2966 use tokio_tungstenite::{accept_async, accept_hdr_async, tungstenite, tungstenite::Message};
2967 use tungstenite::handshake::server::Request;
2968
2969 fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
2970 let events = Arc::new(Mutex::new(Vec::new()));
2971 let events_clone = events.clone();
2972 common.events.subscribe(move |event| {
2973 let events_clone = events_clone.clone();
2974 tokio::spawn(async move {
2975 events_clone.lock().await.push(event);
2976 });
2977 });
2978 events
2979 }
2980
2981 async fn create_connection(
2982 id: &str,
2983 has_writer: bool,
2984 reconnection_pending: bool,
2985 renewal_pending: bool,
2986 close_initiated: bool,
2987 ) -> Arc<WebsocketConnection> {
2988 let conn = WebsocketConnection::new(id);
2989 let mut st = conn.state.lock().await;
2990 st.reconnection_pending = reconnection_pending;
2991 st.renewal_pending = renewal_pending;
2992 st.close_initiated = close_initiated;
2993 if has_writer {
2994 let (tx, _) = unbounded_channel::<Message>();
2995 st.ws_write_tx = Some(tx);
2996 } else {
2997 st.ws_write_tx = None;
2998 }
2999 drop(st);
3000 conn
3001 }
3002
3003 fn create_websocket_api(
3004 time_unit: Option<TimeUnit>,
3005 mode: Option<WebsocketMode>,
3006 auto_session_relogon: Option<bool>,
3007 ) -> Arc<WebsocketApi> {
3008 let mode = mode.unwrap_or(WebsocketMode::Single);
3009 let auto_session_relogon = auto_session_relogon.unwrap_or(true);
3010 let sig_gen = SignatureGenerator::new(
3011 Some("api_secret".into()),
3012 None::<PrivateKey>,
3013 None::<String>,
3014 );
3015 let config = ConfigurationWebsocketApi {
3016 api_key: Some("api_key".into()),
3017 api_secret: Some("api_secret".into()),
3018 private_key: None,
3019 private_key_passphrase: None,
3020 ws_url: Some("wss://example.com".into()),
3021 mode,
3022 reconnect_delay: 1000,
3023 signature_gen: sig_gen,
3024 timeout: 500,
3025 time_unit,
3026 auto_session_relogon,
3027 agent: None,
3028 user_agent: build_user_agent("product"),
3029 };
3030 let conn1 = WebsocketConnection::new("c1");
3031 let conn2 = WebsocketConnection::new("c2");
3032 WebsocketApi::new(config, vec![conn1, conn2])
3033 }
3034
3035 fn create_websocket_streams(
3036 ws_url: Option<&str>,
3037 conns: Option<Vec<Arc<WebsocketConnection>>>,
3038 url_paths: Option<Vec<String>>,
3039 ) -> Arc<WebsocketStreams> {
3040 let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
3041 let url_paths = url_paths.unwrap_or_default();
3042 if conns.is_none() {
3043 connections.push(WebsocketConnection::new("c1"));
3044 connections.push(WebsocketConnection::new("c2"));
3045 } else {
3046 connections = conns.expect("Expected connections to be set");
3047 }
3048 let config = ConfigurationWebsocketStreams {
3049 ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
3050 mode: WebsocketMode::Single,
3051 reconnect_delay: 500,
3052 time_unit: None,
3053 agent: None,
3054 user_agent: build_user_agent("product"),
3055 };
3056 WebsocketStreams::new(config, connections, url_paths)
3057 }
3058
3059 fn subscribe_to_emitter(emitter: &WebsocketEventEmitter) -> Receiver<WebsocketEvent> {
3060 let (test_tx, test_rx) = tokio::sync::mpsc::channel(16);
3061 let _sub = emitter.subscribe(move |evt| {
3062 let _ = test_tx.try_send(evt);
3063 });
3064 test_rx
3065 }
3066
3067 async fn expect_websocket_event(rx: &mut Receiver<WebsocketEvent>) -> WebsocketEvent {
3068 timeout(Duration::from_millis(200), rx.recv())
3069 .await
3070 .expect("timed out waiting for event")
3071 .expect("subscriber channel closed")
3072 }
3073
3074 async fn eventually_async<F, Fut>(max_wait: Duration, mut f: F) -> bool
3075 where
3076 F: FnMut() -> Fut,
3077 Fut: std::future::Future<Output = bool>,
3078 {
3079 let start = tokio::time::Instant::now();
3080 while start.elapsed() < max_wait {
3081 if f().await {
3082 return true;
3083 }
3084 sleep(Duration::from_millis(20)).await;
3085 }
3086 false
3087 }
3088
3089 mod event_emitter {
3090 use super::*;
3091
3092 #[test]
3093 fn event_emitter_subscribe_and_emit() {
3094 TOKIO_SHARED_RT.block_on(async {
3095 let emitter = WebsocketEventEmitter::new();
3096 let (tx, rx) = oneshot::channel();
3097 let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
3098 let tx_clone = tx.clone();
3099 let _sub = emitter.subscribe(move |event| {
3100 if let Some(sender) = tx_clone.lock().unwrap().take() {
3101 let _ = sender.send(event);
3102 }
3103 });
3104 emitter.emit(&WebsocketEvent::Open);
3105 let received = timeout(Duration::from_millis(100), rx)
3106 .await
3107 .expect("timed out");
3108 assert_eq!(received, Ok(WebsocketEvent::Open));
3109 });
3110 }
3111
3112 #[test]
3113 fn single_subscriber_gets_event() {
3114 TOKIO_SHARED_RT.block_on(async {
3115 let emitter = WebsocketEventEmitter::new();
3116 let mut rx = subscribe_to_emitter(&emitter);
3117
3118 let e1 = WebsocketEvent::Open;
3119 emitter.emit(&e1);
3120
3121 let got = expect_websocket_event(&mut rx).await;
3122 assert_eq!(got, e1);
3123 });
3124 }
3125
3126 #[test]
3127 fn multiple_subscribers_get_event() {
3128 TOKIO_SHARED_RT.block_on(async {
3129 let emitter = WebsocketEventEmitter::new();
3130 let mut rx1 = subscribe_to_emitter(&emitter);
3131 let mut rx2 = subscribe_to_emitter(&emitter);
3132
3133 let e = WebsocketEvent::Message("hello".into());
3134 emitter.emit(&e);
3135
3136 assert_eq!(expect_websocket_event(&mut rx1).await, e.clone());
3137 assert_eq!(expect_websocket_event(&mut rx2).await, e);
3138 });
3139 }
3140
3141 #[test]
3142 fn closed_subscribers_are_pruned() {
3143 TOKIO_SHARED_RT.block_on(async {
3144 let emitter = WebsocketEventEmitter::new();
3145 let rx1 = subscribe_to_emitter(&emitter);
3146 let mut rx2 = subscribe_to_emitter(&emitter);
3147 drop(rx1);
3148
3149 let e = WebsocketEvent::Pong;
3150 emitter.emit(&e);
3151
3152 assert_eq!(expect_websocket_event(&mut rx2).await, e);
3153 });
3154 }
3155
3156 #[test]
3157 fn prune_on_error_does_not_hang() {
3158 TOKIO_SHARED_RT.block_on(async {
3159 let emitter = WebsocketEventEmitter::new();
3160 let rx = subscribe_to_emitter(&emitter);
3161 drop(rx);
3162
3163 let e = WebsocketEvent::Close(1000, "bye".into());
3164 emitter.emit(&e);
3165 });
3166 }
3167 }
3168
3169 mod websocket_common {
3170 use super::*;
3171
3172 mod initialisation {
3173 use super::*;
3174
3175 #[test]
3176 fn single_mode() {
3177 TOKIO_SHARED_RT.block_on(async {
3178 let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3179 assert_eq!(common.connection_pool.len(), 1);
3180 });
3181 }
3182
3183 #[test]
3184 fn pool_mode() {
3185 TOKIO_SHARED_RT.block_on(async {
3186 let common =
3187 WebsocketCommon::new(vec![], WebsocketMode::Pool(3), 0, None, None);
3188 assert_eq!(common.connection_pool.len(), 3);
3189 });
3190 }
3191 }
3192
3193 mod spawn_reconnect_loop {
3194 use super::*;
3195
3196 #[test]
3197 fn successful_reconnect_entry_triggers_init_connect() {
3198 TOKIO_SHARED_RT.block_on(async {
3199 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3200 let addr = listener.local_addr().unwrap();
3201 tokio::spawn(async move {
3202 if let Ok((stream, _)) = listener.accept().await {
3203 let mut ws = accept_async(stream).await.unwrap();
3204 sleep(Duration::from_secs(5)).await;
3205 let _ = ws.close(None).await;
3206 }
3207 });
3208
3209 let conn = WebsocketConnection::new("c1");
3210 let common = WebsocketCommon::new(
3211 vec![conn.clone()],
3212 WebsocketMode::Single,
3213 10,
3214 None,
3215 None,
3216 );
3217 let url = format!("ws://{addr}");
3218 common
3219 .reconnect_tx
3220 .send(ReconnectEntry {
3221 connection_id: "c1".into(),
3222 url: url.clone(),
3223 is_renewal: false,
3224 })
3225 .await
3226 .unwrap();
3227
3228 let mut ok = false;
3229 for _ in 0..100 {
3230 if conn.state.lock().await.ws_write_tx.is_some() {
3231 ok = true;
3232 break;
3233 }
3234 sleep(Duration::from_millis(50)).await;
3235 }
3236 assert!(ok, "expected ws_write_tx to be Some after reconnect");
3237 });
3238 }
3239
3240 #[test]
3241 fn reconnect_entry_with_unknown_id_is_ignored() {
3242 TOKIO_SHARED_RT.block_on(async {
3243 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3244 let addr = listener.local_addr().unwrap();
3245 tokio::spawn(async move {
3246 if let Ok((stream, _)) = listener.accept().await {
3247 let mut ws = accept_async(stream).await.unwrap();
3248 let _ = ws.close(None).await;
3249 }
3250 });
3251
3252 let conn = WebsocketConnection::new("c1");
3253 let common = WebsocketCommon::new(
3254 vec![conn.clone()],
3255 WebsocketMode::Single,
3256 5,
3257 None,
3258 None,
3259 );
3260 let url = format!("ws://{addr}");
3261 common
3262 .reconnect_tx
3263 .send(ReconnectEntry {
3264 connection_id: "other".into(),
3265 url,
3266 is_renewal: false,
3267 })
3268 .await
3269 .unwrap();
3270
3271 sleep(Duration::from_secs(1)).await;
3272
3273 let st = conn.state.lock().await;
3274 assert!(st.ws_write_tx.is_none());
3275 });
3276 }
3277
3278 #[test]
3279 fn renewal_entries_bypass_initial_delay() {
3280 TOKIO_SHARED_RT.block_on(async {
3281 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3282 let addr = listener.local_addr().unwrap();
3283 tokio::spawn(async move {
3284 if let Ok((stream, _)) = listener.accept().await {
3285 let mut ws = accept_async(stream).await.unwrap();
3286 let _ = ws.close(None).await;
3287 }
3288 });
3289
3290 let conn = WebsocketConnection::new("renew");
3291 let common = WebsocketCommon::new(
3292 vec![conn.clone()],
3293 WebsocketMode::Single,
3294 200,
3295 None,
3296 None,
3297 );
3298 let url = format!("ws://{addr}");
3299 common
3300 .reconnect_tx
3301 .send(ReconnectEntry {
3302 connection_id: "renew".into(),
3303 url: url.clone(),
3304 is_renewal: true,
3305 })
3306 .await
3307 .unwrap();
3308
3309 sleep(Duration::from_secs(2)).await;
3310
3311 let st = conn.state.lock().await;
3312
3313 assert!(st.ws_write_tx.is_some());
3314 });
3315 }
3316
3317 #[test]
3318 fn non_renewal_entries_respect_initial_delay() {
3319 TOKIO_SHARED_RT.block_on(async {
3320 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3321 let addr = listener.local_addr().unwrap();
3322 tokio::spawn(async move {
3323 if let Ok((stream, _)) = listener.accept().await {
3324 let mut ws = accept_async(stream).await.unwrap();
3325 sleep(Duration::from_secs(2)).await;
3326 let _ = ws.close(None).await;
3327 }
3328 });
3329
3330 let conn = WebsocketConnection::new("nonrenew");
3331 let common = WebsocketCommon::new(
3332 vec![conn.clone()],
3333 WebsocketMode::Single,
3334 200,
3335 None,
3336 None,
3337 );
3338 let url = format!("ws://{addr}");
3339 common
3340 .reconnect_tx
3341 .send(ReconnectEntry {
3342 connection_id: "nonrenew".into(),
3343 url: url.clone(),
3344 is_renewal: false,
3345 })
3346 .await
3347 .unwrap();
3348
3349 sleep(Duration::from_millis(50)).await;
3350 assert!(conn.state.lock().await.ws_write_tx.is_none());
3351
3352 let mut ok = false;
3353 for _ in 0..200 {
3354 if conn.state.lock().await.ws_write_tx.is_some() {
3355 ok = true;
3356 break;
3357 }
3358 sleep(Duration::from_millis(50)).await;
3359 }
3360 assert!(
3361 ok,
3362 "expected ws_write_tx to be Some after reconnect delay elapsed"
3363 );
3364 });
3365 }
3366 }
3367
3368 mod spawn_renewal_loop {
3369 use super::*;
3370
3371 #[tokio::test]
3372 async fn scheduling_renewal_does_not_panic_for_known_connection() {
3373 pause();
3374
3375 let conn = WebsocketConnection::new("known");
3376 let common =
3377 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3378 let url = "wss://example".to_string();
3379 common
3380 .renewal_tx
3381 .send((conn.id.clone(), url))
3382 .await
3383 .unwrap();
3384 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3385 resume();
3386 }
3387
3388 #[tokio::test]
3389 async fn scheduling_renewal_ignored_for_unknown_connection() {
3390 pause();
3391
3392 let conn = WebsocketConnection::new("c1");
3393 let common =
3394 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3395 common
3396 .renewal_tx
3397 .send(("other".into(), "u".into()))
3398 .await
3399 .unwrap();
3400 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3401
3402 resume();
3403 }
3404 }
3405
3406 mod reconnect_regressions {
3407 use super::*;
3408
3409 #[test]
3410 fn init_connect_is_not_skipped_when_reconnection_pending() {
3411 TOKIO_SHARED_RT.block_on(async {
3412 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3413 let addr = listener.local_addr().unwrap();
3414
3415 tokio::spawn(async move {
3416 if let Ok((stream, _)) = listener.accept().await {
3417 tokio::spawn(async move {
3418 let mut ws = accept_async(stream).await.unwrap();
3419 sleep(Duration::from_millis(500)).await;
3420 let _ = ws.close(None).await;
3421 });
3422 }
3423 });
3424
3425 let conn = WebsocketConnection::new("c-reconnect");
3426 {
3427 let mut st = conn.state.lock().await;
3428 let (tx, _) = unbounded_channel::<Message>();
3429 st.ws_write_tx = Some(tx);
3430 st.reconnection_pending = true;
3431 }
3432
3433 let common = WebsocketCommon::new(
3434 vec![conn.clone()],
3435 WebsocketMode::Single,
3436 0,
3437 None,
3438 None,
3439 );
3440
3441 let url = format!("ws://{addr}");
3442 common
3443 .clone()
3444 .init_connect(&url, false, Some(conn.clone()))
3445 .await
3446 .unwrap();
3447
3448 let ok = eventually_async(Duration::from_secs(2), || {
3449 let conn = conn.clone();
3450 async move {
3451 let st = conn.state.lock().await;
3452 st.ws_write_tx.is_some() && !st.reconnection_pending
3453 }
3454 })
3455 .await;
3456
3457 assert!(
3458 ok,
3459 "expected writer installed and reconnection_pending cleared"
3460 );
3461 });
3462 }
3463
3464 #[test]
3465 fn pending_request_is_resolved_on_socket_drop() {
3466 TOKIO_SHARED_RT.block_on(async {
3467 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3468 let addr = listener.local_addr().unwrap();
3469
3470 tokio::spawn(async move {
3471 if let Ok((stream, _)) = listener.accept().await {
3472 let mut ws = accept_async(stream).await.unwrap();
3473 let _ = ws.next().await;
3474 let _ = ws.close(None).await;
3475 }
3476 });
3477
3478 let conn = WebsocketConnection::new("c1");
3479 let common = WebsocketCommon::new(
3480 vec![conn.clone()],
3481 WebsocketMode::Single,
3482 0,
3483 None,
3484 None,
3485 );
3486
3487 let url = format!("ws://{addr}");
3488 common
3489 .clone()
3490 .init_connect(&url, false, Some(conn.clone()))
3491 .await
3492 .unwrap();
3493
3494 let rx = common
3495 .send(
3496 "{\"id\":\"req-1\",\"method\":\"PING\"}".to_string(),
3497 Some("req-1".to_string()),
3498 true,
3499 Duration::from_millis(150),
3500 Some(conn.clone()),
3501 )
3502 .await
3503 .unwrap()
3504 .expect("expected oneshot receiver");
3505
3506 let res = timeout(Duration::from_secs(2), rx)
3507 .await
3508 .expect("did not resolve pending request")
3509 .expect("oneshot cancelled");
3510
3511 assert!(matches!(res, Err(WebsocketError::Timeout)));
3512
3513 let ok = eventually_async(Duration::from_secs(1), || {
3514 let conn = conn.clone();
3515 async move { conn.state.lock().await.pending_requests.is_empty() }
3516 })
3517 .await;
3518
3519 assert!(ok, "pending_requests should be drained");
3520 });
3521 }
3522 }
3523
3524 mod is_connection_ready {
3525 use super::*;
3526
3527 #[test]
3528 fn is_connection_ready() {
3529 TOKIO_SHARED_RT.block_on(async {
3530 let conn = WebsocketConnection::new("c1");
3531 let common = WebsocketCommon::new(
3532 vec![conn.clone()],
3533 WebsocketMode::Single,
3534 0,
3535 None,
3536 None,
3537 );
3538 assert!(!common.is_connection_ready(&conn, false).await);
3539 assert!(common.is_connection_ready(&conn, true).await);
3540 });
3541 }
3542
3543 #[test]
3544 fn connection_ready_basic() {
3545 TOKIO_SHARED_RT.block_on(async {
3546 let conn = create_connection("c1", true, false, false, false).await;
3547 let common = WebsocketCommon::new(
3548 vec![conn.clone()],
3549 WebsocketMode::Single,
3550 0,
3551 None,
3552 None,
3553 );
3554 assert!(common.is_connection_ready(&conn, false).await);
3555 });
3556 }
3557
3558 #[test]
3559 fn connection_not_ready_without_writer() {
3560 TOKIO_SHARED_RT.block_on(async {
3561 let conn = create_connection("c1", false, false, false, false).await;
3562 let common = WebsocketCommon::new(
3563 vec![conn.clone()],
3564 WebsocketMode::Single,
3565 0,
3566 None,
3567 None,
3568 );
3569 assert!(!common.is_connection_ready(&conn, false).await);
3570 assert!(common.is_connection_ready(&conn, true).await);
3571 });
3572 }
3573
3574 #[test]
3575 fn connection_not_ready_when_flagged() {
3576 TOKIO_SHARED_RT.block_on(async {
3577 let conn1 = create_connection("c1", true, true, false, false).await;
3578 let conn2 = create_connection("c2", true, false, true, false).await;
3579 let conn3 = create_connection("c3", true, false, false, true).await;
3580
3581 let common = WebsocketCommon::new(
3582 vec![conn1.clone(), conn2.clone(), conn3.clone()],
3583 WebsocketMode::Pool(3),
3584 0,
3585 None,
3586 None,
3587 );
3588
3589 assert!(!common.is_connection_ready(&conn1, false).await);
3590 assert!(common.is_connection_ready(&conn2, false).await);
3591 assert!(!common.is_connection_ready(&conn3, false).await);
3592 });
3593 }
3594 }
3595
3596 mod is_connected {
3597 use super::*;
3598
3599 #[test]
3600 fn with_pool_various_connections() {
3601 TOKIO_SHARED_RT.block_on(async {
3602 let conn_a = create_connection("a", true, false, false, false).await;
3603 let conn_b = create_connection("b", false, false, false, false).await;
3604 let conn_c = create_connection("c", true, true, false, false).await;
3605 let pool = vec![conn_a.clone(), conn_b.clone(), conn_c.clone()];
3606 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None, None);
3607
3608 assert!(common.is_connected(None).await);
3609 assert!(common.is_connected(Some(&conn_a)).await);
3610 assert!(!common.is_connected(Some(&conn_b)).await);
3611 assert!(!common.is_connected(Some(&conn_c)).await);
3612 });
3613 }
3614
3615 #[test]
3616 fn with_pool_all_bad_connections() {
3617 TOKIO_SHARED_RT.block_on(async {
3618 let bad1 = create_connection("c1", false, false, false, false).await;
3619 let bad2 = create_connection("c2", true, true, false, false).await;
3620 let bad3 = create_connection("c3", true, false, false, true).await;
3621 let common = WebsocketCommon::new(
3622 vec![bad1, bad2, bad3],
3623 WebsocketMode::Pool(3),
3624 0,
3625 None,
3626 None,
3627 );
3628
3629 assert!(!common.is_connected(None).await);
3630 });
3631 }
3632
3633 #[test]
3634 fn with_pool_ignore_close_initiated() {
3635 TOKIO_SHARED_RT.block_on(async {
3636 let good = create_connection("c1", true, false, false, false).await;
3637 let closed = create_connection("c2", true, false, false, true).await;
3638 let bad = create_connection("c3", false, false, false, false).await;
3639 let common = WebsocketCommon::new(
3640 vec![closed.clone(), good.clone(), bad.clone()],
3641 WebsocketMode::Pool(3),
3642 0,
3643 None,
3644 None,
3645 );
3646
3647 assert!(common.is_connected(None).await);
3648 assert!(!common.is_connected(Some(&closed)).await);
3649 });
3650 }
3651 }
3652
3653 mod get_available_connections {
3654 use super::*;
3655
3656 #[test]
3657 fn single_mode() {
3658 TOKIO_SHARED_RT.block_on(async {
3659 let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3660 let connections = common.get_available_connections(false, None).await;
3661 assert_eq!(connections[0].id, common.connection_pool[0].id);
3662 });
3663 }
3664
3665 #[test]
3666 fn single_mode_with_url_path_does_not_force_first_connection() {
3667 TOKIO_SHARED_RT.block_on(async {
3668 let conn1 = WebsocketConnection::new("c1");
3669 let conn2 = WebsocketConnection::new("c2");
3670
3671 let (tx2, _rx2) = unbounded_channel();
3672 {
3673 let mut s2 = conn2.state.lock().await;
3674 s2.ws_write_tx = Some(tx2);
3675 s2.url_path = Some("path1".to_string());
3676 }
3677
3678 {
3679 let mut s1 = conn1.state.lock().await;
3680 s1.url_path = Some("path1".to_string());
3681 }
3682
3683 let pool = vec![conn1.clone(), conn2.clone()];
3684 let common = WebsocketCommon::new(pool, WebsocketMode::Single, 0, None, None);
3685
3686 let connections = common.get_available_connections(false, Some("path1")).await;
3687
3688 assert_eq!(connections.len(), 1);
3689 assert_eq!(connections[0].id, "c2");
3690 });
3691 }
3692
3693 #[test]
3694 fn pool_mode_not_ready() {
3695 TOKIO_SHARED_RT.block_on(async {
3696 let common =
3697 WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
3698 let connections = common.get_available_connections(false, None).await;
3699 assert!(connections.is_empty());
3700 });
3701 }
3702
3703 #[test]
3704 fn pool_mode_with_ready() {
3705 TOKIO_SHARED_RT.block_on(async {
3706 let conn1 = WebsocketConnection::new("c1");
3707 let conn2 = WebsocketConnection::new("c2");
3708 let (tx1, _rx1) = unbounded_channel();
3709 {
3710 let mut s1 = conn1.state.lock().await;
3711 s1.ws_write_tx = Some(tx1);
3712 }
3713 let pool = vec![conn1.clone(), conn2.clone()];
3714 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3715 let connections = common.get_available_connections(false, None).await;
3716 assert!(connections.len() == 1);
3717 });
3718 }
3719 }
3720
3721 mod get_connection {
3722 use super::*;
3723
3724 #[test]
3725 fn single_mode() {
3726 TOKIO_SHARED_RT.block_on(async {
3727 let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3728 let conn = common
3729 .get_connection(false, None)
3730 .await
3731 .expect("should get connection");
3732 assert_eq!(conn.id, common.connection_pool[0].id);
3733 });
3734 }
3735
3736 #[test]
3737 fn single_mode_with_url_path_selects_matching_ready_connection() {
3738 TOKIO_SHARED_RT.block_on(async {
3739 let conn1 = WebsocketConnection::new("c1");
3740 let conn2 = WebsocketConnection::new("c2");
3741
3742 let (tx1, _rx1) = unbounded_channel();
3743 {
3744 let mut s1 = conn1.state.lock().await;
3745 s1.ws_write_tx = Some(tx1);
3746 s1.url_path = Some("path2".to_string());
3747 }
3748
3749 let (tx2, _rx2) = unbounded_channel();
3750 {
3751 let mut s2 = conn2.state.lock().await;
3752 s2.ws_write_tx = Some(tx2);
3753 s2.url_path = Some("path1".to_string());
3754 }
3755
3756 let pool = vec![conn1.clone(), conn2.clone()];
3757 let common = WebsocketCommon::new(pool, WebsocketMode::Single, 0, None, None);
3758
3759 let chosen = common
3760 .get_connection(false, Some("path1"))
3761 .await
3762 .expect("should get connection");
3763
3764 assert_eq!(chosen.id, "c2");
3765 });
3766 }
3767
3768 #[test]
3769 fn pool_mode_not_ready() {
3770 TOKIO_SHARED_RT.block_on(async {
3771 let common =
3772 WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
3773 let result = common.get_connection(false, None).await;
3774 assert!(matches!(
3775 result,
3776 Err(crate::errors::WebsocketError::NotConnected)
3777 ));
3778 });
3779 }
3780
3781 #[test]
3782 fn pool_mode_with_ready() {
3783 TOKIO_SHARED_RT.block_on(async {
3784 let conn1 = WebsocketConnection::new("c1");
3785 let conn2 = WebsocketConnection::new("c2");
3786 let (tx1, _rx1) = unbounded_channel();
3787 {
3788 let mut s1 = conn1.state.lock().await;
3789 s1.ws_write_tx = Some(tx1);
3790 }
3791 let pool = vec![conn1.clone(), conn2.clone()];
3792 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3793 let result = common.get_connection(false, None).await;
3794 assert!(result.is_ok());
3795 let chosen = result.unwrap();
3796 assert_eq!(chosen.id, conn1.id);
3797 });
3798 }
3799
3800 #[test]
3801 fn pool_mode_with_url_path_filters_connections() {
3802 TOKIO_SHARED_RT.block_on(async {
3803 let conn1 = WebsocketConnection::new("c1");
3804 let conn2 = WebsocketConnection::new("c2");
3805
3806 let (tx1, _rx1) = unbounded_channel();
3807 {
3808 let mut s1 = conn1.state.lock().await;
3809 s1.ws_write_tx = Some(tx1);
3810 s1.url_path = Some("path1".to_string());
3811 }
3812
3813 let (tx2, _rx2) = unbounded_channel();
3814 {
3815 let mut s2 = conn2.state.lock().await;
3816 s2.ws_write_tx = Some(tx2);
3817 s2.url_path = Some("path2".to_string());
3818 }
3819
3820 let pool = vec![conn1.clone(), conn2.clone()];
3821 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3822
3823 let chosen = common
3824 .get_connection(false, Some("path2"))
3825 .await
3826 .expect("should pick ready connection for path2");
3827
3828 assert_eq!(chosen.id, "c2");
3829 });
3830 }
3831
3832 #[test]
3833 fn url_path_no_match_returns_not_connected() {
3834 TOKIO_SHARED_RT.block_on(async {
3835 let conn1 = WebsocketConnection::new("c1");
3836 let (tx1, _rx1) = unbounded_channel();
3837 {
3838 let mut s1 = conn1.state.lock().await;
3839 s1.ws_write_tx = Some(tx1);
3840 s1.url_path = Some("path1".to_string());
3841 }
3842
3843 let pool = vec![conn1.clone()];
3844 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(1), 0, None, None);
3845
3846 let result = common.get_connection(false, Some("path2")).await;
3847 assert!(matches!(
3848 result,
3849 Err(crate::errors::WebsocketError::NotConnected)
3850 ));
3851 });
3852 }
3853 }
3854
3855 mod close_connection_gracefully {
3856 use super::*;
3857
3858 #[tokio::test]
3859 async fn waits_for_pending_requests_then_closes() {
3860 pause();
3861
3862 let conn = WebsocketConnection::new("c1");
3863 let (tx, mut rx) = unbounded_channel::<Message>();
3864 let (req_tx, _req_rx) = oneshot::channel();
3865 {
3866 let mut st = conn.state.lock().await;
3867 st.pending_requests
3868 .insert("r".to_string(), PendingRequest { completion: req_tx });
3869 }
3870 let common =
3871 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3872 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3873 advance(Duration::from_secs(1)).await;
3874 {
3875 let mut st = conn.state.lock().await;
3876 st.pending_requests.clear();
3877 }
3878 conn.drain_notify.notify_waiters();
3879 advance(Duration::from_secs(1)).await;
3880 close_fut.await.unwrap();
3881 match rx.try_recv() {
3882 Ok(Message::Close(_)) => {}
3883 other => panic!("expected Close, got {other:?}"),
3884 }
3885
3886 resume();
3887 }
3888
3889 #[tokio::test]
3890 async fn force_closes_after_timeout() {
3891 pause();
3892
3893 let conn = WebsocketConnection::new("c2");
3894 let (tx, mut rx) = unbounded_channel::<Message>();
3895 let (req_tx, _req_rx) = oneshot::channel();
3896 {
3897 let mut st = conn.state.lock().await;
3898 st.pending_requests.insert(
3899 "request_id".to_string(),
3900 PendingRequest { completion: req_tx },
3901 );
3902 }
3903 let common =
3904 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3905 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3906 advance(Duration::from_secs(30)).await;
3907 close_fut.await.unwrap();
3908 match rx.try_recv() {
3909 Ok(Message::Close(_)) => {}
3910 other => panic!("expected Close on timeout, got {other:?}"),
3911 }
3912
3913 resume();
3914 }
3915 }
3916
3917 mod get_reconnect_url {
3918 use super::*;
3919
3920 struct DummyHandler {
3921 url: String,
3922 }
3923
3924 #[async_trait::async_trait]
3925 impl WebsocketHandler for DummyHandler {
3926 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
3927 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
3928 async fn get_reconnect_url(
3929 &self,
3930 _default_url: String,
3931 _connection: Arc<WebsocketConnection>,
3932 ) -> String {
3933 self.url.clone()
3934 }
3935 }
3936
3937 #[test]
3938 fn returns_default_when_no_handler() {
3939 TOKIO_SHARED_RT.block_on(async {
3940 let conn = WebsocketConnection::new("c1");
3941 let common = WebsocketCommon::new(
3942 vec![conn.clone()],
3943 WebsocketMode::Single,
3944 0,
3945 None,
3946 None,
3947 );
3948 let default = "wss://default".to_string();
3949 let result = common.get_reconnect_url(&default, conn.clone()).await;
3950 assert_eq!(result, default);
3951 });
3952 }
3953
3954 #[test]
3955 fn returns_handler_url_when_set() {
3956 TOKIO_SHARED_RT.block_on(async {
3957 let conn = WebsocketConnection::new("c2");
3958 let handler = Arc::new(DummyHandler {
3959 url: "wss://custom".into(),
3960 });
3961 conn.set_handler(handler).await;
3962 let common = WebsocketCommon::new(
3963 vec![conn.clone()],
3964 WebsocketMode::Single,
3965 0,
3966 None,
3967 None,
3968 );
3969 let default = "wss://default".to_string();
3970 let result = common.get_reconnect_url(&default, conn.clone()).await;
3971 assert_eq!(result, "wss://custom");
3972 });
3973 }
3974 }
3975
3976 mod on_open {
3977 use super::*;
3978
3979 struct DummyHandler {
3980 called: Arc<Mutex<bool>>,
3981 opened_url: Arc<Mutex<Option<String>>>,
3982 }
3983
3984 #[async_trait]
3985 impl WebsocketHandler for DummyHandler {
3986 async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
3987 let mut flag = self.called.lock().await;
3988 *flag = true;
3989 let mut store = self.opened_url.lock().await;
3990 *store = Some(url);
3991 }
3992 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
3993 async fn get_reconnect_url(
3994 &self,
3995 default_url: String,
3996 _connection: Arc<WebsocketConnection>,
3997 ) -> String {
3998 default_url
3999 }
4000 }
4001
4002 #[test]
4003 fn emits_open_and_calls_handler() {
4004 TOKIO_SHARED_RT.block_on(async {
4005 let conn = WebsocketConnection::new("c1");
4006 let called = Arc::new(Mutex::new(false));
4007 let opened_url = Arc::new(Mutex::new(None));
4008 let handler = Arc::new(DummyHandler {
4009 called: called.clone(),
4010 opened_url: opened_url.clone(),
4011 });
4012
4013 conn.set_handler(handler.clone()).await;
4014 let common = WebsocketCommon::new(
4015 vec![conn.clone()],
4016 WebsocketMode::Single,
4017 0,
4018 None,
4019 None,
4020 );
4021 let events = subscribe_events(&common);
4022 common
4023 .on_open("wss://example.com".into(), conn.clone(), None)
4024 .await;
4025
4026 sleep(std::time::Duration::from_millis(10)).await;
4027
4028 let evs = events.lock().await;
4029 assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
4030 assert!(*called.lock().await);
4031 assert_eq!(
4032 opened_url.lock().await.as_deref(),
4033 Some("wss://example.com")
4034 );
4035 });
4036 }
4037
4038 #[test]
4039 fn handles_renewal_pending_and_closes_old_writer() {
4040 TOKIO_SHARED_RT.block_on(async {
4041 let conn = WebsocketConnection::new("c2");
4042 let (old_tx, mut old_rx) = unbounded_channel::<Message>();
4043 {
4044 let mut st = conn.state.lock().await;
4045 st.renewal_pending = true;
4046 }
4047 let common = WebsocketCommon::new(
4048 vec![conn.clone()],
4049 WebsocketMode::Single,
4050 0,
4051 None,
4052 None,
4053 );
4054 common
4055 .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
4056 .await;
4057 assert!(!conn.state.lock().await.renewal_pending);
4058 match old_rx.try_recv() {
4059 Ok(Message::Close(_)) => {}
4060 other => panic!("expected Close, got {other:?}"),
4061 }
4062 });
4063 }
4064 }
4065
4066 mod on_message {
4067 use super::*;
4068
4069 struct DummyHandler {
4070 called_with: Arc<Mutex<Vec<String>>>,
4071 }
4072
4073 #[async_trait]
4074 impl WebsocketHandler for DummyHandler {
4075 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
4076 async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
4077 self.called_with.lock().await.push(data);
4078 }
4079 async fn get_reconnect_url(
4080 &self,
4081 default_url: String,
4082 _connection: Arc<WebsocketConnection>,
4083 ) -> String {
4084 default_url
4085 }
4086 }
4087
4088 #[test]
4089 fn emits_message_event_without_handler() {
4090 TOKIO_SHARED_RT.block_on(async {
4091 let conn = WebsocketConnection::new("c1");
4092 let common = WebsocketCommon::new(
4093 vec![conn.clone()],
4094 WebsocketMode::Single,
4095 0,
4096 None,
4097 None,
4098 );
4099 let events = subscribe_events(&common);
4100 common.on_message("msg".into(), conn.clone()).await;
4101
4102 sleep(Duration::from_millis(10)).await;
4103
4104 let locked = events.lock().await;
4105 assert!(
4106 locked
4107 .iter()
4108 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4109 );
4110 });
4111 }
4112
4113 #[test]
4114 fn calls_handler_and_emits_message() {
4115 TOKIO_SHARED_RT.block_on(async {
4116 let conn = WebsocketConnection::new("c2");
4117 let called = Arc::new(Mutex::new(Vec::new()));
4118 let handler = Arc::new(DummyHandler {
4119 called_with: called.clone(),
4120 });
4121 conn.set_handler(handler.clone()).await;
4122
4123 let common = WebsocketCommon::new(
4124 vec![conn.clone()],
4125 WebsocketMode::Single,
4126 0,
4127 None,
4128 None,
4129 );
4130 let events = subscribe_events(&common);
4131 common.on_message("msg".into(), conn.clone()).await;
4132
4133 sleep(Duration::from_millis(10)).await;
4134
4135 let evs = events.lock().await;
4136 assert!(
4137 evs.iter()
4138 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4139 );
4140 let msgs = called.lock().await;
4141 assert_eq!(msgs.as_slice(), &["msg".to_string()]);
4142 });
4143 }
4144 }
4145
4146 mod create_websocket {
4147 use super::*;
4148
4149 #[test]
4150 fn successful_connection() {
4151 TOKIO_SHARED_RT.block_on(async {
4152 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4153 let addr: SocketAddr = listener.local_addr().unwrap();
4154
4155 let expected_ua = build_user_agent("product");
4156 let expected_ua_clone = expected_ua.clone();
4157
4158 tokio::spawn(async move {
4159 if let Ok((stream, _)) = listener.accept().await {
4160 let callback = |req: &Request, resp| {
4161 let got = req
4162 .headers()
4163 .get(USER_AGENT)
4164 .expect("no USER_AGENT header in WS handshake")
4165 .to_str()
4166 .expect("invalid USER_AGENT header");
4167 assert_eq!(got, expected_ua_clone, "User-Agent mismatch");
4168 Ok(resp)
4169 };
4170 let _ = accept_hdr_async(stream, callback).await.unwrap();
4171 }
4172 });
4173
4174 let url = format!("ws://{addr}");
4175 let res =
4176 WebsocketCommon::create_websocket(&url, None, Some(expected_ua)).await;
4177 assert!(res.is_ok(), "handshake failed: {res:?}");
4178 });
4179 }
4180
4181 #[test]
4182 fn invalid_url_returns_handshake_error() {
4183 TOKIO_SHARED_RT.block_on(async {
4184 let res =
4185 WebsocketCommon::create_websocket("not-a-valid-url", None, None).await;
4186 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4187 });
4188 }
4189
4190 #[test]
4191 fn unreachable_host_returns_handshake_error() {
4192 TOKIO_SHARED_RT.block_on(async {
4193 let res =
4194 WebsocketCommon::create_websocket("ws://127.0.0.1:1", None, None).await;
4195 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4196 });
4197 }
4198 }
4199
4200 mod connect_pool {
4201 use super::*;
4202
4203 #[test]
4204 fn connects_all_in_pool() {
4205 TOKIO_SHARED_RT.block_on(async {
4206 let pool_size = 3;
4207 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4208 let addr = listener.local_addr().unwrap();
4209 tokio::spawn(async move {
4210 for _ in 0..pool_size {
4211 if let Ok((stream, _)) = listener.accept().await {
4212 tokio::spawn(async move {
4213 let mut ws = accept_async(stream).await.unwrap();
4214 sleep(Duration::from_millis(500)).await;
4215 let _ = ws.close(None).await;
4216 });
4217 }
4218 }
4219 });
4220 let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
4221 .map(|i| WebsocketConnection::new(format!("c{i}")))
4222 .collect();
4223 let common = WebsocketCommon::new(
4224 conns.clone(),
4225 WebsocketMode::Pool(pool_size),
4226 0,
4227 None,
4228 None,
4229 );
4230 let url = format!("ws://{addr}");
4231 common.clone().connect_pool(&url, None).await.unwrap();
4232 for conn in conns {
4233 let mut ok = false;
4234 for _ in 0..100 {
4235 if conn.state.lock().await.ws_write_tx.is_some() {
4236 ok = true;
4237 break;
4238 }
4239 sleep(Duration::from_millis(50)).await;
4240 }
4241 assert!(ok, "expected ws_write_tx Some after connect");
4242 }
4243 });
4244 }
4245
4246 #[test]
4247 fn fails_if_any_refused() {
4248 TOKIO_SHARED_RT.block_on(async {
4249 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4250 let addr = listener.local_addr().unwrap();
4251 let pool_size = 3;
4252 tokio::spawn(async move {
4253 for _ in 0..2 {
4254 if let Ok((stream, _)) = listener.accept().await {
4255 let mut ws = accept_async(stream).await.unwrap();
4256 let _ = ws.close(None).await;
4257 }
4258 }
4259 });
4260 let mut conns = Vec::new();
4261 let valid_url = format!("ws://{addr}");
4262 for i in 0..2 {
4263 conns.push(WebsocketConnection::new(format!("c{i}")));
4264 }
4265 conns.push(WebsocketConnection::new("bad"));
4266 let common = WebsocketCommon::new(
4267 conns.clone(),
4268 WebsocketMode::Pool(pool_size),
4269 0,
4270 None,
4271 None,
4272 );
4273 let res = common.clone().connect_pool(&valid_url, None).await;
4274 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4275 });
4276 }
4277
/// A URL that cannot be parsed is reported as a handshake error.
#[test]
fn fails_on_invalid_url() {
    TOKIO_SHARED_RT.block_on(async {
        let common = WebsocketCommon::new(
            vec![WebsocketConnection::new("c1")],
            WebsocketMode::Pool(1),
            0,
            None,
            None,
        );
        let outcome = common.connect_pool("not-a-url", None).await;
        assert!(matches!(outcome, Err(WebsocketError::Handshake(_))));
    });
}
4287
/// Two pooled connections against a server that accepts only one session:
/// the pool as a whole must fail with a handshake error.
#[test]
fn fails_if_mixed_success_and_invalid_url() {
    TOKIO_SHARED_RT.block_on(async {
        let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = server.local_addr().unwrap();
        // Serve exactly one session, then drop the listener.
        tokio::spawn(async move {
            let Ok((socket, _)) = server.accept().await else {
                return;
            };
            let mut ws = accept_async(socket).await.unwrap();
            let _ = ws.close(None).await;
        });
        let common = WebsocketCommon::new(
            vec![
                WebsocketConnection::new("good"),
                WebsocketConnection::new("bad"),
            ],
            WebsocketMode::Pool(2),
            0,
            None,
            None,
        );
        let url = format!("ws://{local}");
        let outcome = common.connect_pool(&url, None).await;
        assert!(matches!(outcome, Err(WebsocketError::Handshake(_))));
    });
}
4313
/// After `connect_pool`, every pooled connection eventually gets a writer.
#[test]
fn init_connect_invoked_for_each() {
    TOKIO_SHARED_RT.block_on(async {
        let pool_size = 2;
        let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = server.local_addr().unwrap();
        tokio::spawn(async move {
            for _ in 0..pool_size {
                let Ok((socket, _)) = server.accept().await else {
                    continue;
                };
                tokio::spawn(async move {
                    let mut ws = accept_async(socket).await.unwrap();
                    // Hold each session open for 500ms before closing it.
                    sleep(Duration::from_millis(500)).await;
                    let _ = ws.close(None).await;
                });
            }
        });
        let pool: Vec<Arc<WebsocketConnection>> = (0..pool_size)
            .map(|idx| WebsocketConnection::new(format!("c{idx}")))
            .collect();
        let common = WebsocketCommon::new(
            pool.clone(),
            WebsocketMode::Pool(pool_size),
            0,
            None,
            None,
        );
        let url = format!("ws://{local}");
        common.clone().connect_pool(&url, None).await.unwrap();
        for conn in pool {
            // Up to 100 probes, 25ms apart, for a writer to be installed.
            let mut probes = 0;
            let ready = loop {
                if conn.state.lock().await.ws_write_tx.is_some() {
                    break true;
                }
                sleep(Duration::from_millis(25)).await;
                probes += 1;
                if probes == 100 {
                    break false;
                }
            };
            assert!(ready, "expected ws_write_tx Some for {}", conn.id);
        }
    });
}
4356
/// In `Single` mode, `connect_pool` connects the (only) first connection.
#[test]
fn single_mode_uses_first_connection() {
    TOKIO_SHARED_RT.block_on(async {
        let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = server.local_addr().unwrap();
        tokio::spawn(async move {
            let Ok((socket, _)) = server.accept().await else {
                return;
            };
            let mut ws = accept_async(socket).await.unwrap();
            let _ = ws.close(None).await;
        });
        let first = WebsocketConnection::new("c1");
        let common = WebsocketCommon::new(
            vec![first.clone()],
            WebsocketMode::Single,
            0,
            None,
            None,
        );
        let url = format!("ws://{local}");
        common.connect_pool(&url, None).await.unwrap();
        let selected = eventually_async(Duration::from_secs(5), || {
            let probe = first.clone();
            async move { probe.state.lock().await.ws_write_tx.is_some() }
        })
        .await;

        assert!(selected, "single mode did not select first connection");
    });
}
4387
/// Passing an empty subset to `connect_pool` succeeds and leaves every
/// connection without a writer.
#[test]
fn empty_subset_is_ok_and_connects_none() {
    TOKIO_SHARED_RT.block_on(async {
        // Bound only to obtain a real address for the URL. No task accepts on
        // this listener: with an empty subset nothing should be dialled.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let c1 = WebsocketConnection::new("c1");
        let c2 = WebsocketConnection::new("c2");

        let common = WebsocketCommon::new(
            vec![c1.clone(), c2.clone()],
            WebsocketMode::Pool(2),
            0,
            None,
            None,
        );

        let url = format!("ws://{addr}");
        common
            .clone()
            .connect_pool(&url, Some(vec![]))
            .await
            .unwrap();

        assert!(c1.state.lock().await.ws_write_tx.is_none());
        assert!(c2.state.lock().await.ws_write_tx.is_none());
    });
}
4420 }
4421
mod init_connect {
    use super::*;

    /// `init_connect(.., None)` in pool mode connects only the first pooled
    /// connection; the second keeps no writer.
    #[test]
    fn pool_mode_none_connection_uses_first() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                for _ in 0..2 {
                    if let Ok((stream, _)) = listener.accept().await {
                        let mut ws = accept_async(stream).await.unwrap();
                        ws.close(None).await.ok();
                    }
                }
            });

            let c1 = WebsocketConnection::new("c1");
            let c2 = WebsocketConnection::new("c2");
            let common = WebsocketCommon::new(
                vec![c1.clone(), c2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");

            common
                .clone()
                .init_connect(&url, false, None)
                .await
                .unwrap();

            let ok = eventually_async(Duration::from_secs(5), || {
                let conn1 = c1.clone();
                async move { conn1.state.lock().await.ws_write_tx.is_some() }
            })
            .await;

            assert!(ok, "first connection was never selected");
            assert!(c2.state.lock().await.ws_write_tx.is_none());
        });
    }

    /// A text frame pushed into the connection's writer channel reaches the
    /// server.
    #[test]
    fn writer_channel_can_send_text() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let received = Arc::new(Mutex::new(None::<String>));
            let received_clone = received.clone();

            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    if let Some(Ok(Message::Text(txt))) = ws.next().await {
                        *received_clone.lock().await = Some(txt.to_string());
                    }
                    ws.close(None).await.ok();
                }
            });

            let conn = WebsocketConnection::new("cw");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn.clone()))
                .await
                .unwrap();

            let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
            tx.send(Message::Text("ping".into())).unwrap();

            // Poll instead of a fixed 50ms sleep: the frame is delivered
            // asynchronously over a real socket to the spawned server task,
            // so its arrival time is not bounded.
            let _ = eventually_async(Duration::from_secs(5), || {
                let received = received.clone();
                async move { received.lock().await.as_deref() == Some("ping") }
            })
            .await;

            let lock = received.lock().await;
            assert_eq!(lock.as_deref(), Some("ping"));
        });
    }

    /// A pending reconnection forces a fresh connect even when a writer is
    /// already installed; a successful connect clears the pending flag.
    #[test]
    fn does_not_skip_when_reconnection_pending_even_if_writer_exists() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.close(None).await;
                }
            });

            let conn = WebsocketConnection::new("c-reconnect");
            {
                let mut st = conn.state.lock().await;
                let (tx, _) = unbounded_channel::<Message>();
                st.ws_write_tx = Some(tx);
                st.reconnection_pending = true;
            }

            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );

            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn.clone()))
                .await
                .unwrap();

            let st = conn.state.lock().await;
            assert!(
                st.ws_write_tx.is_some(),
                "writer should be set after connect"
            );
            assert!(
                !st.reconnection_pending,
                "reconnection_pending should be cleared after successful connect"
            );
        });
    }

    /// A server Ping is answered with a Pong carrying the same payload.
    #[test]
    fn responds_to_ping_with_pong() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            let saw_pong = Arc::new(Mutex::new(false));
            let saw_pong2 = saw_pong.clone();

            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
                    if let Some(Ok(Message::Pong(payload))) = ws.next().await {
                        if payload[..] == [1, 2, 3] {
                            *saw_pong2.lock().await = true;
                        }
                    }
                    let _ = ws.close(None).await;
                }
            });

            let conn = WebsocketConnection::new("c-ping");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn))
                .await
                .unwrap();

            // Poll instead of a fixed 50ms sleep: the spawned server task
            // records the Pong asynchronously after a network round trip.
            let _ = eventually_async(Duration::from_secs(5), || {
                let flag = saw_pong.clone();
                async move { *flag.lock().await }
            })
            .await;

            assert!(*saw_pong.lock().await, "server should have seen a Pong");
        });
    }

    /// An unparsable URL is reported as a handshake error.
    #[test]
    fn handshake_error_on_invalid_url() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-invalid");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .clone()
                .init_connect("not-a-url", false, Some(conn.clone()))
                .await;
            assert!(matches!(res, Err(WebsocketError::Handshake(_))));
        });
    }

    /// With a writer already present and no renewal requested, `init_connect`
    /// returns Ok without dialling the (unreachable) URL and without writing
    /// to the existing channel.
    #[test]
    fn skip_if_writer_exists_and_not_renewal() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-writer");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx.clone());
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .clone()
                .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
                .await;

            assert!(res.is_ok());
            assert!(rx.try_recv().is_err());
        });
    }

    /// A renewal requested while one is already pending returns Ok and does
    /// not install a writer.
    #[test]
    fn short_circuit_on_already_renewing() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-renew");
            {
                let mut st = conn.state.lock().await;
                st.renewal_pending = true;
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .clone()
                .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
                .await;

            assert!(res.is_ok());
            assert!(conn.state.lock().await.ws_write_tx.is_none());
        });
    }

    /// A renewal connect leaves `renewal_pending` set until `on_open` runs,
    /// which then clears it.
    #[test]
    fn is_renewal_true_sets_and_clears_flag() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.close(None).await;
                }
            });

            let conn = WebsocketConnection::new("c-new-renew");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            let res = common
                .clone()
                .init_connect(&url, true, Some(conn.clone()))
                .await;

            assert!(res.is_ok());

            {
                let st = conn.state.lock().await;
                assert!(st.ws_write_tx.is_some(), "writer should be set");
                assert!(
                    st.renewal_pending,
                    "renewal_pending must be true until on_open"
                );
            }

            common.on_open(url.clone(), conn.clone(), None).await;

            let st = conn.state.lock().await;
            assert!(
                !st.renewal_pending,
                "renewal_pending should be cleared in on_open"
            );
        });
    }

    /// Passing no explicit connection makes `init_connect` use the pool's
    /// only connection.
    #[test]
    fn default_connection_selected_when_none_passed() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.close(None).await;
                }
            });
            let conn = WebsocketConnection::new("c-default");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            let res = common.clone().init_connect(&url, false, None).await;

            assert!(res.is_ok());
            let ok = eventually_async(Duration::from_secs(5), || {
                let conn = conn.clone();
                async move { conn.state.lock().await.ws_write_tx.is_some() }
            })
            .await;

            assert!(ok, "default connection was never selected");
        });
    }

    /// A server close with an abnormal code marks the connection for
    /// reconnection and drops its writer.
    #[test]
    fn schedules_reconnect_on_abnormal_close() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    ws.close(Some(tungstenite::protocol::CloseFrame {
                        code: tungstenite::protocol::frame::coding::CloseCode::Abnormal,
                        reason: "oops".into(),
                    }))
                    .await
                    .ok();
                }
            });
            let conn = WebsocketConnection::new("c-close");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                5_000,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn.clone()))
                .await
                .unwrap();

            // Poll instead of a fixed 50ms sleep until the close has been
            // processed. The 2s window stays well below the `5_000` passed to
            // `WebsocketCommon::new` above (presumably the reconnect delay in
            // ms), so a reconnect attempt should not race the assertions.
            let _ = eventually_async(Duration::from_secs(2), || {
                let conn = conn.clone();
                async move {
                    let st = conn.state.lock().await;
                    st.reconnection_pending && st.ws_write_tx.is_none()
                }
            })
            .await;

            let st = conn.state.lock().await;
            assert!(
                st.reconnection_pending,
                "expected reconnection_pending to be true after abnormal close"
            );
            assert!(
                st.ws_write_tx.is_none(),
                "ws_write_tx should be cleared when scheduling a reconnect"
            );
        });
    }
}
4796
mod disconnect {
    use super::*;

    /// With no connection holding a writer, `disconnect` succeeds and marks
    /// nothing as closing.
    #[test]
    fn returns_ok_when_no_connections_are_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common.disconnect().await;

            assert!(res.is_ok());
            assert!(!conn.state.lock().await.close_initiated);
        });
    }

    /// Every connection with a writer is marked `close_initiated`, has its
    /// session logon state cleared, and receives a Close frame.
    #[test]
    fn closes_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx1, mut rx1) = unbounded_channel::<Message>();
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = Some(tx1);
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            // Await directly: nothing else runs in this test, so a sleep
            // before the await would only add latency.
            common.disconnect().await.unwrap();

            assert!(conn1.state.lock().await.close_initiated);
            assert!(conn2.state.lock().await.close_initiated);

            {
                let st = conn1.state.lock().await;
                assert!(!st.is_session_logged_on, "conn1 should be logged out");
                assert!(st.session_logon_req.is_none(), "conn1 req cleared");
            }
            {
                let st = conn2.state.lock().await;
                assert!(!st.is_session_logged_on, "conn2 should be logged out");
                assert!(st.session_logon_req.is_none(), "conn2 req cleared");
            }

            match (rx1.try_recv(), rx2.try_recv()) {
                (Ok(Message::Close(_)), Ok(Message::Close(_))) => {}
                other => panic!("expected two Close frames, got {other:?}"),
            }
        });
    }

    /// A single connection without a writer is not marked `close_initiated`.
    #[test]
    fn does_not_mark_close_initiated_if_no_writer() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-new");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.disconnect().await.unwrap();

            assert!(!conn.state.lock().await.close_initiated);
        });
    }

    /// In a pool where only one connection has a writer, both are marked
    /// `close_initiated`, but only the one with a writer receives a Close.
    #[test]
    fn mixed_pool_marks_all_and_closes_only_writers() {
        TOKIO_SHARED_RT.block_on(async {
            let conn_w = WebsocketConnection::new("with");
            let conn_wo = WebsocketConnection::new("without");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn_w.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn_w.clone(), conn_wo.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            // Await directly: nothing else runs in this test, so a sleep
            // before the await would only add latency.
            common.disconnect().await.unwrap();

            assert!(conn_w.state.lock().await.close_initiated);
            assert!(conn_wo.state.lock().await.close_initiated);
            assert!(matches!(rx.try_recv(), Ok(Message::Close(_))));
        });
    }

    /// After `disconnect`, the connection no longer reports as connected.
    #[test]
    fn after_disconnect_not_connected() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            // Keep the receiver bound (not `_`) so the channel stays open
            // while `disconnect` sends the Close frame on it.
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.disconnect().await.unwrap();
            assert!(!common.is_connected(Some(&conn)).await);
        });
    }
}
4934
mod ping_server {
    use super::*;

    /// Each connection that holds a writer receives one empty-payload Ping.
    #[test]
    fn sends_ping_to_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let mut pool = Vec::new();
            let mut inboxes = Vec::new();
            for idx in 0..3 {
                let conn = WebsocketConnection::new(format!("c{idx}"));
                let (writer, inbox) = unbounded_channel::<Message>();
                conn.state.lock().await.ws_write_tx = Some(writer);
                pool.push(conn);
                inboxes.push(inbox);
            }
            let common = WebsocketCommon::new(
                pool.clone(),
                WebsocketMode::Pool(3),
                0,
                None,
                None,
            );
            common.ping_server().await;
            for mut inbox in inboxes {
                let frame = inbox.try_recv();
                assert!(
                    matches!(&frame, Ok(Message::Ping(payload)) if payload.is_empty()),
                    "expected empty-payload Ping, got {frame:?}"
                );
            }
        });
    }

    /// A connection without a writer is skipped; the ready one still gets
    /// its empty Ping.
    #[test]
    fn skips_not_ready_and_partial() {
        TOKIO_SHARED_RT.block_on(async {
            let ready = WebsocketConnection::new("ready");
            let not_ready = WebsocketConnection::new("not-ready");
            let (writer, mut inbox) = unbounded_channel::<Message>();
            ready.state.lock().await.ws_write_tx = Some(writer);
            not_ready.state.lock().await.ws_write_tx = None;
            let common = WebsocketCommon::new(
                vec![ready.clone(), not_ready.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            common.ping_server().await;
            let frame = inbox.try_recv();
            assert!(
                matches!(&frame, Ok(Message::Ping(payload)) if payload.is_empty()),
                "expected Ping on ready, got {frame:?}"
            );
        });
    }

    /// A pending reconnection suppresses the Ping even though a writer exists.
    #[test]
    fn no_ping_when_flags_block() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (writer, mut inbox) = unbounded_channel::<Message>();
            {
                let mut guard = conn.state.lock().await;
                guard.reconnection_pending = true;
                guard.ws_write_tx = Some(writer);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.ping_server().await;
            let nothing_sent = inbox.try_recv().is_err();
            assert!(nothing_sent);
        });
    }
}
5019
mod send {
    use super::*;

    /// Without an explicit connection, consecutive sync sends alternate
    /// across the ready connections (c1 first, then c2).
    #[test]
    fn round_robin_send_without_specific() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx1, mut rx1) = unbounded_channel::<Message>();
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = Some(tx1);
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );

            let res1 = common
                .send("a".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res1.is_none());

            let res2 = common
                .send("b".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res2.is_none());

            match rx1.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "a"),
                other => panic!("expected Text on c1, got {other:?}"),
            }
            match rx2.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "b"),
                other => panic!("expected Text on c2, got {other:?}"),
            }
        });
    }

    /// Round-robin selection skips a connection without a writer.
    #[test]
    fn round_robin_skips_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = None;
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let res = common
                .send("bar".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res.is_none());
            match rx2.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "bar"),
                other => panic!("unexpected {other:?}"),
            }
        });
    }

    /// A sync send on an explicit connection writes to that connection only.
    #[test]
    fn sync_send_on_specific_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut st = conn2.state.lock().await;
                st.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let res = common
                .send(
                    "payload".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn2.clone()),
                )
                .await
                .unwrap();
            assert!(res.is_none());
            match rx2.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "payload"),
                other => panic!("expected Text, got {other:?}"),
            }
        });
    }

    /// A sync send that carries an id does not register a pending request.
    #[test]
    fn sync_send_with_id_does_not_insert_pending() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap();
            assert!(res.is_none());
            assert!(conn.state.lock().await.pending_requests.is_empty());
            match rx.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "msg"),
                other => panic!("unexpected {other:?}"),
            }
        });
    }

    /// Sending on an explicit connection without a writer yields `NotConnected`.
    #[test]
    fn sync_send_error_if_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    /// A sync send with no ready connection in the pool yields `NotConnected`.
    #[test]
    fn sync_send_error_when_no_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send("msg".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    /// An async send registers a pending request under its id; completing
    /// that request resolves the returned future with the supplied value.
    #[test]
    fn async_send_and_receive() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let fut = common
                .send(
                    "hello".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(5),
                    Some(conn.clone()),
                )
                .await
                .unwrap()
                .unwrap();
            match rx.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "hello"),
                other => panic!("expected Text, got {other:?}"),
            }
            {
                let mut st = conn.state.lock().await;
                let pr = st.pending_requests.remove("id").unwrap();
                pr.completion.send(Ok(serde_json::json!("ok"))).unwrap();
            }
            let resp = fut.await.unwrap().unwrap();
            assert_eq!(resp, serde_json::json!("ok"));
        });
    }

    /// An async send without an explicit connection uses the pool's default.
    #[test]
    fn async_send_default_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let fut = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(5),
                    None,
                )
                .await
                .unwrap()
                .unwrap();
            match rx.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "msg"),
                _ => panic!("no text"),
            }
            {
                let mut st = conn.state.lock().await;
                let pr = st.pending_requests.remove("id").unwrap();
                pr.completion.send(Ok(serde_json::json!(123))).unwrap();
            }
            let resp = fut.await.unwrap().unwrap();
            assert_eq!(resp, serde_json::json!(123));
        });
    }

    /// An async send without an id is rejected; the error observed here is
    /// `NotConnected`.
    #[test]
    fn async_send_error_if_no_id() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c§");
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send(
                    "msg".into(),
                    None,
                    true,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    /// Once the timeout elapses, the async send resolves to an error and its
    /// pending request is removed.
    #[test]
    fn timeout_rejects_async() {
        TOKIO_SHARED_RT.block_on(async {
            pause();
            let conn = WebsocketConnection::new("c1");
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let fut = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap()
                .unwrap();
            advance(Duration::from_secs(1)).await;
            let outcome = fut.await;
            let still_pending = conn.state.lock().await.pending_requests.contains_key("id");
            // TOKIO_SHARED_RT is used by every test in this file and a paused
            // clock persists on the runtime. Resume it before any assertion
            // can panic, so later tests run against real time.
            tokio::time::resume();
            let res = outcome.unwrap();
            assert!(res.is_err(), "expected timeout error");
            assert!(!still_pending);
        });
    }

    /// An async send with no ready connection yields `NotConnected`.
    #[test]
    fn async_send_errors_if_no_connection_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(1),
                    None,
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }
}
5396 }
5397
5398 mod websocket_api {
5399 use super::*;
5400
mod initialisation {
    use super::*;

    /// `WebsocketApi::new` exposes the supplied pool and mode through its
    /// `common` field and starts with `is_connecting` cleared.
    #[test]
    fn new_initializes_common() {
        TOKIO_SHARED_RT.block_on(async {
            let connection = WebsocketConnection::new("id");
            let connections = vec![connection.clone()];

            let signer = SignatureGenerator::new(
                Some("api_secret".to_string()),
                None::<PrivateKey>,
                None::<String>,
            );

            let config = ConfigurationWebsocketApi {
                api_key: Some("api_key".to_string()),
                api_secret: Some("api_secret".to_string()),
                private_key: None,
                private_key_passphrase: None,
                ws_url: Some("wss://example".to_string()),
                mode: WebsocketMode::Single,
                reconnect_delay: 1000,
                signature_gen: signer,
                timeout: 500,
                time_unit: None,
                auto_session_relogon: false,
                agent: None,
                user_agent: build_user_agent("product"),
            };

            let api = WebsocketApi::new(config, connections.clone());

            let shared = &api.common;
            assert_eq!(shared.connection_pool.len(), 1);
            assert_eq!(shared.mode, WebsocketMode::Single);

            let connecting = *api.is_connecting.lock().await;
            assert!(!connecting);
        });
    }
}
5442
5443 mod connect {
5444 use super::*;
5445
/// `connect` on a connection without a writer attempts a real connect.
///
/// NOTE(review): the host `doesnotexist` is not expected to resolve, so this
/// only checks that the failure is not reported as `Timeout`.
#[test]
fn connect_when_not_connected_establishes() {
    TOKIO_SHARED_RT.block_on(async {
        let connection = WebsocketConnection::new("id");
        connection.state.lock().await.ws_write_tx = None;
        let signer = SignatureGenerator::new(
            Some("api_secret".into()),
            None::<PrivateKey>,
            None::<String>,
        );
        let config = ConfigurationWebsocketApi {
            api_key: Some("api_key".into()),
            api_secret: Some("api_secret".to_string()),
            private_key: None,
            private_key_passphrase: None,
            ws_url: Some("ws://doesnotexist:1".to_string()),
            mode: WebsocketMode::Single,
            reconnect_delay: 0,
            signature_gen: signer,
            timeout: 10,
            time_unit: None,
            auto_session_relogon: false,
            agent: None,
            user_agent: build_user_agent("product"),
        };
        let api = WebsocketApi::new(config, vec![connection.clone()]);
        let outcome = api.clone().connect().await;
        assert!(!matches!(outcome, Err(WebsocketError::Timeout)));
    });
}
5479
/// `connect` returns Ok when the connection already holds a writer.
#[test]
fn already_connected_returns_ok() {
    TOKIO_SHARED_RT.block_on(async {
        let connection = WebsocketConnection::new("id2");
        // Install a writer up front; this is the state the test treats as
        // "already connected". The receiver is dropped immediately.
        let (writer, _) = unbounded_channel();
        connection.state.lock().await.ws_write_tx = Some(writer);
        let signer = SignatureGenerator::new(
            Some("api_secret".to_string()),
            None::<PrivateKey>,
            None::<String>,
        );
        let config = ConfigurationWebsocketApi {
            api_key: Some("api_key".to_string()),
            api_secret: Some("api_secret".to_string()),
            private_key: None,
            private_key_passphrase: None,
            ws_url: Some("ws://example.com".to_string()),
            mode: WebsocketMode::Single,
            reconnect_delay: 0,
            signature_gen: signer,
            timeout: 10,
            time_unit: None,
            auto_session_relogon: false,
            agent: None,
            user_agent: build_user_agent("product"),
        };
        let api = WebsocketApi::new(config, vec![connection.clone()]);
        let outcome = api.connect().await;
        assert!(outcome.is_ok());
    });
}
5514
5515 #[test]
5516 fn not_connected_returns_error() {
5517 TOKIO_SHARED_RT.block_on(async {
5518 let conn = WebsocketConnection::new("id1");
5519 let sig = SignatureGenerator::new(
5520 Some("api_secret".to_string()),
5521 None::<PrivateKey>,
5522 None::<String>,
5523 );
5524 let cfg = ConfigurationWebsocketApi {
5525 api_key: Some("api_key".to_string()),
5526 api_secret: Some("api_secret".to_string()),
5527 private_key: None,
5528 private_key_passphrase: None,
5529 ws_url: Some("ws://127.0.0.1:9".to_string()),
5530 mode: WebsocketMode::Single,
5531 reconnect_delay: 0,
5532 signature_gen: sig,
5533 timeout: 10,
5534 time_unit: None,
5535 auto_session_relogon: false,
5536 agent: None,
5537 user_agent: build_user_agent("product"),
5538 };
5539 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5540 let res = api.connect().await;
5541 assert!(res.is_err());
5542 });
5543 }
5544
5545 #[test]
5546 fn concurrent_calls_both_error_or_ok() {
5547 TOKIO_SHARED_RT.block_on(async {
5548 let conn = WebsocketConnection::new("id3");
5549 let sig = SignatureGenerator::new(
5550 Some("api_secret".to_string()),
5551 None::<PrivateKey>,
5552 None::<String>,
5553 );
5554 let cfg = ConfigurationWebsocketApi {
5555 api_key: Some("api_key".to_string()),
5556 api_secret: Some("api_secret".to_string()),
5557 private_key: None,
5558 private_key_passphrase: None,
5559 ws_url: Some("wss://invalid-domain".to_string()),
5560 mode: WebsocketMode::Single,
5561 reconnect_delay: 0,
5562 signature_gen: sig,
5563 timeout: 10,
5564 time_unit: None,
5565 auto_session_relogon: false,
5566 agent: None,
5567 user_agent: build_user_agent("product"),
5568 };
5569 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5570 let fut1 = tokio::spawn(api.clone().connect());
5571 let fut2 = tokio::spawn(api.clone().connect());
5572 let r1 = fut1.await.unwrap();
5573 let r2 = fut2.await.unwrap();
5574
5575 assert!(r1.is_err());
5576 assert!(r2.is_err() || r2.is_ok());
5577 });
5578 }
5579
5580 #[test]
5581 fn pool_failure_is_propagated() {
5582 TOKIO_SHARED_RT.block_on(async {
5583 let conn = WebsocketConnection::new("w");
5584 let sig = SignatureGenerator::new(
5585 Some("api_secret".to_string()),
5586 None::<PrivateKey>,
5587 None::<String>,
5588 );
5589 let cfg = ConfigurationWebsocketApi {
5590 api_key: Some("api_key".into()),
5591 api_secret: Some("api_secret".to_string()),
5592 private_key: None,
5593 private_key_passphrase: None,
5594 ws_url: Some("ws://doesnotexist:1".to_string()),
5595 mode: WebsocketMode::Single,
5596 reconnect_delay: 0,
5597 signature_gen: sig,
5598 timeout: 10,
5599 time_unit: None,
5600 auto_session_relogon: false,
5601 agent: None,
5602 user_agent: build_user_agent("product"),
5603 };
5604 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5605 let res = api.clone().connect().await;
5606 match res {
5607 Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
5608 _ => panic!("expected handshake or timeout error"),
5609 }
5610 });
5611 }
5612 }
5613
/// `WebsocketApi::send_message`: request framing, signing and session handling.
///
/// Each test installs an in-memory write channel in place of a socket, reads the
/// outgoing frame from it, and completes the matching pending request by hand.
/// The state lock is always released before the spawned sender task is awaited.
mod send_message {
    use super::*;

    /// With `with_api_key` and `is_signed` off, no apiKey/timestamp/signature is sent.
    #[test]
    fn unsigned_message() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, None, None);
            let conn = &api.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }

            let fut = tokio::spawn({
                let api = api.clone();
                async move {
                    let mut params = BTreeMap::new();
                    params.insert("foo".into(), Value::String("bar".into()));
                    let send_res = api
                        .send_message::<Value>(
                            "method",
                            params,
                            WebsocketMessageSendOptions {
                                with_api_key: false,
                                is_signed: false,
                                ..Default::default()
                            },
                        )
                        .await
                        .unwrap();

                    match send_res {
                        SendWebsocketMessageResult::Single(resp) => resp,
                        SendWebsocketMessageResult::Multiple(_) => {
                            panic!("expected single response")
                        }
                    }
                }
            });

            let Message::Text(txt) = rx.recv().await.unwrap() else {
                panic!()
            };
            let req: Value = serde_json::from_str(&txt).unwrap();
            assert_eq!(req["method"], "method");
            assert_eq!(req["params"]["foo"], "bar");
            assert!(req["params"].get("apiKey").is_none());
            assert!(req["params"].get("timestamp").is_none());
            assert!(req["params"].get("signature").is_none());

            let id = req["id"].as_str().unwrap().to_string();
            {
                let mut st = conn.state.lock().await;
                let pending = st.pending_requests.remove(&id).unwrap();
                let reply = json!({
                    "id": id,
                    "result": { "x": 42 },
                    "rateLimits": [{ "limit": 7 }]
                });
                pending.completion.send(Ok(reply)).unwrap();
            }

            let resp = fut.await.unwrap();
            let rate_limits = resp.rate_limits.unwrap_or_default();

            // NOTE(review): the reply carries a `rateLimits` entry, yet the parsed
            // list is expected to be empty. Confirm this matches how partial
            // rate-limit objects are meant to be deserialized.
            assert!(rate_limits.is_empty());
            assert_eq!(resp.raw, json!({"x": 42}));
        });
    }

    /// `with_api_key` alone adds `apiKey` to the params.
    #[test]
    fn with_api_key_only() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, None, None);
            let conn = &api.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }

            let fut = tokio::spawn({
                let api = api.clone();
                async move {
                    let params = BTreeMap::new();
                    let send_res = api
                        .send_message::<Value>(
                            "method",
                            params,
                            WebsocketMessageSendOptions {
                                with_api_key: true,
                                is_signed: false,
                                ..Default::default()
                            },
                        )
                        .await
                        .unwrap();

                    match send_res {
                        SendWebsocketMessageResult::Single(resp) => resp,
                        SendWebsocketMessageResult::Multiple(_) => {
                            panic!("expected single response")
                        }
                    }
                }
            });

            let Message::Text(txt) = rx.recv().await.unwrap() else {
                panic!()
            };
            let req: Value = serde_json::from_str(&txt).unwrap();
            assert_eq!(req["params"]["apiKey"], "api_key");

            let id = req["id"].as_str().unwrap().to_string();
            {
                let mut st = conn.state.lock().await;
                let pending = st.pending_requests.remove(&id).unwrap();
                pending
                    .completion
                    .send(Ok(json!({
                        "id": id,
                        "result": {},
                        "rateLimits": []
                    })))
                    .unwrap();
            }

            let resp = fut.await.unwrap();

            assert_eq!(resp.raw, json!({}));
            assert!(conn.state.lock().await.pending_requests.is_empty());
        });
    }

    /// A signed request carries `apiKey`, a numeric `timestamp` and a `signature`.
    #[test]
    fn signed_message_has_timestamp_and_signature() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, None, None);
            let conn = &api.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }

            let fut = tokio::spawn({
                let api = api.clone();
                async move {
                    let mut params = BTreeMap::new();
                    params.insert("foo".into(), Value::String("bar".into()));
                    let send_res = api
                        .send_message::<Value>(
                            "method",
                            params,
                            WebsocketMessageSendOptions {
                                with_api_key: true,
                                is_signed: true,
                                ..Default::default()
                            },
                        )
                        .await
                        .unwrap();

                    match send_res {
                        SendWebsocketMessageResult::Single(resp) => resp,
                        SendWebsocketMessageResult::Multiple(_) => {
                            panic!("expected single response")
                        }
                    }
                }
            });

            let Message::Text(txt) = rx.recv().await.unwrap() else {
                panic!()
            };
            let req: Value = serde_json::from_str(&txt).unwrap();
            let p = &req["params"];
            assert_eq!(p["apiKey"], "api_key");
            assert!(p["timestamp"].is_number());
            assert!(p["signature"].is_string());

            let id = req["id"].as_str().unwrap().to_string();
            {
                let mut st = conn.state.lock().await;
                let pending = st.pending_requests.remove(&id).unwrap();
                pending
                    .completion
                    .send(Ok(json!({
                        "id": id,
                        "result": { "ok": true },
                        "rateLimits": []
                    })))
                    .unwrap();
            }

            let resp = fut.await.unwrap();
            assert_eq!(resp.raw, json!({ "ok": true }));
        });
    }

    /// In pool mode, a session logon goes to every connection with one shared id
    /// and leaves each connection logged on with its logon request stored.
    #[test]
    fn multi_session_logon() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);
            let conn0 = &api.common.connection_pool[0];
            let conn1 = &api.common.connection_pool[1];

            let (tx0, mut rx0) = unbounded_channel::<Message>();
            let (tx1, mut rx1) = unbounded_channel::<Message>();
            {
                let mut st0 = conn0.state.lock().await;
                st0.ws_write_tx = Some(tx0);
            }
            {
                let mut st1 = conn1.state.lock().await;
                st1.ws_write_tx = Some(tx1);
            }

            let fut = tokio::spawn({
                let api = api.clone();
                async move {
                    let params = BTreeMap::new();
                    let send_res = api
                        .send_message::<Value>(
                            "method",
                            params,
                            WebsocketMessageSendOptions {
                                is_session_logon: Some(true),
                                ..Default::default()
                            },
                        )
                        .await
                        .unwrap();

                    match send_res {
                        SendWebsocketMessageResult::Multiple(v) => v,
                        SendWebsocketMessageResult::Single(_) => {
                            panic!("expected multiple responses")
                        }
                    }
                }
            });

            let Message::Text(txt0) = rx0.recv().await.unwrap() else {
                panic!()
            };
            let Message::Text(txt1) = rx1.recv().await.unwrap() else {
                panic!()
            };
            let req0: Value = serde_json::from_str(&txt0).unwrap();
            let req1: Value = serde_json::from_str(&txt1).unwrap();
            assert_eq!(req0["method"], "method");
            assert_eq!(req1["method"], "method");
            let id = req0["id"].as_str().unwrap().to_string();
            assert_eq!(req1["id"].as_str().unwrap(), &id);

            {
                let mut st0 = conn0.state.lock().await;
                let pending0 = st0.pending_requests.remove(&id).unwrap();
                pending0
                    .completion
                    .send(Ok(json!({
                        "id": id,
                        "result": { "ok": true },
                        "rateLimits": []
                    })))
                    .unwrap();
            }
            {
                let mut st1 = conn1.state.lock().await;
                let pending1 = st1.pending_requests.remove(&id).unwrap();
                pending1
                    .completion
                    .send(Ok(json!({
                        "id": id,
                        "result": { "ok": true },
                        "rateLimits": []
                    })))
                    .unwrap();
            }

            let results = fut.await.unwrap();
            assert_eq!(results.len(), 2);

            for conn in &api.common.connection_pool {
                let st = conn.state.lock().await;
                assert!(st.is_session_logged_on, "should be logged on");
                assert!(st.session_logon_req.is_some(), "logon req should be stored");
            }
        });
    }

    /// In pool mode, a session logout goes to every connection and clears the
    /// logged-on flag and stored logon request on each.
    #[test]
    fn multi_session_logout() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);

            // Mark every pooled connection as connected and logged on, keeping the
            // receiving end of each write channel to observe outgoing frames.
            let mut rxs = Vec::new();
            for conn in &api.common.connection_pool {
                let (tx, rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.is_session_logged_on = true;
                st.session_logon_req = Some(WebsocketSessionLogonReq {
                    method: "method".into(),
                    payload: BTreeMap::new(),
                    options: WebsocketMessageSendOptions::default(),
                });
                rxs.push(rx);
            }

            let fut = tokio::spawn({
                let api = api.clone();
                async move {
                    let send_res = api
                        .send_message::<Value>(
                            "method",
                            BTreeMap::new(),
                            WebsocketMessageSendOptions {
                                is_signed: false,
                                with_api_key: false,
                                is_session_logout: Some(true),
                                ..Default::default()
                            },
                        )
                        .await
                        .unwrap();

                    match send_res {
                        SendWebsocketMessageResult::Multiple(v) => v,
                        SendWebsocketMessageResult::Single(_) => panic!("expected multi"),
                    }
                }
            });

            let mut ids = Vec::new();
            for mut rx in rxs {
                let Message::Text(txt) = rx.recv().await.unwrap() else {
                    panic!()
                };
                let req: Value = serde_json::from_str(&txt).unwrap();
                assert_eq!(req["method"], "method");
                ids.push(req["id"].as_str().unwrap().to_string());
            }

            assert_eq!(ids[0], ids[1]);

            for conn in &api.common.connection_pool {
                let id = &ids[0];
                let mut st = conn.state.lock().await;
                let pending = st.pending_requests.remove(id).unwrap();
                pending
                    .completion
                    .send(Ok(json!({
                        "id": id,
                        "result": {},
                        "rateLimits": []
                    })))
                    .unwrap();
            }

            let results = fut.await.unwrap();
            assert_eq!(results.len(), 2);

            for conn in &api.common.connection_pool {
                let st = conn.state.lock().await;
                assert!(!st.is_session_logged_on, "should be logged out");
                assert!(st.session_logon_req.is_none(), "req cleared");
            }
        });
    }

    /// With the default auto-relogon setting and an active session, a signed
    /// request carries a timestamp but no signature.
    #[test]
    fn skip_signature_when_logged_on_and_auto_relogon() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, Some(WebsocketMode::Single), None);
            let conn = &api.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.is_session_logged_on = true;
            }

            let fut = tokio::spawn({
                let api = api.clone();
                async move {
                    let send_res = api
                        .send_message::<Value>(
                            "method",
                            BTreeMap::new(),
                            WebsocketMessageSendOptions {
                                is_signed: true,
                                ..Default::default()
                            },
                        )
                        .await
                        .unwrap();

                    match send_res {
                        SendWebsocketMessageResult::Single(resp) => resp,
                        SendWebsocketMessageResult::Multiple(_) => {
                            panic!("expected single")
                        }
                    }
                }
            });

            let Message::Text(txt) = rx.recv().await.unwrap() else {
                panic!()
            };
            let req: Value = serde_json::from_str(&txt).unwrap();
            let p = &req["params"];
            assert!(p.get("timestamp").is_some());
            assert!(p.get("signature").is_none());

            let id = req["id"].as_str().unwrap();
            {
                let mut st = conn.state.lock().await;
                let pending = st.pending_requests.remove(id).unwrap();
                pending
                    .completion
                    .send(Ok(json!({
                        "id": id,
                        "result": {},
                        "rateLimits": []
                    })))
                    .unwrap();
            }

            let resp = fut.await.unwrap();
            assert_eq!(resp.raw, json!({}));
        });
    }

    /// With auto-relogon disabled, a signed request is still signed even when
    /// the session is logged on.
    #[test]
    fn include_signature_when_logged_on_and_no_auto_relogon() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, Some(WebsocketMode::Single), Some(false));
            let conn = &api.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.is_session_logged_on = true;
            }

            let fut = tokio::spawn({
                let api = api.clone();
                async move {
                    let send_res = api
                        .send_message::<Value>(
                            "method",
                            BTreeMap::new(),
                            WebsocketMessageSendOptions {
                                is_signed: true,
                                ..Default::default()
                            },
                        )
                        .await
                        .unwrap();

                    match send_res {
                        SendWebsocketMessageResult::Single(resp) => resp,
                        SendWebsocketMessageResult::Multiple(_) => {
                            panic!("expected single")
                        }
                    }
                }
            });

            let Message::Text(txt) = rx.recv().await.unwrap() else {
                panic!()
            };
            let req: Value = serde_json::from_str(&txt).unwrap();
            let p = &req["params"];
            assert!(p.get("timestamp").is_some());
            assert!(p.get("signature").is_some());

            let id = req["id"].as_str().unwrap();
            {
                let mut st = conn.state.lock().await;
                let pending = st.pending_requests.remove(id).unwrap();
                pending
                    .completion
                    .send(Ok(json!({
                        "id": id,
                        "result": {},
                        "rateLimits": []
                    })))
                    .unwrap();
            }

            let resp = fut.await.unwrap();
            assert_eq!(resp.raw, json!({}));
        });
    }

    /// Sending without a write channel fails with `NotConnected`.
    #[test]
    fn error_if_not_connected() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, None, None);
            let conn = &api.common.connection_pool[0];
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = None;
            }
            let params = BTreeMap::new();
            let err = api
                .send_message::<Value>(
                    "method",
                    params,
                    WebsocketMessageSendOptions {
                        with_api_key: false,
                        is_signed: false,
                        ..Default::default()
                    },
                )
                .await
                .unwrap_err();
            // `matches!` only yields a bool; it must be asserted to check anything.
            assert!(
                matches!(err, WebsocketError::NotConnected),
                "expected NotConnected, got {err:?}"
            );
        });
    }
}
6167
/// `WebsocketApi::prepare_url`: appending the configured `timeUnit` query parameter.
mod prepare_url {
    use super::*;

    /// With no time unit configured, the URL comes back unchanged.
    #[test]
    fn no_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let url = "wss://example.com/ws".to_string();
            let api = create_websocket_api(None, None, None);
            let prepared = api.prepare_url(&url);
            assert_eq!(prepared, url);
        });
    }

    /// A URL without a query string gains `?timeUnit=...`.
    #[test]
    fn appends_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let base = "wss://example.com/ws".to_string();
            let expected = format!("{base}?timeUnit=millisecond");
            let api = create_websocket_api(Some(TimeUnit::Millisecond), None, None);
            assert_eq!(api.prepare_url(&base), expected);
        });
    }

    /// A URL that already has a query string gains `&timeUnit=...`.
    #[test]
    fn handles_existing_query() {
        TOKIO_SHARED_RT.block_on(async {
            let base = "wss://example.com/ws?foo=bar".to_string();
            let expected = format!("{base}&timeUnit=microsecond");
            let api = create_websocket_api(Some(TimeUnit::Microsecond), None, None);
            assert_eq!(api.prepare_url(&base), expected);
        });
    }
}
6200
6201 mod on_open {
6202 use super::*;
6203
6204 fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6205 let sig_gen = SignatureGenerator::new(
6206 Some("api_secret".to_string()),
6207 None::<_>,
6208 None::<String>,
6209 );
6210 let config = ConfigurationWebsocketApi {
6211 api_key: Some("api_key".to_string()),
6212 api_secret: Some("api_secret".to_string()),
6213 private_key: None,
6214 private_key_passphrase: None,
6215 ws_url: Some("wss://example".to_string()),
6216 mode: WebsocketMode::Single,
6217 reconnect_delay: 0,
6218 signature_gen: sig_gen,
6219 timeout: 1000,
6220 time_unit: None,
6221 auto_session_relogon: true,
6222 agent: None,
6223 user_agent: build_user_agent("product"),
6224 };
6225 let conn = WebsocketConnection::new("test-conn");
6226 let api = WebsocketApi::new(config, vec![conn.clone()]);
6227 (api, conn)
6228 }
6229
6230 #[test]
6231 fn session_relogon_on_open() {
6232 TOKIO_SHARED_RT.block_on(async {
6233 let (api, conn) = create_websocket_api_and_conn();
6234
6235 let req = WebsocketSessionLogonReq {
6236 method: "method".into(),
6237 payload: {
6238 let mut m = BTreeMap::new();
6239 m.insert("foo".into(), Value::String("bar".into()));
6240 m
6241 },
6242 options: WebsocketMessageSendOptions {
6243 with_api_key: true,
6244 is_signed: true,
6245 is_session_logon: Some(true),
6246 ..Default::default()
6247 },
6248 };
6249
6250 let (tx, mut rx) = unbounded_channel::<Message>();
6251 {
6252 let mut st = conn.state.lock().await;
6253 st.session_logon_req = Some(req.clone());
6254 st.is_session_logged_on = false;
6255 st.ws_write_tx = Some(tx);
6256 }
6257
6258 api.on_open("wss://example".to_string(), conn.clone()).await;
6259
6260 let Message::Text(raw) = rx.recv().await.unwrap() else {
6261 panic!("expected a Text message");
6262 };
6263 let msg: Value = serde_json::from_str(&raw).unwrap();
6264 assert_eq!(msg["method"], "method");
6265 assert_eq!(msg["params"]["foo"], "bar");
6266
6267 let id = msg["id"].as_str().unwrap().to_string();
6268 {
6269 let mut st = conn.state.lock().await;
6270 let pending = st.pending_requests.remove(&id).expect("pending request");
6271 pending
6272 .completion
6273 .send(Ok(json!({
6274 "id": id,
6275 "result": {},
6276 "rateLimits": []
6277 })))
6278 .unwrap();
6279 }
6280
6281 sleep(Duration::from_millis(10)).await;
6282
6283 let st = conn.state.lock().await;
6284 assert!(st.is_session_logged_on, "should now be logged on");
6285 });
6286 }
6287
6288 #[test]
6289 fn no_relogon_if_already_logged_on() {
6290 TOKIO_SHARED_RT.block_on(async {
6291 let (api, conn) = create_websocket_api_and_conn();
6292
6293 let req = WebsocketSessionLogonReq {
6294 method: "method".into(),
6295 payload: BTreeMap::new(),
6296 options: WebsocketMessageSendOptions {
6297 is_session_logon: Some(true),
6298 ..Default::default()
6299 },
6300 };
6301
6302 let (tx, mut rx) = unbounded_channel::<Message>();
6303 {
6304 let mut st = conn.state.lock().await;
6305 st.session_logon_req = Some(req);
6306 st.is_session_logged_on = true;
6307 st.ws_write_tx = Some(tx);
6308 }
6309
6310 api.on_open("wss://example".to_string(), conn.clone()).await;
6311
6312 assert!(rx.try_recv().is_err(), "no re‐logon when already on");
6313
6314 let st = conn.state.lock().await;
6315 assert!(st.is_session_logged_on);
6316 });
6317 }
6318
6319 #[test]
6320 fn session_relogon_fails_gracefully() {
6321 TOKIO_SHARED_RT.block_on(async {
6322 let (api, conn) = create_websocket_api_and_conn();
6323
6324 let req = WebsocketSessionLogonReq {
6325 method: "method".into(),
6326 payload: {
6327 let mut m = BTreeMap::new();
6328 m.insert("x".into(), Value::Number(1.into()));
6329 m
6330 },
6331 options: WebsocketMessageSendOptions {
6332 is_session_logon: Some(true),
6333 ..Default::default()
6334 },
6335 };
6336 {
6337 let mut st = conn.state.lock().await;
6338 st.session_logon_req = Some(req);
6339 st.is_session_logged_on = false;
6340 st.ws_write_tx = None;
6341 }
6342
6343 api.on_open("wss://example".into(), conn.clone()).await;
6344
6345 let st = conn.state.lock().await;
6346 assert!(
6347 !st.is_session_logged_on,
6348 "should remain logged‐off on failure"
6349 );
6350 });
6351 }
6352
6353 #[test]
6354 fn session_relogon_noop_when_no_req() {
6355 TOKIO_SHARED_RT.block_on(async {
6356 let (api, conn) = create_websocket_api_and_conn();
6357
6358 {
6359 let mut st = conn.state.lock().await;
6360 st.session_logon_req = None;
6361 st.is_session_logged_on = false;
6362 st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6363 }
6364
6365 api.on_open("wss://example".into(), conn.clone()).await;
6366
6367 let st = conn.state.lock().await;
6368 assert!(!st.is_session_logged_on, "still logged‐off");
6369 });
6370 }
6371 }
6372
6373 mod on_message {
6374 use super::*;
6375
6376 fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6377 let sig_gen = SignatureGenerator::new(
6378 Some("api_secret".to_string()),
6379 None::<_>,
6380 None::<String>,
6381 );
6382 let config = ConfigurationWebsocketApi {
6383 api_key: Some("api_key".to_string()),
6384 api_secret: Some("api_secret".to_string()),
6385 private_key: None,
6386 private_key_passphrase: None,
6387 ws_url: Some("wss://example".to_string()),
6388 mode: WebsocketMode::Single,
6389 reconnect_delay: 0,
6390 signature_gen: sig_gen,
6391 timeout: 1000,
6392 time_unit: None,
6393 auto_session_relogon: false,
6394 agent: None,
6395 user_agent: build_user_agent("product"),
6396 };
6397 let conn = WebsocketConnection::new("test");
6398 let api = WebsocketApi::new(config, vec![conn.clone()]);
6399 (api, conn)
6400 }
6401
6402 #[test]
6403 fn resolves_pending_and_removes_request() {
6404 TOKIO_SHARED_RT.block_on(async {
6405 let (api, conn) = create_websocket_api_and_conn();
6406 let (tx, rx) = oneshot::channel();
6407 {
6408 let mut st = conn.state.lock().await;
6409 st.pending_requests
6410 .insert("id1".to_string(), PendingRequest { completion: tx });
6411 }
6412 let msg = json!({"id":"id1","status":200,"foo":"bar"});
6413 api.on_message(msg.to_string(), conn.clone()).await;
6414 let got = rx.await.unwrap().unwrap();
6415 assert_eq!(got, msg);
6416 let st = conn.state.lock().await;
6417 assert!(!st.pending_requests.contains_key("id1"));
6418 });
6419 }
6420
6421 #[test]
6422 fn uses_result_when_present() {
6423 TOKIO_SHARED_RT.block_on(async {
6424 let (api, conn) = create_websocket_api_and_conn();
6425 let (tx, rx) = oneshot::channel();
6426 {
6427 let mut st = conn.state.lock().await;
6428 st.pending_requests
6429 .insert("id1".to_string(), PendingRequest { completion: tx });
6430 }
6431 let msg = json!({
6432 "id": "id1",
6433 "status": 200,
6434 "response": [1,2],
6435 "result": {"a":1}
6436 });
6437 api.on_message(msg.to_string(), conn.clone()).await;
6438 let got = rx.await.unwrap().unwrap();
6439 assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
6440 });
6441 }
6442
6443 #[test]
6444 fn uses_response_when_no_result() {
6445 TOKIO_SHARED_RT.block_on(async {
6446 let (api, conn) = create_websocket_api_and_conn();
6447 let (tx, rx) = oneshot::channel();
6448 {
6449 let mut st = conn.state.lock().await;
6450 st.pending_requests
6451 .insert("id1".to_string(), PendingRequest { completion: tx });
6452 }
6453 let msg = json!({
6454 "id": "id1",
6455 "status": 200,
6456 "response": ["ok"]
6457 });
6458 api.on_message(msg.to_string(), conn.clone()).await;
6459 let got = rx.await.unwrap().unwrap();
6460 assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
6461 });
6462 }
6463
6464 #[test]
6465 fn errors_for_status_ge_400() {
6466 TOKIO_SHARED_RT.block_on(async {
6467 let (api, conn) = create_websocket_api_and_conn();
6468 let (tx, rx) = oneshot::channel();
6469 {
6470 let mut st = conn.state.lock().await;
6471 st.pending_requests
6472 .insert("bad".to_string(), PendingRequest { completion: tx });
6473 }
6474 let err_obj = json!({"code":123,"msg":"oops"});
6475 let msg = json!({"id":"bad","status":500,"error":err_obj});
6476 api.on_message(msg.to_string(), conn.clone()).await;
6477 match rx.await.unwrap() {
6478 Err(WebsocketError::ResponseError { code, message }) => {
6479 assert_eq!(code, 123);
6480 assert_eq!(message, "oops");
6481 }
6482 other => panic!("expected ResponseError, got {other:?}"),
6483 }
6484 let st = conn.state.lock().await;
6485 assert!(!st.pending_requests.contains_key("bad"));
6486 });
6487 }
6488
6489 #[test]
6490 fn ignores_unknown_id() {
6491 TOKIO_SHARED_RT.block_on(async {
6492 let (api, conn) = create_websocket_api_and_conn();
6493 let msg = json!({"id":"nope","status":200});
6494 api.on_message(msg.to_string(), conn.clone()).await;
6495 let st = conn.state.lock().await;
6496 assert!(st.pending_requests.is_empty());
6497 });
6498 }
6499
6500 #[test]
6501 fn parse_error_ignored() {
6502 TOKIO_SHARED_RT.block_on(async {
6503 let (api, conn) = create_websocket_api_and_conn();
6504 api.on_message("not json".to_string(), conn.clone()).await;
6505 let st = conn.state.lock().await;
6506 assert!(st.pending_requests.is_empty());
6507 });
6508 }
6509
6510 #[test]
6511 fn error_status_sends_error() {
6512 TOKIO_SHARED_RT.block_on(async {
6513 let (api, conn) = create_websocket_api_and_conn();
6514 let (tx, rx) = oneshot::channel();
6515 {
6516 let mut st = conn.state.lock().await;
6517 st.pending_requests
6518 .insert("err".to_string(), PendingRequest { completion: tx });
6519 }
6520 let msg = json!({
6521 "id": "err",
6522 "status": 500,
6523 "error": { "code": 42, "msg": "Bad!" }
6524 });
6525 api.on_message(msg.to_string(), conn.clone()).await;
6526 match rx.await.unwrap() {
6527 Err(WebsocketError::ResponseError { code, message }) => {
6528 assert_eq!(code, 42);
6529 assert_eq!(message, "Bad!");
6530 }
6531 other => panic!("expected ResponseError, got {other:?}"),
6532 }
6533 });
6534 }
6535
6536 #[test]
6537 fn unknown_id_logs_warning_and_leaves_pending() {
6538 TOKIO_SHARED_RT.block_on(async {
6539 let (api, conn) = create_websocket_api_and_conn();
6540 {
6541 let mut st = conn.state.lock().await;
6542 st.pending_requests.insert(
6543 "keep".to_string(),
6544 PendingRequest {
6545 completion: oneshot::channel().0,
6546 },
6547 );
6548 }
6549 api.on_message(
6550 json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
6551 conn.clone(),
6552 )
6553 .await;
6554 let st = conn.state.lock().await;
6555 assert!(st.pending_requests.contains_key("keep"));
6556 });
6557 }
6558 }
6559 }
6560
6561 mod websocket_streams {
6562 use super::*;
6563
6564 mod initialisation {
6565 use super::*;
6566
6567 #[test]
6568 fn new_initializes_fields() {
6569 TOKIO_SHARED_RT.block_on(async {
6570 let config = ConfigurationWebsocketStreams {
6571 ws_url: Some("wss://example".to_string()),
6572 mode: WebsocketMode::Pool(2),
6573 reconnect_delay: 500,
6574 time_unit: None,
6575 agent: None,
6576 user_agent: build_user_agent("product"),
6577 };
6578 let conn1 = WebsocketConnection::new("c1");
6579 let conn2 = WebsocketConnection::new("c2");
6580 let api = WebsocketStreams::new(
6581 config.clone(),
6582 vec![conn1.clone(), conn2.clone()],
6583 vec![],
6584 );
6585
6586 assert_eq!(api.common.connection_pool.len(), 2);
6587 assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
6588 assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
6589 assert_eq!(api.configuration.ws_url, Some("wss://example".to_string()));
6590 let flag = api.is_connecting.lock().await;
6591 assert!(!*flag);
6592 });
6593 }
6594
6595 #[test]
6596 fn new_expands_pool_when_url_paths_present() {
6597 TOKIO_SHARED_RT.block_on(async {
6598 let config = ConfigurationWebsocketStreams {
6599 ws_url: Some("wss://example".to_string()),
6600 mode: WebsocketMode::Pool(2),
6601 reconnect_delay: 500,
6602 time_unit: None,
6603 agent: None,
6604 user_agent: build_user_agent("product"),
6605 };
6606
6607 let conn1 = WebsocketConnection::new("c1");
6608 let conn2 = WebsocketConnection::new("c2");
6609
6610 let api = WebsocketStreams::new(
6611 config,
6612 vec![conn1.clone(), conn2.clone()],
6613 vec!["path1".to_string(), "path2".to_string()],
6614 );
6615
6616 assert_eq!(api.common.connection_pool.len(), 4);
6617 assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
6618 assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
6619 });
6620 }
6621
6622 #[test]
6623 fn new_does_not_expand_pool_when_already_sized_for_url_paths() {
6624 TOKIO_SHARED_RT.block_on(async {
6625 let config = ConfigurationWebsocketStreams {
6626 ws_url: Some("wss://example".to_string()),
6627 mode: WebsocketMode::Pool(2),
6628 reconnect_delay: 500,
6629 time_unit: None,
6630 agent: None,
6631 user_agent: build_user_agent("product"),
6632 };
6633
6634 let conns = vec![
6635 WebsocketConnection::new("c1"),
6636 WebsocketConnection::new("c2"),
6637 WebsocketConnection::new("c3"),
6638 WebsocketConnection::new("c4"),
6639 ];
6640
6641 let api = WebsocketStreams::new(
6642 config,
6643 conns.clone(),
6644 vec!["path1".to_string(), "path2".to_string()],
6645 );
6646
6647 assert_eq!(api.common.connection_pool.len(), 4);
6648 for (i, c) in conns.iter().enumerate() {
6649 assert!(Arc::ptr_eq(&api.common.connection_pool[i], c));
6650 }
6651 });
6652 }
6653 }
6654
6655 mod connect {
6656 use super::*;
6657
/// Connecting a two-connection pool to a local server that accepts the
/// handshakes succeeds.
#[test]
fn establishes_successfully() {
    TOKIO_SHARED_RT.block_on(async {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept up to two handshakes (one per pooled connection) and close each at once.
        tokio::spawn(async move {
            for _ in 0..2 {
                let Ok((stream, _)) = listener.accept().await else {
                    continue;
                };
                let mut ws = accept_async(stream).await.unwrap();
                ws.close(None).await.ok();
            }
        });

        let config = ConfigurationWebsocketStreams {
            ws_url: Some(format!("ws://127.0.0.1:{port}")),
            mode: WebsocketMode::Pool(2),
            reconnect_delay: 500,
            time_unit: None,
            agent: None,
            user_agent: build_user_agent("product"),
        };
        let ws = WebsocketStreams::new(
            config,
            vec![WebsocketConnection::new("c1"), WebsocketConnection::new("c2")],
            vec![],
        );

        let outcome = ws.connect(vec!["stream1".into()]).await;
        assert!(outcome.is_ok());
    });
}
6694
6695 #[test]
6696 fn establishes_successfully_with_url_paths() {
6697 TOKIO_SHARED_RT.block_on(async {
6698 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
6699 let addr = listener.local_addr().unwrap();
6700
6701 tokio::spawn(async move {
6702 for _ in 0..4 {
6703 if let Ok((stream, _)) = listener.accept().await {
6704 let mut ws = accept_async(stream).await.unwrap();
6705 ws.close(None).await.ok();
6706 }
6707 }
6708 });
6709
6710 let config = ConfigurationWebsocketStreams {
6711 ws_url: Some(format!("ws://{}", addr)),
6712 mode: WebsocketMode::Pool(2),
6713 reconnect_delay: 500,
6714 time_unit: None,
6715 agent: None,
6716 user_agent: build_user_agent("product"),
6717 };
6718
6719 let ws = WebsocketStreams::new(
6720 config,
6721 vec![],
6722 vec!["path1".to_string(), "path2".to_string()],
6723 );
6724
6725 let res = ws.clone().connect(vec!["stream1".into()]).await;
6726 assert!(res.is_ok());
6727 assert_eq!(ws.common.connection_pool.len(), 4);
6728 });
6729 }
6730
6731 #[test]
6732 fn connect_sets_url_path_on_connections_when_url_paths_present() {
6733 TOKIO_SHARED_RT.block_on(async {
6734 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
6735 let addr = listener.local_addr().unwrap();
6736
6737 tokio::spawn(async move {
6738 for _ in 0..4 {
6739 if let Ok((stream, _)) = listener.accept().await {
6740 let mut ws = accept_async(stream).await.unwrap();
6741 ws.close(None).await.ok();
6742 }
6743 }
6744 });
6745
6746 let config = ConfigurationWebsocketStreams {
6747 ws_url: Some(format!("ws://{}", addr)),
6748 mode: WebsocketMode::Pool(2),
6749 reconnect_delay: 500,
6750 time_unit: None,
6751 agent: None,
6752 user_agent: build_user_agent("product"),
6753 };
6754
6755 let ws = WebsocketStreams::new(
6756 config,
6757 vec![],
6758 vec!["path1".to_string(), "path2".to_string()],
6759 );
6760
6761 ws.clone().connect(vec!["stream1".into()]).await.unwrap();
6762
6763 let pool_size = ws.configuration.mode.pool_size();
6764
6765 for (i, conn) in ws.common.connection_pool.iter().enumerate() {
6766 let expected = if i < pool_size { "path1" } else { "path2" };
6767 let st = conn.state.lock().await;
6768 assert_eq!(st.url_path.as_deref(), Some(expected));
6769 }
6770 });
6771 }
6772
6773 #[test]
6774 fn refused_returns_error() {
6775 TOKIO_SHARED_RT.block_on(async {
6776 let ws = create_websocket_streams(Some("ws://127.0.0.1:9"), None, None);
6777 let res = ws.connect(vec!["stream1".into()]).await;
6778 assert!(res.is_err());
6779 });
6780 }
6781
6782 #[test]
6783 fn invalid_url_returns_error() {
6784 TOKIO_SHARED_RT.block_on(async {
6785 let ws = create_websocket_streams(Some("not-a-url"), None, None);
6786 let res = ws.connect(vec!["s".into()]).await;
6787 assert!(res.is_err());
6788 });
6789 }
6790 }
6791
mod disconnect {
    use super::*;

    #[test]
    fn disconnect_clears_state_and_streams() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let first = &streams.common.connection_pool[0];

            // Seed per-connection state and the stream -> connection map so that
            // `disconnect` has something to clear. Each guard is released before
            // `disconnect` runs.
            {
                let mut guard = first.state.lock().await;
                guard.pending_subscriptions.push_back("s2".to_string());
                guard.stream_callbacks.insert("s1".to_string(), Vec::new());
            }
            streams
                .connection_streams
                .lock()
                .await
                .insert("s3".to_string(), Arc::clone(first));

            assert!(streams.disconnect().await.is_ok());

            let guard = first.state.lock().await;
            assert!(guard.stream_callbacks.is_empty());
            assert!(guard.pending_subscriptions.is_empty());
            drop(guard);

            assert!(streams.connection_streams.lock().await.is_empty());
        });
    }
}
6822
6823 mod subscribe {
6824 use super::*;
6825
6826 #[test]
6827 fn empty_list_does_nothing() {
6828 TOKIO_SHARED_RT.block_on(async {
6829 let ws = create_websocket_streams(None, None, None);
6830 ws.clone().subscribe(Vec::new(), None, None).await;
6831 let map = ws.connection_streams.lock().await;
6832 assert!(map.is_empty());
6833 });
6834 }
6835
6836 #[test]
6837 fn queue_when_not_ready() {
6838 TOKIO_SHARED_RT.block_on(async {
6839 let ws = create_websocket_streams(None, None, None);
6840 let conn = ws.common.connection_pool[0].clone();
6841 ws.clone().subscribe(vec!["s1".into()], None, None).await;
6842 let state = conn.state.lock().await;
6843 let pending: Vec<String> =
6844 state.pending_subscriptions.iter().cloned().collect();
6845 assert_eq!(pending, vec!["s1".to_string()]);
6846 });
6847 }
6848
6849 #[test]
6850 fn only_one_subscription_per_stream() {
6851 TOKIO_SHARED_RT.block_on(async {
6852 let ws = create_websocket_streams(None, None, None);
6853 let conn = ws.common.connection_pool[0].clone();
6854 ws.clone().subscribe(vec!["s1".into()], None, None).await;
6855 ws.clone().subscribe(vec!["s1".into()], None, None).await;
6856 let state = conn.state.lock().await;
6857 let pending: Vec<String> =
6858 state.pending_subscriptions.iter().cloned().collect();
6859 assert_eq!(pending, vec!["s1".to_string()]);
6860 });
6861 }
6862
6863 #[test]
6864 fn multiple_streams_assigned() {
6865 TOKIO_SHARED_RT.block_on(async {
6866 let ws = create_websocket_streams(None, None, None);
6867 ws.clone()
6868 .subscribe(vec!["s1".into(), "s2".into()], None, None)
6869 .await;
6870 let map = ws.connection_streams.lock().await;
6871 assert!(map.contains_key("s1"));
6872 assert!(map.contains_key("s2"));
6873 });
6874 }
6875
6876 #[test]
6877 fn existing_stream_not_reassigned() {
6878 TOKIO_SHARED_RT.block_on(async {
6879 let ws = create_websocket_streams(None, None, None);
6880 ws.clone().subscribe(vec!["s1".into()], None, None).await;
6881 let first_id = {
6882 let map = ws.connection_streams.lock().await;
6883 map.get("s1").unwrap().id.clone()
6884 };
6885 ws.clone()
6886 .subscribe(vec!["s1".into(), "s2".into()], None, None)
6887 .await;
6888 let map = ws.connection_streams.lock().await;
6889 let second_id = map.get("s1").unwrap().id.clone();
6890 assert_eq!(first_id, second_id);
6891 assert!(map.contains_key("s2"));
6892 });
6893 }
6894
6895 #[test]
6896 fn queue_when_not_ready_with_url_path() {
6897 TOKIO_SHARED_RT.block_on(async {
6898 let ws = create_websocket_streams(None, None, None);
6899
6900 let conn = ws.common.connection_pool[0].clone();
6901 {
6902 let mut st = conn.state.lock().await;
6903 st.ws_write_tx = None;
6904 st.url_path = Some("path1".to_string());
6905 st.reconnection_pending = false;
6906 st.close_initiated = false;
6907 }
6908
6909 ws.clone()
6910 .subscribe(vec!["s1".into()], None, Some("path1"))
6911 .await;
6912
6913 let state = conn.state.lock().await;
6914 let pending: Vec<String> =
6915 state.pending_subscriptions.iter().cloned().collect();
6916 assert_eq!(pending, vec!["s1".to_string()]);
6917 });
6918 }
6919
6920 #[test]
6921 fn only_one_subscription_per_stream_per_url_path() {
6922 TOKIO_SHARED_RT.block_on(async {
6923 let ws = create_websocket_streams(None, None, None);
6924
6925 let conn = ws.common.connection_pool[0].clone();
6926 {
6927 let mut st = conn.state.lock().await;
6928 st.ws_write_tx = None;
6929 st.url_path = Some("path1".to_string());
6930 st.reconnection_pending = false;
6931 st.close_initiated = false;
6932 }
6933
6934 ws.clone()
6935 .subscribe(vec!["s1".into()], None, Some("path1"))
6936 .await;
6937 ws.clone()
6938 .subscribe(vec!["s1".into()], None, Some("path1"))
6939 .await;
6940
6941 let state = conn.state.lock().await;
6942 let pending: Vec<String> =
6943 state.pending_subscriptions.iter().cloned().collect();
6944 assert_eq!(pending, vec!["s1".to_string()]);
6945 });
6946 }
6947
6948 #[test]
6949 fn same_stream_can_be_subscribed_on_different_url_paths() {
6950 TOKIO_SHARED_RT.block_on(async {
6951 let ws = create_websocket_streams(None, None, None);
6952
6953 let conn1 = ws.common.connection_pool[0].clone();
6954 let conn2 = ws.common.connection_pool[1].clone();
6955
6956 {
6957 let mut st1 = conn1.state.lock().await;
6958 st1.ws_write_tx = None;
6959 st1.url_path = Some("path1".to_string());
6960 st1.reconnection_pending = false;
6961 st1.close_initiated = false;
6962 }
6963 {
6964 let mut st2 = conn2.state.lock().await;
6965 st2.ws_write_tx = None;
6966 st2.url_path = Some("path2".to_string());
6967 st2.reconnection_pending = false;
6968 st2.close_initiated = false;
6969 }
6970
6971 ws.clone()
6972 .subscribe(vec!["s1".into()], None, Some("path1"))
6973 .await;
6974 ws.clone()
6975 .subscribe(vec!["s1".into()], None, Some("path2"))
6976 .await;
6977
6978 let map = ws.connection_streams.lock().await;
6979 assert!(map.contains_key("path1::s1"));
6980 assert!(map.contains_key("path2::s1"));
6981 });
6982 }
6983 }
6984
mod unsubscribe {
    use super::*;

    // NOTE: several tests below create the write channel inside a scoped block.
    // The `_rx` receiver is therefore dropped when that block ends, and the
    // connection is left holding a sender whose receiving half is closed, so any
    // frame written to it fails to send. These tests check that the stream
    // bookkeeping is still removed in that situation.

    #[test]
    fn removes_stream_with_no_callbacks() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s1".to_string(), conn.clone());
            }
            {
                let mut st = conn.state.lock().await;
                st.stream_callbacks.insert("s1".to_string(), Vec::new());
            }

            ws.unsubscribe(vec!["s1".to_string()], None, None).await;

            assert!(!ws.connection_streams.lock().await.contains_key("s1"));
            assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
        });
    }

    #[test]
    fn preserves_stream_with_callbacks() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[1].clone();

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s2".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                state
                    .stream_callbacks
                    .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
            }

            ws.unsubscribe(vec!["s2".to_string()], None, None).await;

            assert!(ws.connection_streams.lock().await.contains_key("s2"));
            assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
        });
    }

    #[test]
    fn does_not_send_if_callbacks_exist() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();
            // Keep the receiver alive for the whole test so that any frame
            // written by `unsubscribe` can be observed below.
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s1".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert(
                    "s1".to_string(),
                    vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
                );
            }
            ws.unsubscribe(vec!["s1".into()], None, None).await;
            assert!(ws.connection_streams.lock().await.contains_key("s1"));
            assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
            assert!(
                rx.try_recv().is_err(),
                "no UNSUBSCRIBE frame should be sent while callbacks remain"
            );
        });
    }

    #[test]
    fn warns_if_not_associated() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            ws.unsubscribe(vec!["nope".into()], None, None).await;
            // Unknown streams are only warned about: the map must be untouched.
            assert!(ws.connection_streams.lock().await.is_empty());
        });
    }

    #[test]
    fn empty_list_does_nothing() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let before = ws.connection_streams.lock().await.len();
            ws.unsubscribe(Vec::<String>::new(), None, None).await;
            let after = ws.connection_streams.lock().await.len();
            assert_eq!(before, after);
        });
    }

    #[test]
    fn invalid_custom_id_falls_back() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("foo".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                let (tx, _rx) = unbounded_channel();
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert("foo".to_string(), Vec::new());
            }
            ws.unsubscribe(
                vec!["foo".into()],
                Some(StreamId::Str("bad-id".into())),
                None,
            )
            .await;
            assert!(!ws.connection_streams.lock().await.contains_key("foo"));
        });
    }

    #[test]
    fn removes_even_without_write_channel() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("x".to_string(), conn.clone());
            }
            {
                // `_rx` is dropped at the end of this block, so the installed
                // sender has no live receiver.
                let mut state = conn.state.lock().await;
                let (tx, _rx) = unbounded_channel();
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert("x".to_string(), Vec::new());
            }
            ws.unsubscribe(vec!["x".into()], None, None).await;
            assert!(!ws.connection_streams.lock().await.contains_key("x"));
        });
    }

    #[test]
    fn removes_stream_with_no_callbacks_with_url_path() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.url_path = Some("path1".to_string());
            }

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("path1::s1".to_string(), conn.clone());
            }
            {
                let mut st = conn.state.lock().await;
                st.stream_callbacks
                    .insert("path1::s1".to_string(), Vec::new());
            }

            ws.unsubscribe(vec!["s1".to_string()], None, Some("path1"))
                .await;

            assert!(!ws.connection_streams.lock().await.contains_key("path1::s1"));
            assert!(
                !conn
                    .state
                    .lock()
                    .await
                    .stream_callbacks
                    .contains_key("path1::s1")
            );
        });
    }

    #[test]
    fn preserves_stream_with_callbacks_with_url_path() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.url_path = Some("path1".to_string());
            }

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("path1::s2".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                state
                    .stream_callbacks
                    .insert("path1::s2".to_string(), vec![Arc::new(|_: &Value| {})]);
            }

            ws.unsubscribe(vec!["s2".to_string()], None, Some("path1"))
                .await;

            assert!(ws.connection_streams.lock().await.contains_key("path1::s2"));
            assert!(
                conn.state
                    .lock()
                    .await
                    .stream_callbacks
                    .contains_key("path1::s2")
            );
        });
    }

    #[test]
    fn url_path_mismatch_does_not_remove_other_path_subscription() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.url_path = Some("path1".to_string());
            }

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("path1::s1".to_string(), conn.clone());
            }
            {
                let mut st = conn.state.lock().await;
                st.stream_callbacks
                    .insert("path1::s1".to_string(), Vec::new());
            }

            // Unsubscribing under a different path must not touch "path1::s1".
            ws.unsubscribe(vec!["s1".to_string()], None, Some("path2"))
                .await;

            assert!(ws.connection_streams.lock().await.contains_key("path1::s1"));
            assert!(
                conn.state
                    .lock()
                    .await
                    .stream_callbacks
                    .contains_key("path1::s1")
            );
        });
    }
}
7239
7240 mod is_subscribed {
7241 use super::*;
7242
7243 #[test]
7244 fn returns_false_when_not_subscribed() {
7245 TOKIO_SHARED_RT.block_on(async {
7246 let ws = create_websocket_streams(None, None, None);
7247 assert!(!ws.is_subscribed("unknown").await);
7248 });
7249 }
7250
7251 #[test]
7252 fn returns_true_when_subscribed() {
7253 TOKIO_SHARED_RT.block_on(async {
7254 let ws = create_websocket_streams(None, None, None);
7255 let conn = ws.common.connection_pool[0].clone();
7256 {
7257 let mut map = ws.connection_streams.lock().await;
7258 map.insert("stream1".to_string(), conn);
7259 }
7260 assert!(ws.is_subscribed("stream1").await);
7261 });
7262 }
7263
7264 #[test]
7265 fn returns_true_when_subscribed_with_url_path_key() {
7266 TOKIO_SHARED_RT.block_on(async {
7267 let ws = create_websocket_streams(None, None, None);
7268 let conn = ws.common.connection_pool[0].clone();
7269 {
7270 let mut map = ws.connection_streams.lock().await;
7271 map.insert("path1::stream1".to_string(), conn);
7272 }
7273 assert!(ws.is_subscribed("stream1").await);
7274 });
7275 }
7276
7277 #[test]
7278 fn returns_true_when_same_stream_subscribed_on_multiple_paths() {
7279 TOKIO_SHARED_RT.block_on(async {
7280 let ws = create_websocket_streams(None, None, None);
7281 let conn1 = ws.common.connection_pool[0].clone();
7282 let conn2 = ws.common.connection_pool[1].clone();
7283 {
7284 let mut map = ws.connection_streams.lock().await;
7285 map.insert("path1::stream1".to_string(), conn1);
7286 map.insert("path2::stream1".to_string(), conn2);
7287 }
7288 assert!(ws.is_subscribed("stream1").await);
7289 });
7290 }
7291
7292 #[test]
7293 fn returns_false_when_only_similar_suffix_exists() {
7294 TOKIO_SHARED_RT.block_on(async {
7295 let ws = create_websocket_streams(None, None, None);
7296 let conn = ws.common.connection_pool[0].clone();
7297 {
7298 let mut map = ws.connection_streams.lock().await;
7299 map.insert("path1::stream10".to_string(), conn);
7300 }
7301 assert!(!ws.is_subscribed("stream1").await);
7302 });
7303 }
7304 }
7305
mod stream_key {
    use super::*;

    #[test]
    fn stream_key_without_url_path_returns_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let key = streams.stream_key("s1", None);
            assert_eq!(key, "s1");
        });
    }

    #[test]
    fn stream_key_with_empty_url_path_returns_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            // An empty path is treated like no path at all.
            let key = streams.stream_key("s1", Some(""));
            assert_eq!(key, "s1");
        });
    }

    #[test]
    fn stream_key_with_url_path_prefixes_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let key = streams.stream_key("s1", Some("path1"));
            assert_eq!(key, "path1::s1");
        });
    }

    #[test]
    fn stream_key_distinguishes_paths() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let cases = [("path1", "path1::s1"), ("path2", "path2::s1")];
            for (path, expected) in cases {
                assert_eq!(streams.stream_key("s1", Some(path)), expected);
            }
        });
    }
}
7342
7343 mod prepare_url {
7344 use super::*;
7345
7346 #[test]
7347 fn without_time_unit_returns_base_url() {
7348 TOKIO_SHARED_RT.block_on(async {
7349 let conns = vec![
7350 WebsocketConnection::new("c1"),
7351 WebsocketConnection::new("c2"),
7352 ];
7353 let config = ConfigurationWebsocketStreams {
7354 ws_url: Some("wss://example".to_string()),
7355 mode: WebsocketMode::Single,
7356 reconnect_delay: 100,
7357 time_unit: None,
7358 agent: None,
7359 user_agent: build_user_agent("product"),
7360 };
7361 let ws = WebsocketStreams::new(config, conns, vec![]);
7362 let url = ws.prepare_url(&["s1".into(), "s2".into()], None);
7363 assert_eq!(url, "wss://example/stream?streams=s1/s2");
7364 });
7365 }
7366
7367 #[test]
7368 fn with_time_unit_appends_parameter() {
7369 TOKIO_SHARED_RT.block_on(async {
7370 let conns = vec![WebsocketConnection::new("c1")];
7371 let config = ConfigurationWebsocketStreams {
7372 ws_url: Some("wss://example".to_string()),
7373 mode: WebsocketMode::Single,
7374 reconnect_delay: 100,
7375 time_unit: Some(TimeUnit::Millisecond),
7376 agent: None,
7377 user_agent: build_user_agent("product"),
7378 };
7379 let ws = WebsocketStreams::new(config, conns, vec![]);
7380 let url = ws.prepare_url(&["a".into()], None);
7381 assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
7382 });
7383 }
7384
7385 #[test]
7386 fn multiple_streams_and_time_unit() {
7387 TOKIO_SHARED_RT.block_on(async {
7388 let conns = vec![WebsocketConnection::new("c1")];
7389 let config = ConfigurationWebsocketStreams {
7390 ws_url: Some("wss://example".to_string()),
7391 mode: WebsocketMode::Single,
7392 reconnect_delay: 100,
7393 time_unit: Some(TimeUnit::Microsecond),
7394 agent: None,
7395 user_agent: build_user_agent("product"),
7396 };
7397 let ws = WebsocketStreams::new(config, conns, vec![]);
7398 let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()], None);
7399 assert_eq!(
7400 url,
7401 "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
7402 );
7403 });
7404 }
7405
7406 #[test]
7407 fn with_url_path_prefixes_base_url() {
7408 TOKIO_SHARED_RT.block_on(async {
7409 let conns = vec![WebsocketConnection::new("c1")];
7410 let config = ConfigurationWebsocketStreams {
7411 ws_url: Some("wss://example".to_string()),
7412 mode: WebsocketMode::Single,
7413 reconnect_delay: 100,
7414 time_unit: None,
7415 agent: None,
7416 user_agent: build_user_agent("product"),
7417 };
7418 let ws = WebsocketStreams::new(config, conns, vec![]);
7419 let url = ws.prepare_url(["s1".into()].as_ref(), Some("path1"));
7420 assert_eq!(url, "wss://example/path1/stream?streams=s1");
7421 });
7422 }
7423
7424 #[test]
7425 fn with_url_path_and_time_unit_appends_parameter() {
7426 TOKIO_SHARED_RT.block_on(async {
7427 let conns = vec![WebsocketConnection::new("c1")];
7428 let config = ConfigurationWebsocketStreams {
7429 ws_url: Some("wss://example".to_string()),
7430 mode: WebsocketMode::Single,
7431 reconnect_delay: 100,
7432 time_unit: Some(TimeUnit::Millisecond),
7433 agent: None,
7434 user_agent: build_user_agent("product"),
7435 };
7436 let ws = WebsocketStreams::new(config, conns, vec![]);
7437 let url = ws.prepare_url(["a".into()].as_ref(), Some("path1"));
7438 assert_eq!(
7439 url,
7440 "wss://example/path1/stream?streams=a&timeUnit=millisecond"
7441 );
7442 });
7443 }
7444
7445 #[test]
7446 fn url_path_distinguishes_urls_for_same_streams() {
7447 TOKIO_SHARED_RT.block_on(async {
7448 let conns = vec![WebsocketConnection::new("c1")];
7449 let config = ConfigurationWebsocketStreams {
7450 ws_url: Some("wss://example".to_string()),
7451 mode: WebsocketMode::Single,
7452 reconnect_delay: 100,
7453 time_unit: None,
7454 agent: None,
7455 user_agent: build_user_agent("product"),
7456 };
7457 let ws = WebsocketStreams::new(config, conns, vec![]);
7458 let u1 = ws.prepare_url(["s1".into()].as_ref(), Some("path1"));
7459 let u2 = ws.prepare_url(["s1".into()].as_ref(), Some("path2"));
7460 assert_eq!(u1, "wss://example/path1/stream?streams=s1");
7461 assert_eq!(u2, "wss://example/path2/stream?streams=s1");
7462 });
7463 }
7464 }
7465
7466 mod handle_stream_assignment {
7467 use super::*;
7468
7469 #[test]
7470 fn assigns_new_streams_to_connections() {
7471 TOKIO_SHARED_RT.block_on(async {
7472 let ws = create_websocket_streams(None, None, None);
7473 let groups = ws
7474 .clone()
7475 .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7476 .await;
7477 let mut seen_streams = HashSet::new();
7478 for (_conn, streams) in &groups {
7479 for s in streams {
7480 seen_streams.insert(s);
7481 }
7482 }
7483 assert_eq!(
7484 seen_streams,
7485 ["s1".to_string(), "s2".to_string()].iter().collect()
7486 );
7487 assert_eq!(groups.len(), 1);
7488 });
7489 }
7490
7491 #[test]
7492 fn reuses_existing_connection_for_duplicate_stream() {
7493 TOKIO_SHARED_RT.block_on(async {
7494 let ws = create_websocket_streams(None, None, None);
7495 let _ = ws
7496 .clone()
7497 .handle_stream_assignment(vec!["s1".into()], None)
7498 .await;
7499 let groups = ws
7500 .clone()
7501 .handle_stream_assignment(vec!["s1".into(), "s3".into()], None)
7502 .await;
7503 let mut all_streams = Vec::new();
7504 for (_conn, streams) in groups {
7505 all_streams.extend(streams);
7506 }
7507 all_streams.sort();
7508 assert_eq!(all_streams, vec!["s1".to_string(), "s3".to_string()]);
7509 });
7510 }
7511
7512 #[test]
7513 fn empty_stream_list_returns_empty() {
7514 TOKIO_SHARED_RT.block_on(async {
7515 let ws = create_websocket_streams(None, None, None);
7516 let groups = ws.clone().handle_stream_assignment(vec![], None).await;
7517 assert!(groups.is_empty());
7518 });
7519 }
7520
7521 #[test]
7522 fn closed_or_reconnecting_forces_reassignment_of_stream() {
7523 TOKIO_SHARED_RT.block_on(async {
7524 let ws = create_websocket_streams(None, None, None);
7525 let mut groups = ws
7526 .clone()
7527 .handle_stream_assignment(vec!["s1".into()], None)
7528 .await;
7529 let (conn, _) = groups.pop().unwrap();
7530 {
7531 let mut st = conn.state.lock().await;
7532 st.close_initiated = true;
7533 }
7534 let groups2 = ws
7535 .clone()
7536 .handle_stream_assignment(vec!["s2".into()], None)
7537 .await;
7538 assert_eq!(groups2.len(), 1);
7539 let (_new_conn, streams) = &groups2[0];
7540 assert_eq!(streams, &vec!["s2".to_string()]);
7541 });
7542 }
7543
7544 #[test]
7545 fn no_available_connections_falls_back_to_one() {
7546 TOKIO_SHARED_RT.block_on(async {
7547 let ws = create_websocket_streams(None, Some(vec![]), None);
7548 let assigned = ws.handle_stream_assignment(vec!["foo".into()], None).await;
7549 assert_eq!(assigned.len(), 1);
7550 let (_conn, streams) = &assigned[0];
7551 assert_eq!(streams.as_slice(), &["foo".to_string()]);
7552 });
7553 }
7554
7555 #[test]
7556 fn single_connection_groups_multiple_streams() {
7557 TOKIO_SHARED_RT.block_on(async {
7558 let conn = WebsocketConnection::new("c1");
7559 let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7560 let assigned = ws
7561 .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7562 .await;
7563 assert_eq!(assigned.len(), 1);
7564 let (assigned_conn, streams) = &assigned[0];
7565 assert!(Arc::ptr_eq(assigned_conn, &conn));
7566 assert_eq!(streams.len(), 2);
7567 assert!(streams.contains(&"s1".to_string()));
7568 assert!(streams.contains(&"s2".to_string()));
7569 });
7570 }
7571
7572 #[test]
7573 fn reuse_existing_healthy_connection() {
7574 TOKIO_SHARED_RT.block_on(async {
7575 let conn = WebsocketConnection::new("c");
7576 let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7577 let _ = ws.handle_stream_assignment(vec!["s1".into()], None).await;
7578 let second = ws.handle_stream_assignment(vec!["s1".into()], None).await;
7579 assert_eq!(second.len(), 1);
7580 let (assigned_conn, streams) = &second[0];
7581 assert!(Arc::ptr_eq(assigned_conn, &conn));
7582 assert_eq!(streams.as_slice(), &["s1".to_string()]);
7583 });
7584 }
7585
7586 #[test]
7587 fn mix_new_and_assigned_streams() {
7588 TOKIO_SHARED_RT.block_on(async {
7589 let conn = WebsocketConnection::new("c");
7590 let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7591 let _ = ws
7592 .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7593 .await;
7594 let mixed = ws
7595 .handle_stream_assignment(vec!["s2".into(), "s3".into()], None)
7596 .await;
7597 assert_eq!(mixed.len(), 1);
7598 let (assigned_conn, streams) = &mixed[0];
7599 assert!(Arc::ptr_eq(assigned_conn, &conn));
7600 let mut got = streams.clone();
7601 got.sort();
7602 assert_eq!(got, vec!["s2".to_string(), "s3".to_string()]);
7603 });
7604 }
7605
7606 #[test]
7607 fn assigns_streams_with_url_path_keys() {
7608 TOKIO_SHARED_RT.block_on(async {
7609 let ws = create_websocket_streams(None, None, None);
7610
7611 let conn = ws.common.connection_pool[0].clone();
7612 {
7613 let mut st = conn.state.lock().await;
7614 st.url_path = Some("path1".to_string());
7615 st.ws_write_tx = None;
7616 st.reconnection_pending = false;
7617 st.close_initiated = false;
7618 }
7619
7620 let groups = ws
7621 .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
7622 .await;
7623
7624 let map = ws.connection_streams.lock().await;
7625 assert!(map.contains_key("path1::s1"));
7626 assert!(map.contains_key("path1::s2"));
7627 assert_eq!(groups.len(), 1);
7628
7629 let (_assigned_conn, streams) = &groups[0];
7630 let mut got = streams.clone();
7631 got.sort();
7632 assert_eq!(got, vec!["s1".to_string(), "s2".to_string()]);
7633 });
7634 }
7635
7636 #[test]
7637 fn same_stream_on_different_paths_creates_distinct_keys() {
7638 TOKIO_SHARED_RT.block_on(async {
7639 let ws = create_websocket_streams(None, None, None);
7640
7641 let conn1 = ws.common.connection_pool[0].clone();
7642 let conn2 = ws.common.connection_pool[1].clone();
7643
7644 {
7645 let mut st = conn1.state.lock().await;
7646 st.url_path = Some("path1".to_string());
7647 st.ws_write_tx = None;
7648 st.reconnection_pending = false;
7649 st.close_initiated = false;
7650 }
7651 {
7652 let mut st = conn2.state.lock().await;
7653 st.url_path = Some("path2".to_string());
7654 st.ws_write_tx = None;
7655 st.reconnection_pending = false;
7656 st.close_initiated = false;
7657 }
7658
7659 let g1 = ws
7660 .handle_stream_assignment(vec!["s1".into()], Some("path1"))
7661 .await;
7662 let g2 = ws
7663 .handle_stream_assignment(vec!["s1".into()], Some("path2"))
7664 .await;
7665
7666 assert_eq!(g1.len(), 1);
7667 assert_eq!(g2.len(), 1);
7668
7669 let map = ws.connection_streams.lock().await;
7670 assert!(map.contains_key("path1::s1"));
7671 assert!(map.contains_key("path2::s1"));
7672 });
7673 }
7674
7675 #[test]
7676 fn reuses_existing_connection_for_same_path_and_stream() {
7677 TOKIO_SHARED_RT.block_on(async {
7678 let ws = create_websocket_streams(None, None, None);
7679
7680 let conn = ws.common.connection_pool[0].clone();
7681 {
7682 let mut st = conn.state.lock().await;
7683 st.url_path = Some("path1".to_string());
7684 st.ws_write_tx = None;
7685 st.reconnection_pending = false;
7686 st.close_initiated = false;
7687 }
7688
7689 let first = ws
7690 .handle_stream_assignment(vec!["s1".into()], Some("path1"))
7691 .await;
7692 let second = ws
7693 .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
7694 .await;
7695
7696 assert_eq!(first.len(), 1);
7697 assert_eq!(second.len(), 1);
7698
7699 let map = ws.connection_streams.lock().await;
7700 let c1 = map.get("path1::s1").unwrap().clone();
7701 let c2 = map.get("path1::s2").unwrap().clone();
7702 assert!(Arc::ptr_eq(&c1, &c2));
7703 });
7704 }
7705
/// Once the connection serving `path1::s1` has begun closing, assigning the
/// same stream again must place it on a different pooled connection.
#[test]
fn closed_or_reconnecting_forces_reassignment_with_url_path() {
    TOKIO_SHARED_RT.block_on(async {
        let ws = create_websocket_streams(None, None, None);

        let conn1 = ws.common.connection_pool[0].clone();
        let conn2 = ws.common.connection_pool[1].clone();

        // Make both connections candidates for `path1`.
        for conn in [&conn1, &conn2] {
            let mut st = conn.state.lock().await;
            st.url_path = Some("path1".to_string());
            st.ws_write_tx = None;
            st.reconnection_pending = false;
            st.close_initiated = false;
        }

        let _ = ws
            .handle_stream_assignment(vec!["s1".into()], Some("path1"))
            .await;

        // Close whichever connection actually received `s1`. Assuming it was
        // `conn1` would let the final assertion pass vacuously if the stream
        // had been placed on `conn2` in the first place.
        let first = ws
            .connection_streams
            .lock()
            .await
            .get("path1::s1")
            .cloned()
            .expect("path1::s1 should be assigned by the first call");
        first.state.lock().await.close_initiated = true;

        let _ = ws
            .handle_stream_assignment(vec!["s1".into()], Some("path1"))
            .await;

        let map = ws.connection_streams.lock().await;
        let assigned = map.get("path1::s1").unwrap();
        assert!(!Arc::ptr_eq(assigned, &first));
    });
}
7747 }
7748
mod send_subscription_payload {
    use super::*;

    /// The custom id `"badid"` is not sent as-is; the SUBSCRIBE frame carries
    /// a 32-character hex id instead.
    #[test]
    fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let conn = &ws.common.connection_pool[0];

            // Capture outbound frames on a channel instead of a real socket.
            let (tx, mut rx) = unbounded_channel();
            conn.state.lock().await.ws_write_tx = Some(tx);

            let custom_id = Some("badid".to_string()).map(StreamId::from);
            ws.send_subscription_payload(conn, &vec!["s1".to_string()], custom_id);

            let frame = rx.recv().await.expect("no message sent");
            let txt = match frame {
                Message::Text(txt) => txt,
                msg => panic!("unexpected message: {msg:?}"),
            };
            let payload: serde_json::Value = serde_json::from_str(&txt).unwrap();
            assert_eq!(payload["method"], "SUBSCRIBE");

            let sent_id = payload["id"].as_str().unwrap();
            assert_ne!(sent_id, "badid");
            assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(sent_id));
        });
    }

    /// A 32-character hex string id is sent verbatim, and a payload sent
    /// without an id still carries a 32-character hex id.
    #[test]
    fn subscribe_payload_with_and_without_custom_string_id() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://unused"), None, None);
            let conn = &ws.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel();
            conn.state.lock().await.ws_write_tx = Some(tx);

            let custom_id =
                Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string()).map(StreamId::from);
            ws.send_subscription_payload(
                conn,
                &vec!["a".to_string(), "b".to_string()],
                custom_id,
            );
            let with_id = rx.recv().await.unwrap();

            ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
            let without_id = rx.recv().await.unwrap();

            let Message::Text(first_txt) = with_id else {
                panic!()
            };
            let first: serde_json::Value = serde_json::from_str(&first_txt).unwrap();
            assert_eq!(first["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
            assert_eq!(
                first["params"].as_array().unwrap(),
                &vec![serde_json::json!("a"), serde_json::json!("b")]
            );

            let Message::Text(second_txt) = without_id else {
                panic!()
            };
            let second: serde_json::Value = serde_json::from_str(&second_txt).unwrap();
            assert_eq!(second["method"], "SUBSCRIBE");
            let params = second["params"].as_array().unwrap();
            assert_eq!(params.len(), 1);
            assert_eq!(params[0], "x");
            let generated = second["id"].as_str().unwrap();
            assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(generated));
        });
    }

    /// With `stream_id_is_strictly_number` set, a numeric custom id is sent as
    /// a JSON number, and an omitted id is replaced by a number fitting `u32`.
    #[test]
    fn subscribe_payload_with_and_without_custom_integer_id() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://unused"), None, None);
            ws.stream_id_is_strictly_number
                .store(true, Ordering::Relaxed);
            let conn = &ws.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel();
            conn.state.lock().await.ws_write_tx = Some(tx);

            let numeric_id = Some(123u32).map(StreamId::from);
            ws.send_subscription_payload(
                conn,
                &vec!["a".to_string(), "b".to_string()],
                numeric_id,
            );
            let with_id = rx.recv().await.unwrap();

            ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
            let without_id = rx.recv().await.unwrap();

            let Message::Text(first_txt) = with_id else {
                panic!("Expected Message::Text for msg1");
            };
            let first: serde_json::Value = serde_json::from_str(&first_txt).unwrap();
            assert_eq!(first["method"], "SUBSCRIBE");
            assert_eq!(first["id"].as_u64(), Some(123));
            assert_eq!(
                first["params"].as_array().unwrap(),
                &vec![serde_json::json!("a"), serde_json::json!("b")]
            );

            let Message::Text(second_txt) = without_id else {
                panic!("Expected Message::Text for msg2");
            };
            let second: serde_json::Value = serde_json::from_str(&second_txt).unwrap();
            assert_eq!(second["method"], "SUBSCRIBE");

            let params = second["params"].as_array().unwrap();
            assert_eq!(params.len(), 1);
            assert_eq!(params[0], "x");

            let id2 = second.get("id").expect("payload should contain id");
            assert!(
                id2.is_number(),
                "expected numeric id in strict-number mode, got: {id2:?}"
            );
            let n = id2.as_u64().unwrap();
            assert!(u32::try_from(n).is_ok(), "id should fit u32, got {n}");
        });
    }
}
7887
mod on_open {
    use super::*;

    /// Queued subscriptions are sent in a SUBSCRIBE frame whose params keep
    /// queue order, and the queue is empty afterwards.
    #[test]
    fn sends_pending_subscriptions() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let conn = &ws.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.pending_subscriptions
                    .extend(["foo", "bar"].map(String::from));
            }

            ws.on_open("ws://example.com".to_string(), conn.clone())
                .await;

            let frame = rx.recv().await.expect("no subscription sent");
            let txt = match frame {
                Message::Text(txt) => txt,
                msg => panic!("unexpected message: {msg:?}"),
            };
            let payload: Value = serde_json::from_str(&txt).unwrap();
            assert_eq!(payload["method"], "SUBSCRIBE");
            assert_eq!(
                payload["params"].as_array().unwrap(),
                &vec![json!("foo"), json!("bar")]
            );

            assert!(conn.state.lock().await.pending_subscriptions.is_empty());
        });
    }

    /// With nothing queued, opening writes no frame to the channel.
    #[test]
    fn with_no_pending_subscriptions_sends_nothing() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let conn = &ws.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel();
            conn.state.lock().await.ws_write_tx = Some(tx);

            ws.on_open("ws://example.com".to_string(), conn.clone())
                .await;

            assert!(rx.try_recv().is_err(), "unexpected message sent");
        });
    }

    /// Even without a write channel, the pending queue is drained on open.
    #[test]
    fn clears_pending_without_write_channel() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let conn = &ws.common.connection_pool[0];
            conn.state
                .lock()
                .await
                .pending_subscriptions
                .push_back("solo".to_string());

            ws.on_open("ws://example.com".to_string(), conn.clone())
                .await;

            assert!(conn.state.lock().await.pending_subscriptions.is_empty());
        });
    }
}
7957
7958 mod on_message {
7959 use super::*;
7960
7961 #[test]
7962 fn invokes_registered_callback() {
7963 TOKIO_SHARED_RT.block_on(async {
7964 let ws: Arc<WebsocketStreams> =
7965 create_websocket_streams(Some("ws://example.com"), None, None);
7966 let conn = &ws.common.connection_pool[0];
7967 let called = Arc::new(AtomicBool::new(false));
7968 let called_clone = called.clone();
7969
7970 {
7971 let mut st = conn.state.lock().await;
7972 st.stream_callbacks
7973 .entry("stream1".to_string())
7974 .or_default()
7975 .push(
7976 (Box::new(move |_: &Value| {
7977 called_clone.store(true, Ordering::SeqCst);
7978 })
7979 as Box<dyn Fn(&Value) + Send + Sync>)
7980 .into(),
7981 );
7982 }
7983
7984 let msg = json!({
7985 "stream": "stream1",
7986 "data": { "key": "value" }
7987 })
7988 .to_string();
7989
7990 ws.on_message(msg, conn.clone()).await;
7991
7992 assert!(called.load(Ordering::SeqCst));
7993 });
7994 }
7995
7996 #[test]
7997 fn invokes_all_registered_callbacks() {
7998 TOKIO_SHARED_RT.block_on(async {
7999 let ws: Arc<WebsocketStreams> =
8000 create_websocket_streams(Some("ws://example.com"), None, None);
8001 let conn = &ws.common.connection_pool[0];
8002 let counter = Arc::new(AtomicUsize::new(0));
8003
8004 {
8005 let mut st = conn.state.lock().await;
8006 let entry = st.stream_callbacks.entry("s".into()).or_default();
8007 let c1 = counter.clone();
8008 entry.push(
8009 (Box::new(move |_: &Value| {
8010 c1.fetch_add(1, Ordering::SeqCst);
8011 }) as Box<dyn Fn(&Value) + Send + Sync>)
8012 .into(),
8013 );
8014 let c2 = counter.clone();
8015 entry.push(
8016 (Box::new(move |_: &Value| {
8017 c2.fetch_add(1, Ordering::SeqCst);
8018 }) as Box<dyn Fn(&Value) + Send + Sync>)
8019 .into(),
8020 );
8021 }
8022
8023 let msg = json!({"stream":"s","data":42}).to_string();
8024 ws.on_message(msg, conn.clone()).await;
8025
8026 assert_eq!(counter.load(Ordering::SeqCst), 2);
8027 });
8028 }
8029
8030 #[test]
8031 fn handles_null_data_field() {
8032 TOKIO_SHARED_RT.block_on(async {
8033 let ws: Arc<WebsocketStreams> =
8034 create_websocket_streams(Some("ws://example.com"), None, None);
8035 let conn = &ws.common.connection_pool[0];
8036 let called = Arc::new(AtomicUsize::new(0));
8037 {
8038 let mut st = conn.state.lock().await;
8039 st.stream_callbacks.entry("n".into()).or_default().push(
8040 (Box::new({
8041 let c = called.clone();
8042 move |data: &Value| {
8043 if data.is_null() {
8044 c.fetch_add(1, Ordering::SeqCst);
8045 }
8046 }
8047 }) as Box<dyn Fn(&Value) + Send + Sync>)
8048 .into(),
8049 );
8050 }
8051 let msg = json!({"stream":"n","data":null}).to_string();
8052 ws.on_message(msg, conn.clone()).await;
8053 assert_eq!(called.load(Ordering::SeqCst), 1);
8054 });
8055 }
8056
8057 #[test]
8058 fn with_invalid_json_does_not_panic() {
8059 TOKIO_SHARED_RT.block_on(async {
8060 let ws: Arc<WebsocketStreams> =
8061 create_websocket_streams(Some("ws://example.com"), None, None);
8062 let conn = &ws.common.connection_pool[0];
8063 let bad = "not a json";
8064 ws.on_message(bad.to_string(), conn.clone()).await;
8065 });
8066 }
8067
8068 #[test]
8069 fn without_stream_field_does_nothing() {
8070 TOKIO_SHARED_RT.block_on(async {
8071 let ws: Arc<WebsocketStreams> =
8072 create_websocket_streams(Some("ws://example.com"), None, None);
8073 let conn = &ws.common.connection_pool[0];
8074 let msg = json!({ "data": { "foo": 1 } }).to_string();
8075 ws.on_message(msg, conn.clone()).await;
8076 });
8077 }
8078
8079 #[test]
8080 fn with_unregistered_stream_does_not_panic() {
8081 TOKIO_SHARED_RT.block_on(async {
8082 let ws: Arc<WebsocketStreams> =
8083 create_websocket_streams(Some("ws://example.com"), None, None);
8084 let conn = &ws.common.connection_pool[0];
8085 let msg = json!({
8086 "stream": "nope",
8087 "data": { "foo": 1 }
8088 })
8089 .to_string();
8090 ws.on_message(msg, conn.clone()).await;
8091 });
8092 }
8093
8094 #[test]
8095 fn invokes_registered_callback_with_url_path_key() {
8096 TOKIO_SHARED_RT.block_on(async {
8097 let ws: Arc<WebsocketStreams> =
8098 create_websocket_streams(Some("ws://example.com"), None, None);
8099 let conn = &ws.common.connection_pool[0];
8100
8101 {
8102 let mut st = conn.state.lock().await;
8103 st.url_path = Some("path1".to_string());
8104 }
8105
8106 let called = Arc::new(AtomicBool::new(false));
8107 let called_clone = called.clone();
8108
8109 {
8110 let mut st = conn.state.lock().await;
8111 st.stream_callbacks
8112 .entry("path1::stream1".to_string())
8113 .or_default()
8114 .push(
8115 (Box::new(move |_: &Value| {
8116 called_clone.store(true, Ordering::SeqCst);
8117 })
8118 as Box<dyn Fn(&Value) + Send + Sync>)
8119 .into(),
8120 );
8121 }
8122
8123 let msg = json!({
8124 "stream": "stream1",
8125 "data": { "key": "value" }
8126 })
8127 .to_string();
8128
8129 ws.on_message(msg, conn.clone()).await;
8130
8131 assert!(called.load(Ordering::SeqCst));
8132 });
8133 }
8134
8135 #[test]
8136 fn does_not_invoke_callback_when_url_path_mismatch() {
8137 TOKIO_SHARED_RT.block_on(async {
8138 let ws: Arc<WebsocketStreams> =
8139 create_websocket_streams(Some("ws://example.com"), None, None);
8140 let conn = &ws.common.connection_pool[0];
8141
8142 {
8143 let mut st = conn.state.lock().await;
8144 st.url_path = Some("path2".to_string());
8145 }
8146
8147 let called = Arc::new(AtomicBool::new(false));
8148 let called_clone = called.clone();
8149
8150 {
8151 let mut st = conn.state.lock().await;
8152 st.stream_callbacks
8153 .entry("path1::stream1".to_string())
8154 .or_default()
8155 .push(
8156 (Box::new(move |_: &Value| {
8157 called_clone.store(true, Ordering::SeqCst);
8158 })
8159 as Box<dyn Fn(&Value) + Send + Sync>)
8160 .into(),
8161 );
8162 }
8163
8164 let msg = json!({
8165 "stream": "stream1",
8166 "data": { "key": "value" }
8167 })
8168 .to_string();
8169
8170 ws.on_message(msg, conn.clone()).await;
8171
8172 assert!(!called.load(Ordering::SeqCst));
8173 });
8174 }
8175
8176 #[test]
8177 fn invokes_only_callbacks_for_current_url_path_when_both_exist() {
8178 TOKIO_SHARED_RT.block_on(async {
8179 let ws: Arc<WebsocketStreams> =
8180 create_websocket_streams(Some("ws://example.com"), None, None);
8181 let conn = &ws.common.connection_pool[0];
8182
8183 {
8184 let mut st = conn.state.lock().await;
8185 st.url_path = Some("path1".to_string());
8186 }
8187
8188 let c1 = Arc::new(AtomicUsize::new(0));
8189 let c2 = Arc::new(AtomicUsize::new(0));
8190
8191 {
8192 let mut st = conn.state.lock().await;
8193
8194 let a = c1.clone();
8195 st.stream_callbacks
8196 .entry("path1::s".to_string())
8197 .or_default()
8198 .push(
8199 (Box::new(move |_: &Value| {
8200 a.fetch_add(1, Ordering::SeqCst);
8201 })
8202 as Box<dyn Fn(&Value) + Send + Sync>)
8203 .into(),
8204 );
8205
8206 let b = c2.clone();
8207 st.stream_callbacks
8208 .entry("path2::s".to_string())
8209 .or_default()
8210 .push(
8211 (Box::new(move |_: &Value| {
8212 b.fetch_add(1, Ordering::SeqCst);
8213 })
8214 as Box<dyn Fn(&Value) + Send + Sync>)
8215 .into(),
8216 );
8217 }
8218
8219 let msg = json!({"stream":"s","data":42}).to_string();
8220 ws.on_message(msg, conn.clone()).await;
8221
8222 assert_eq!(c1.load(Ordering::SeqCst), 1);
8223 assert_eq!(c2.load(Ordering::SeqCst), 0);
8224 });
8225 }
8226 }
8227
mod get_reconnect_url {
    use super::*;

    /// Strips `prefix` from `url` (panicking if it is absent), takes the text
    /// up to the first `&`, and returns the `/`-separated stream names in it.
    fn stream_names<'a>(url: &'a str, prefix: &str) -> std::collections::HashSet<&'a str> {
        let query = url.strip_prefix(prefix).unwrap();
        let streams = query.split('&').next().unwrap();
        streams.split('/').collect()
    }

    /// A lone stream yields `/stream?streams=<name>` on the configured base URL.
    #[test]
    fn single_stream_reconnect_url() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = ws.common.connection_pool[0].clone();
            ws.connection_streams
                .lock()
                .await
                .insert("s1".to_string(), c0.clone());

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(url, "ws://example.com/stream?streams=s1");
        });
    }

    /// All streams owned by the connection appear in the URL (order-agnostic).
    #[test]
    fn multiple_streams_same_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = ws.common.connection_pool[0].clone();
            {
                let mut owners = ws.connection_streams.lock().await;
                for name in ["a", "b"] {
                    owners.insert(name.to_string(), c0.clone());
                }
            }

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                stream_names(&url, "ws://example.com/stream?streams="),
                ["a", "b"].iter().copied().collect()
            );
        });
    }

    /// A configured time unit is appended as a `timeUnit` query parameter.
    #[test]
    fn reconnect_url_with_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let mut ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
                Some(TimeUnit::Microsecond);
            let c0 = ws.common.connection_pool[0].clone();
            ws.connection_streams
                .lock()
                .await
                .insert("x".to_string(), c0.clone());

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                url,
                "ws://example.com/stream?streams=x&timeUnit=microsecond"
            );
        });
    }

    /// The connection's `url_path` is inserted before `/stream`, and the
    /// `"<path>::"` key prefix is stripped from the stream name.
    #[test]
    fn reconnect_url_uses_url_path_from_connection_state() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = ws.common.connection_pool[0].clone();
            c0.state.lock().await.url_path = Some("path1".to_string());
            ws.connection_streams
                .lock()
                .await
                .insert("path1::s1".to_string(), c0.clone());

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(url, "ws://example.com/path1/stream?streams=s1");
        });
    }

    /// The path prefix is stripped from every key owned by the connection.
    #[test]
    fn reconnect_url_strips_prefix_from_multiple_keys_with_url_path() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = ws.common.connection_pool[0].clone();
            c0.state.lock().await.url_path = Some("path1".to_string());
            {
                let mut owners = ws.connection_streams.lock().await;
                for key in ["path1::a", "path1::b"] {
                    owners.insert(key.to_string(), c0.clone());
                }
            }

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                stream_names(&url, "ws://example.com/path1/stream?streams="),
                ["a", "b"].iter().copied().collect()
            );
        });
    }

    /// `url_path` and `timeUnit` can both be present in one reconnect URL.
    #[test]
    fn reconnect_url_with_url_path_and_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let mut ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
                Some(TimeUnit::Microsecond);

            let c0 = ws.common.connection_pool[0].clone();
            c0.state.lock().await.url_path = Some("path1".to_string());
            ws.connection_streams
                .lock()
                .await
                .insert("path1::x".to_string(), c0.clone());

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                url,
                "ws://example.com/path1/stream?streams=x&timeUnit=microsecond"
            );
        });
    }

    /// Streams owned by a different connection on the same path are excluded.
    #[test]
    fn reconnect_url_ignores_streams_from_other_connections_even_if_same_path_prefix() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = ws.common.connection_pool[0].clone();
            let c1 = ws.common.connection_pool[1].clone();

            for conn in [&c0, &c1] {
                conn.state.lock().await.url_path = Some("path1".to_string());
            }
            {
                let mut owners = ws.connection_streams.lock().await;
                owners.insert("path1::a".to_string(), c0.clone());
                owners.insert("path1::b".to_string(), c1.clone());
            }

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                stream_names(&url, "ws://example.com/path1/stream?streams="),
                ["a"].iter().copied().collect()
            );
        });
    }
}
8400 }
8401
8402 mod websocket_stream {
8403 use super::*;
8404
8405 mod on {
8406 use super::*;
8407
8408 #[test]
8409 fn registers_callback_and_stream_callback_for_websocket_streams() {
8410 TOKIO_SHARED_RT.block_on(async {
8411 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8412 let stream_name = "s1".to_string();
8413 let conn = ws_base.common.connection_pool[0].clone();
8414
8415 let key = ws_base.stream_key(&stream_name, None);
8416
8417 {
8418 let mut map = ws_base.connection_streams.lock().await;
8419 map.insert(key.clone(), conn.clone());
8420 }
8421 {
8422 let mut state = conn.state.lock().await;
8423 state.stream_callbacks.insert(key.clone(), Vec::new());
8424 }
8425
8426 let stream = Arc::new(WebsocketStream::<Value> {
8427 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8428 stream_or_id: stream_name.clone(),
8429 callback: Mutex::new(None),
8430 url_path: None,
8431 id: None,
8432 _phantom: PhantomData,
8433 });
8434
8435 stream.on("message", |_| {}).await;
8436
8437 let cb_guard = stream.callback.lock().await;
8438 assert!(cb_guard.is_some());
8439
8440 let cbs = {
8441 let state = conn.state.lock().await;
8442 state.stream_callbacks.get(&key).unwrap().clone()
8443 };
8444 assert_eq!(cbs.len(), 1);
8445 });
8446 }
8447
8448 #[test]
8449 fn message_twice_registers_two_wrappers_for_websocket_streams() {
8450 TOKIO_SHARED_RT.block_on(async {
8451 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8452 let stream_name = "s2".to_string();
8453 let conn = ws_base.common.connection_pool[0].clone();
8454
8455 let key = ws_base.stream_key(&stream_name, None);
8456
8457 {
8458 let mut map = ws_base.connection_streams.lock().await;
8459 map.insert(key.clone(), conn.clone());
8460 }
8461 {
8462 let mut state = conn.state.lock().await;
8463 state.stream_callbacks.insert(key.clone(), Vec::new());
8464 }
8465
8466 let stream = Arc::new(WebsocketStream::<Value> {
8467 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8468 stream_or_id: stream_name.clone(),
8469 url_path: None,
8470 callback: Mutex::new(None),
8471 id: None,
8472 _phantom: PhantomData,
8473 });
8474
8475 stream.on("message", |_| {}).await;
8476 stream.on("message", |_| {}).await;
8477
8478 let state = conn.state.lock().await;
8479 let callbacks = state.stream_callbacks.get(&key).unwrap();
8480 assert_eq!(callbacks.len(), 2);
8481 });
8482 }
8483
8484 #[test]
8485 fn ignores_non_message_event_for_websocket_streams() {
8486 TOKIO_SHARED_RT.block_on(async {
8487 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8488 let stream = Arc::new(WebsocketStream::<Value> {
8489 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8490 stream_or_id: "s".into(),
8491 url_path: None,
8492 callback: Mutex::new(None),
8493 id: None,
8494 _phantom: PhantomData,
8495 });
8496 stream.on("open", |_| {}).await;
8497 let guard = stream.callback.lock().await;
8498 assert!(guard.is_none());
8499 });
8500 }
8501
8502 #[test]
8503 fn registers_callback_and_stream_callback_for_websocket_api() {
8504 TOKIO_SHARED_RT.block_on(async {
8505 let ws_base = create_websocket_api(None, None, None);
8506
8507 {
8508 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
8509 stream_callbacks.insert("id1".to_string(), Vec::new());
8510 }
8511
8512 let stream = Arc::new(WebsocketStream::<Value> {
8513 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8514 stream_or_id: "id1".to_string(),
8515 url_path: None,
8516 callback: Mutex::new(None),
8517 id: None,
8518 _phantom: PhantomData,
8519 });
8520
8521 let called = Arc::new(Mutex::new(false));
8522 let called_clone = called.clone();
8523 stream
8524 .on("message", move |v: Value| {
8525 let mut lock = called_clone.blocking_lock();
8526 *lock = v == Value::String("x".into());
8527 })
8528 .await;
8529
8530 let cb_guard = stream.callback.lock().await;
8531 assert!(cb_guard.is_some());
8532
8533 let stream_callbacks = ws_base.stream_callbacks.lock().await;
8534 let callbacks = stream_callbacks.get("id1").unwrap();
8535 assert_eq!(callbacks.len(), 1);
8536 });
8537 }
8538
8539 #[test]
8540 fn message_twice_registers_two_wrappers_for_websocket_api() {
8541 TOKIO_SHARED_RT.block_on(async {
8542 let ws_base = create_websocket_api(None, None, None);
8543
8544 let stream = Arc::new(WebsocketStream::<Value> {
8545 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8546 stream_or_id: "id2".to_string(),
8547 url_path: None,
8548 callback: Mutex::new(None),
8549 id: None,
8550 _phantom: PhantomData,
8551 });
8552
8553 stream.on("message", |_| {}).await;
8554 stream.on("message", |_| {}).await;
8555
8556 let stream_callbacks = ws_base.stream_callbacks.lock().await;
8557 let callbacks = stream_callbacks.get("id2").unwrap();
8558 assert_eq!(callbacks.len(), 2);
8559 });
8560 }
8561
8562 #[test]
8563 fn ignores_non_message_event_for_websocket_api() {
8564 TOKIO_SHARED_RT.block_on(async {
8565 let ws_base = create_websocket_api(None, None, None);
8566
8567 let stream = Arc::new(WebsocketStream::<Value> {
8568 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8569 stream_or_id: "id3".into(),
8570 url_path: None,
8571 callback: Mutex::new(None),
8572 id: None,
8573 _phantom: PhantomData,
8574 });
8575
8576 stream.on("open", |_| {}).await;
8577
8578 let guard = stream.callback.lock().await;
8579 assert!(guard.is_none());
8580
8581 let stream_callbacks = ws_base.stream_callbacks.lock().await;
8582 assert!(stream_callbacks.get("id3").is_none());
8583 assert!(stream_callbacks.is_empty());
8584 });
8585 }
8586
8587 #[test]
8588 fn registers_callback_for_websocket_streams_with_url_path() {
8589 TOKIO_SHARED_RT.block_on(async {
8590 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8591 let stream_name = "s1".to_string();
8592 let conn = ws_base.common.connection_pool[0].clone();
8593
8594 let key = ws_base.stream_key(&stream_name, Some("path1"));
8595
8596 {
8597 let mut map = ws_base.connection_streams.lock().await;
8598 map.insert(key.clone(), conn.clone());
8599 }
8600 {
8601 let mut state = conn.state.lock().await;
8602 state.stream_callbacks.insert(key.clone(), Vec::new());
8603 }
8604
8605 let stream = Arc::new(WebsocketStream::<Value> {
8606 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8607 stream_or_id: stream_name.clone(),
8608 url_path: Some("path1".to_string()),
8609 callback: Mutex::new(None),
8610 id: None,
8611 _phantom: PhantomData,
8612 });
8613
8614 stream.on("message", |_| {}).await;
8615
8616 let cb_guard = stream.callback.lock().await;
8617 assert!(cb_guard.is_some());
8618
8619 let callbacks = {
8620 let state = conn.state.lock().await;
8621 state.stream_callbacks.get(&key).unwrap().clone()
8622 };
8623 assert_eq!(callbacks.len(), 1);
8624 });
8625 }
8626
8627 #[test]
8628 fn url_path_routes_callback_to_correct_key_when_same_stream_name_used() {
8629 TOKIO_SHARED_RT.block_on(async {
8630 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8631 let stream_name = "s1".to_string();
8632 let conn = ws_base.common.connection_pool[0].clone();
8633
8634 let key1 = ws_base.stream_key(&stream_name, Some("path1"));
8635 let key2 = ws_base.stream_key(&stream_name, Some("path2"));
8636
8637 {
8638 let mut map = ws_base.connection_streams.lock().await;
8639 map.insert(key1.clone(), conn.clone());
8640 map.insert(key2.clone(), conn.clone());
8641 }
8642 {
8643 let mut state = conn.state.lock().await;
8644 state.stream_callbacks.insert(key1.clone(), Vec::new());
8645 state.stream_callbacks.insert(key2.clone(), Vec::new());
8646 }
8647
8648 let stream_path1 = Arc::new(WebsocketStream::<Value> {
8649 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8650 stream_or_id: stream_name.clone(),
8651 url_path: Some("path1".to_string()),
8652 callback: Mutex::new(None),
8653 id: None,
8654 _phantom: PhantomData,
8655 });
8656
8657 let stream_path2 = Arc::new(WebsocketStream::<Value> {
8658 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8659 stream_or_id: stream_name.clone(),
8660 url_path: Some("path2".to_string()),
8661 callback: Mutex::new(None),
8662 id: None,
8663 _phantom: PhantomData,
8664 });
8665
8666 stream_path1.on("message", |_| {}).await;
8667 stream_path2.on("message", |_| {}).await;
8668
8669 let state = conn.state.lock().await;
8670 assert_eq!(state.stream_callbacks.get(&key1).unwrap().len(), 1);
8671 assert_eq!(state.stream_callbacks.get(&key2).unwrap().len(), 1);
8672 });
8673 }
8674 }
8675
8676 mod on_message {
8677 use super::*;
8678
8679 #[test]
8680 fn on_message_registers_callback_for_websocket_streams() {
8681 TOKIO_SHARED_RT.block_on(async {
8682 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8683 let stream_name = "s".to_string();
8684 let conn = ws_base.common.connection_pool[0].clone();
8685 {
8686 let mut map = ws_base.connection_streams.lock().await;
8687 map.insert(stream_name.clone(), conn.clone());
8688 }
8689 {
8690 let mut state = conn.state.lock().await;
8691 state
8692 .stream_callbacks
8693 .insert(stream_name.clone(), Vec::new());
8694 }
8695 let stream = Arc::new(WebsocketStream::<Value> {
8696 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8697 stream_or_id: stream_name.clone(),
8698 url_path: None,
8699 callback: Mutex::new(None),
8700 id: None,
8701 _phantom: PhantomData,
8702 });
8703 stream.on_message(|_v| {});
8704 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
8705 assert_eq!(callbacks.len(), 1);
8706 });
8707 }
8708
8709 #[test]
8710 fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
8711 TOKIO_SHARED_RT.block_on(async {
8712 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8713 let stream_name = "s".to_string();
8714 let conn = ws_base.common.connection_pool[0].clone();
8715 {
8716 let mut map = ws_base.connection_streams.lock().await;
8717 map.insert(stream_name.clone(), conn.clone());
8718 }
8719 {
8720 let mut state = conn.state.lock().await;
8721 state
8722 .stream_callbacks
8723 .insert(stream_name.clone(), Vec::new());
8724 }
8725 let stream = Arc::new(WebsocketStream::<Value> {
8726 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8727 stream_or_id: stream_name.clone(),
8728 url_path: None,
8729 callback: Mutex::new(None),
8730 id: None,
8731 _phantom: PhantomData,
8732 });
8733 stream.on_message(|_v| {});
8734 stream.on_message(|_v| {});
8735 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
8736 assert_eq!(callbacks.len(), 2);
8737 });
8738 }
8739
8740 #[test]
8741 fn on_message_registers_callback_for_websocket_api() {
8742 TOKIO_SHARED_RT.block_on(async {
8743 let ws_base = create_websocket_api(None, None, None);
8744 let identifier = "id1".to_string();
8745
8746 let stream = Arc::new(WebsocketStream::<Value> {
8747 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8748 stream_or_id: identifier.clone(),
8749 url_path: None,
8750 callback: Mutex::new(None),
8751 id: None,
8752 _phantom: PhantomData,
8753 });
8754
8755 stream.on_message(|_v: Value| {});
8756
8757 let stream_callbacks = ws_base.stream_callbacks.lock().await;
8758 let callbacks = stream_callbacks.get(&identifier).unwrap();
8759 assert_eq!(callbacks.len(), 1);
8760 });
8761 }
8762
/// Each `on_message` call appends a callback: two calls on the same
/// `WebsocketApi`-backed handle must leave two callbacks for its identifier.
#[test]
fn on_message_twice_registers_two_callbacks_for_websocket_api() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, None, None);
        let key = String::from("id2");

        let handle = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketApi(api.clone()),
            stream_or_id: key.clone(),
            url_path: None,
            callback: Mutex::new(None),
            id: None,
            _phantom: PhantomData,
        });

        for _ in 0..2 {
            handle.on_message(|_v: Value| {});
        }

        let registry = api.stream_callbacks.lock().await;
        let registered = registry
            .get(&key)
            .expect("on_message should create an entry for the identifier");
        assert_eq!(registered.len(), 2);
    });
}
8786 }
8787
8788 mod unsubscribe {
8789 use super::*;
8790
8791 #[test]
8792 fn without_callback_does_nothing() {
8793 TOKIO_SHARED_RT.block_on(async {
8794 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8795 let stream_name = "s1".to_string();
8796 let conn = ws_base.common.connection_pool[0].clone();
8797 {
8798 let mut map = ws_base.connection_streams.lock().await;
8799 map.insert(stream_name.clone(), conn.clone());
8800 }
8801 let mut state = conn.state.lock().await;
8802 state.stream_callbacks.insert(stream_name.clone(), vec![]);
8803 drop(state);
8804 let stream = Arc::new(WebsocketStream::<Value> {
8805 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8806 stream_or_id: stream_name.clone(),
8807 url_path: None,
8808 callback: Mutex::new(None),
8809 id: None,
8810 _phantom: PhantomData,
8811 });
8812 stream.unsubscribe().await;
8813 let state = conn.state.lock().await;
8814 assert!(state.stream_callbacks.contains_key(&stream_name));
8815 });
8816 }
8817
8818 #[test]
8819 fn removes_registered_callback_and_clears_state() {
8820 TOKIO_SHARED_RT.block_on(async {
8821 let ws_base = create_websocket_streams(Some("example.com"), None, None);
8822 let stream_name = "s2".to_string();
8823 let conn = ws_base.common.connection_pool[0].clone();
8824 {
8825 let mut map = ws_base.connection_streams.lock().await;
8826 map.insert(stream_name.clone(), conn.clone());
8827 }
8828 {
8829 let mut state = conn.state.lock().await;
8830 state
8831 .stream_callbacks
8832 .insert(stream_name.clone(), Vec::new());
8833 }
8834 let stream = Arc::new(WebsocketStream::<Value> {
8835 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8836 stream_or_id: stream_name.clone(),
8837 url_path: None,
8838 callback: Mutex::new(None),
8839 id: None,
8840 _phantom: PhantomData,
8841 });
8842 stream.on("message", |_| {}).await;
8843 {
8844 let guard = stream.callback.lock().await;
8845 assert!(guard.is_some());
8846 }
8847 stream.unsubscribe().await;
8848 sleep(Duration::from_millis(10)).await;
8849 let guard = stream.callback.lock().await;
8850 assert!(guard.is_none());
8851 let state = conn.state.lock().await;
8852 assert!(
8853 state
8854 .stream_callbacks
8855 .get(&stream_name)
8856 .is_none_or(std::vec::Vec::is_empty)
8857 );
8858 });
8859 }
8860
8861 #[test]
8862 fn without_callback_does_nothing_for_websocket_api() {
8863 TOKIO_SHARED_RT.block_on(async {
8864 let ws_base = create_websocket_api(None, None, None);
8865 let identifier = "id1".to_string();
8866
8867 {
8868 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
8869 stream_callbacks.insert(identifier.clone(), Vec::new());
8870 }
8871
8872 let stream = Arc::new(WebsocketStream::<Value> {
8873 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8874 stream_or_id: identifier.clone(),
8875 url_path: None,
8876 callback: Mutex::new(None),
8877 id: None,
8878 _phantom: PhantomData,
8879 });
8880
8881 stream.unsubscribe().await;
8882
8883 let stream_callbacks = ws_base.stream_callbacks.lock().await;
8884 assert!(stream_callbacks.contains_key(&identifier));
8885 let callbacks = stream_callbacks.get(&identifier).unwrap();
8886 assert!(callbacks.is_empty());
8887 });
8888 }
8889
8890 #[test]
8891 fn removes_registered_callback_and_clears_state_for_websocket_api() {
8892 TOKIO_SHARED_RT.block_on(async {
8893 let ws_base = create_websocket_api(None, None, None);
8894 let identifier = "id2".to_string();
8895
8896 {
8897 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
8898 stream_callbacks.insert(identifier.clone(), Vec::new());
8899 }
8900
8901 let stream = Arc::new(WebsocketStream::<Value> {
8902 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8903 stream_or_id: identifier.clone(),
8904 url_path: None,
8905 callback: Mutex::new(None),
8906 id: None,
8907 _phantom: PhantomData,
8908 });
8909
8910 stream.on("message", |_| {}).await;
8911
8912 {
8913 let stream_callbacks = ws_base.stream_callbacks.lock().await;
8914 let callbacks = stream_callbacks
8915 .get(&identifier)
8916 .expect("Entry for 'id2' should exist");
8917 assert_eq!(callbacks.len(), 1);
8918 }
8919
8920 stream.unsubscribe().await;
8921
8922 {
8923 let guard = stream.callback.lock().await;
8924 assert!(guard.is_none());
8925 }
8926
8927 {
8928 let stream_callbacks = ws_base.stream_callbacks.lock().await;
8929 let callbacks = stream_callbacks
8930 .get(&identifier)
8931 .expect("Entry for 'id2' should still exist");
8932 assert!(callbacks.is_empty());
8933 }
8934 });
8935 }
8936 }
8937 }
8938
8939 mod create_stream_handler {
8940 use super::*;
8941
8942 #[test]
8943 fn create_stream_handler_without_id_registers_stream() {
8944 TOKIO_SHARED_RT.block_on(async {
8945 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
8946 let stream_name = "foo".to_string();
8947 let handler = create_stream_handler::<serde_json::Value>(
8948 WebsocketBase::WebsocketStreams(ws.clone()),
8949 stream_name.clone(),
8950 None,
8951 None,
8952 )
8953 .await;
8954 assert_eq!(handler.stream_or_id, stream_name);
8955 assert!(handler.id.is_none());
8956 let map = ws.connection_streams.lock().await;
8957 assert!(map.contains_key(&stream_name));
8958 });
8959 }
8960
8961 #[test]
8962 fn create_stream_handler_with_custom_string_id_registers_stream_and_id() {
8963 TOKIO_SHARED_RT.block_on(async {
8964 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
8965 let stream_name = "bar".to_string();
8966 let custom_id = StreamId::from("my-custom-id".to_string());
8967 let handler = create_stream_handler::<serde_json::Value>(
8968 WebsocketBase::WebsocketStreams(ws.clone()),
8969 stream_name.clone(),
8970 Some(custom_id.clone()),
8971 None,
8972 )
8973 .await;
8974 assert_eq!(handler.stream_or_id, stream_name);
8975 assert_eq!(handler.id, Some(custom_id));
8976 let map = ws.connection_streams.lock().await;
8977 assert!(map.contains_key(&stream_name));
8978 });
8979 }
8980
8981 #[test]
8982 fn create_stream_handler_with_custom_integer_id_registers_stream_and_id() {
8983 TOKIO_SHARED_RT.block_on(async {
8984 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
8985 let stream_name = "bar".to_string();
8986 let custom_id = StreamId::from(123u32);
8987 let handler = create_stream_handler::<serde_json::Value>(
8988 WebsocketBase::WebsocketStreams(ws.clone()),
8989 stream_name.clone(),
8990 Some(custom_id.clone()),
8991 None,
8992 )
8993 .await;
8994 assert_eq!(handler.stream_or_id, stream_name);
8995 assert_eq!(handler.id, Some(custom_id));
8996 let map = ws.connection_streams.lock().await;
8997 assert!(map.contains_key(&stream_name));
8998 });
8999 }
9000
9001 #[test]
9002 fn create_stream_handler_without_id_registers_api_stream() {
9003 TOKIO_SHARED_RT.block_on(async {
9004 let ws_base = create_websocket_api(None, None, None);
9005 let identifier = "foo-api".to_string();
9006
9007 let handler = create_stream_handler::<Value>(
9008 WebsocketBase::WebsocketApi(ws_base.clone()),
9009 identifier.clone(),
9010 None,
9011 None,
9012 )
9013 .await;
9014
9015 assert_eq!(handler.stream_or_id, identifier);
9016 assert!(handler.id.is_none());
9017 });
9018 }
9019
9020 #[test]
9021 fn create_stream_handler_with_custom_string_id_registers_api_stream_and_id() {
9022 TOKIO_SHARED_RT.block_on(async {
9023 let ws_base = create_websocket_api(None, None, None);
9024 let identifier = "bar-api".to_string();
9025 let custom_id = StreamId::from("custom-123".to_string());
9026
9027 let handler = create_stream_handler::<Value>(
9028 WebsocketBase::WebsocketApi(ws_base.clone()),
9029 identifier.clone(),
9030 Some(custom_id.clone()),
9031 None,
9032 )
9033 .await;
9034
9035 assert_eq!(handler.stream_or_id, identifier);
9036 assert_eq!(handler.id, Some(custom_id));
9037 });
9038 }
9039
9040 #[test]
9041 fn create_stream_handler_with_custom_integer_id_registers_api_stream_and_id() {
9042 TOKIO_SHARED_RT.block_on(async {
9043 let ws_base = create_websocket_api(None, None, None);
9044 let identifier = "bar-api".to_string();
9045 let custom_id = StreamId::from(123u32);
9046
9047 let handler = create_stream_handler::<Value>(
9048 WebsocketBase::WebsocketApi(ws_base.clone()),
9049 identifier.clone(),
9050 Some(custom_id.clone()),
9051 None,
9052 )
9053 .await;
9054
9055 assert_eq!(handler.stream_or_id, identifier);
9056 assert_eq!(handler.id, Some(custom_id));
9057 });
9058 }
9059
9060 #[test]
9061 fn websocket_streams_without_url_path_registers_stream_key() {
9062 TOKIO_SHARED_RT.block_on(async {
9063 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9064 let stream_name = "foo".to_string();
9065
9066 let handler = create_stream_handler::<Value>(
9067 WebsocketBase::WebsocketStreams(ws.clone()),
9068 stream_name.clone(),
9069 None,
9070 None,
9071 )
9072 .await;
9073
9074 assert_eq!(handler.stream_or_id, stream_name);
9075 assert!(handler.id.is_none());
9076
9077 let map = ws.connection_streams.lock().await;
9078 assert!(map.contains_key("foo"));
9079 });
9080 }
9081
9082 #[test]
9083 fn websocket_streams_with_url_path_registers_prefixed_stream_key() {
9084 TOKIO_SHARED_RT.block_on(async {
9085 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9086
9087 {
9088 let conn = ws.common.connection_pool[0].clone();
9089 let mut st = conn.state.lock().await;
9090 st.url_path = Some("path1".to_string());
9091 }
9092
9093 let stream_name = "foo".to_string();
9094
9095 let handler = create_stream_handler::<Value>(
9096 WebsocketBase::WebsocketStreams(ws.clone()),
9097 stream_name.clone(),
9098 None,
9099 Some("path1".to_string()),
9100 )
9101 .await;
9102
9103 assert_eq!(handler.stream_or_id, stream_name);
9104 assert!(handler.id.is_none());
9105
9106 let map = ws.connection_streams.lock().await;
9107 assert!(map.contains_key("path1::foo"));
9108 });
9109 }
9110
9111 #[test]
9112 fn websocket_streams_with_custom_id_preserves_id_and_registers_prefixed_key() {
9113 TOKIO_SHARED_RT.block_on(async {
9114 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9115
9116 {
9117 let conn = ws.common.connection_pool[0].clone();
9118 let mut st = conn.state.lock().await;
9119 st.url_path = Some("path1".to_string());
9120 }
9121
9122 let stream_name = "bar".to_string();
9123 let custom_id = StreamId::from("my-custom-id".to_string());
9124
9125 let handler = create_stream_handler::<Value>(
9126 WebsocketBase::WebsocketStreams(ws.clone()),
9127 stream_name.clone(),
9128 Some(custom_id.clone()),
9129 Some("path1".to_string()),
9130 )
9131 .await;
9132
9133 assert_eq!(handler.stream_or_id, stream_name);
9134 assert_eq!(handler.id, Some(custom_id));
9135
9136 let map = ws.connection_streams.lock().await;
9137 assert!(map.contains_key("path1::bar"));
9138 });
9139 }
9140
9141 #[test]
9142 fn websocket_api_does_not_register_stream_in_connection_map() {
9143 TOKIO_SHARED_RT.block_on(async {
9144 let ws_base = create_websocket_api(None, None, None);
9145 let identifier = "foo-api".to_string();
9146
9147 let handler = create_stream_handler::<Value>(
9148 WebsocketBase::WebsocketApi(ws_base.clone()),
9149 identifier.clone(),
9150 None,
9151 Some("path1".to_string()),
9152 )
9153 .await;
9154
9155 assert_eq!(handler.stream_or_id, identifier);
9156 assert!(handler.id.is_none());
9157 });
9158 }
9159 }
9160
9161 mod websocket_connection_failure_reason {
9162 use super::*;
9163 use std::io::{Error as IoError, ErrorKind};
9164 use tokio_tungstenite::tungstenite::Error as TungsteniteError;
9165
9166 #[test]
9167 fn from_tungstenite_error_classifies_connection_closed() {
9168 let error = TungsteniteError::ConnectionClosed;
9169 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9170 assert!(matches!(
9171 reason,
9172 WebsocketConnectionFailureReason::ConnectionReset
9173 ));
9174 assert!(reason.should_reconnect());
9175 }
9176
9177 #[test]
9178 fn from_tungstenite_error_classifies_already_closed() {
9179 let error = TungsteniteError::AlreadyClosed;
9180 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9181 assert!(matches!(
9182 reason,
9183 WebsocketConnectionFailureReason::ConnectionReset
9184 ));
9185 assert!(reason.should_reconnect());
9186 }
9187
9188 #[test]
9189 fn from_tungstenite_error_classifies_io_errors() {
9190 let io_error = IoError::new(ErrorKind::ConnectionReset, "connection reset");
9192 let error = TungsteniteError::Io(io_error);
9193 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9194 assert!(matches!(
9195 reason,
9196 WebsocketConnectionFailureReason::ConnectionReset
9197 ));
9198 assert!(reason.should_reconnect());
9199
9200 let io_error = IoError::new(ErrorKind::ConnectionAborted, "connection aborted");
9202 let error = TungsteniteError::Io(io_error);
9203 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9204 assert!(matches!(
9205 reason,
9206 WebsocketConnectionFailureReason::ConnectionReset
9207 ));
9208 assert!(reason.should_reconnect());
9209
9210 let io_error = IoError::new(ErrorKind::TimedOut, "timed out");
9212 let error = TungsteniteError::Io(io_error);
9213 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9214 assert!(matches!(
9215 reason,
9216 WebsocketConnectionFailureReason::NetworkInterruption
9217 ));
9218 assert!(reason.should_reconnect());
9219
9220 let io_error = IoError::new(ErrorKind::UnexpectedEof, "unexpected eof");
9222 let error = TungsteniteError::Io(io_error);
9223 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9224 assert!(matches!(
9225 reason,
9226 WebsocketConnectionFailureReason::StreamEnded
9227 ));
9228 assert!(reason.should_reconnect());
9229
9230 let io_error = IoError::new(ErrorKind::PermissionDenied, "permission denied");
9232 let error = TungsteniteError::Io(io_error);
9233 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9234 assert!(matches!(
9235 reason,
9236 WebsocketConnectionFailureReason::AuthenticationFailure
9237 ));
9238 assert!(!reason.should_reconnect());
9239
9240 let io_error = IoError::other("other error");
9242 let error = TungsteniteError::Io(io_error);
9243 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9244 assert!(matches!(
9245 reason,
9246 WebsocketConnectionFailureReason::NetworkInterruption
9247 ));
9248 assert!(reason.should_reconnect());
9249 }
9250
9251 #[test]
9252 fn from_tungstenite_error_classifies_protocol_errors() {
9253 use tokio_tungstenite::tungstenite::error::ProtocolError;
9255 let protocol_error = ProtocolError::ResetWithoutClosingHandshake;
9256 let error = TungsteniteError::Protocol(protocol_error);
9257 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9258 assert!(matches!(
9259 reason,
9260 WebsocketConnectionFailureReason::ProtocolViolation
9261 ));
9262 assert!(!reason.should_reconnect());
9263
9264 let error = TungsteniteError::Utf8;
9266 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9267 assert!(matches!(
9268 reason,
9269 WebsocketConnectionFailureReason::ProtocolViolation
9270 ));
9271 assert!(!reason.should_reconnect());
9272 }
9273
9274 #[test]
9275 fn from_close_code_classifies_standard_codes() {
9276 let reason = WebsocketConnectionFailureReason::from_close_code(1000, false);
9278 assert!(matches!(
9279 reason,
9280 WebsocketConnectionFailureReason::NormalClose
9281 ));
9282 assert!(!reason.should_reconnect());
9283
9284 let reason = WebsocketConnectionFailureReason::from_close_code(1001, false);
9286 assert!(matches!(
9287 reason,
9288 WebsocketConnectionFailureReason::ServerTemporaryError
9289 ));
9290 assert!(reason.should_reconnect());
9291
9292 let reason = WebsocketConnectionFailureReason::from_close_code(1002, false);
9294 assert!(matches!(
9295 reason,
9296 WebsocketConnectionFailureReason::ProtocolViolation
9297 ));
9298 assert!(!reason.should_reconnect());
9299
9300 let reason = WebsocketConnectionFailureReason::from_close_code(1006, false);
9302 assert!(matches!(
9303 reason,
9304 WebsocketConnectionFailureReason::UnexpectedClose
9305 ));
9306 assert!(reason.should_reconnect());
9307
9308 let reason = WebsocketConnectionFailureReason::from_close_code(1008, false);
9310 assert!(matches!(
9311 reason,
9312 WebsocketConnectionFailureReason::PermanentServerError
9313 ));
9314 assert!(!reason.should_reconnect());
9315
9316 let reason = WebsocketConnectionFailureReason::from_close_code(1011, false);
9318 assert!(matches!(
9319 reason,
9320 WebsocketConnectionFailureReason::ServerTemporaryError
9321 ));
9322 assert!(reason.should_reconnect());
9323
9324 let reason = WebsocketConnectionFailureReason::from_close_code(1015, false);
9326 assert!(matches!(
9327 reason,
9328 WebsocketConnectionFailureReason::ConfigurationError
9329 ));
9330 assert!(!reason.should_reconnect());
9331
9332 let reason = WebsocketConnectionFailureReason::from_close_code(4000, false);
9334 assert!(matches!(
9335 reason,
9336 WebsocketConnectionFailureReason::PermanentServerError
9337 ));
9338 assert!(!reason.should_reconnect());
9339
9340 let reason = WebsocketConnectionFailureReason::from_close_code(4999, false);
9341 assert!(matches!(
9342 reason,
9343 WebsocketConnectionFailureReason::PermanentServerError
9344 ));
9345 assert!(!reason.should_reconnect());
9346
9347 let reason = WebsocketConnectionFailureReason::from_close_code(9999, false);
9349 assert!(matches!(
9350 reason,
9351 WebsocketConnectionFailureReason::UnexpectedClose
9352 ));
9353 assert!(reason.should_reconnect());
9354 }
9355
9356 #[test]
9357 fn from_close_code_handles_user_initiated() {
9358 let reason = WebsocketConnectionFailureReason::from_close_code(1000, true);
9360 assert!(matches!(
9361 reason,
9362 WebsocketConnectionFailureReason::UserInitiatedClose
9363 ));
9364 assert!(!reason.should_reconnect());
9365
9366 let reason = WebsocketConnectionFailureReason::from_close_code(1006, true);
9367 assert!(matches!(
9368 reason,
9369 WebsocketConnectionFailureReason::UserInitiatedClose
9370 ));
9371 assert!(!reason.should_reconnect());
9372
9373 let reason = WebsocketConnectionFailureReason::from_close_code(4000, true);
9374 assert!(matches!(
9375 reason,
9376 WebsocketConnectionFailureReason::UserInitiatedClose
9377 ));
9378 assert!(!reason.should_reconnect());
9379 }
9380
9381 #[test]
9382 fn should_reconnect_logic() {
9383 assert!(WebsocketConnectionFailureReason::NetworkInterruption.should_reconnect());
9385 assert!(WebsocketConnectionFailureReason::ConnectionReset.should_reconnect());
9386 assert!(WebsocketConnectionFailureReason::ServerTemporaryError.should_reconnect());
9387 assert!(WebsocketConnectionFailureReason::UnexpectedClose.should_reconnect());
9388 assert!(WebsocketConnectionFailureReason::StreamEnded.should_reconnect());
9389
9390 assert!(!WebsocketConnectionFailureReason::AuthenticationFailure.should_reconnect());
9392 assert!(!WebsocketConnectionFailureReason::ProtocolViolation.should_reconnect());
9393 assert!(!WebsocketConnectionFailureReason::ConfigurationError.should_reconnect());
9394 assert!(!WebsocketConnectionFailureReason::UserInitiatedClose.should_reconnect());
9395 assert!(!WebsocketConnectionFailureReason::PermanentServerError.should_reconnect());
9396 assert!(!WebsocketConnectionFailureReason::NormalClose.should_reconnect());
9397 }
9398
9399 #[test]
9400 fn debug_and_clone_work() {
9401 let reason = WebsocketConnectionFailureReason::NetworkInterruption;
9402 let cloned = reason;
9403 let debug_str = format!("{:?}", reason);
9404
9405 assert!(matches!(
9406 cloned,
9407 WebsocketConnectionFailureReason::NetworkInterruption
9408 ));
9409 assert!(debug_str.contains("NetworkInterruption"));
9410 }
9411 }
9412}