1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, future::try_join_all, stream::FuturesUnordered};
4use http::header::USER_AGENT;
5use serde::de::DeserializeOwned;
6use serde_json::{Value, json};
7use std::{
8 collections::{BTreeMap, HashMap, VecDeque},
9 io::Read,
10 marker::PhantomData,
11 mem::take,
12 sync::{
13 Arc,
14 atomic::{AtomicBool, AtomicUsize, Ordering},
15 },
16 time::Duration,
17};
18use tokio::{
19 net::TcpStream,
20 select, spawn,
21 sync::{
22 Mutex, Notify,
23 mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
24 oneshot,
25 },
26 task::JoinHandle,
27 time::{sleep, timeout},
28};
29use tokio_tungstenite::{
30 Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
31 tungstenite::{
32 Message,
33 client::IntoClientRequest,
34 protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
35 },
36};
37use tokio_util::time::DelayQueue;
38use tracing::{debug, error, info, warn};
39
40use super::{
41 config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
42 errors::{WebsocketConnectionFailureReason, WebsocketError},
43 models::{StreamId, WebsocketApiResponse, WebsocketEvent, WebsocketMode},
44 utils::{build_websocket_api_message, normalize_stream_id, random_string, validate_time_unit},
45};
46
/// WebSocket stream type produced by the handshake: a plain or TLS-wrapped TCP socket.
pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Age (23 hours) after which the renewal loop replaces a connection with a fresh one.
const MAX_CONN_DURATION: Duration = Duration::from_secs(23 * 60 * 60);
50
51pub struct Subscription {
52 handle: JoinHandle<()>,
53}
54
55impl Subscription {
56 pub fn unsubscribe(self) {
71 self.handle.abort();
72 }
73}
74
/// Either flavour of client built on the shared WebSocket core, held behind an `Arc`.
#[derive(Clone)]
pub enum WebsocketBase {
    /// Request/response WebSocket API client.
    WebsocketApi(Arc<WebsocketApi>),
    /// Market/user data stream client.
    WebsocketStreams(Arc<WebsocketStreams>),
}
80
81pub struct WebsocketEventEmitter {
82 subscribers: Arc<std::sync::Mutex<Vec<UnboundedSender<WebsocketEvent>>>>,
83}
84
85impl Default for WebsocketEventEmitter {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl WebsocketEventEmitter {
92 #[must_use]
93 pub fn new() -> Self {
94 Self {
95 subscribers: Arc::new(std::sync::Mutex::new(Vec::new())),
96 }
97 }
98
99 pub fn subscribe<F>(&self, mut callback: F) -> Subscription
126 where
127 F: FnMut(WebsocketEvent) + Send + 'static,
128 {
129 let (tx, mut rx) = unbounded_channel();
130 let mut guard = match self.subscribers.lock() {
131 Ok(guard) => guard,
132 Err(poisoned) => poisoned.into_inner(),
133 };
134 guard.push(tx);
135 drop(guard);
136
137 let handle = spawn(async move {
138 while let Some(event) = rx.recv().await {
139 callback(event);
140 }
141 });
142 Subscription { handle }
143 }
144
145 pub fn emit(&self, event: &WebsocketEvent) {
155 let mut guard = match self.subscribers.lock() {
156 Ok(guard) => guard,
157 Err(poisoned) => poisoned.into_inner(),
158 };
159
160 guard.retain(|tx| {
161 if tx.send(event.clone()).is_ok() {
162 true
163 } else {
164 warn!("subscriber dropped without unsubscribing");
165 false
166 }
167 });
168 }
169}
170
/// Hooks through which a client (for example [`WebsocketApi`]) reacts to connection activity.
#[async_trait]
pub trait WebsocketHandler: Send + Sync + 'static {
    /// Called by the pool core after a socket to `url` has been opened for `connection`.
    async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
    /// Called for every text frame, and for every binary frame after zlib decompression to text.
    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
    /// Chooses the URL used to re-establish `connection`.
    ///
    /// `default_url` is the URL the connection was last opened with.
    async fn get_reconnect_url(
        &self,
        default_url: String,
        connection: Arc<WebsocketConnection>,
    ) -> String;
}
195
/// A request sent with `wait_for_reply`, waiting for its response.
pub struct PendingRequest {
    /// Completed with the response payload, or with `WebsocketError::Timeout` by `send`'s timer.
    pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
}

/// A session-logon request captured as method, payload and send options.
///
/// NOTE(review): presumably kept so the session can be re-established after a reconnect; the
/// code that replays it is not in this part of the file.
#[derive(Clone)]
pub struct WebsocketSessionLogonReq {
    pub method: String,
    pub payload: BTreeMap<String, Value>,
    pub options: WebsocketMessageSendOptions,
}
206
207pub struct WebsocketConnectionState {
208 pub reconnection_pending: bool,
209 pub renewal_pending: bool,
210 pub close_initiated: bool,
211 pub pending_requests: HashMap<String, PendingRequest>,
212 pub pending_subscriptions: VecDeque<String>,
213 pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
214 pub is_session_logged_on: bool,
215 pub session_logon_req: Option<WebsocketSessionLogonReq>,
216 pub url_path: Option<String>,
217 pub handler: Option<Arc<dyn WebsocketHandler>>,
218 pub ws_write_tx: Option<UnboundedSender<Message>>,
219}
220
221impl Default for WebsocketConnectionState {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227impl WebsocketConnectionState {
228 #[must_use]
229 pub fn new() -> Self {
230 Self {
231 reconnection_pending: false,
232 renewal_pending: false,
233 close_initiated: false,
234 pending_requests: HashMap::new(),
235 pending_subscriptions: VecDeque::new(),
236 stream_callbacks: HashMap::new(),
237 is_session_logged_on: false,
238 session_logon_req: None,
239 url_path: None,
240 handler: None,
241 ws_write_tx: None,
242 }
243 }
244}
245
246pub struct WebsocketConnection {
247 pub id: String,
248 pub drain_notify: Notify,
249 pub state: Mutex<WebsocketConnectionState>,
250}
251
252impl WebsocketConnection {
253 pub fn new(id: impl Into<String>) -> Arc<Self> {
254 Arc::new(Self {
255 id: id.into(),
256 drain_notify: Notify::new(),
257 state: Mutex::new(WebsocketConnectionState::new()),
258 })
259 }
260
261 pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
262 let mut conn_state = self.state.lock().await;
263 conn_state.handler = Some(handler);
264 }
265}
266
/// Work item consumed by the reconnect loop.
struct ReconnectEntry {
    /// Id of the pool slot to (re)connect.
    connection_id: String,
    /// URL to connect to.
    url: String,
    /// `true` for scheduled renewals; the reconnect delay is skipped for these.
    is_renewal: bool,
}
272
/// Connection-pool core shared by the WebSocket clients.
///
/// Owns the pool of [`WebsocketConnection`] slots, the event emitter, and the senders that feed
/// the background reconnect and renewal loops spawned by [`WebsocketCommon::new`].
pub struct WebsocketCommon {
    /// Broadcasts open/message/ping/pong/close/error events to subscribers.
    pub events: WebsocketEventEmitter,
    /// Pool mode. In `Single` mode, requests without a URL path always use the first slot.
    mode: WebsocketMode,
    /// Counter used to pick among ready slots round-robin.
    round_robin_index: AtomicUsize,
    /// Fixed set of slots, created once in `new`.
    connection_pool: Vec<Arc<WebsocketConnection>>,
    /// Feeds the reconnect loop.
    reconnect_tx: Sender<ReconnectEntry>,
    /// Feeds the renewal loop with `(connection id, url)` pairs.
    renewal_tx: Sender<(String, String)>,
    /// Delay in milliseconds applied before non-renewal reconnect attempts.
    reconnect_delay: usize,
    /// Optional connector handed to the handshake.
    agent: Option<AgentConnector>,
    /// Optional `User-Agent` header sent with the handshake.
    user_agent: Option<String>,
}
284
285impl WebsocketCommon {
286 #[must_use]
287 pub fn new(
288 mut initial_pool: Vec<Arc<WebsocketConnection>>,
289 mode: WebsocketMode,
290 reconnect_delay: usize,
291 agent: Option<AgentConnector>,
292 user_agent: Option<String>,
293 ) -> Arc<Self> {
294 if initial_pool.is_empty() {
295 for _ in 0..mode.pool_size() {
296 let id = random_string();
297 initial_pool.push(WebsocketConnection::new(id));
298 }
299 }
300
301 let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
302 let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
303
304 let common = Arc::new(Self {
305 events: WebsocketEventEmitter::new(),
306 mode,
307 round_robin_index: AtomicUsize::new(0),
308 connection_pool: initial_pool,
309 reconnect_tx,
310 renewal_tx,
311 reconnect_delay,
312 agent,
313 user_agent,
314 });
315
316 Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
317 Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
318
319 common
320 }
321
322 fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
340 spawn(async move {
341 while let Some(entry) = reconnect_rx.recv().await {
342 info!("Scheduling reconnect for id {}", entry.connection_id);
343
344 if !entry.is_renewal {
345 sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
346 }
347
348 if let Some(conn_arc) = common
349 .connection_pool
350 .iter()
351 .find(|c| c.id == entry.connection_id)
352 .cloned()
353 {
354 let common_clone = Arc::clone(&common);
355 if let Err(err) = common_clone
356 .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
357 .await
358 {
359 error!(
360 "Reconnect failed for {} → {}: {:?}",
361 entry.connection_id, entry.url, err
362 );
363 }
364
365 sleep(Duration::from_secs(1)).await;
366 } else {
367 warn!("No connection {} found for reconnect", entry.connection_id);
368 }
369 }
370 });
371 }
372
373 fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
387 let common = Arc::clone(common);
388 spawn(async move {
389 let mut dq = DelayQueue::new();
390 let mut renewal_rx = renewal_rx;
391
392 loop {
393 select! {
394 Some((conn_id, url)) = renewal_rx.recv() => {
395 debug!("Scheduling renewal for {}", conn_id);
396 dq.insert((conn_id, url), MAX_CONN_DURATION);
397 }
398
399 Some(expired) = dq.next() => {
400 let (conn_id, default_url) = expired.into_inner();
401
402 if let Some(conn_arc) = common
403 .connection_pool
404 .iter()
405 .find(|c| c.id == conn_id)
406 .cloned()
407 {
408 debug!("Renewing connection {}", conn_id);
409 let url = common
410 .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
411 .await;
412 if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
413 connection_id: conn_id.clone(),
414 url,
415 is_renewal: true,
416 }).await {
417 error!(
418 "Failed to enqueue renewal for {}: {:?}",
419 conn_id, e
420 );
421 }
422 } else {
423 warn!("No connection {} found for renewal", conn_id);
424 }
425 }
426 }
427 }
428 });
429 }
430
431 pub async fn is_connection_ready(
449 &self,
450 connection: &WebsocketConnection,
451 allow_non_established: bool,
452 ) -> bool {
453 let conn_state = connection.state.lock().await;
454 (allow_non_established || conn_state.ws_write_tx.is_some())
455 && !conn_state.reconnection_pending
456 && !conn_state.close_initiated
457 }
458
459 async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
475 if let Some(conn_arc) = connection {
476 return self.is_connection_ready(conn_arc, false).await;
477 }
478
479 for conn_arc in &self.connection_pool {
480 if self.is_connection_ready(conn_arc, false).await {
481 return true;
482 }
483 }
484
485 false
486 }
487
488 async fn get_available_connections(
505 &self,
506 allow_non_established: bool,
507 url_path: Option<&str>,
508 ) -> Vec<Arc<WebsocketConnection>> {
509 if matches!(self.mode, WebsocketMode::Single) && url_path.is_none() {
510 return vec![Arc::clone(&self.connection_pool[0])];
511 }
512
513 let mut ready = Vec::new();
514 for conn in &self.connection_pool {
515 if self.is_connection_ready(conn, allow_non_established).await {
516 ready.push(Arc::clone(conn));
517 }
518 }
519
520 ready
521 }
522
523 async fn get_connection(
544 &self,
545 allow_non_established: bool,
546 url_path: Option<&str>,
547 ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
548 let candidates = self
549 .get_available_connections(allow_non_established, url_path)
550 .await;
551
552 let mut ready = Vec::new();
553 for conn in candidates {
554 if let Some(path) = url_path {
555 let st = conn.state.lock().await;
556 if st.url_path.as_deref() != Some(path) {
557 continue;
558 }
559 }
560 ready.push(conn);
561 }
562
563 if ready.is_empty() {
564 return Err(WebsocketError::NotConnected);
565 }
566
567 let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
568
569 Ok(Arc::clone(&ready[idx]))
570 }
571
572 async fn close_connection_gracefully(
589 &self,
590 ws_write_tx_to_close: UnboundedSender<Message>,
591 connection: Arc<WebsocketConnection>,
592 ) -> Result<(), WebsocketError> {
593 debug!("Waiting for pending requests to complete before disconnecting.");
594
595 let drain = async {
596 loop {
597 {
598 let conn_state = connection.state.lock().await;
599 if conn_state.pending_requests.is_empty() {
600 debug!("All pending requests completed, proceeding to close.");
601 break;
602 }
603 }
604 connection.drain_notify.notified().await;
605 }
606 };
607
608 if timeout(Duration::from_secs(30), drain).await.is_err() {
609 warn!("Timeout waiting for pending requests; forcing close.");
610 }
611
612 info!("Closing WebSocket connection for {}", connection.id);
613 let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
614 code: CloseCode::Normal,
615 reason: "".into(),
616 })));
617
618 Ok(())
619 }
620
621 async fn get_reconnect_url(
638 &self,
639 default_url: &str,
640 connection: Arc<WebsocketConnection>,
641 ) -> String {
642 if let Some(handler) = {
643 let conn_state = connection.state.lock().await;
644 conn_state.handler.clone()
645 } {
646 return handler
647 .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
648 .await;
649 }
650
651 default_url.to_string()
652 }
653
654 async fn on_open(
676 &self,
677 url: String,
678 connection: Arc<WebsocketConnection>,
679 old_ws_writer: Option<UnboundedSender<Message>>,
680 ) {
681 if let Some(handler) = {
682 let conn_state = connection.state.lock().await;
683 conn_state.handler.clone()
684 } {
685 handler.on_open(url.clone(), Arc::clone(&connection)).await;
686 }
687
688 let conn_id = &connection.id;
689 info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
690
691 {
692 let mut conn_state = connection.state.lock().await;
693
694 if conn_state.renewal_pending {
695 conn_state.renewal_pending = false;
696 drop(conn_state);
697 if let Some(tx) = old_ws_writer {
698 info!("Connection renewal in progress; closing previous connection.");
699 let _ = self
700 .close_connection_gracefully(tx, Arc::clone(&connection))
701 .await;
702 }
703 return;
704 }
705
706 if conn_state.close_initiated {
707 drop(conn_state);
708 if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
709 info!("Close initiated; closing connection.");
710 let _ = self
711 .close_connection_gracefully(tx, Arc::clone(&connection))
712 .await;
713 }
714 return;
715 }
716
717 self.events.emit(&WebsocketEvent::Open);
718 }
719 }
720
721 async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
733 let handler = connection.state.lock().await.handler.clone();
734 if let Some(handler) = handler {
735 handler
736 .on_message(msg.clone(), Arc::clone(&connection))
737 .await;
738 }
739 self.events.emit(&WebsocketEvent::Message(msg));
740 }
741
742 async fn create_websocket(
765 url: &str,
766 agent: Option<AgentConnector>,
767 user_agent: Option<String>,
768 ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
769 let mut req = url
770 .into_client_request()
771 .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
772
773 if let Some(ua) = user_agent {
774 req.headers_mut().insert(USER_AGENT, ua.parse().unwrap());
775 }
776
777 let ws_config: Option<WebSocketConfig> = None;
778 let disable_nagle = false;
779 let connector: Option<Connector> = agent.map(|dbg| dbg.0);
780
781 let timeout_duration = Duration::from_secs(10);
782 let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
783 match timeout(timeout_duration, handshake).await {
784 Ok(Ok((ws_stream, response))) => {
785 debug!("WebSocket connected: {:?}", response);
786 Ok(ws_stream)
787 }
788 Ok(Err(e)) => {
789 let msg = e.to_string();
790 error!("WebSocket handshake failed: {}", msg);
791 Err(WebsocketError::Handshake(msg))
792 }
793 Err(_) => {
794 error!(
795 "WebSocket connection timed out after {}s",
796 timeout_duration.as_secs()
797 );
798 Err(WebsocketError::Timeout)
799 }
800 }
801 }
802
803 async fn connect_pool(
823 self: Arc<Self>,
824 url: &str,
825 connections: Option<Vec<Arc<WebsocketConnection>>>,
826 ) -> Result<(), WebsocketError> {
827 let pool: Vec<Arc<WebsocketConnection>> = match connections {
828 Some(v) => v,
829 None => self.connection_pool.clone(),
830 };
831
832 let mut tasks = FuturesUnordered::new();
833
834 for conn in pool {
835 let common = Arc::clone(&self);
836 let url = url.to_owned();
837
838 tasks.push(async move {
839 match common.init_connect(&url, false, Some(conn)).await {
840 Ok(()) => {
841 info!("Successfully connected to {}", url);
842 Ok(())
843 }
844 Err(err) => {
845 error!("Failed to connect to {}: {:?}", url, err);
846 Err(err)
847 }
848 }
849 });
850 }
851
852 while let Some(result) = tasks.next().await {
853 result?;
854 }
855
856 Ok(())
857 }
858
859 async fn init_connect(
880 self: Arc<Self>,
881 url: &str,
882 is_renewal: bool,
883 connection: Option<Arc<WebsocketConnection>>,
884 ) -> Result<(), WebsocketError> {
885 let conn = connection.unwrap_or(self.get_connection(true, None).await?);
886
887 {
888 let mut conn_state = conn.state.lock().await;
889 if conn_state.renewal_pending && is_renewal {
890 info!("Renewal in progress {}→{}", conn.id, url);
891 return Ok(());
892 }
893 if conn_state.ws_write_tx.is_some() && !is_renewal && !conn_state.reconnection_pending {
894 info!("Exists {}; skipping {}", conn.id, url);
895 return Ok(());
896 }
897 if is_renewal {
898 conn_state.renewal_pending = true;
899 }
900
901 conn_state.is_session_logged_on = false;
902 }
903
904 let ws = Self::create_websocket(url, self.agent.clone(), self.user_agent.clone())
905 .await
906 .map_err(|e| {
907 error!("Handshake failed {}: {:?}", url, e);
908 e
909 })?;
910
911 info!("Established {} → {}", conn.id, url);
912
913 if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
914 error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
915 }
916
917 let (write_half, mut read_half) = ws.split();
918 let (tx, mut rx) = unbounded_channel::<Message>();
919
920 let old_writer = {
921 let mut conn_state = conn.state.lock().await;
922 conn_state.reconnection_pending = false;
923 conn_state.ws_write_tx.replace(tx.clone())
924 };
925
926 {
927 let wconn = conn.clone();
928 let common_clone = self.clone();
929 let writer_url = url.to_string();
930
931 spawn(async move {
932 let mut sink = write_half;
933 while let Some(msg) = rx.recv().await {
934 if let Err(e) = sink.send(msg).await {
935 let failure_reason =
936 WebsocketConnectionFailureReason::from_tungstenite_error(&e);
937
938 error!(
939 "Write error on {}: {:?}, classified as {:?}",
940 wconn.id, e, failure_reason
941 );
942
943 let mut conn_state = wconn.state.lock().await;
945 if !conn_state.close_initiated
946 && !is_renewal
947 && failure_reason.should_reconnect()
948 {
949 info!(
950 "Writer connection {} has recoverable error, attempting reconnection: {:?}",
951 wconn.id, failure_reason
952 );
953 conn_state.reconnection_pending = true;
954 conn_state.is_session_logged_on = false;
955 drop(conn_state);
956 let reconnect_url = common_clone
957 .get_reconnect_url(&writer_url, Arc::clone(&wconn))
958 .await;
959
960 let _ = common_clone
961 .reconnect_tx
962 .send(ReconnectEntry {
963 connection_id: wconn.id.clone(),
964 url: reconnect_url,
965 is_renewal: false,
966 })
967 .await;
968 } else {
969 warn!(
970 "Writer connection {} has permanent error, will not reconnect: {:?}",
971 wconn.id, failure_reason
972 );
973 }
974
975 break;
976 }
977 }
978 debug!("Writer {} exit", wconn.id);
979 });
980 }
981
982 {
983 let common = self.clone();
984 let conn = conn.clone();
985 let url = url.to_string();
986 spawn(async move {
987 common.on_open(url, conn, old_writer).await;
988 });
989 }
990
991 {
992 let common = self.clone();
993 let reader_conn = conn.clone();
994 let read_url = url.to_string();
995
996 spawn(async move {
997 let mut stream_end_reason = None;
998 while let Some(item) = read_half.next().await {
999 match item {
1000 Ok(Message::Text(msg)) => {
1001 common
1002 .on_message(msg.to_string(), Arc::clone(&reader_conn))
1003 .await;
1004 }
1005 Ok(Message::Binary(bin)) => {
1006 let mut decoder = ZlibDecoder::new(&bin[..]);
1007 let mut decompressed = String::new();
1008 if let Err(err) = decoder.read_to_string(&mut decompressed) {
1009 error!("Binary message decompress failed: {:?}", err);
1010 continue;
1011 }
1012 common
1013 .on_message(decompressed, Arc::clone(&reader_conn))
1014 .await;
1015 }
1016 Ok(Message::Ping(payload)) => {
1017 info!("PING received from server on {}", reader_conn.id);
1018 common.events.emit(&WebsocketEvent::Ping);
1019 if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
1020 let _ = tx.send(Message::Pong(payload));
1021 info!(
1022 "Responded PONG to server's PING message on {}",
1023 reader_conn.id
1024 );
1025 }
1026 }
1027 Ok(Message::Pong(_)) => {
1028 info!("Received PONG from server on {}", reader_conn.id);
1029 common.events.emit(&WebsocketEvent::Pong);
1030 }
1031 Ok(Message::Close(frame)) => {
1032 let (code, reason) = frame
1033 .map_or((1000, String::new()), |CloseFrame { code, reason }| {
1034 (code.into(), reason.to_string())
1035 });
1036 common
1037 .events
1038 .emit(&WebsocketEvent::Close(code, reason.clone()));
1039
1040 let user_initiated = {
1042 let conn_state = reader_conn.state.lock().await;
1043 conn_state.close_initiated
1044 };
1045
1046 let failure_reason = WebsocketConnectionFailureReason::from_close_code(
1047 code,
1048 user_initiated,
1049 );
1050 stream_end_reason = Some(failure_reason);
1051
1052 info!(
1053 "Connection {} received close frame: code={}, reason='{}', classified as {:?}",
1054 reader_conn.id, code, reason, failure_reason
1055 );
1056
1057 let mut conn_state = reader_conn.state.lock().await;
1058 if !conn_state.close_initiated
1059 && !is_renewal
1060 && failure_reason.should_reconnect()
1061 {
1062 info!(
1063 "Connection {} received close frame with reconnectable failure: {:?}",
1064 reader_conn.id, failure_reason
1065 );
1066 conn_state.reconnection_pending = true;
1067 conn_state.is_session_logged_on = false;
1068 drop(conn_state);
1069 let reconnect_url = common
1070 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1071 .await;
1072
1073 let _ = common
1074 .reconnect_tx
1075 .send(ReconnectEntry {
1076 connection_id: reader_conn.id.clone(),
1077 url: reconnect_url,
1078 is_renewal: false,
1079 })
1080 .await;
1081 } else {
1082 warn!(
1083 "Connection {} received close frame with non-reconnectable failure: {:?}",
1084 reader_conn.id, failure_reason
1085 );
1086
1087 common.events.emit(&WebsocketEvent::Error(format!(
1089 "[CRITICAL] Connection {} permanently failed: {:?}",
1090 reader_conn.id, failure_reason
1091 )));
1092 }
1093
1094 break;
1095 }
1096 Err(e) => {
1097 let failure_reason =
1099 WebsocketConnectionFailureReason::from_tungstenite_error(&e);
1100
1101 stream_end_reason = Some(failure_reason);
1102 error!(
1103 "WebSocket error on {}: {:?}, classified as {:?}",
1104 reader_conn.id, e, failure_reason
1105 );
1106
1107 common.events.emit(&WebsocketEvent::Error(e.to_string()));
1108
1109 let mut conn_state = reader_conn.state.lock().await;
1111 if !conn_state.close_initiated
1112 && !is_renewal
1113 && failure_reason.should_reconnect()
1114 {
1115 info!(
1116 "Connection {} has recoverable error, attempting reconnection: {:?}",
1117 reader_conn.id, failure_reason
1118 );
1119 conn_state.reconnection_pending = true;
1120 conn_state.is_session_logged_on = false;
1121 drop(conn_state);
1122 let reconnect_url = common
1123 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1124 .await;
1125
1126 let _ = common
1127 .reconnect_tx
1128 .send(ReconnectEntry {
1129 connection_id: reader_conn.id.clone(),
1130 url: reconnect_url,
1131 is_renewal: false,
1132 })
1133 .await;
1134 } else {
1135 warn!(
1136 "Connection {} has permanent error, will not reconnect: {:?}",
1137 reader_conn.id, failure_reason
1138 );
1139
1140 common.events.emit(&WebsocketEvent::Error(format!(
1142 "[CRITICAL] Connection {} permanently failed: {:?}",
1143 reader_conn.id, failure_reason
1144 )));
1145 }
1146
1147 break;
1148 }
1149 _ => {}
1150 }
1151 }
1152
1153 info!("WebSocket stream ended for connection {}", reader_conn.id);
1155
1156 let failure_reason =
1158 stream_end_reason.unwrap_or(WebsocketConnectionFailureReason::StreamEnded);
1159
1160 info!(
1161 "WebSocket stream ended for connection {}, classified as {:?}",
1162 reader_conn.id, failure_reason
1163 );
1164
1165 let mut conn_state = reader_conn.state.lock().await;
1166 if !conn_state.close_initiated && !is_renewal && failure_reason.should_reconnect() {
1167 info!(
1168 "Connection {} stream ended unexpectedly, attempting reconnection",
1169 reader_conn.id
1170 );
1171 conn_state.reconnection_pending = true;
1172 conn_state.is_session_logged_on = false;
1173 conn_state.ws_write_tx = None;
1174 drop(conn_state);
1175 let reconnect_url = common
1176 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1177 .await;
1178
1179 let _ = common
1180 .reconnect_tx
1181 .send(ReconnectEntry {
1182 connection_id: reader_conn.id.clone(),
1183 url: reconnect_url,
1184 is_renewal: false,
1185 })
1186 .await;
1187 } else {
1188 debug!(
1189 "Connection {} stream ended normally (close_initiated={}, is_renewal={})",
1190 reader_conn.id, conn_state.close_initiated, is_renewal
1191 );
1192 }
1193
1194 debug!("Reader actor for {} exiting", reader_conn.id);
1195 });
1196 }
1197
1198 Ok(())
1199 }
1200 async fn disconnect(&self) -> Result<(), WebsocketError> {
1215 if !self.is_connected(None).await {
1216 warn!("No active connection to close.");
1217 return Ok(());
1218 }
1219
1220 let mut shutdowns = FuturesUnordered::new();
1221 for conn in &self.connection_pool {
1222 {
1223 let mut conn_state = conn.state.lock().await;
1224 conn_state.close_initiated = true;
1225 if let Some(tx) = &conn_state.ws_write_tx {
1226 shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
1227 }
1228 }
1229 }
1230
1231 let close_all = async {
1232 while let Some(result) = shutdowns.next().await {
1233 result?;
1234 }
1235 Ok::<(), WebsocketError>(())
1236 };
1237
1238 match timeout(Duration::from_secs(30), close_all).await {
1239 Ok(Ok(())) => {
1240 info!("Disconnected all WebSocket connections successfully.");
1241 for conn in &self.connection_pool {
1242 let mut st = conn.state.lock().await;
1243 st.is_session_logged_on = false;
1244 st.session_logon_req = None;
1245 }
1246 Ok(())
1247 }
1248 Ok(Err(err)) => {
1249 error!("Error while disconnecting: {:?}", err);
1250 Err(err)
1251 }
1252 Err(_) => {
1253 error!("Timed out while disconnecting WebSocket connections.");
1254 Err(WebsocketError::Timeout)
1255 }
1256 }
1257 }
1258
1259 async fn ping_server(&self) {
1272 let mut ready = Vec::new();
1273 for conn in &self.connection_pool {
1274 if self.is_connection_ready(conn, false).await {
1275 let id = conn.id.clone();
1276 let ws_write_tx = {
1277 let conn_state = conn.state.lock().await;
1278 conn_state.ws_write_tx.clone()
1279 };
1280 ready.push((id, ws_write_tx));
1281 }
1282 }
1283
1284 if ready.is_empty() {
1285 warn!("No ready connections for PING.");
1286 return;
1287 }
1288 info!("Sending PING to {} WebSocket connections.", ready.len());
1289
1290 let mut tasks = FuturesUnordered::new();
1291 for (id, ws_write_tx_opt) in ready {
1292 if let Some(tx) = ws_write_tx_opt {
1293 tasks.push(async move {
1294 if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1295 error!("Failed to send PING to {}: {:?}", id, e);
1296 } else {
1297 debug!("Sent PING to connection {}", id);
1298 }
1299 });
1300 } else {
1301 error!("Connection {} was ready but has no write channel", id);
1302 }
1303 }
1304
1305 while tasks.next().await.is_some() {}
1306 }
1307
1308 async fn send(
1326 &self,
1327 payload: String,
1328 id: Option<String>,
1329 wait_for_reply: bool,
1330 timeout: Duration,
1331 connection: Option<Arc<WebsocketConnection>>,
1332 ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1333 let conn = if let Some(c) = connection {
1334 c
1335 } else {
1336 self.get_connection(false, None).await?
1337 };
1338
1339 if !self.is_connected(Some(&conn)).await {
1340 warn!("Send attempted on a non-connected socket");
1341 return Err(WebsocketError::NotConnected);
1342 }
1343
1344 let ws_write_tx = {
1345 let conn_state = conn.state.lock().await;
1346 conn_state
1347 .ws_write_tx
1348 .clone()
1349 .ok_or(WebsocketError::NotConnected)?
1350 };
1351
1352 let pending_setup = if wait_for_reply {
1353 let request_id = id.ok_or_else(|| {
1354 error!("id is required when waiting for a reply");
1355 WebsocketError::NotConnected
1356 })?;
1357
1358 let (tx, rx) = oneshot::channel();
1359 {
1360 let mut conn_state = conn.state.lock().await;
1361 conn_state
1362 .pending_requests
1363 .insert(request_id.clone(), PendingRequest { completion: tx });
1364 }
1365
1366 Some((request_id, rx))
1367 } else {
1368 None
1369 };
1370
1371 debug!("Sending message to WebSocket on connection {}", conn.id);
1372
1373 if ws_write_tx
1374 .send(Message::Text(payload.clone().into()))
1375 .is_err()
1376 {
1377 if let Some((request_id, _)) = &pending_setup {
1378 let mut conn_state = conn.state.lock().await;
1379 conn_state.pending_requests.remove(request_id);
1380 }
1381 return Err(WebsocketError::NotConnected);
1382 }
1383
1384 let rx = if let Some((request_id, rx)) = pending_setup {
1385 let conn_clone = Arc::clone(&conn);
1386 let timeout_id = request_id.clone();
1387 spawn(async move {
1388 sleep(timeout).await;
1389 let mut conn_state = conn_clone.state.lock().await;
1390 if let Some(pending_req) = conn_state.pending_requests.remove(&timeout_id) {
1391 let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1392 }
1393 });
1394 Some(rx)
1395 } else {
1396 None
1397 };
1398
1399 Ok(rx)
1400 }
1401}
1402
/// Per-request options for WebSocket API messages, built fluently from [`Default`].
#[derive(Debug, Default, Clone)]
pub struct WebsocketMessageSendOptions {
    /// Attach the API key to the request.
    pub with_api_key: bool,
    /// Sign the request.
    pub is_signed: bool,
    /// `Some(true)` marks a session-logon request.
    pub is_session_logon: Option<bool>,
    /// `Some(true)` marks a session-logout request.
    pub is_session_logout: Option<bool>,
}

impl WebsocketMessageSendOptions {
    /// Returns options with every flag unset.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Enables `with_api_key`.
    #[must_use]
    pub fn with_api_key(self) -> Self {
        Self {
            with_api_key: true,
            ..self
        }
    }

    /// Enables `is_signed`.
    #[must_use]
    pub fn signed(self) -> Self {
        Self {
            is_signed: true,
            ..self
        }
    }

    /// Marks the request as a session logon.
    #[must_use]
    pub fn session_logon(self) -> Self {
        Self {
            is_session_logon: Some(true),
            ..self
        }
    }

    /// Marks the request as a session logout.
    #[must_use]
    pub fn session_logout(self) -> Self {
        Self {
            is_session_logout: Some(true),
            ..self
        }
    }
}
1441
1442#[derive(Debug)]
1443pub enum SendWebsocketMessageResult<R> {
1444 Single(WebsocketApiResponse<R>),
1445 Multiple(Vec<WebsocketApiResponse<R>>),
1446}
1447
1448impl<R> IntoIterator for SendWebsocketMessageResult<R> {
1449 type Item = WebsocketApiResponse<R>;
1450 type IntoIter = std::vec::IntoIter<Self::Item>;
1451
1452 fn into_iter(self) -> Self::IntoIter {
1453 match self {
1454 SendWebsocketMessageResult::Single(resp) => vec![resp].into_iter(),
1455 SendWebsocketMessageResult::Multiple(v) => v.into_iter(),
1456 }
1457 }
1458}
1459
/// Client for the request/response WebSocket API, built on a shared [`WebsocketCommon`] pool.
pub struct WebsocketApi {
    /// Shared pool core (connections, events, reconnect/renewal loops).
    pub common: Arc<WebsocketCommon>,
    /// API configuration (URL, timeouts, mode, relogon behaviour, ...).
    configuration: ConfigurationWebsocketApi,
    /// `true` while `connect` is running, so concurrent calls return early.
    is_connecting: Arc<Mutex<bool>>,
    /// NOTE(review): not read in the visible part of this file; presumably per-stream callbacks.
    stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
}
1466
1467impl WebsocketApi {
1468 #[must_use]
1469 pub fn new(
1490 configuration: ConfigurationWebsocketApi,
1491 connection_pool: Vec<Arc<WebsocketConnection>>,
1492 ) -> Arc<Self> {
1493 let agent_clone = configuration.agent.clone();
1494 let user_agent_clone = configuration.user_agent.clone();
1495 let common = WebsocketCommon::new(
1496 connection_pool,
1497 configuration.mode.clone(),
1498 usize::try_from(configuration.reconnect_delay)
1499 .expect("reconnect_delay should fit in usize"),
1500 agent_clone,
1501 Some(user_agent_clone),
1502 );
1503
1504 Arc::new(Self {
1505 common: Arc::clone(&common),
1506 configuration,
1507 is_connecting: Arc::new(Mutex::new(false)),
1508 stream_callbacks: Mutex::new(HashMap::new()),
1509 })
1510 }
1511
1512 pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1534 if self.common.is_connected(None).await {
1535 info!("WebSocket connection already established");
1536 return Ok(());
1537 }
1538
1539 {
1540 let mut flag = self.is_connecting.lock().await;
1541 if *flag {
1542 info!("Already connecting...");
1543 return Ok(());
1544 }
1545 *flag = true;
1546 }
1547
1548 let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1549
1550 let handler: Arc<dyn WebsocketHandler> = self.clone();
1551 for slot in &self.common.connection_pool {
1552 slot.set_handler(handler.clone()).await;
1553 }
1554
1555 let result = select! {
1556 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1557 r = self.common.clone().connect_pool(&url, None) => r,
1558 };
1559
1560 {
1561 let mut flag = self.is_connecting.lock().await;
1562 *flag = false;
1563 }
1564
1565 result
1566 }
1567
1568 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1581 self.common.disconnect().await
1582 }
1583
1584 pub async fn is_connected(&self) -> bool {
1590 self.common.is_connected(None).await
1591 }
1592
1593 pub async fn ping_server(&self) {
1598 self.common.ping_server().await;
1599 }
1600
1601 pub async fn send_message<R>(
1631 &self,
1632 method: &str,
1633 payload: BTreeMap<String, Value>,
1634 options: WebsocketMessageSendOptions,
1635 ) -> Result<SendWebsocketMessageResult<R>, WebsocketError>
1636 where
1637 R: DeserializeOwned + Send + Sync + 'static,
1638 {
1639 if !self.common.is_connected(None).await {
1640 return Err(WebsocketError::NotConnected);
1641 }
1642
1643 let do_multi =
1644 options.is_session_logon.unwrap_or(false) || options.is_session_logout.unwrap_or(false);
1645
1646 let connections = if do_multi {
1647 self.common.get_available_connections(false, None).await
1648 } else {
1649 vec![self.common.get_connection(false, None).await?]
1650 };
1651
1652 let skip_auth = if do_multi {
1653 false
1654 } else {
1655 let connection = &connections[0];
1656 let conn_state = connection.state.lock().await;
1657 self.configuration.auto_session_relogon && conn_state.is_session_logged_on
1658 };
1659
1660 let payload_clone = payload.clone();
1661
1662 let (id, request) =
1663 build_websocket_api_message(&self.configuration, method, payload, &options, skip_auth);
1664 let raw_payload = serde_json::to_string(&request).unwrap();
1665 debug!("Sending message to WebSocket API: {:?}", request);
1666
1667 let timeout = Duration::from_millis(self.configuration.timeout);
1668
1669 let mut receivers = Vec::with_capacity(connections.len());
1670 for connection in &connections {
1671 let opt_rx = self
1672 .common
1673 .send(
1674 raw_payload.clone(),
1675 Some(id.clone()),
1676 true,
1677 timeout,
1678 Some(connection.clone()),
1679 )
1680 .await?;
1681 receivers.push((connection.clone(), opt_rx));
1682 }
1683
1684 let mut raw_msgs = Vec::with_capacity(receivers.len());
1685 for (_conn, opt_rx) in receivers {
1686 let rx = opt_rx.ok_or(WebsocketError::NoResponse)?;
1687 let msg = rx.await.unwrap_or(Err(WebsocketError::Timeout))?;
1688 raw_msgs.push(msg);
1689 }
1690
1691 let mut responses = Vec::with_capacity(raw_msgs.len());
1692 for msg in raw_msgs {
1693 let raw = msg
1694 .get("result")
1695 .or_else(|| msg.get("response"))
1696 .cloned()
1697 .unwrap_or(Value::Null);
1698
1699 let rate_limits = msg
1700 .get("rateLimits")
1701 .and_then(Value::as_array)
1702 .map(|arr| {
1703 arr.iter()
1704 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1705 .collect()
1706 })
1707 .unwrap_or_default();
1708
1709 responses.push(WebsocketApiResponse {
1710 raw,
1711 rate_limits,
1712 _marker: PhantomData,
1713 });
1714 }
1715
1716 if do_multi && self.configuration.auto_session_relogon {
1717 for connection in &connections {
1718 let mut state = connection.state.lock().await;
1719 if options.is_session_logon.unwrap_or(false) {
1720 state.is_session_logged_on = true;
1721 state.session_logon_req = Some(WebsocketSessionLogonReq {
1722 method: method.to_string(),
1723 payload: payload_clone.clone(),
1724 options: options.clone(),
1725 });
1726 } else {
1727 state.is_session_logged_on = false;
1728 state.session_logon_req = None;
1729 }
1730 }
1731 }
1732
1733 Ok(if responses.len() == 1 && !do_multi {
1734 SendWebsocketMessageResult::Single(responses.into_iter().next().unwrap())
1735 } else {
1736 SendWebsocketMessageResult::Multiple(responses)
1737 })
1738 }
1739
1740 fn prepare_url(&self, ws_url: &str) -> String {
1755 let mut url = ws_url.to_string();
1756
1757 let time_unit = match &self.configuration.time_unit {
1758 Some(u) => u.to_string(),
1759 None => return url,
1760 };
1761
1762 match validate_time_unit(&time_unit) {
1763 Ok(Some(validated)) => {
1764 let sep = if url.contains('?') { '&' } else { '?' };
1765 url.push(sep);
1766 url.push_str("timeUnit=");
1767 url.push_str(validated);
1768 }
1769 Ok(None) => {}
1770 Err(e) => {
1771 error!("Invalid time unit provided: {:?}", e);
1772 }
1773 }
1774
1775 url
1776 }
1777}
1778
1779#[async_trait]
1780impl WebsocketHandler for WebsocketApi {
1781 async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
1799 let session_req = {
1800 let conn_state = connection.state.lock().await;
1801 conn_state.session_logon_req.clone()
1802 };
1803
1804 let Some(req) = session_req else {
1805 return;
1806 };
1807
1808 let already_logged_on = {
1809 let conn_state = connection.state.lock().await;
1810 conn_state.is_session_logged_on
1811 };
1812
1813 if already_logged_on {
1814 debug!(
1815 "Connection {} already logged on, skipping re-logon",
1816 connection.id
1817 );
1818 return;
1819 }
1820
1821 let conn = connection.clone();
1822 let common = Arc::clone(&self.common);
1823 let configuration = self.configuration.clone();
1824 let method = req.method.clone();
1825 let payload = req.payload.clone();
1826 let options = req.options.clone();
1827
1828 spawn(async move {
1829 let (id, json_msg) =
1830 build_websocket_api_message(&configuration, &method, payload, &options, false);
1831
1832 let raw_message = match serde_json::to_string(&json_msg) {
1833 Ok(msg) => msg,
1834 Err(e) => {
1835 warn!(
1836 "Failed to serialize session logon message for connection {}: {}",
1837 conn.id, e
1838 );
1839 return;
1840 }
1841 };
1842
1843 debug!(
1844 "Session re-logon on connection {}: {}",
1845 conn.id, raw_message
1846 );
1847
1848 let rx = match common
1849 .send(
1850 raw_message,
1851 Some(id.clone()),
1852 true,
1853 Duration::from_millis(configuration.timeout),
1854 Some(conn.clone()),
1855 )
1856 .await
1857 {
1858 Ok(Some(rx)) => rx,
1859 Ok(None) => {
1860 warn!(
1861 "Session re-logon dispatch returned None for connection {}",
1862 conn.id
1863 );
1864 return;
1865 }
1866 Err(e) => {
1867 warn!(
1868 "Session re-logon dispatch failed on connection {}: {}",
1869 conn.id, e
1870 );
1871 return;
1872 }
1873 };
1874
1875 let Ok(result) = timeout(Duration::from_millis(configuration.timeout), rx).await else {
1876 warn!("Session re-logon timed out on connection {}", conn.id);
1877 return;
1878 };
1879
1880 let final_result = match result {
1881 Ok(final_result) => final_result,
1882 Err(e) => {
1883 warn!(
1884 "Session re-logon receiver error on connection {}: {}",
1885 conn.id, e
1886 );
1887 return;
1888 }
1889 };
1890
1891 let payload = match final_result {
1892 Ok(payload) => payload,
1893 Err(e) => {
1894 warn!(
1895 "Session re-logon payload error on connection {}: {}",
1896 conn.id, e
1897 );
1898 return;
1899 }
1900 };
1901
1902 debug!(
1903 "Session re-logon succeeded on connection {}: {}",
1904 conn.id, payload
1905 );
1906 let mut conn_state = conn.state.lock().await;
1907 conn_state.is_session_logged_on = true;
1908 });
1909 }
1910
1911 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1931 let msg: Value = match serde_json::from_str(&data) {
1932 Ok(v) => v,
1933 Err(err) => {
1934 error!("Failed to parse WebSocket message {} – {}", data, err);
1935 return;
1936 }
1937 };
1938
1939 if let Some(id) = msg.get("id").and_then(Value::as_str) {
1940 let maybe_sender = {
1941 let mut conn_state = connection.state.lock().await;
1942 conn_state.pending_requests.remove(id)
1943 };
1944
1945 if let Some(PendingRequest { completion }) = maybe_sender {
1946 connection.drain_notify.notify_one();
1947 let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1948 if status >= 400 {
1949 let error_map = msg
1950 .get("error")
1951 .and_then(Value::as_object)
1952 .unwrap_or(&serde_json::Map::new())
1953 .clone();
1954
1955 let code = error_map
1956 .get("code")
1957 .and_then(Value::as_i64)
1958 .unwrap_or(status.try_into().unwrap());
1959
1960 let message = error_map
1961 .get("msg")
1962 .and_then(Value::as_str)
1963 .unwrap_or("Unknown error")
1964 .to_string();
1965
1966 let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1967 } else {
1968 let _ = completion.send(Ok(msg.clone()));
1969 }
1970 }
1971
1972 return;
1973 }
1974
1975 if let Some(event) = msg.get("event") {
1976 if let Some(event_type) = event.get("e").and_then(Value::as_str) {
1977 if event_type == "serverShutdown" {
1978 warn!(
1979 "Received serverShutdown event on connection {}",
1980 connection.id
1981 );
1982
1983 let mut conn_state = connection.state.lock().await;
1984
1985 if !conn_state.renewal_pending && !conn_state.close_initiated {
1986 conn_state.renewal_pending = true;
1987
1988 let url = conn_state.url_path.clone().unwrap_or_default();
1989
1990 drop(conn_state);
1991
1992 if let Err(e) = self
1993 .common
1994 .reconnect_tx
1995 .send(ReconnectEntry {
1996 connection_id: connection.id.clone(),
1997 url,
1998 is_renewal: true,
1999 })
2000 .await
2001 {
2002 error!("Failed to enqueue serverShutdown renewal: {:?}", e);
2003 }
2004 }
2005
2006 return;
2007 }
2008 }
2009 }
2010
2011 if let Some(event) = msg.get("event") {
2012 if event.get("e").is_some() {
2013 for callbacks in self.stream_callbacks.lock().await.values() {
2014 for callback in callbacks {
2015 callback(event);
2016 }
2017 }
2018
2019 return;
2020 }
2021 }
2022
2023 warn!(
2024 "Received response for unknown or timed-out request: {}",
2025 data
2026 );
2027 }
2028
2029 async fn get_reconnect_url(
2040 &self,
2041 default_url: String,
2042 _connection: Arc<WebsocketConnection>,
2043 ) -> String {
2044 default_url
2045 }
2046}
2047
2048pub struct WebsocketStreams {
2049 pub common: Arc<WebsocketCommon>,
2050 pub stream_id_is_strictly_number: AtomicBool,
2051 url_paths: Vec<String>,
2052 is_connecting: Mutex<bool>,
2053 connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
2054 configuration: ConfigurationWebsocketStreams,
2055}
2056
2057impl WebsocketStreams {
2058 #[must_use]
2074 pub fn new(
2075 configuration: ConfigurationWebsocketStreams,
2076 mut connection_pool: Vec<Arc<WebsocketConnection>>,
2077 url_paths: Vec<String>,
2078 ) -> Arc<Self> {
2079 if !url_paths.is_empty() {
2080 let base_pool_size = configuration.mode.pool_size();
2081 let expected = base_pool_size * url_paths.len();
2082
2083 while connection_pool.len() < expected {
2084 connection_pool.push(WebsocketConnection::new(random_string()));
2085 }
2086 }
2087
2088 let agent_clone = configuration.agent.clone();
2089 let user_agent_clone = configuration.user_agent.clone();
2090 let common = WebsocketCommon::new(
2091 connection_pool,
2092 configuration.mode.clone(),
2093 usize::try_from(configuration.reconnect_delay)
2094 .expect("reconnect_delay should fit in usize"),
2095 agent_clone,
2096 Some(user_agent_clone),
2097 );
2098 Arc::new(Self {
2099 common,
2100 is_connecting: Mutex::new(false),
2101 connection_streams: Mutex::new(HashMap::new()),
2102 configuration,
2103 stream_id_is_strictly_number: AtomicBool::new(false),
2104 url_paths,
2105 })
2106 }
2107
2108 pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
2125 if self.common.is_connected(None).await {
2126 info!("WebSocket connection already established");
2127 return Ok(());
2128 }
2129
2130 {
2131 let mut flag = self.is_connecting.lock().await;
2132 if *flag {
2133 info!("Already connecting...");
2134 return Ok(());
2135 }
2136 *flag = true;
2137 }
2138
2139 let handler: Arc<dyn WebsocketHandler> = self.clone();
2140 for conn in &self.common.connection_pool {
2141 conn.set_handler(handler.clone()).await;
2142 }
2143
2144 let base_pool_size = self.configuration.mode.pool_size();
2145
2146 let connect_fut = async {
2147 if self.url_paths.is_empty() {
2148 let url = self.prepare_url(&streams, None);
2149 self.common.clone().connect_pool(&url, None).await
2150 } else {
2151 let mut futures = Vec::with_capacity(self.url_paths.len());
2152
2153 for (i, path) in self.url_paths.iter().enumerate() {
2154 let start = i * base_pool_size;
2155
2156 let subset: Vec<Arc<WebsocketConnection>> = self
2157 .common
2158 .connection_pool
2159 .iter()
2160 .skip(start)
2161 .take(base_pool_size)
2162 .cloned()
2163 .collect();
2164
2165 if subset.len() != base_pool_size {
2166 return Err(WebsocketError::ServerError(format!(
2167 "connection_pool too small for url_paths: need {} per path, got {} for path index {}",
2168 base_pool_size,
2169 subset.len(),
2170 i
2171 )));
2172 }
2173
2174 for c in &subset {
2175 let mut st = c.state.lock().await;
2176 st.url_path = Some(path.clone());
2177 }
2178
2179 let url = self.prepare_url(&streams, Some(path.as_str()));
2180 let common = self.common.clone();
2181
2182 futures.push(async move { common.connect_pool(&url, Some(subset)).await });
2183 }
2184
2185 try_join_all(futures).await?;
2186 Ok(())
2187 }
2188 };
2189
2190 let connect_res = select! {
2191 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
2192 r = connect_fut => r,
2193 };
2194
2195 {
2196 let mut flag = self.is_connecting.lock().await;
2197 *flag = false;
2198 }
2199
2200 connect_res
2201 }
2202
2203 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
2219 for connection in &self.common.connection_pool {
2220 let mut conn_state = connection.state.lock().await;
2221 conn_state.stream_callbacks.clear();
2222 conn_state.pending_subscriptions.clear();
2223 }
2224 self.connection_streams.lock().await.clear();
2225 self.common.disconnect().await
2226 }
2227
2228 pub async fn is_connected(&self) -> bool {
2234 self.common.is_connected(None).await
2235 }
2236
2237 pub async fn ping_server(&self) {
2246 self.common.ping_server().await;
2247 }
2248
2249 pub async fn subscribe(
2269 self: Arc<Self>,
2270 streams: Vec<String>,
2271 id: Option<StreamId>,
2272 url_path: Option<&str>,
2273 ) {
2274 let streams: Vec<String> = {
2275 let map = self.connection_streams.lock().await;
2276 streams
2277 .into_iter()
2278 .filter(|s| {
2279 let key = self.stream_key(s, url_path);
2280 !map.contains_key(&key)
2281 })
2282 .collect()
2283 };
2284
2285 if streams.is_empty() {
2286 return;
2287 }
2288
2289 let connection_streams = self.handle_stream_assignment(streams, url_path).await;
2290
2291 for (conn, assigned_streams) in connection_streams {
2292 if !self.common.is_connected(Some(&conn)).await {
2293 info!(
2294 "Connection {} is not ready. Queuing subscription for streams: {:?}",
2295 conn.id, assigned_streams
2296 );
2297
2298 let mut conn_state = conn.state.lock().await;
2299 conn_state
2300 .pending_subscriptions
2301 .extend(assigned_streams.iter().cloned());
2302
2303 continue;
2304 }
2305
2306 self.send_subscription_payload(&conn, &assigned_streams, id.clone());
2307 }
2308 }
2309
2310 pub async fn unsubscribe(
2339 &self,
2340 streams: Vec<String>,
2341 id: Option<StreamId>,
2342 url_path: Option<&str>,
2343 ) {
2344 let request_id = normalize_stream_id(
2345 id.clone(),
2346 self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2347 );
2348
2349 for stream in streams {
2350 let key = self.stream_key(&stream, url_path);
2351 let maybe_conn = { self.connection_streams.lock().await.get(&key).cloned() };
2352
2353 let conn = match maybe_conn {
2354 Some(c) => {
2355 if !self.common.is_connected(Some(&c)).await {
2356 warn!(
2357 "Stream {} not associated with an active connection.",
2358 stream
2359 );
2360 continue;
2361 }
2362 c
2363 }
2364 None => {
2365 warn!("Stream {} was not subscribed.", stream);
2366 continue;
2367 }
2368 };
2369
2370 let has_callbacks = {
2371 let conn_state = conn.state.lock().await;
2372 conn_state
2373 .stream_callbacks
2374 .get(&key)
2375 .is_some_and(|v| !v.is_empty())
2376 };
2377
2378 if has_callbacks {
2379 continue;
2380 }
2381
2382 let payload = json!({
2383 "method": "UNSUBSCRIBE",
2384 "params": [stream.clone()],
2385 "id": request_id,
2386 });
2387
2388 info!("UNSUBSCRIBE → {:?}", payload);
2389
2390 let common = Arc::clone(&self.common);
2391 let conn_clone = Arc::clone(&conn);
2392 let msg = serde_json::to_string(&payload).unwrap();
2393 spawn(async move {
2394 let _ = common
2395 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2396 .await;
2397 });
2398
2399 {
2400 let mut connection_streams = self.connection_streams.lock().await;
2401 connection_streams.remove(&key);
2402 }
2403 {
2404 let mut conn_state = conn.state.lock().await;
2405 conn_state.stream_callbacks.remove(&key);
2406 }
2407 }
2408 }
2409
2410 pub async fn is_subscribed(&self, stream: &str) -> bool {
2424 let map = self.connection_streams.lock().await;
2425
2426 if map.contains_key(stream) {
2427 return true;
2428 }
2429
2430 let suffix = format!("::{}", stream);
2431 map.keys().any(|k| k.ends_with(&suffix))
2432 }
2433
2434 fn stream_key(&self, stream: &str, url_path: Option<&str>) -> String {
2446 match url_path {
2447 Some(p) if !p.is_empty() => format!("{p}::{stream}"),
2448 _ => stream.to_string(),
2449 }
2450 }
2451
2452 fn prepare_url(&self, streams: &[String], url_path: Option<&str>) -> String {
2469 let mut url = format!(
2470 "{}/stream?streams={}",
2471 match url_path {
2472 Some(path) => format!(
2473 "{}/{}",
2474 self.configuration.ws_url.as_deref().unwrap_or(""),
2475 path
2476 ),
2477 None => self
2478 .configuration
2479 .ws_url
2480 .as_deref()
2481 .unwrap_or("")
2482 .to_string(),
2483 },
2484 streams.join("/")
2485 );
2486
2487 let time_unit = match &self.configuration.time_unit {
2488 Some(u) => u.to_string(),
2489 None => return url,
2490 };
2491
2492 match validate_time_unit(&time_unit) {
2493 Ok(Some(validated)) => {
2494 let sep = if url.contains('?') { '&' } else { '?' };
2495 url.push(sep);
2496 url.push_str("timeUnit=");
2497 url.push_str(validated);
2498 }
2499 Ok(None) => {}
2500 Err(e) => {
2501 error!("Invalid time unit provided: {:?}", e);
2502 }
2503 }
2504
2505 url
2506 }
2507
2508 async fn handle_stream_assignment(
2527 &self,
2528 streams: Vec<String>,
2529 url_path: Option<&str>,
2530 ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
2531 let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
2532
2533 for stream in streams {
2534 let key = self.stream_key(&stream, url_path);
2535
2536 let mut conn_opt = {
2537 let map = self.connection_streams.lock().await;
2538 map.get(&key).cloned()
2539 };
2540
2541 let need_new = if let Some(conn) = &conn_opt {
2542 let state = conn.state.lock().await;
2543 state.close_initiated || state.reconnection_pending
2544 } else {
2545 true
2546 };
2547
2548 if need_new {
2549 match self.common.get_connection(true, url_path).await {
2550 Ok(new_conn) => {
2551 let mut map = self.connection_streams.lock().await;
2552 map.insert(key.clone(), new_conn.clone());
2553 conn_opt = Some(new_conn);
2554 }
2555 Err(err) => {
2556 warn!(
2557 "No available WebSocket connection to subscribe stream `{}` (key `{}`): {:?}",
2558 stream, key, err
2559 );
2560 continue;
2561 }
2562 }
2563 }
2564
2565 if let Some(conn) = conn_opt {
2566 {
2567 let mut conn_state = conn.state.lock().await;
2568 conn_state.stream_callbacks.entry(key.clone()).or_default();
2569 }
2570 connection_streams.push((stream, conn));
2571 }
2572 }
2573
2574 let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
2575 for (stream, conn) in connection_streams {
2576 if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
2577 vec.push(stream);
2578 } else {
2579 groups.push((conn, vec![stream]));
2580 }
2581 }
2582
2583 groups
2584 }
2585
2586 fn send_subscription_payload(
2599 &self,
2600 connection: &Arc<WebsocketConnection>,
2601 streams: &Vec<String>,
2602 id: Option<StreamId>,
2603 ) {
2604 let request_id = normalize_stream_id(
2605 id.clone(),
2606 self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2607 );
2608
2609 let payload = json!({
2610 "method": "SUBSCRIBE",
2611 "params": streams,
2612 "id": request_id,
2613 });
2614
2615 info!("SUBSCRIBE → {:?}", payload);
2616
2617 let common = Arc::clone(&self.common);
2618 let msg = match serde_json::to_string(&payload) {
2619 Ok(s) => s,
2620 Err(e) => {
2621 error!("Failed to serialize SUBSCRIBE payload: {}", e);
2622 return;
2623 }
2624 };
2625 let conn_clone = Arc::clone(connection);
2626
2627 spawn(async move {
2628 let _ = common
2629 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2630 .await;
2631 });
2632 }
2633}
2634
2635#[async_trait]
2636impl WebsocketHandler for WebsocketStreams {
2637 async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2654 let pending_subs: Vec<String> = {
2655 let mut conn_state = connection.state.lock().await;
2656 take(&mut conn_state.pending_subscriptions)
2657 .into_iter()
2658 .collect()
2659 };
2660
2661 if !pending_subs.is_empty() {
2662 info!("Processing queued subscriptions for connection");
2663 self.send_subscription_payload(&connection, &pending_subs, None);
2664 }
2665 }
2666
2667 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2684 let msg: Value = match serde_json::from_str(&data) {
2685 Ok(v) => v,
2686 Err(err) => {
2687 error!(
2688 "Failed to parse WebSocket stream message {} – {}",
2689 data, err
2690 );
2691 return;
2692 }
2693 };
2694
2695 let (stream_name, payload) = match (
2696 msg.get("stream").and_then(Value::as_str),
2697 msg.get("data").cloned(),
2698 ) {
2699 (Some(name), Some(data)) => (name.to_string(), data),
2700 _ => return,
2701 };
2702
2703 let callbacks = {
2704 let conn_state = connection.state.lock().await;
2705 let key = self.stream_key(&stream_name, conn_state.url_path.as_deref());
2706 conn_state
2707 .stream_callbacks
2708 .get(&key)
2709 .cloned()
2710 .unwrap_or_else(Vec::new)
2711 };
2712
2713 for callback in callbacks {
2714 callback(&payload);
2715 }
2716 }
2717
2718 async fn get_reconnect_url(
2729 &self,
2730 _default_url: String,
2731 connection: Arc<WebsocketConnection>,
2732 ) -> String {
2733 let connection_streams = self.connection_streams.lock().await;
2734 let reconnect_streams = connection_streams
2735 .iter()
2736 .filter_map(|(key, conn_arc)| {
2737 if Arc::ptr_eq(conn_arc, &connection) {
2738 let stream = match key.split_once("::") {
2739 Some((_prefix, rest)) => rest.to_string(),
2740 None => key.clone(),
2741 };
2742 Some(stream)
2743 } else {
2744 None
2745 }
2746 })
2747 .collect::<Vec<_>>();
2748
2749 let url_path = {
2750 let st = connection.state.lock().await;
2751 st.url_path.as_deref().map(std::string::ToString::to_string)
2752 };
2753
2754 self.prepare_url(&reconnect_streams, url_path.as_deref())
2755 }
2756}
2757
2758pub struct WebsocketStream<T> {
2759 websocket_base: WebsocketBase,
2760 stream_or_id: String,
2761 url_path: Option<String>,
2762 callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
2763 pub id: Option<StreamId>,
2764 _phantom: PhantomData<T>,
2765}
2766
2767impl<T> WebsocketStream<T>
2768where
2769 T: DeserializeOwned + Send + 'static,
2770{
2771 async fn on<F>(&self, event: &str, callback_fn: F)
2792 where
2793 F: Fn(T) + Send + Sync + 'static,
2794 {
2795 if event != "message" {
2796 return;
2797 }
2798
2799 let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2800 Arc::new(
2801 move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2802 Ok(data) => callback_fn(data),
2803 Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2804 },
2805 );
2806
2807 {
2808 let mut guard = self.callback.lock().await;
2809 *guard = Some(cb_wrapper.clone());
2810 }
2811
2812 match &self.websocket_base {
2813 WebsocketBase::WebsocketStreams(ws_streams) => {
2814 let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2815 let conn = {
2816 let map = ws_streams.connection_streams.lock().await;
2817 map.get(&key).cloned().expect("stream must be subscribed")
2818 };
2819
2820 {
2821 let mut conn_state = conn.state.lock().await;
2822 let entry = conn_state.stream_callbacks.entry(key).or_default();
2823
2824 if !entry
2825 .iter()
2826 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2827 {
2828 entry.push(cb_wrapper);
2829 }
2830 }
2831 }
2832 WebsocketBase::WebsocketApi(ws_api) => {
2833 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2834 let entry = stream_callbacks
2835 .entry(self.stream_or_id.clone())
2836 .or_default();
2837
2838 if !entry
2839 .iter()
2840 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2841 {
2842 entry.push(cb_wrapper);
2843 }
2844 }
2845 }
2846 }
2847
2848 pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2867 where
2868 T: Send + Sync,
2869 F: Fn(T) + Send + Sync + 'static,
2870 {
2871 let handler: Arc<Self> = Arc::clone(self);
2872
2873 std::thread::spawn(move || {
2874 let rt = tokio::runtime::Builder::new_current_thread()
2875 .enable_all()
2876 .build()
2877 .expect("failed to build Tokio runtime");
2878
2879 rt.block_on(handler.on("message", callback_fn));
2880 })
2881 .join()
2882 .expect("on_message thread panicked");
2883 }
2884
2885 pub async fn unsubscribe(&self) {
2900 let maybe_cb = {
2901 let mut guard = self.callback.lock().await;
2902 guard.take()
2903 };
2904
2905 if let Some(cb) = maybe_cb {
2906 match &self.websocket_base {
2907 WebsocketBase::WebsocketStreams(ws_streams) => {
2908 let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2909 let conn = {
2910 let map = ws_streams.connection_streams.lock().await;
2911 map.get(&key)
2912 .cloned()
2913 .expect("stream must have been subscribed")
2914 };
2915
2916 {
2917 let mut conn_state = conn.state.lock().await;
2918 if let Some(list) = conn_state.stream_callbacks.get_mut(&key) {
2919 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2920 }
2921 }
2922
2923 let stream = self.stream_or_id.clone();
2924 let id = self.id.clone();
2925 let url_path = self.url_path.clone();
2926 let websocket_streams_base = Arc::clone(ws_streams);
2927 spawn(async move {
2928 websocket_streams_base
2929 .unsubscribe(vec![stream], id, url_path.as_deref())
2930 .await;
2931 });
2932 }
2933 WebsocketBase::WebsocketApi(ws_api) => {
2934 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2935 if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2936 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2937 }
2938 }
2939 }
2940 }
2941 }
2942}
2943
2944pub async fn create_stream_handler<T>(
2959 websocket_base: WebsocketBase,
2960 stream_or_id: String,
2961 id: Option<StreamId>,
2962 url_path: Option<String>,
2963) -> Arc<WebsocketStream<T>>
2964where
2965 T: DeserializeOwned + Send + 'static,
2966{
2967 match &websocket_base {
2968 WebsocketBase::WebsocketStreams(ws_streams) => {
2969 ws_streams
2970 .clone()
2971 .subscribe(vec![stream_or_id.clone()], id.clone(), url_path.as_deref())
2972 .await;
2973 }
2974 WebsocketBase::WebsocketApi(_) => {}
2975 }
2976
2977 Arc::new(WebsocketStream {
2978 websocket_base,
2979 stream_or_id,
2980 url_path,
2981 id,
2982 callback: Mutex::new(None),
2983 _phantom: PhantomData,
2984 })
2985}
2986
2987#[cfg(test)]
2988mod tests {
2989 use crate::TOKIO_SHARED_RT;
2990 use crate::common::utils::{SignatureGenerator, build_user_agent};
2991 use crate::common::websocket::{
2992 PendingRequest, ReconnectEntry, SendWebsocketMessageResult, WebsocketApi, WebsocketBase,
2993 WebsocketCommon, WebsocketConnection, WebsocketEvent, WebsocketEventEmitter,
2994 WebsocketHandler, WebsocketMessageSendOptions, WebsocketMode, WebsocketSessionLogonReq,
2995 WebsocketStream, WebsocketStreams, create_stream_handler,
2996 };
2997 use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2998 use crate::errors::{WebsocketConnectionFailureReason, WebsocketError};
2999 use crate::models::{StreamId, TimeUnit};
3000 use async_trait::async_trait;
3001 use futures::{SinkExt, StreamExt};
3002 use http::header::USER_AGENT;
3003 use regex::Regex;
3004 use serde_json::{Value, json};
3005 use std::collections::{BTreeMap, HashSet};
3006 use std::marker::PhantomData;
3007 use std::net::SocketAddr;
3008 use std::sync::{
3009 Arc,
3010 atomic::{AtomicBool, AtomicUsize, Ordering},
3011 };
3012 use tokio::net::TcpListener;
3013 use tokio::sync::{
3014 Mutex,
3015 mpsc::{Receiver, unbounded_channel},
3016 oneshot,
3017 };
3018 use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
3019 use tokio_tungstenite::{accept_async, accept_hdr_async, tungstenite, tungstenite::Message};
3020 use tungstenite::handshake::server::Request;
3021
3022 struct AbortOnDrop(tokio::task::JoinHandle<()>);
3028
3029 impl Drop for AbortOnDrop {
3030 fn drop(&mut self) {
3031 self.0.abort();
3032 }
3033 }
3034
3035 fn spawn_mock_ws_listener(listener: TcpListener) -> AbortOnDrop {
3039 AbortOnDrop(tokio::spawn(async move {
3040 if let Ok((stream, _)) = listener.accept().await {
3041 let Ok(mut ws) = accept_async(stream).await else {
3042 return;
3043 };
3044 while ws.next().await.is_some() {}
3045 }
3046 }))
3047 }
3048
3049 fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
3050 let events = Arc::new(Mutex::new(Vec::new()));
3051 let events_clone = events.clone();
3052 common.events.subscribe(move |event| {
3053 let events_clone = events_clone.clone();
3054 tokio::spawn(async move {
3055 events_clone.lock().await.push(event);
3056 });
3057 });
3058 events
3059 }
3060
3061 async fn create_connection(
3062 id: &str,
3063 has_writer: bool,
3064 reconnection_pending: bool,
3065 renewal_pending: bool,
3066 close_initiated: bool,
3067 ) -> Arc<WebsocketConnection> {
3068 let conn = WebsocketConnection::new(id);
3069 let mut st = conn.state.lock().await;
3070 st.reconnection_pending = reconnection_pending;
3071 st.renewal_pending = renewal_pending;
3072 st.close_initiated = close_initiated;
3073 if has_writer {
3074 let (tx, _) = unbounded_channel::<Message>();
3075 st.ws_write_tx = Some(tx);
3076 } else {
3077 st.ws_write_tx = None;
3078 }
3079 drop(st);
3080 conn
3081 }
3082
3083 fn create_websocket_api(
3084 time_unit: Option<TimeUnit>,
3085 mode: Option<WebsocketMode>,
3086 auto_session_relogon: Option<bool>,
3087 ) -> Arc<WebsocketApi> {
3088 let mode = mode.unwrap_or(WebsocketMode::Single);
3089 let auto_session_relogon = auto_session_relogon.unwrap_or(true);
3090 let sig_gen = SignatureGenerator::new(
3091 Some("api_secret".into()),
3092 None::<PrivateKey>,
3093 None::<String>,
3094 );
3095 let config = ConfigurationWebsocketApi {
3096 api_key: Some("api_key".into()),
3097 api_secret: Some("api_secret".into()),
3098 private_key: None,
3099 private_key_passphrase: None,
3100 ws_url: Some("wss://example.com".into()),
3101 mode,
3102 reconnect_delay: 1000,
3103 signature_gen: sig_gen,
3104 timeout: 500,
3105 time_unit,
3106 auto_session_relogon,
3107 agent: None,
3108 user_agent: build_user_agent("product"),
3109 };
3110 let conn1 = WebsocketConnection::new("c1");
3111 let conn2 = WebsocketConnection::new("c2");
3112 WebsocketApi::new(config, vec![conn1, conn2])
3113 }
3114
3115 fn create_websocket_streams(
3116 ws_url: Option<&str>,
3117 conns: Option<Vec<Arc<WebsocketConnection>>>,
3118 url_paths: Option<Vec<String>>,
3119 ) -> Arc<WebsocketStreams> {
3120 let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
3121 let url_paths = url_paths.unwrap_or_default();
3122 if conns.is_none() {
3123 connections.push(WebsocketConnection::new("c1"));
3124 connections.push(WebsocketConnection::new("c2"));
3125 } else {
3126 connections = conns.expect("Expected connections to be set");
3127 }
3128 let config = ConfigurationWebsocketStreams {
3129 ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
3130 mode: WebsocketMode::Single,
3131 reconnect_delay: 500,
3132 time_unit: None,
3133 agent: None,
3134 user_agent: build_user_agent("product"),
3135 };
3136 WebsocketStreams::new(config, connections, url_paths)
3137 }
3138
3139 fn subscribe_to_emitter(emitter: &WebsocketEventEmitter) -> Receiver<WebsocketEvent> {
3140 let (test_tx, test_rx) = tokio::sync::mpsc::channel(16);
3141 let _sub = emitter.subscribe(move |evt| {
3142 let _ = test_tx.try_send(evt);
3143 });
3144 test_rx
3145 }
3146
3147 async fn expect_websocket_event(rx: &mut Receiver<WebsocketEvent>) -> WebsocketEvent {
3148 timeout(Duration::from_millis(200), rx.recv())
3149 .await
3150 .expect("timed out waiting for event")
3151 .expect("subscriber channel closed")
3152 }
3153
3154 async fn eventually_async<F, Fut>(max_wait: Duration, mut f: F) -> bool
3155 where
3156 F: FnMut() -> Fut,
3157 Fut: std::future::Future<Output = bool>,
3158 {
3159 let start = tokio::time::Instant::now();
3160 while start.elapsed() < max_wait {
3161 if f().await {
3162 return true;
3163 }
3164 sleep(Duration::from_millis(20)).await;
3165 }
3166 false
3167 }
3168
/// Tests for `WebsocketEventEmitter` fan-out to subscribers.
mod event_emitter {
    use super::*;

    #[test]
    fn event_emitter_subscribe_and_emit() {
        TOKIO_SHARED_RT.block_on(async {
            let emitter = WebsocketEventEmitter::new();
            let (done_tx, done_rx) = oneshot::channel();
            // The callback may run more than once; only the first event is forwarded.
            let sender_slot = Arc::new(std::sync::Mutex::new(Some(done_tx)));
            let _sub = emitter.subscribe(move |event| {
                let first_sender = sender_slot.lock().unwrap().take();
                if let Some(sender) = first_sender {
                    let _ = sender.send(event);
                }
            });

            emitter.emit(&WebsocketEvent::Open);

            let outcome = timeout(Duration::from_millis(100), done_rx).await;
            assert_eq!(outcome.expect("timed out"), Ok(WebsocketEvent::Open));
        });
    }

    #[test]
    fn single_subscriber_gets_event() {
        TOKIO_SHARED_RT.block_on(async {
            let emitter = WebsocketEventEmitter::new();
            let mut events = subscribe_to_emitter(&emitter);
            let sent = WebsocketEvent::Open;

            emitter.emit(&sent);

            assert_eq!(expect_websocket_event(&mut events).await, sent);
        });
    }

    #[test]
    fn multiple_subscribers_get_event() {
        TOKIO_SHARED_RT.block_on(async {
            let emitter = WebsocketEventEmitter::new();
            let mut receivers = [
                subscribe_to_emitter(&emitter),
                subscribe_to_emitter(&emitter),
            ];
            let sent = WebsocketEvent::Message("hello".into());

            emitter.emit(&sent);

            for events in receivers.iter_mut() {
                assert_eq!(expect_websocket_event(events).await, sent);
            }
        });
    }

    #[test]
    fn closed_subscribers_are_pruned() {
        TOKIO_SHARED_RT.block_on(async {
            let emitter = WebsocketEventEmitter::new();
            let dropped = subscribe_to_emitter(&emitter);
            let mut live = subscribe_to_emitter(&emitter);
            drop(dropped);
            let sent = WebsocketEvent::Pong;

            emitter.emit(&sent);

            // The surviving subscriber must still receive the event.
            assert_eq!(expect_websocket_event(&mut live).await, sent);
        });
    }

    #[test]
    fn prune_on_error_does_not_hang() {
        TOKIO_SHARED_RT.block_on(async {
            let emitter = WebsocketEventEmitter::new();
            // The only subscriber's receiver is gone before anything is emitted.
            drop(subscribe_to_emitter(&emitter));

            emitter.emit(&WebsocketEvent::Close(1000, "bye".into()));
        });
    }
}
3248
3249 mod websocket_common {
3250 use super::*;
3251
/// Checks the size of the connection pool `WebsocketCommon::new` builds when
/// no connections are supplied.
mod initialisation {
    use super::*;

    #[test]
    fn single_mode() {
        TOKIO_SHARED_RT.block_on(async {
            let common = WebsocketCommon::new(Vec::new(), WebsocketMode::Single, 0, None, None);
            let pool_len = common.connection_pool.len();
            assert_eq!(pool_len, 1);
        });
    }

    #[test]
    fn pool_mode() {
        TOKIO_SHARED_RT.block_on(async {
            let requested = 3;
            let common =
                WebsocketCommon::new(Vec::new(), WebsocketMode::Pool(requested), 0, None, None);
            let pool_len = common.connection_pool.len();
            assert_eq!(pool_len, 3);
        });
    }
}
3272
3273 mod spawn_reconnect_loop {
3274 use super::*;
3275
3276 #[test]
3277 fn successful_reconnect_entry_triggers_init_connect() {
3278 TOKIO_SHARED_RT.block_on(async {
3279 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3280 let addr = listener.local_addr().unwrap();
3281 tokio::spawn(async move {
3282 if let Ok((stream, _)) = listener.accept().await {
3283 let mut ws = accept_async(stream).await.unwrap();
3284 sleep(Duration::from_secs(5)).await;
3285 let _ = ws.close(None).await;
3286 }
3287 });
3288
3289 let conn = WebsocketConnection::new("c1");
3290 let common = WebsocketCommon::new(
3291 vec![conn.clone()],
3292 WebsocketMode::Single,
3293 10,
3294 None,
3295 None,
3296 );
3297 let url = format!("ws://{addr}");
3298 common
3299 .reconnect_tx
3300 .send(ReconnectEntry {
3301 connection_id: "c1".into(),
3302 url: url.clone(),
3303 is_renewal: false,
3304 })
3305 .await
3306 .unwrap();
3307
3308 let mut ok = false;
3309 for _ in 0..100 {
3310 if conn.state.lock().await.ws_write_tx.is_some() {
3311 ok = true;
3312 break;
3313 }
3314 sleep(Duration::from_millis(50)).await;
3315 }
3316 assert!(ok, "expected ws_write_tx to be Some after reconnect");
3317 });
3318 }
3319
3320 #[test]
3321 fn reconnect_entry_with_unknown_id_is_ignored() {
3322 TOKIO_SHARED_RT.block_on(async {
3323 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3324 let addr = listener.local_addr().unwrap();
3325 tokio::spawn(async move {
3326 if let Ok((stream, _)) = listener.accept().await {
3327 let mut ws = accept_async(stream).await.unwrap();
3328 let _ = ws.close(None).await;
3329 }
3330 });
3331
3332 let conn = WebsocketConnection::new("c1");
3333 let common = WebsocketCommon::new(
3334 vec![conn.clone()],
3335 WebsocketMode::Single,
3336 5,
3337 None,
3338 None,
3339 );
3340 let url = format!("ws://{addr}");
3341 common
3342 .reconnect_tx
3343 .send(ReconnectEntry {
3344 connection_id: "other".into(),
3345 url,
3346 is_renewal: false,
3347 })
3348 .await
3349 .unwrap();
3350
3351 sleep(Duration::from_secs(1)).await;
3352
3353 let st = conn.state.lock().await;
3354 assert!(st.ws_write_tx.is_none());
3355 });
3356 }
3357
3358 #[test]
3359 fn renewal_entries_bypass_initial_delay() {
3360 TOKIO_SHARED_RT.block_on(async {
3361 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3362 let addr = listener.local_addr().unwrap();
3363 tokio::spawn(async move {
3364 if let Ok((stream, _)) = listener.accept().await {
3365 let mut ws = accept_async(stream).await.unwrap();
3366 let _ = ws.close(None).await;
3367 }
3368 });
3369
3370 let conn = WebsocketConnection::new("renew");
3371 let common = WebsocketCommon::new(
3372 vec![conn.clone()],
3373 WebsocketMode::Single,
3374 200,
3375 None,
3376 None,
3377 );
3378 let url = format!("ws://{addr}");
3379 common
3380 .reconnect_tx
3381 .send(ReconnectEntry {
3382 connection_id: "renew".into(),
3383 url: url.clone(),
3384 is_renewal: true,
3385 })
3386 .await
3387 .unwrap();
3388
3389 sleep(Duration::from_secs(2)).await;
3390
3391 let st = conn.state.lock().await;
3392
3393 assert!(st.ws_write_tx.is_some());
3394 });
3395 }
3396
3397 #[test]
3398 fn non_renewal_entries_respect_initial_delay() {
3399 TOKIO_SHARED_RT.block_on(async {
3400 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3401 let addr = listener.local_addr().unwrap();
3402 let _listener_guard = spawn_mock_ws_listener(listener);
3403
3404 let reconnect_delay_ms = 1500;
3405 let conn = WebsocketConnection::new("nonrenew");
3406 let common = WebsocketCommon::new(
3407 vec![conn.clone()],
3408 WebsocketMode::Single,
3409 reconnect_delay_ms,
3410 None,
3411 None,
3412 );
3413 let url = format!("ws://{addr}");
3414 let queued_at = tokio::time::Instant::now();
3415 common
3416 .reconnect_tx
3417 .send(ReconnectEntry {
3418 connection_id: "nonrenew".into(),
3419 url: url.clone(),
3420 is_renewal: false,
3421 })
3422 .await
3423 .unwrap();
3424
3425 sleep(Duration::from_millis(100)).await;
3426 if queued_at.elapsed()
3427 < Duration::from_millis(reconnect_delay_ms as u64)
3428 .saturating_sub(Duration::from_millis(200))
3429 {
3430 assert!(conn.state.lock().await.ws_write_tx.is_none());
3431 }
3432
3433 let mut ok = false;
3434 for _ in 0..200 {
3435 if conn.state.lock().await.ws_write_tx.is_some() {
3436 ok = true;
3437 break;
3438 }
3439 sleep(Duration::from_millis(50)).await;
3440 }
3441 assert!(
3442 ok,
3443 "expected ws_write_tx to be Some after reconnect delay elapsed"
3444 );
3445 });
3446 }
3447 }
3448
3449 mod spawn_renewal_loop {
3450 use super::*;
3451
3452 #[tokio::test]
3453 async fn scheduling_renewal_does_not_panic_for_known_connection() {
3454 pause();
3455
3456 let conn = WebsocketConnection::new("known");
3457 let common =
3458 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3459 let url = "wss://example".to_string();
3460 common
3461 .renewal_tx
3462 .send((conn.id.clone(), url))
3463 .await
3464 .unwrap();
3465 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3466 resume();
3467 }
3468
3469 #[tokio::test]
3470 async fn scheduling_renewal_ignored_for_unknown_connection() {
3471 pause();
3472
3473 let conn = WebsocketConnection::new("c1");
3474 let common =
3475 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3476 common
3477 .renewal_tx
3478 .send(("other".into(), "u".into()))
3479 .await
3480 .unwrap();
3481 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3482
3483 resume();
3484 }
3485 }
3486
/// Regression tests for reconnect edge cases, each run against a local mock
/// WebSocket server:
/// - `init_connect` while `reconnection_pending` is already set;
/// - an in-flight request when the server closes the socket.
mod reconnect_regressions {
    use super::*;

    /// `init_connect` must still install a writer and clear
    /// `reconnection_pending`, even though the connection was pre-seeded with a
    /// writer and flagged as reconnecting.
    #[test]
    fn init_connect_is_not_skipped_when_reconnection_pending() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            // Mock server: accept one connection, complete the WebSocket
            // handshake, keep it open for 500 ms, then close it.
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    tokio::spawn(async move {
                        let mut ws = accept_async(stream).await.unwrap();
                        sleep(Duration::from_millis(500)).await;
                        let _ = ws.close(None).await;
                    });
                }
            });

            let conn = WebsocketConnection::new("c-reconnect");
            // Pre-seed state: a writer whose receiver is not kept (`_` pattern),
            // plus `reconnection_pending = true`.
            {
                let mut st = conn.state.lock().await;
                let (tx, _) = unbounded_channel::<Message>();
                st.ws_write_tx = Some(tx);
                st.reconnection_pending = true;
            }

            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );

            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn.clone()))
                .await
                .unwrap();

            // Poll (up to 2 s) until a writer is present and the pending flag is cleared.
            let ok = eventually_async(Duration::from_secs(2), || {
                let conn = conn.clone();
                async move {
                    let st = conn.state.lock().await;
                    st.ws_write_tx.is_some() && !st.reconnection_pending
                }
            })
            .await;

            assert!(
                ok,
                "expected writer installed and reconnection_pending cleared"
            );
        });
    }

    /// A request sent just before the server closes the socket must still
    /// resolve its oneshot receiver, and `pending_requests` must end up empty.
    /// NOTE(review): the name says "resolved on socket drop", but the assertion
    /// expects `WebsocketError::Timeout` (150 ms request timeout). Confirm which
    /// code path is meant to resolve it.
    #[test]
    fn pending_request_is_resolved_on_socket_drop() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            // Mock server: accept one connection, read a single frame (the
            // request), then close without replying.
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.next().await;
                    let _ = ws.close(None).await;
                }
            });

            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );

            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn.clone()))
                .await
                .unwrap();

            // Send a request with id "req-1" and a 150 ms timeout. The third
            // argument (`true`) presumably requests a response receiver; confirm
            // against `send`. The call is expected to return `Some(receiver)`.
            let rx = common
                .send(
                    "{\"id\":\"req-1\",\"method\":\"PING\"}".to_string(),
                    Some("req-1".to_string()),
                    true,
                    Duration::from_millis(150),
                    Some(conn.clone()),
                )
                .await
                .unwrap()
                .expect("expected oneshot receiver");

            // The receiver must resolve well within 2 s rather than hang.
            let res = timeout(Duration::from_secs(2), rx)
                .await
                .expect("did not resolve pending request")
                .expect("oneshot cancelled");

            assert!(matches!(res, Err(WebsocketError::Timeout)));

            // The request must also be removed from the connection's bookkeeping.
            let ok = eventually_async(Duration::from_secs(1), || {
                let conn = conn.clone();
                async move { conn.state.lock().await.pending_requests.is_empty() }
            })
            .await;

            assert!(ok, "pending_requests should be drained");
        });
    }
}
3604
/// Tests for `WebsocketCommon::is_connection_ready` on individual connections.
mod is_connection_ready {
    use super::*;

    #[test]
    fn is_connection_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let fresh = WebsocketConnection::new("c1");
            let common =
                WebsocketCommon::new(vec![fresh.clone()], WebsocketMode::Single, 0, None, None);

            // A freshly created connection is not ready with `false`,
            // but is reported ready with `true`.
            let with_false = common.is_connection_ready(&fresh, false).await;
            assert!(!with_false);
            let with_true = common.is_connection_ready(&fresh, true).await;
            assert!(with_true);
        });
    }

    #[test]
    fn connection_ready_basic() {
        TOKIO_SHARED_RT.block_on(async {
            let healthy = create_connection("c1", true, false, false, false).await;
            let common =
                WebsocketCommon::new(vec![healthy.clone()], WebsocketMode::Single, 0, None, None);

            let ready = common.is_connection_ready(&healthy, false).await;
            assert!(ready);
        });
    }

    #[test]
    fn connection_not_ready_without_writer() {
        TOKIO_SHARED_RT.block_on(async {
            let writerless = create_connection("c1", false, false, false, false).await;
            let common = WebsocketCommon::new(
                vec![writerless.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );

            let with_false = common.is_connection_ready(&writerless, false).await;
            assert!(!with_false);
            let with_true = common.is_connection_ready(&writerless, true).await;
            assert!(with_true);
        });
    }

    #[test]
    fn connection_not_ready_when_flagged() {
        TOKIO_SHARED_RT.block_on(async {
            // Each connection has exactly one of the trailing flags of
            // `create_connection` set.
            let second_flag = create_connection("c1", true, true, false, false).await;
            let third_flag = create_connection("c2", true, false, true, false).await;
            let fourth_flag = create_connection("c3", true, false, false, true).await;

            let common = WebsocketCommon::new(
                vec![second_flag.clone(), third_flag.clone(), fourth_flag.clone()],
                WebsocketMode::Pool(3),
                0,
                None,
                None,
            );

            let expectations = [
                (&second_flag, false),
                (&third_flag, true),
                (&fourth_flag, false),
            ];
            for (conn, expected) in expectations {
                assert_eq!(common.is_connection_ready(conn, false).await, expected);
            }
        });
    }
}
3676
/// Tests for `WebsocketCommon::is_connected`, both pool-wide (`None`) and per connection.
mod is_connected {
    use super::*;

    #[test]
    fn with_pool_various_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let conn_a = create_connection("a", true, false, false, false).await;
            let conn_b = create_connection("b", false, false, false, false).await;
            let conn_c = create_connection("c", true, true, false, false).await;
            let common = WebsocketCommon::new(
                vec![conn_a.clone(), conn_b.clone(), conn_c.clone()],
                WebsocketMode::Pool(3),
                0,
                None,
                None,
            );

            // With `a` usable, the pool-wide check reports connected.
            assert!(common.is_connected(None).await);
            for (conn, expected) in [(&conn_a, true), (&conn_b, false), (&conn_c, false)] {
                assert_eq!(common.is_connected(Some(conn)).await, expected);
            }
        });
    }

    #[test]
    fn with_pool_all_bad_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let pool = vec![
                create_connection("c1", false, false, false, false).await,
                create_connection("c2", true, true, false, false).await,
                create_connection("c3", true, false, false, true).await,
            ];
            let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None, None);

            let connected = common.is_connected(None).await;
            assert!(!connected);
        });
    }

    #[test]
    fn with_pool_ignore_close_initiated() {
        TOKIO_SHARED_RT.block_on(async {
            let good = create_connection("c1", true, false, false, false).await;
            let closed = create_connection("c2", true, false, false, true).await;
            let bad = create_connection("c3", false, false, false, false).await;
            let common = WebsocketCommon::new(
                vec![closed.clone(), good.clone(), bad.clone()],
                WebsocketMode::Pool(3),
                0,
                None,
                None,
            );

            let pool_connected = common.is_connected(None).await;
            let closed_connected = common.is_connected(Some(&closed)).await;
            assert!(pool_connected);
            assert!(!closed_connected);
        });
    }
}
3733
/// Tests for `WebsocketCommon::get_available_connections` in single and pool modes.
mod get_available_connections {
    use super::*;

    #[test]
    fn single_mode() {
        TOKIO_SHARED_RT.block_on(async {
            let common = WebsocketCommon::new(Vec::new(), WebsocketMode::Single, 0, None, None);
            let available = common.get_available_connections(false, None).await;
            let expected_id = &common.connection_pool[0].id;
            assert_eq!(&available[0].id, expected_id);
        });
    }

    #[test]
    fn single_mode_with_url_path_does_not_force_first_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");

            // Both connections serve "path1", but only c2 has a writer installed.
            let (tx2, _rx2) = unbounded_channel::<Message>();
            {
                let mut state = conn2.state.lock().await;
                state.ws_write_tx = Some(tx2);
                state.url_path = Some("path1".to_string());
            }
            conn1.state.lock().await.url_path = Some("path1".to_string());

            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let available = common.get_available_connections(false, Some("path1")).await;

            let ids: Vec<String> = available.iter().map(|conn| conn.id.clone()).collect();
            assert_eq!(ids, ["c2"]);
        });
    }

    #[test]
    fn pool_mode_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let common =
                WebsocketCommon::new(Vec::new(), WebsocketMode::Pool(2), 0, None, None);
            let available = common.get_available_connections(false, None).await;
            assert!(available.is_empty());
        });
    }

    #[test]
    fn pool_mode_with_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            // Only c1 gets a writer.
            let (tx1, _rx1) = unbounded_channel::<Message>();
            conn1.state.lock().await.ws_write_tx = Some(tx1);

            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let available = common.get_available_connections(false, None).await;
            assert_eq!(available.len(), 1);
        });
    }
}
3801
/// Tests for `WebsocketCommon::get_connection`, including URL-path filtering.
mod get_connection {
    use super::*;

    #[test]
    fn single_mode() {
        TOKIO_SHARED_RT.block_on(async {
            let common = WebsocketCommon::new(Vec::new(), WebsocketMode::Single, 0, None, None);
            let picked = common.get_connection(false, None).await;
            let picked = picked.expect("should get connection");
            assert_eq!(picked.id, common.connection_pool[0].id);
        });
    }

    #[test]
    fn single_mode_with_url_path_selects_matching_ready_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");

            // Both connections have a writer; only c2 serves "path1".
            let (tx1, _rx1) = unbounded_channel::<Message>();
            let (tx2, _rx2) = unbounded_channel::<Message>();
            for (conn, writer, path) in [(&conn1, tx1, "path2"), (&conn2, tx2, "path1")] {
                let mut state = conn.state.lock().await;
                state.ws_write_tx = Some(writer);
                state.url_path = Some(path.to_string());
            }

            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let chosen = common.get_connection(false, Some("path1")).await;
            let chosen = chosen.expect("should get connection");

            assert_eq!(chosen.id, "c2");
        });
    }

    #[test]
    fn pool_mode_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let common =
                WebsocketCommon::new(Vec::new(), WebsocketMode::Pool(2), 0, None, None);
            let outcome = common.get_connection(false, None).await;
            let not_connected =
                matches!(outcome, Err(crate::errors::WebsocketError::NotConnected));
            assert!(not_connected);
        });
    }

    #[test]
    fn pool_mode_with_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            // Only c1 gets a writer, so it is the one that must be chosen.
            let (tx1, _rx1) = unbounded_channel::<Message>();
            conn1.state.lock().await.ws_write_tx = Some(tx1);

            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let chosen = common.get_connection(false, None).await.unwrap();
            assert_eq!(chosen.id, conn1.id);
        });
    }

    #[test]
    fn pool_mode_with_url_path_filters_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");

            let (tx1, _rx1) = unbounded_channel::<Message>();
            let (tx2, _rx2) = unbounded_channel::<Message>();
            for (conn, writer, path) in [(&conn1, tx1, "path1"), (&conn2, tx2, "path2")] {
                let mut state = conn.state.lock().await;
                state.ws_write_tx = Some(writer);
                state.url_path = Some(path.to_string());
            }

            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let chosen = common.get_connection(false, Some("path2")).await;
            let chosen = chosen.expect("should pick ready connection for path2");

            assert_eq!(chosen.id, "c2");
        });
    }

    #[test]
    fn url_path_no_match_returns_not_connected() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let (tx1, _rx1) = unbounded_channel::<Message>();
            {
                let mut state = conn1.state.lock().await;
                state.ws_write_tx = Some(tx1);
                state.url_path = Some("path1".to_string());
            }

            let common =
                WebsocketCommon::new(vec![conn1.clone()], WebsocketMode::Pool(1), 0, None, None);

            // The only ready connection serves "path1", so "path2" has no candidate.
            let outcome = common.get_connection(false, Some("path2")).await;
            let not_connected =
                matches!(outcome, Err(crate::errors::WebsocketError::NotConnected));
            assert!(not_connected);
        });
    }
}
3935
3936 mod close_connection_gracefully {
3937 use super::*;
3938
3939 #[tokio::test]
3940 async fn waits_for_pending_requests_then_closes() {
3941 pause();
3942
3943 let conn = WebsocketConnection::new("c1");
3944 let (tx, mut rx) = unbounded_channel::<Message>();
3945 let (req_tx, _req_rx) = oneshot::channel();
3946 {
3947 let mut st = conn.state.lock().await;
3948 st.pending_requests
3949 .insert("r".to_string(), PendingRequest { completion: req_tx });
3950 }
3951 let common =
3952 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3953 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3954 advance(Duration::from_secs(1)).await;
3955 {
3956 let mut st = conn.state.lock().await;
3957 st.pending_requests.clear();
3958 }
3959 conn.drain_notify.notify_waiters();
3960 advance(Duration::from_secs(1)).await;
3961 close_fut.await.unwrap();
3962 match rx.try_recv() {
3963 Ok(Message::Close(_)) => {}
3964 other => panic!("expected Close, got {other:?}"),
3965 }
3966
3967 resume();
3968 }
3969
3970 #[tokio::test]
3971 async fn force_closes_after_timeout() {
3972 pause();
3973
3974 let conn = WebsocketConnection::new("c2");
3975 let (tx, mut rx) = unbounded_channel::<Message>();
3976 let (req_tx, _req_rx) = oneshot::channel();
3977 {
3978 let mut st = conn.state.lock().await;
3979 st.pending_requests.insert(
3980 "request_id".to_string(),
3981 PendingRequest { completion: req_tx },
3982 );
3983 }
3984 let common =
3985 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3986 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3987 advance(Duration::from_secs(30)).await;
3988 close_fut.await.unwrap();
3989 match rx.try_recv() {
3990 Ok(Message::Close(_)) => {}
3991 other => panic!("expected Close on timeout, got {other:?}"),
3992 }
3993
3994 resume();
3995 }
3996 }
3997
3998 mod get_reconnect_url {
3999 use super::*;
4000
4001 struct DummyHandler {
4002 url: String,
4003 }
4004
4005 #[async_trait::async_trait]
4006 impl WebsocketHandler for DummyHandler {
4007 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
4008 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
4009 async fn get_reconnect_url(
4010 &self,
4011 _default_url: String,
4012 _connection: Arc<WebsocketConnection>,
4013 ) -> String {
4014 self.url.clone()
4015 }
4016 }
4017
4018 #[test]
4019 fn returns_default_when_no_handler() {
4020 TOKIO_SHARED_RT.block_on(async {
4021 let conn = WebsocketConnection::new("c1");
4022 let common = WebsocketCommon::new(
4023 vec![conn.clone()],
4024 WebsocketMode::Single,
4025 0,
4026 None,
4027 None,
4028 );
4029 let default = "wss://default".to_string();
4030 let result = common.get_reconnect_url(&default, conn.clone()).await;
4031 assert_eq!(result, default);
4032 });
4033 }
4034
4035 #[test]
4036 fn returns_handler_url_when_set() {
4037 TOKIO_SHARED_RT.block_on(async {
4038 let conn = WebsocketConnection::new("c2");
4039 let handler = Arc::new(DummyHandler {
4040 url: "wss://custom".into(),
4041 });
4042 conn.set_handler(handler).await;
4043 let common = WebsocketCommon::new(
4044 vec![conn.clone()],
4045 WebsocketMode::Single,
4046 0,
4047 None,
4048 None,
4049 );
4050 let default = "wss://default".to_string();
4051 let result = common.get_reconnect_url(&default, conn.clone()).await;
4052 assert_eq!(result, "wss://custom");
4053 });
4054 }
4055 }
4056
4057 mod on_open {
4058 use super::*;
4059
4060 struct DummyHandler {
4061 called: Arc<Mutex<bool>>,
4062 opened_url: Arc<Mutex<Option<String>>>,
4063 }
4064
4065 #[async_trait]
4066 impl WebsocketHandler for DummyHandler {
4067 async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
4068 let mut flag = self.called.lock().await;
4069 *flag = true;
4070 let mut store = self.opened_url.lock().await;
4071 *store = Some(url);
4072 }
4073 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
4074 async fn get_reconnect_url(
4075 &self,
4076 default_url: String,
4077 _connection: Arc<WebsocketConnection>,
4078 ) -> String {
4079 default_url
4080 }
4081 }
4082
4083 #[test]
4084 fn emits_open_and_calls_handler() {
4085 TOKIO_SHARED_RT.block_on(async {
4086 let conn = WebsocketConnection::new("c1");
4087 let called = Arc::new(Mutex::new(false));
4088 let opened_url = Arc::new(Mutex::new(None));
4089 let handler = Arc::new(DummyHandler {
4090 called: called.clone(),
4091 opened_url: opened_url.clone(),
4092 });
4093
4094 conn.set_handler(handler.clone()).await;
4095 let common = WebsocketCommon::new(
4096 vec![conn.clone()],
4097 WebsocketMode::Single,
4098 0,
4099 None,
4100 None,
4101 );
4102 let events = subscribe_events(&common);
4103 common
4104 .on_open("wss://example.com".into(), conn.clone(), None)
4105 .await;
4106
4107 sleep(std::time::Duration::from_millis(10)).await;
4108
4109 let evs = events.lock().await;
4110 assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
4111 assert!(*called.lock().await);
4112 assert_eq!(
4113 opened_url.lock().await.as_deref(),
4114 Some("wss://example.com")
4115 );
4116 });
4117 }
4118
4119 #[test]
4120 fn handles_renewal_pending_and_closes_old_writer() {
4121 TOKIO_SHARED_RT.block_on(async {
4122 let conn = WebsocketConnection::new("c2");
4123 let (old_tx, mut old_rx) = unbounded_channel::<Message>();
4124 {
4125 let mut st = conn.state.lock().await;
4126 st.renewal_pending = true;
4127 }
4128 let common = WebsocketCommon::new(
4129 vec![conn.clone()],
4130 WebsocketMode::Single,
4131 0,
4132 None,
4133 None,
4134 );
4135 common
4136 .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
4137 .await;
4138 assert!(!conn.state.lock().await.renewal_pending);
4139 match old_rx.try_recv() {
4140 Ok(Message::Close(_)) => {}
4141 other => panic!("expected Close, got {other:?}"),
4142 }
4143 });
4144 }
4145 }
4146
4147 mod on_message {
4148 use super::*;
4149
4150 struct DummyHandler {
4151 called_with: Arc<Mutex<Vec<String>>>,
4152 }
4153
4154 #[async_trait]
4155 impl WebsocketHandler for DummyHandler {
4156 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
4157 async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
4158 self.called_with.lock().await.push(data);
4159 }
4160 async fn get_reconnect_url(
4161 &self,
4162 default_url: String,
4163 _connection: Arc<WebsocketConnection>,
4164 ) -> String {
4165 default_url
4166 }
4167 }
4168
4169 #[test]
4170 fn emits_message_event_without_handler() {
4171 TOKIO_SHARED_RT.block_on(async {
4172 let conn = WebsocketConnection::new("c1");
4173 let common = WebsocketCommon::new(
4174 vec![conn.clone()],
4175 WebsocketMode::Single,
4176 0,
4177 None,
4178 None,
4179 );
4180 let events = subscribe_events(&common);
4181 common.on_message("msg".into(), conn.clone()).await;
4182
4183 sleep(Duration::from_millis(10)).await;
4184
4185 let locked = events.lock().await;
4186 assert!(
4187 locked
4188 .iter()
4189 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4190 );
4191 });
4192 }
4193
4194 #[test]
4195 fn calls_handler_and_emits_message() {
4196 TOKIO_SHARED_RT.block_on(async {
4197 let conn = WebsocketConnection::new("c2");
4198 let called = Arc::new(Mutex::new(Vec::new()));
4199 let handler = Arc::new(DummyHandler {
4200 called_with: called.clone(),
4201 });
4202 conn.set_handler(handler.clone()).await;
4203
4204 let common = WebsocketCommon::new(
4205 vec![conn.clone()],
4206 WebsocketMode::Single,
4207 0,
4208 None,
4209 None,
4210 );
4211 let events = subscribe_events(&common);
4212 common.on_message("msg".into(), conn.clone()).await;
4213
4214 sleep(Duration::from_millis(10)).await;
4215
4216 let evs = events.lock().await;
4217 assert!(
4218 evs.iter()
4219 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4220 );
4221 let msgs = called.lock().await;
4222 assert_eq!(msgs.as_slice(), &["msg".to_string()]);
4223 });
4224 }
4225
4226 #[test]
4227 fn preserves_message_order() {
4228 TOKIO_SHARED_RT.block_on(async {
4229 let conn = WebsocketConnection::new("c3");
4230 let called = Arc::new(Mutex::new(Vec::new()));
4231 let handler = Arc::new(DummyHandler {
4232 called_with: called.clone(),
4233 });
4234 conn.set_handler(handler.clone()).await;
4235
4236 let common = WebsocketCommon::new(
4237 vec![conn.clone()],
4238 WebsocketMode::Single,
4239 0,
4240 None,
4241 None,
4242 );
4243
4244 for i in 0..20 {
4245 common.on_message(format!("msg_{i}"), conn.clone()).await;
4246 }
4247
4248 let msgs = called.lock().await;
4249 let expected: Vec<String> = (0..20).map(|i| format!("msg_{i}")).collect();
4250 assert_eq!(msgs.as_slice(), expected.as_slice());
4251 });
4252 }
4253
/// Feeds ten messages through `WebsocketCommon::on_message` using a handler
/// that yields mid-callback, and checks both that the messages arrive in order
/// and that the handler never sees two callbacks in flight at the same time.
#[test]
fn preserves_order_with_slow_handler() {
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Handler that logs each payload and tracks how many callbacks overlap.
    struct SlowHandler {
        log: Arc<Mutex<Vec<String>>>,
        in_flight: Arc<AtomicU32>,
        peak_in_flight: Arc<AtomicU32>,
    }

    #[async_trait]
    impl WebsocketHandler for SlowHandler {
        async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}

        async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
            // Mark this call as in flight and remember the highest overlap seen.
            let now_in_flight = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak_in_flight.fetch_max(now_in_flight, Ordering::SeqCst);
            // Yield so that a concurrently dispatched call would overlap here.
            sleep(Duration::from_millis(1)).await;
            self.log.lock().await.push(data);
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
        }

        async fn get_reconnect_url(
            &self,
            default_url: String,
            _connection: Arc<WebsocketConnection>,
        ) -> String {
            default_url
        }
    }

    const MESSAGE_COUNT: u32 = 10;

    TOKIO_SHARED_RT.block_on(async {
        let log = Arc::new(Mutex::new(Vec::new()));
        let in_flight = Arc::new(AtomicU32::new(0));
        let peak_in_flight = Arc::new(AtomicU32::new(0));

        let conn = WebsocketConnection::new("c4");
        conn.set_handler(Arc::new(SlowHandler {
            log: Arc::clone(&log),
            in_flight: Arc::clone(&in_flight),
            peak_in_flight: Arc::clone(&peak_in_flight),
        }))
        .await;

        let common = WebsocketCommon::new(
            vec![conn.clone()],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        for i in 0..MESSAGE_COUNT {
            common.on_message(format!("msg_{i}"), conn.clone()).await;
        }

        let expected = (0..MESSAGE_COUNT)
            .map(|i| format!("msg_{i}"))
            .collect::<Vec<String>>();
        assert_eq!(log.lock().await.as_slice(), expected.as_slice());
        assert_eq!(
            peak_in_flight.load(Ordering::SeqCst),
            1,
            "messages must be processed sequentially, not concurrently"
        );
    });
}
4321 }
4322
/// Tests for `WebsocketCommon::create_websocket`: a successful handshake that
/// carries the configured `User-Agent`, and `Handshake` errors for malformed or
/// unreachable URLs.
mod create_websocket {
    use super::*;

    /// The handshake succeeds and the server observes the exact `User-Agent`
    /// value passed to `create_websocket`.
    #[test]
    fn successful_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr: SocketAddr = listener.local_addr().unwrap();

            let ua = build_user_agent("product");
            let ua_seen_by_server = ua.clone();

            // Server side: accept one handshake and check the User-Agent header in
            // the handshake callback. A failed check panics the server task, so the
            // client handshake cannot complete and the assertion below fails.
            tokio::spawn(async move {
                let Ok((stream, _)) = listener.accept().await else {
                    return;
                };
                let check_user_agent = |req: &Request, resp| {
                    let header_value = req
                        .headers()
                        .get(USER_AGENT)
                        .expect("no USER_AGENT header in WS handshake")
                        .to_str()
                        .expect("invalid USER_AGENT header");
                    assert_eq!(header_value, ua_seen_by_server, "User-Agent mismatch");
                    Ok(resp)
                };
                let _ = accept_hdr_async(stream, check_user_agent).await.unwrap();
            });

            let outcome =
                WebsocketCommon::create_websocket(&format!("ws://{addr}"), None, Some(ua)).await;
            assert!(outcome.is_ok(), "handshake failed: {outcome:?}");
        });
    }

    /// A string that is not a websocket URL is reported as a handshake error.
    #[test]
    fn invalid_url_returns_handshake_error() {
        TOKIO_SHARED_RT.block_on(async {
            assert!(matches!(
                WebsocketCommon::create_websocket("not-a-valid-url", None, None).await,
                Err(WebsocketError::Handshake(_))
            ));
        });
    }

    /// A host/port with nothing listening is reported as a handshake error.
    #[test]
    fn unreachable_host_returns_handshake_error() {
        TOKIO_SHARED_RT.block_on(async {
            assert!(matches!(
                WebsocketCommon::create_websocket("ws://127.0.0.1:1", None, None).await,
                Err(WebsocketError::Handshake(_))
            ));
        });
    }
}
4376
4377 mod connect_pool {
4378 use super::*;
4379
4380 #[test]
4381 fn connects_all_in_pool() {
4382 TOKIO_SHARED_RT.block_on(async {
4383 let pool_size = 3;
4384 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4385 let addr = listener.local_addr().unwrap();
4386 tokio::spawn(async move {
4387 for _ in 0..pool_size {
4388 if let Ok((stream, _)) = listener.accept().await {
4389 tokio::spawn(async move {
4390 let mut ws = accept_async(stream).await.unwrap();
4391 sleep(Duration::from_millis(500)).await;
4392 let _ = ws.close(None).await;
4393 });
4394 }
4395 }
4396 });
4397 let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
4398 .map(|i| WebsocketConnection::new(format!("c{i}")))
4399 .collect();
4400 let common = WebsocketCommon::new(
4401 conns.clone(),
4402 WebsocketMode::Pool(pool_size),
4403 0,
4404 None,
4405 None,
4406 );
4407 let url = format!("ws://{addr}");
4408 common.clone().connect_pool(&url, None).await.unwrap();
4409 for conn in conns {
4410 let mut ok = false;
4411 for _ in 0..100 {
4412 if conn.state.lock().await.ws_write_tx.is_some() {
4413 ok = true;
4414 break;
4415 }
4416 sleep(Duration::from_millis(50)).await;
4417 }
4418 assert!(ok, "expected ws_write_tx Some after connect");
4419 }
4420 });
4421 }
4422
4423 #[test]
4424 fn fails_if_any_refused() {
4425 TOKIO_SHARED_RT.block_on(async {
4426 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4427 let addr = listener.local_addr().unwrap();
4428 let pool_size = 3;
4429 tokio::spawn(async move {
4430 for _ in 0..2 {
4431 if let Ok((stream, _)) = listener.accept().await {
4432 let mut ws = accept_async(stream).await.unwrap();
4433 let _ = ws.close(None).await;
4434 }
4435 }
4436 });
4437 let mut conns = Vec::new();
4438 let valid_url = format!("ws://{addr}");
4439 for i in 0..2 {
4440 conns.push(WebsocketConnection::new(format!("c{i}")));
4441 }
4442 conns.push(WebsocketConnection::new("bad"));
4443 let common = WebsocketCommon::new(
4444 conns.clone(),
4445 WebsocketMode::Pool(pool_size),
4446 0,
4447 None,
4448 None,
4449 );
4450 let res = common.clone().connect_pool(&valid_url, None).await;
4451 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4452 });
4453 }
4454
4455 #[test]
4456 fn fails_on_invalid_url() {
4457 TOKIO_SHARED_RT.block_on(async {
4458 let conns = vec![WebsocketConnection::new("c1")];
4459 let common = WebsocketCommon::new(conns, WebsocketMode::Pool(1), 0, None, None);
4460 let res = common.connect_pool("not-a-url", None).await;
4461 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4462 });
4463 }
4464
4465 #[test]
4466 fn fails_if_mixed_success_and_invalid_url() {
4467 TOKIO_SHARED_RT.block_on(async {
4468 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4469 let addr = listener.local_addr().unwrap();
4470 tokio::spawn(async move {
4471 if let Ok((stream, _)) = listener.accept().await {
4472 let mut ws = accept_async(stream).await.unwrap();
4473 let _ = ws.close(None).await;
4474 }
4475 });
4476 let good = WebsocketConnection::new("good");
4477 let bad = WebsocketConnection::new("bad");
4478 let common = WebsocketCommon::new(
4479 vec![good, bad],
4480 WebsocketMode::Pool(2),
4481 0,
4482 None,
4483 None,
4484 );
4485 let url = format!("ws://{addr}");
4486 let res = common.connect_pool(&url, None).await;
4487 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4488 });
4489 }
4490
4491 #[test]
4492 fn init_connect_invoked_for_each() {
4493 TOKIO_SHARED_RT.block_on(async {
4494 let pool_size = 2;
4495 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4496 let addr = listener.local_addr().unwrap();
4497 tokio::spawn(async move {
4498 for _ in 0..pool_size {
4499 if let Ok((stream, _)) = listener.accept().await {
4500 tokio::spawn(async move {
4501 let mut ws = accept_async(stream).await.unwrap();
4502 sleep(Duration::from_millis(500)).await;
4503 let _ = ws.close(None).await;
4504 });
4505 }
4506 }
4507 });
4508 let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
4509 .map(|i| WebsocketConnection::new(format!("c{i}")))
4510 .collect();
4511 let common = WebsocketCommon::new(
4512 conns.clone(),
4513 WebsocketMode::Pool(pool_size),
4514 0,
4515 None,
4516 None,
4517 );
4518 let url = format!("ws://{addr}");
4519 common.clone().connect_pool(&url, None).await.unwrap();
4520 for conn in conns {
4521 let mut ok = false;
4522 for _ in 0..100 {
4523 if conn.state.lock().await.ws_write_tx.is_some() {
4524 ok = true;
4525 break;
4526 }
4527 sleep(Duration::from_millis(25)).await;
4528 }
4529 assert!(ok, "expected ws_write_tx Some for {}", conn.id);
4530 }
4531 });
4532 }
4533
4534 #[test]
4535 fn single_mode_uses_first_connection() {
4536 TOKIO_SHARED_RT.block_on(async {
4537 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4538 let addr = listener.local_addr().unwrap();
4539 let _listener_guard = spawn_mock_ws_listener(listener);
4540 let conn = WebsocketConnection::new("c1");
4541 let common = WebsocketCommon::new(
4542 vec![conn.clone()],
4543 WebsocketMode::Single,
4544 0,
4545 None,
4546 None,
4547 );
4548 let url = format!("ws://{addr}");
4549 common.connect_pool(&url, None).await.unwrap();
4550 let ok = eventually_async(Duration::from_secs(5), || {
4551 let conn = conn.clone();
4552 async move { conn.state.lock().await.ws_write_tx.is_some() }
4553 })
4554 .await;
4555
4556 assert!(ok, "single mode did not select first connection");
4557 });
4558 }
4559
4560 #[test]
4561 fn empty_subset_is_ok_and_connects_none() {
4562 TOKIO_SHARED_RT.block_on(async {
4563 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4564 let addr = listener.local_addr().unwrap();
4565
4566 tokio::spawn(async move {
4567 let _ = addr;
4568 });
4569
4570 let c1 = WebsocketConnection::new("c1");
4571 let c2 = WebsocketConnection::new("c2");
4572
4573 let common = WebsocketCommon::new(
4574 vec![c1.clone(), c2.clone()],
4575 WebsocketMode::Pool(2),
4576 0,
4577 None,
4578 None,
4579 );
4580
4581 let url = format!("ws://{addr}");
4582 common
4583 .clone()
4584 .connect_pool(&url, Some(vec![]))
4585 .await
4586 .unwrap();
4587
4588 assert!(c1.state.lock().await.ws_write_tx.is_none());
4589 assert!(c2.state.lock().await.ws_write_tx.is_none());
4590 });
4591 }
4592 }
4593
4594 mod init_connect {
4595 use super::*;
4596 use std::panic;
4597 use tokio::sync::mpsc::{channel, error::TryRecvError};
4598
4599 #[test]
4600 fn pool_mode_none_connection_uses_first() {
4601 TOKIO_SHARED_RT.block_on(async {
4602 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4603 let addr = listener.local_addr().unwrap();
4604 tokio::spawn(async move {
4605 for _ in 0..2 {
4606 if let Ok((stream, _)) = listener.accept().await {
4607 let mut ws = accept_async(stream).await.unwrap();
4608 ws.close(None).await.ok();
4609 }
4610 }
4611 });
4612
4613 let c1 = WebsocketConnection::new("c1");
4614 let c2 = WebsocketConnection::new("c2");
4615 let common = WebsocketCommon::new(
4616 vec![c1.clone(), c2.clone()],
4617 WebsocketMode::Pool(2),
4618 0,
4619 None,
4620 None,
4621 );
4622 let url = format!("ws://{addr}");
4623
4624 common
4625 .clone()
4626 .init_connect(&url, false, None)
4627 .await
4628 .unwrap();
4629
4630 let ok = eventually_async(Duration::from_secs(5), || {
4631 let conn1 = c1.clone();
4632 async move { conn1.state.lock().await.ws_write_tx.is_some() }
4633 })
4634 .await;
4635
4636 assert!(ok, "first connection was never selected");
4637 assert!(c2.state.lock().await.ws_write_tx.is_none());
4638 });
4639 }
4640
4641 #[test]
4642 fn writer_channel_can_send_text() {
4643 TOKIO_SHARED_RT.block_on(async {
4644 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4645 let addr = listener.local_addr().unwrap();
4646 let received = Arc::new(Mutex::new(None::<String>));
4647 let received_clone = received.clone();
4648
4649 tokio::spawn(async move {
4650 if let Ok((stream, _)) = listener.accept().await {
4651 let mut ws = accept_async(stream).await.unwrap();
4652 if let Some(Ok(Message::Text(txt))) = ws.next().await {
4653 *received_clone.lock().await = Some(txt.to_string());
4654 }
4655 ws.close(None).await.ok();
4656 }
4657 });
4658
4659 let conn = WebsocketConnection::new("cw");
4660 let common = WebsocketCommon::new(
4661 vec![conn.clone()],
4662 WebsocketMode::Single,
4663 0,
4664 None,
4665 None,
4666 );
4667 let url = format!("ws://{addr}");
4668 common
4669 .clone()
4670 .init_connect(&url, false, Some(conn.clone()))
4671 .await
4672 .unwrap();
4673
4674 let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
4675 tx.send(Message::Text("ping".into())).unwrap();
4676
4677 sleep(Duration::from_millis(50)).await;
4678
4679 let lock = received.lock().await;
4680 assert_eq!(lock.as_deref(), Some("ping"));
4681 });
4682 }
4683
4684 #[test]
4685 fn does_not_skip_when_reconnection_pending_even_if_writer_exists() {
4686 TOKIO_SHARED_RT.block_on(async {
4687 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4688 let addr = listener.local_addr().unwrap();
4689 let _listener_guard = spawn_mock_ws_listener(listener);
4690
4691 let conn = WebsocketConnection::new("c-reconnect");
4692 {
4693 let mut st = conn.state.lock().await;
4694 let (tx, _) = unbounded_channel::<Message>();
4695 st.ws_write_tx = Some(tx);
4696 st.reconnection_pending = true;
4697 }
4698
4699 let common = WebsocketCommon::new(
4700 vec![conn.clone()],
4701 WebsocketMode::Single,
4702 0,
4703 None,
4704 None,
4705 );
4706
4707 let url = format!("ws://{addr}");
4708 common
4709 .clone()
4710 .init_connect(&url, false, Some(conn.clone()))
4711 .await
4712 .unwrap();
4713
4714 let st = conn.state.lock().await;
4715 assert!(
4716 st.ws_write_tx.is_some(),
4717 "writer should be set after connect"
4718 );
4719 assert!(
4720 !st.reconnection_pending,
4721 "reconnection_pending should be cleared after successful connect"
4722 );
4723 });
4724 }
4725
4726 #[test]
4727 fn responds_to_ping_with_pong() {
4728 TOKIO_SHARED_RT.block_on(async {
4729 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4730 let addr = listener.local_addr().unwrap();
4731
4732 let saw_pong = Arc::new(Mutex::new(false));
4733 let saw_pong2 = saw_pong.clone();
4734
4735 tokio::spawn(async move {
4736 if let Ok((stream, _)) = listener.accept().await {
4737 let mut ws = accept_async(stream).await.unwrap();
4738 ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
4739 if let Some(Ok(Message::Pong(payload))) = ws.next().await {
4740 if payload[..] == [1, 2, 3] {
4741 *saw_pong2.lock().await = true;
4742 }
4743 }
4744 let _ = ws.close(None).await;
4745 }
4746 });
4747
4748 let conn = WebsocketConnection::new("c-ping");
4749 let common = WebsocketCommon::new(
4750 vec![conn.clone()],
4751 WebsocketMode::Single,
4752 0,
4753 None,
4754 None,
4755 );
4756 let url = format!("ws://{addr}");
4757 common
4758 .clone()
4759 .init_connect(&url, false, Some(conn))
4760 .await
4761 .unwrap();
4762
4763 sleep(Duration::from_millis(50)).await;
4764
4765 assert!(*saw_pong.lock().await, "server should have seen a Pong");
4766 });
4767 }
4768
4769 #[test]
4770 fn handshake_error_on_invalid_url() {
4771 TOKIO_SHARED_RT.block_on(async {
4772 let conn = WebsocketConnection::new("c-invalid");
4773 let common = WebsocketCommon::new(
4774 vec![conn.clone()],
4775 WebsocketMode::Single,
4776 0,
4777 None,
4778 None,
4779 );
4780 let res = common
4781 .clone()
4782 .init_connect("not-a-url", false, Some(conn.clone()))
4783 .await;
4784 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4785 });
4786 }
4787
4788 #[test]
4789 fn skip_if_writer_exists_and_not_renewal() {
4790 TOKIO_SHARED_RT.block_on(async {
4791 let conn = WebsocketConnection::new("c-writer");
4792 let (tx, mut rx) = unbounded_channel::<Message>();
4793 {
4794 let mut st = conn.state.lock().await;
4795 st.ws_write_tx = Some(tx.clone());
4796 }
4797 let common = WebsocketCommon::new(
4798 vec![conn.clone()],
4799 WebsocketMode::Single,
4800 0,
4801 None,
4802 None,
4803 );
4804 let res = common
4805 .clone()
4806 .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
4807 .await;
4808
4809 assert!(res.is_ok());
4810 assert!(rx.try_recv().is_err());
4811 });
4812 }
4813
4814 #[test]
4815 fn short_circuit_on_already_renewing() {
4816 TOKIO_SHARED_RT.block_on(async {
4817 let conn = WebsocketConnection::new("c-renew");
4818 {
4819 let mut st = conn.state.lock().await;
4820 st.renewal_pending = true;
4821 }
4822 let common = WebsocketCommon::new(
4823 vec![conn.clone()],
4824 WebsocketMode::Single,
4825 0,
4826 None,
4827 None,
4828 );
4829 let res = common
4830 .clone()
4831 .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
4832 .await;
4833
4834 assert!(res.is_ok());
4835 assert!(conn.state.lock().await.ws_write_tx.is_none());
4836 });
4837 }
4838
4839 #[test]
4840 fn is_renewal_true_sets_and_clears_flag() {
4841 struct GatedHandler {
4842 gate: Mutex<Option<oneshot::Receiver<()>>>,
4843 }
4844 #[async_trait]
4845 impl WebsocketHandler for GatedHandler {
4846 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {
4847 if let Some(rx) = self.gate.lock().await.take() {
4848 let _ = rx.await;
4849 }
4850 }
4851 async fn on_message(
4852 &self,
4853 _data: String,
4854 _connection: Arc<WebsocketConnection>,
4855 ) {
4856 }
4857 async fn get_reconnect_url(
4858 &self,
4859 url: String,
4860 _connection: Arc<WebsocketConnection>,
4861 ) -> String {
4862 url
4863 }
4864 }
4865
4866 TOKIO_SHARED_RT.block_on(async {
4867 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4868 let addr = listener.local_addr().unwrap();
4869 let _listener_guard = spawn_mock_ws_listener(listener);
4870
4871 let conn = WebsocketConnection::new("c-new-renew");
4872 let (gate_tx, gate_rx) = oneshot::channel();
4873 conn.set_handler(Arc::new(GatedHandler {
4874 gate: Mutex::new(Some(gate_rx)),
4875 }))
4876 .await;
4877
4878 let common = WebsocketCommon::new(
4879 vec![conn.clone()],
4880 WebsocketMode::Single,
4881 0,
4882 None,
4883 None,
4884 );
4885 let url = format!("ws://{addr}");
4886 let res = common
4887 .clone()
4888 .init_connect(&url, true, Some(conn.clone()))
4889 .await;
4890
4891 assert!(res.is_ok());
4892
4893 {
4894 let st = conn.state.lock().await;
4895 assert!(st.ws_write_tx.is_some(), "writer should be set");
4896 assert!(
4897 st.renewal_pending,
4898 "renewal_pending must be true until on_open"
4899 );
4900 }
4901
4902 let _ = gate_tx.send(());
4903
4904 let ok = eventually_async(Duration::from_secs(2), || {
4905 let conn = conn.clone();
4906 async move { !conn.state.lock().await.renewal_pending }
4907 })
4908 .await;
4909 assert!(ok, "renewal_pending should be cleared in on_open");
4910 let st = conn.state.lock().await;
4911 assert!(
4912 !st.renewal_pending,
4913 "renewal_pending should be cleared in on_open"
4914 );
4915 });
4916 }
4917
4918 #[test]
4919 fn does_not_schedule_reconnect_on_renewal() {
4920 TOKIO_SHARED_RT.block_on(async {
4921 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4922 let addr = listener.local_addr().unwrap();
4923 let server_handle = tokio::spawn(async move {
4924 if let Ok((stream, _)) = listener.accept().await {
4925 let _ws = accept_async(stream).await.unwrap();
4926 if let Ok((stream, _)) = listener.accept().await {
4927 let _ws = accept_async(stream).await.unwrap();
4928 }
4929 sleep(Duration::from_secs(10)).await;
4930 }
4931 });
4932
4933 let conn = WebsocketConnection::new("c-needs-renew");
4934 let (reconnect_tx, mut reconnect_rx) = channel::<ReconnectEntry>(1);
4935 let (renewal_tx, _renewal_rx) = channel::<(String, String)>(1);
4936 let common = Arc::new(WebsocketCommon {
4937 events: WebsocketEventEmitter::new(),
4938 mode: WebsocketMode::Single,
4939 round_robin_index: AtomicUsize::new(0),
4940 connection_pool: vec![conn.clone()],
4941 reconnect_tx,
4942 renewal_tx,
4943 reconnect_delay: 0,
4944 agent: None,
4945 user_agent: None,
4946 });
4947 let url = format!("ws://{addr}");
4948 let res = common
4949 .clone()
4950 .init_connect(&url, false, Some(conn.clone()))
4951 .await;
4952
4953 assert!(res.is_ok());
4954
4955 {
4956 let st = conn.state.lock().await;
4957 assert!(st.ws_write_tx.is_some(), "writer should be set");
4958 }
4959
4960 common
4961 .clone()
4962 .init_connect(&url, true, Some(conn.clone()))
4963 .await
4964 .expect("Renewal init_connect should succeed");
4965
4966 sleep(Duration::from_millis(1000)).await;
4967
4968 match reconnect_rx.try_recv() {
4969 Err(TryRecvError::Empty) => {}
4970 Ok(_) => panic!("Received reconnection request on renewal"),
4971 Err(TryRecvError::Disconnected) => {
4972 panic!("Sender for reconnection_rx disconnected")
4973 }
4974 }
4975 server_handle.abort();
4976 });
4977 }
4978
4979 #[test]
4980 fn default_connection_selected_when_none_passed() {
4981 TOKIO_SHARED_RT.block_on(async {
4982 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4983 let addr = listener.local_addr().unwrap();
4984 let _listener_guard = spawn_mock_ws_listener(listener);
4985 let conn = WebsocketConnection::new("c-default");
4986 let common = WebsocketCommon::new(
4987 vec![conn.clone()],
4988 WebsocketMode::Single,
4989 0,
4990 None,
4991 None,
4992 );
4993 let url = format!("ws://{addr}");
4994 let res = common.clone().init_connect(&url, false, None).await;
4995
4996 assert!(res.is_ok());
4997 let ok = eventually_async(Duration::from_secs(5), || {
4998 let conn = conn.clone();
4999 async move { conn.state.lock().await.ws_write_tx.is_some() }
5000 })
5001 .await;
5002
5003 assert!(ok, "default connection was never selected");
5004 });
5005 }
5006
5007 #[test]
5008 fn schedules_reconnect_on_abnormal_close() {
5009 TOKIO_SHARED_RT.block_on(async {
5010 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
5011 let addr = listener.local_addr().unwrap();
5012 tokio::spawn(async move {
5013 if let Ok((stream, _)) = listener.accept().await {
5014 let ws = accept_async(stream).await.unwrap();
5015
5016 drop(ws);
5022 }
5023 });
5024 let conn = WebsocketConnection::new("c-close");
5025 let common = WebsocketCommon::new(
5026 vec![conn.clone()],
5027 WebsocketMode::Single,
5028 5_000,
5029 None,
5030 None,
5031 );
5032 let url = format!("ws://{addr}");
5033 common
5034 .clone()
5035 .init_connect(&url, false, Some(conn.clone()))
5036 .await
5037 .unwrap();
5038
5039 sleep(Duration::from_millis(50)).await;
5040
5041 let st = conn.state.lock().await;
5042 assert!(
5043 st.reconnection_pending,
5044 "expected reconnection_pending to be true after abnormal close"
5045 );
5046 assert!(
5047 st.ws_write_tx.is_none(),
5048 "ws_write_tx should be cleared when scheduling a reconnect"
5049 );
5050 });
5051 }
5052 }
5053
/// Tests for `WebsocketCommon::disconnect`: it is a no-op when no connection
/// has a writer. Otherwise it marks the pooled connections as closing, clears
/// their session logon state, and queues a Close frame on every writer.
mod disconnect {
    use super::*;

    /// With no connection holding a writer, `disconnect` succeeds and leaves
    /// `close_initiated` untouched.
    #[test]
    fn returns_ok_when_no_connections_are_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let idle = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![Arc::clone(&idle)],
                WebsocketMode::Single,
                0,
                None,
                None,
            );

            let outcome = common.disconnect().await;
            assert!(outcome.is_ok());

            let close_initiated = idle.state.lock().await.close_initiated;
            assert!(!close_initiated);
        });
    }

    /// Every connection with a writer is marked as closing, is logged out, and
    /// receives a Close frame.
    #[test]
    fn closes_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let first = WebsocketConnection::new("c1");
            let second = WebsocketConnection::new("c2");
            let (first_tx, mut first_rx) = unbounded_channel::<Message>();
            let (second_tx, mut second_rx) = unbounded_channel::<Message>();
            first.state.lock().await.ws_write_tx = Some(first_tx);
            second.state.lock().await.ws_write_tx = Some(second_tx);
            let common = WebsocketCommon::new(
                vec![Arc::clone(&first), Arc::clone(&second)],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let pending_disconnect = common.disconnect();

            // NOTE(review): if `disconnect` is a plain `async fn`, nothing runs
            // until the future is awaited, so this sleep has no effect. Confirm.
            sleep(Duration::from_millis(50)).await;

            pending_disconnect.await.unwrap();

            for conn in [&first, &second] {
                assert!(conn.state.lock().await.close_initiated);
            }

            for (label, conn) in [("conn1", &first), ("conn2", &second)] {
                let st = conn.state.lock().await;
                assert!(!st.is_session_logged_on, "{label} should be logged out");
                assert!(st.session_logon_req.is_none(), "{label} req cleared");
            }

            let closes = (first_rx.try_recv(), second_rx.try_recv());
            assert!(
                matches!(closes, (Ok(Message::Close(_)), Ok(Message::Close(_)))),
                "expected two Close frames, got {closes:?}"
            );
        });
    }

    /// A single connection without a writer is not marked as closing.
    #[test]
    fn does_not_mark_close_initiated_if_no_writer() {
        TOKIO_SHARED_RT.block_on(async {
            let fresh = WebsocketConnection::new("c-new");
            let common = WebsocketCommon::new(
                vec![Arc::clone(&fresh)],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.disconnect().await.unwrap();

            let close_initiated = fresh.state.lock().await.close_initiated;
            assert!(!close_initiated);
        });
    }

    /// In a pool where only one connection has a writer, both are marked as
    /// closing but only the one with a writer receives a Close frame.
    #[test]
    fn mixed_pool_marks_all_and_closes_only_writers() {
        TOKIO_SHARED_RT.block_on(async {
            let with_writer = WebsocketConnection::new("with");
            let without_writer = WebsocketConnection::new("without");
            let (writer, mut reader) = unbounded_channel::<Message>();
            with_writer.state.lock().await.ws_write_tx = Some(writer);
            let common = WebsocketCommon::new(
                vec![Arc::clone(&with_writer), Arc::clone(&without_writer)],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let pending_disconnect = common.disconnect();

            // NOTE(review): see the identical sleep in `closes_all_ready_connections`.
            sleep(Duration::from_millis(50)).await;

            pending_disconnect.await.unwrap();

            assert!(with_writer.state.lock().await.close_initiated);
            assert!(without_writer.state.lock().await.close_initiated);
            assert!(matches!(reader.try_recv(), Ok(Message::Close(_))));
        });
    }

    /// After `disconnect`, `is_connected` reports the connection as down.
    #[test]
    fn after_disconnect_not_connected() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            // Keep the receiver alive (named binding, not `_`) so the writer
            // channel stays open for the duration of the test.
            let (writer, _reader) = unbounded_channel::<Message>();
            conn.state.lock().await.ws_write_tx = Some(writer);
            let common = WebsocketCommon::new(
                vec![Arc::clone(&conn)],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.disconnect().await.unwrap();

            let still_connected = common.is_connected(Some(&conn)).await;
            assert!(!still_connected);
        });
    }
}
5191
/// Tests for `WebsocketCommon::ping_server`. An empty-payload Ping is queued
/// on every connection that has a writer. Connections without a writer, or
/// with `reconnection_pending` set, are skipped.
mod ping_server {
    use super::*;

    /// Each of three connections with writers receives one empty Ping.
    #[test]
    fn sends_ping_to_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let mut wired = Vec::with_capacity(3);
            for idx in 0..3 {
                let connection = WebsocketConnection::new(format!("c{idx}"));
                let (writer, reader) = unbounded_channel::<Message>();
                connection.state.lock().await.ws_write_tx = Some(writer);
                wired.push((connection, reader));
            }
            let pool = wired.iter().map(|(c, _)| Arc::clone(c)).collect();
            let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None, None);

            common.ping_server().await;

            for (_, mut reader) in wired {
                let frame = reader.try_recv();
                assert!(
                    matches!(&frame, Ok(Message::Ping(payload)) if payload.is_empty()),
                    "expected empty-payload Ping, got {frame:?}"
                );
            }
        });
    }

    /// A connection without a writer is skipped. The ready one still gets its Ping.
    #[test]
    fn skips_not_ready_and_partial() {
        TOKIO_SHARED_RT.block_on(async {
            let ready = WebsocketConnection::new("ready");
            let not_ready = WebsocketConnection::new("not-ready");
            let (ready_writer, mut ready_reader) = unbounded_channel::<Message>();
            ready.state.lock().await.ws_write_tx = Some(ready_writer);
            not_ready.state.lock().await.ws_write_tx = None;
            let common = WebsocketCommon::new(
                vec![Arc::clone(&ready), Arc::clone(&not_ready)],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );

            common.ping_server().await;

            let frame = ready_reader.try_recv();
            assert!(
                matches!(&frame, Ok(Message::Ping(payload)) if payload.is_empty()),
                "expected Ping on ready, got {frame:?}"
            );
        });
    }

    /// A pending reconnection suppresses the ping even though a writer exists.
    #[test]
    fn no_ping_when_flags_block() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (writer, mut reader) = unbounded_channel::<Message>();
            {
                let mut guard = conn.state.lock().await;
                guard.ws_write_tx = Some(writer);
                guard.reconnection_pending = true;
            }
            let common = WebsocketCommon::new(
                vec![Arc::clone(&conn)],
                WebsocketMode::Single,
                0,
                None,
                None,
            );

            common.ping_server().await;

            assert!(reader.try_recv().is_err());
        });
    }
}
5276
5277 mod send {
5278 use super::*;
5279
/// Without an explicit target, consecutive `send` calls in pool mode rotate
/// across ready connections. The first frame goes to `c1` and the second to `c2`.
#[test]
fn round_robin_send_without_specific() {
    TOKIO_SHARED_RT.block_on(async {
        let conn1 = WebsocketConnection::new("c1");
        let conn2 = WebsocketConnection::new("c2");
        let (tx1, mut rx1) = unbounded_channel::<Message>();
        let (tx2, mut rx2) = unbounded_channel::<Message>();
        {
            let mut s1 = conn1.state.lock().await;
            s1.ws_write_tx = Some(tx1);
        }
        {
            let mut s2 = conn2.state.lock().await;
            s2.ws_write_tx = Some(tx2);
        }
        let common = WebsocketCommon::new(
            vec![conn1.clone(), conn2.clone()],
            WebsocketMode::Pool(2),
            0,
            None,
            None,
        );

        let res1 = common
            .send("a".into(), None, false, Duration::from_secs(1), None)
            .await
            .unwrap();
        assert!(res1.is_none());

        let res2 = common
            .send("b".into(), None, false, Duration::from_secs(1), None)
            .await
            .unwrap();
        assert!(res2.is_none());

        // Report what was actually queued instead of a bare `panic!()` /
        // `unwrap()`, so a failure explains itself.
        match rx1.try_recv() {
            Ok(Message::Text(t)) => assert_eq!(t, "a"),
            other => panic!("expected Text \"a\" on c1, got {other:?}"),
        }
        match rx2.try_recv() {
            Ok(Message::Text(t)) => assert_eq!(t, "b"),
            other => panic!("expected Text \"b\" on c2, got {other:?}"),
        }
    });
}
5333
/// Round-robin dispatch skips a connection that has no writer and delivers the
/// frame to the next ready connection.
#[test]
fn round_robin_skips_not_ready() {
    TOKIO_SHARED_RT.block_on(async {
        let idle = WebsocketConnection::new("c1");
        let ready = WebsocketConnection::new("c2");
        let (ready_tx, mut ready_rx) = unbounded_channel::<Message>();
        idle.state.lock().await.ws_write_tx = None;
        ready.state.lock().await.ws_write_tx = Some(ready_tx);
        let common = WebsocketCommon::new(
            vec![Arc::clone(&idle), Arc::clone(&ready)],
            WebsocketMode::Pool(2),
            0,
            None,
            None,
        );

        let outcome = common
            .send("bar".into(), None, false, Duration::from_secs(1), None)
            .await
            .unwrap();
        assert!(outcome.is_none());

        match ready_rx.try_recv().unwrap() {
            Message::Text(body) => assert_eq!(body, "bar"),
            unexpected => panic!("unexpected {unexpected:?}"),
        }
    });
}
5366
/// A `send` with an explicit target connection and the third argument `false`
/// writes to that connection's writer and returns `Ok(None)`.
#[test]
fn sync_send_on_specific_connection() {
    TOKIO_SHARED_RT.block_on(async {
        let bystander = WebsocketConnection::new("c1");
        let target = WebsocketConnection::new("c2");
        let (target_tx, mut target_rx) = unbounded_channel::<Message>();
        target.state.lock().await.ws_write_tx = Some(target_tx);
        let common = WebsocketCommon::new(
            vec![Arc::clone(&bystander), Arc::clone(&target)],
            WebsocketMode::Pool(2),
            0,
            None,
            None,
        );

        let outcome = common
            .send(
                "payload".into(),
                Some("id".into()),
                false,
                Duration::from_secs(1),
                Some(Arc::clone(&target)),
            )
            .await
            .unwrap();
        assert!(outcome.is_none());

        match target_rx.try_recv() {
            Ok(Message::Text(body)) => assert_eq!(body, "payload"),
            other => panic!("expected Text, got {other:?}"),
        }
    });
}
5401
/// A `send` with third argument `false` does not register a pending request,
/// even when an id is supplied, and still writes the frame.
#[test]
fn sync_send_with_id_does_not_insert_pending() {
    TOKIO_SHARED_RT.block_on(async {
        let conn = WebsocketConnection::new("c1");
        let (writer, mut reader) = unbounded_channel::<Message>();
        conn.state.lock().await.ws_write_tx = Some(writer);
        let common = WebsocketCommon::new(
            vec![Arc::clone(&conn)],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let outcome = common
            .send(
                "msg".into(),
                Some("id".into()),
                false,
                Duration::from_secs(1),
                Some(Arc::clone(&conn)),
            )
            .await
            .unwrap();
        assert!(outcome.is_none());

        let no_pending = conn.state.lock().await.pending_requests.is_empty();
        assert!(no_pending);

        match reader.try_recv().unwrap() {
            Message::Text(body) => assert_eq!(body, "msg"),
            unexpected => panic!("unexpected {unexpected:?}"),
        }
    });
}
5436
/// Targeting a connection that has no writer fails with `NotConnected`.
#[test]
fn sync_send_error_if_not_ready() {
    TOKIO_SHARED_RT.block_on(async {
        let unwired = WebsocketConnection::new("c1");
        let common = WebsocketCommon::new(
            vec![Arc::clone(&unwired)],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let target = Some(Arc::clone(&unwired));
        let failure = common
            .send("msg".into(), Some("id".into()), false, Duration::from_secs(1), target)
            .await
            .unwrap_err();
        assert!(matches!(failure, WebsocketError::NotConnected));
    });
}
5461
/// Without an explicit target and with no connection holding a writer, `send`
/// fails with `NotConnected`.
#[test]
fn sync_send_error_when_no_ready() {
    TOKIO_SHARED_RT.block_on(async {
        let unwired = WebsocketConnection::new("c1");
        let common = WebsocketCommon::new(
            vec![Arc::clone(&unwired)],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let failure = common
            .send("msg".into(), None, false, Duration::from_secs(1), None)
            .await
            .unwrap_err();
        assert!(matches!(failure, WebsocketError::NotConnected));
    });
}
5480
/// A `send` with third argument `true` writes the frame, registers a pending
/// request under the given id, and returns a future that resolves with the value
/// delivered through that pending request's `completion`.
#[test]
fn async_send_and_receive() {
    TOKIO_SHARED_RT.block_on(async {
        let conn = WebsocketConnection::new("c1");
        let (writer, mut reader) = unbounded_channel::<Message>();
        conn.state.lock().await.ws_write_tx = Some(writer);
        let common = WebsocketCommon::new(
            vec![Arc::clone(&conn)],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let response_fut = common
            .send(
                "hello".into(),
                Some("id".into()),
                true,
                Duration::from_secs(5),
                Some(Arc::clone(&conn)),
            )
            .await
            .unwrap()
            .unwrap();

        match reader.try_recv() {
            Ok(Message::Text(body)) => assert_eq!(body, "hello"),
            other => panic!("expected Text, got {other:?}"),
        }

        // Complete the pending request by hand, standing in for a server reply.
        let pending = conn.state.lock().await.pending_requests.remove("id").unwrap();
        pending
            .completion
            .send(Ok(serde_json::json!("ok")))
            .unwrap();

        let response = response_fut.await.unwrap().unwrap();
        assert_eq!(response, serde_json::json!("ok"));
    });
}
5521
#[test]
fn async_send_default_connection() {
    // No connection is passed to `send`, so it falls back to the pool's only
    // connection. The frame lands on that connection's writer, and the pending
    // entry is registered in that connection's state.
    TOKIO_SHARED_RT.block_on(async {
        let connection = WebsocketConnection::new("c1");
        let (writer, mut outbox) = unbounded_channel::<Message>();
        connection.state.lock().await.ws_write_tx = Some(writer);

        let common = WebsocketCommon::new(
            vec![connection.clone()],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let reply_rx = common
            .send(
                "msg".into(),
                Some("id".into()),
                true,
                Duration::from_secs(5),
                None,
            )
            .await
            .unwrap()
            .unwrap();

        if let Ok(Message::Text(text)) = outbox.try_recv() {
            assert_eq!(text, "msg");
        } else {
            panic!("no text");
        }

        let pending = connection
            .state
            .lock()
            .await
            .pending_requests
            .remove("id")
            .unwrap();
        pending
            .completion
            .send(Ok(serde_json::json!(123)))
            .unwrap();

        let response = reply_rx.await.unwrap().unwrap();
        assert_eq!(response, serde_json::json!(123));
    });
}
5562
#[test]
fn async_send_failure_removes_pending_entry() {
    // The writer's receiving half is dropped before sending, so the write
    // fails. The send must report `NotConnected` and must not leave the entry
    // it registered for "id" behind.
    TOKIO_SHARED_RT.block_on(async {
        let connection = WebsocketConnection::new("c1");
        let closed_writer = {
            let (tx, rx) = unbounded_channel::<Message>();
            drop(rx);
            tx
        };
        connection.state.lock().await.ws_write_tx = Some(closed_writer);

        let common = WebsocketCommon::new(
            vec![connection.clone()],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let failure = common
            .send(
                "msg".into(),
                Some("id".into()),
                true,
                Duration::from_secs(1),
                Some(connection.clone()),
            )
            .await
            .unwrap_err();
        assert!(matches!(failure, WebsocketError::NotConnected));

        let cleaned_up = connection.state.lock().await.pending_requests.is_empty();
        assert!(
            cleaned_up,
            "pending_requests must be empty after a failed send"
        );
    });
}
5604
#[test]
fn async_send_error_if_no_id() {
    // Waiting for a reply (`true`) without supplying a request id must be
    // refused with `NotConnected`, even though the connection has a live
    // writer. Presumably this is because there is no id to match the reply
    // against.
    TOKIO_SHARED_RT.block_on(async {
        let conn = WebsocketConnection::new("c1");
        // Keep the receiver alive so the writer itself is healthy; the failure
        // must come from the missing id, not from a closed channel.
        let (tx, _rx) = unbounded_channel::<Message>();
        conn.state.lock().await.ws_write_tx = Some(tx);

        let common = WebsocketCommon::new(
            vec![conn.clone()],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let err = common
            .send(
                "msg".into(),
                None,
                true,
                Duration::from_secs(1),
                Some(conn.clone()),
            )
            .await
            .unwrap_err();
        assert!(matches!(err, WebsocketError::NotConnected));
    });
}
5634
#[test]
fn timeout_rejects_async() {
    // A reply that never arrives must resolve the receiver with an error once
    // the 1 s request timeout elapses, and the "id" entry must be dropped from
    // `pending_requests`.
    //
    // The Tokio clock is paused so the timeout can be crossed instantly with
    // `advance`. `pause` freezes time for the whole runtime, and
    // TOKIO_SHARED_RT is (by its name) shared with other tests. The clock is
    // therefore resumed before any assertion can panic. Otherwise later tests
    // would inherit a frozen, auto-advancing clock, and their short timeouts
    // could fire spuriously.
    TOKIO_SHARED_RT.block_on(async {
        pause();

        let conn = WebsocketConnection::new("c1");
        let (tx, _rx) = unbounded_channel::<Message>();
        conn.state.lock().await.ws_write_tx = Some(tx);

        let common = WebsocketCommon::new(
            vec![conn.clone()],
            WebsocketMode::Single,
            0,
            None,
            None,
        );

        let fut = common
            .send(
                "msg".into(),
                Some("id".into()),
                true,
                Duration::from_secs(1),
                Some(conn.clone()),
            )
            .await
            .unwrap()
            .unwrap();

        advance(Duration::from_secs(1)).await;
        let res = fut.await.unwrap();
        let still_pending = conn.state.lock().await.pending_requests.contains_key("id");

        // Restore real time for the rest of the shared runtime.
        tokio::time::resume();

        assert!(res.is_err(), "expected timeout error");
        assert!(!still_pending);
    });
}
5669
#[test]
fn async_send_errors_if_no_connection_ready() {
    // Even when a reply is awaited and an id is given, a pool whose only
    // connection has no writer must yield `NotConnected`.
    TOKIO_SHARED_RT.block_on(async {
        let connection = WebsocketConnection::new("c1");
        let pool = vec![connection.clone()];
        let common = WebsocketCommon::new(pool, WebsocketMode::Single, 0, None, None);

        let outcome = common
            .send(
                "msg".into(),
                Some("id".into()),
                true,
                Duration::from_secs(1),
                None,
            )
            .await;

        let failure = outcome.unwrap_err();
        assert!(matches!(failure, WebsocketError::NotConnected));
    });
}
5694 }
5695 }
5696
5697 mod websocket_api {
5698 use super::*;
5699
mod initialisation {
    use super::*;

    #[test]
    fn new_initializes_common() {
        // `WebsocketApi::new` must carry the supplied pool and mode into its
        // `common` state and must not start out in the connecting state.
        TOKIO_SHARED_RT.block_on(async {
            let connection = WebsocketConnection::new("id");

            let config = ConfigurationWebsocketApi {
                api_key: Some("api_key".to_string()),
                api_secret: Some("api_secret".to_string()),
                private_key: None,
                private_key_passphrase: None,
                ws_url: Some("wss://example".to_string()),
                mode: WebsocketMode::Single,
                reconnect_delay: 1000,
                signature_gen: SignatureGenerator::new(
                    Some("api_secret".to_string()),
                    None::<PrivateKey>,
                    None::<String>,
                ),
                timeout: 500,
                time_unit: None,
                auto_session_relogon: false,
                agent: None,
                user_agent: build_user_agent("product"),
            };

            let api = WebsocketApi::new(config, vec![connection]);

            assert_eq!(api.common.connection_pool.len(), 1);
            assert_eq!(api.common.mode, WebsocketMode::Single);

            let is_connecting = *api.is_connecting.lock().await;
            assert!(!is_connecting);
        });
    }
}
5741
5742 mod connect {
5743 use super::*;
5744
5745 #[test]
5746 fn connect_when_not_connected_establishes() {
5747 TOKIO_SHARED_RT.block_on(async {
5748 let conn = WebsocketConnection::new("id");
5749 {
5750 let mut st = conn.state.lock().await;
5751 st.ws_write_tx = None;
5752 }
5753 let sig = SignatureGenerator::new(
5754 Some("api_secret".into()),
5755 None::<PrivateKey>,
5756 None::<String>,
5757 );
5758 let cfg = ConfigurationWebsocketApi {
5759 api_key: Some("api_key".into()),
5760 api_secret: Some("api_secret".to_string()),
5761 private_key: None,
5762 private_key_passphrase: None,
5763 ws_url: Some("ws://doesnotexist:1".to_string()),
5764 mode: WebsocketMode::Single,
5765 reconnect_delay: 0,
5766 signature_gen: sig,
5767 timeout: 10,
5768 time_unit: None,
5769 auto_session_relogon: false,
5770 agent: None,
5771 user_agent: build_user_agent("product"),
5772 };
5773 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5774 let res = api.clone().connect().await;
5775 assert!(!matches!(res, Err(WebsocketError::Timeout)));
5776 });
5777 }
5778
5779 #[test]
5780 fn already_connected_returns_ok() {
5781 TOKIO_SHARED_RT.block_on(async {
5782 let conn = WebsocketConnection::new("id2");
5783 let (tx, _) = unbounded_channel();
5784 {
5785 let mut st = conn.state.lock().await;
5786 st.ws_write_tx = Some(tx);
5787 }
5788 let sig = SignatureGenerator::new(
5789 Some("api_secret".to_string()),
5790 None::<PrivateKey>,
5791 None::<String>,
5792 );
5793 let cfg = ConfigurationWebsocketApi {
5794 api_key: Some("api_key".to_string()),
5795 api_secret: Some("api_secret".to_string()),
5796 private_key: None,
5797 private_key_passphrase: None,
5798 ws_url: Some("ws://example.com".to_string()),
5799 mode: WebsocketMode::Single,
5800 reconnect_delay: 0,
5801 signature_gen: sig,
5802 timeout: 10,
5803 time_unit: None,
5804 auto_session_relogon: false,
5805 agent: None,
5806 user_agent: build_user_agent("product"),
5807 };
5808 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5809 let res = api.connect().await;
5810 assert!(res.is_ok());
5811 });
5812 }
5813
5814 #[test]
5815 fn not_connected_returns_error() {
5816 TOKIO_SHARED_RT.block_on(async {
5817 let conn = WebsocketConnection::new("id1");
5818 let sig = SignatureGenerator::new(
5819 Some("api_secret".to_string()),
5820 None::<PrivateKey>,
5821 None::<String>,
5822 );
5823 let cfg = ConfigurationWebsocketApi {
5824 api_key: Some("api_key".to_string()),
5825 api_secret: Some("api_secret".to_string()),
5826 private_key: None,
5827 private_key_passphrase: None,
5828 ws_url: Some("ws://127.0.0.1:9".to_string()),
5829 mode: WebsocketMode::Single,
5830 reconnect_delay: 0,
5831 signature_gen: sig,
5832 timeout: 10,
5833 time_unit: None,
5834 auto_session_relogon: false,
5835 agent: None,
5836 user_agent: build_user_agent("product"),
5837 };
5838 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5839 let res = api.connect().await;
5840 assert!(res.is_err());
5841 });
5842 }
5843
5844 #[test]
5845 fn concurrent_calls_both_error_or_ok() {
5846 TOKIO_SHARED_RT.block_on(async {
5847 let conn = WebsocketConnection::new("id3");
5848 let sig = SignatureGenerator::new(
5849 Some("api_secret".to_string()),
5850 None::<PrivateKey>,
5851 None::<String>,
5852 );
5853 let cfg = ConfigurationWebsocketApi {
5854 api_key: Some("api_key".to_string()),
5855 api_secret: Some("api_secret".to_string()),
5856 private_key: None,
5857 private_key_passphrase: None,
5858 ws_url: Some("wss://invalid-domain".to_string()),
5859 mode: WebsocketMode::Single,
5860 reconnect_delay: 0,
5861 signature_gen: sig,
5862 timeout: 10,
5863 time_unit: None,
5864 auto_session_relogon: false,
5865 agent: None,
5866 user_agent: build_user_agent("product"),
5867 };
5868 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5869 let fut1 = tokio::spawn(api.clone().connect());
5870 let fut2 = tokio::spawn(api.clone().connect());
5871 let r1 = fut1.await.unwrap();
5872 let r2 = fut2.await.unwrap();
5873
5874 assert!(r1.is_err());
5875 assert!(r2.is_err() || r2.is_ok());
5876 });
5877 }
5878
5879 #[test]
5880 fn pool_failure_is_propagated() {
5881 TOKIO_SHARED_RT.block_on(async {
5882 let conn = WebsocketConnection::new("w");
5883 let sig = SignatureGenerator::new(
5884 Some("api_secret".to_string()),
5885 None::<PrivateKey>,
5886 None::<String>,
5887 );
5888 let cfg = ConfigurationWebsocketApi {
5889 api_key: Some("api_key".into()),
5890 api_secret: Some("api_secret".to_string()),
5891 private_key: None,
5892 private_key_passphrase: None,
5893 ws_url: Some("ws://doesnotexist:1".to_string()),
5894 mode: WebsocketMode::Single,
5895 reconnect_delay: 0,
5896 signature_gen: sig,
5897 timeout: 10,
5898 time_unit: None,
5899 auto_session_relogon: false,
5900 agent: None,
5901 user_agent: build_user_agent("product"),
5902 };
5903 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5904 let res = api.clone().connect().await;
5905 match res {
5906 Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
5907 _ => panic!("expected handshake or timeout error"),
5908 }
5909 });
5910 }
5911 }
5912
5913 mod send_message {
5914 use super::*;
5915
5916 #[test]
5917 fn unsigned_message() {
5918 TOKIO_SHARED_RT.block_on(async {
5919 let api = create_websocket_api(None, None, None);
5920 let conn = &api.common.connection_pool[0];
5921 let (tx, mut rx) = unbounded_channel::<Message>();
5922 {
5923 let mut st = conn.state.lock().await;
5924 st.ws_write_tx = Some(tx);
5925 }
5926
5927 let fut = tokio::spawn({
5928 let api = api.clone();
5929 async move {
5930 let mut params = BTreeMap::new();
5931 params.insert("foo".into(), Value::String("bar".into()));
5932 let send_res = api
5933 .send_message::<Value>(
5934 "method",
5935 params,
5936 WebsocketMessageSendOptions {
5937 with_api_key: false,
5938 is_signed: false,
5939 ..Default::default()
5940 },
5941 )
5942 .await
5943 .unwrap();
5944
5945 match send_res {
5946 SendWebsocketMessageResult::Single(resp) => resp,
5947 SendWebsocketMessageResult::Multiple(_) => {
5948 panic!("expected single response")
5949 }
5950 }
5951 }
5952 });
5953
5954 let Message::Text(txt) = rx.recv().await.unwrap() else {
5955 panic!()
5956 };
5957 let req: Value = serde_json::from_str(&txt).unwrap();
5958 assert_eq!(req["method"], "method");
5959 assert_eq!(req["params"]["foo"], "bar");
5960 assert!(req["params"].get("apiKey").is_none());
5961 assert!(req["params"].get("timestamp").is_none());
5962 assert!(req["params"].get("signature").is_none());
5963
5964 let id = req["id"].as_str().unwrap().to_string();
5965 let mut st = conn.state.lock().await;
5966 let pending = st.pending_requests.remove(&id).unwrap();
5967 let reply = json!({
5968 "id": id,
5969 "result": { "x": 42 },
5970 "rateLimits": [{ "limit": 7 }]
5971 });
5972 pending.completion.send(Ok(reply)).unwrap();
5973
5974 let resp = fut.await.unwrap();
5975 let rate_limits = resp.rate_limits.unwrap_or_default();
5976
5977 assert!(rate_limits.is_empty());
5978 assert_eq!(resp.raw, json!({"x": 42}));
5979 });
5980 }
5981
5982 #[test]
5983 fn with_api_key_only() {
5984 TOKIO_SHARED_RT.block_on(async {
5985 let api = create_websocket_api(None, None, None);
5986 let conn = &api.common.connection_pool[0];
5987 let (tx, mut rx) = unbounded_channel::<Message>();
5988 {
5989 let mut st = conn.state.lock().await;
5990 st.ws_write_tx = Some(tx);
5991 }
5992
5993 let fut = tokio::spawn({
5994 let api = api.clone();
5995 async move {
5996 let params = BTreeMap::new();
5997 let send_res = api
5998 .send_message::<Value>(
5999 "method",
6000 params,
6001 WebsocketMessageSendOptions {
6002 with_api_key: true,
6003 is_signed: false,
6004 ..Default::default()
6005 },
6006 )
6007 .await
6008 .unwrap();
6009
6010 match send_res {
6011 SendWebsocketMessageResult::Single(resp) => resp,
6012 SendWebsocketMessageResult::Multiple(_) => {
6013 panic!("expected single response")
6014 }
6015 }
6016 }
6017 });
6018
6019 let Message::Text(txt) = rx.recv().await.unwrap() else {
6020 panic!()
6021 };
6022 let req: Value = serde_json::from_str(&txt).unwrap();
6023 assert_eq!(req["params"]["apiKey"], "api_key");
6024
6025 let id = req["id"].as_str().unwrap().to_string();
6026 let mut st = conn.state.lock().await;
6027 let pending = st.pending_requests.remove(&id).unwrap();
6028 pending
6029 .completion
6030 .send(Ok(json!({
6031 "id": id,
6032 "result": {},
6033 "rateLimits": []
6034 })))
6035 .unwrap();
6036
6037 let resp = fut.await.unwrap();
6038
6039 assert_eq!(resp.raw, json!({}));
6040 assert!(st.pending_requests.is_empty());
6041 });
6042 }
6043
6044 #[test]
6045 fn signed_message_has_timestamp_and_signature() {
6046 TOKIO_SHARED_RT.block_on(async {
6047 let api = create_websocket_api(None, None, None);
6048 let conn = &api.common.connection_pool[0];
6049 let (tx, mut rx) = unbounded_channel::<Message>();
6050 {
6051 let mut st = conn.state.lock().await;
6052 st.ws_write_tx = Some(tx);
6053 }
6054
6055 let fut = tokio::spawn({
6056 let api = api.clone();
6057 async move {
6058 let mut params = BTreeMap::new();
6059 params.insert("foo".into(), Value::String("bar".into()));
6060 let send_res = api
6061 .send_message::<Value>(
6062 "method",
6063 params,
6064 WebsocketMessageSendOptions {
6065 with_api_key: true,
6066 is_signed: true,
6067 ..Default::default()
6068 },
6069 )
6070 .await
6071 .unwrap();
6072
6073 match send_res {
6074 SendWebsocketMessageResult::Single(resp) => resp,
6075 SendWebsocketMessageResult::Multiple(_) => {
6076 panic!("expected single response")
6077 }
6078 }
6079 }
6080 });
6081
6082 let Message::Text(txt) = rx.recv().await.unwrap() else {
6083 panic!()
6084 };
6085 let req: Value = serde_json::from_str(&txt).unwrap();
6086 let p = &req["params"];
6087 assert_eq!(p["apiKey"], "api_key");
6088 assert!(p["timestamp"].is_number());
6089 assert!(p["signature"].is_string());
6090
6091 let id = req["id"].as_str().unwrap().to_string();
6092 let mut st = conn.state.lock().await;
6093 let pending = st.pending_requests.remove(&id).unwrap();
6094 pending
6095 .completion
6096 .send(Ok(json!({
6097 "id": id,
6098 "result": { "ok": true },
6099 "rateLimits": []
6100 })))
6101 .unwrap();
6102
6103 let resp = fut.await.unwrap();
6104 assert_eq!(resp.raw, json!({ "ok": true }));
6105 });
6106 }
6107
6108 #[test]
6109 fn multi_session_logon() {
6110 TOKIO_SHARED_RT.block_on(async {
6111 let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);
6112 let conn0 = &api.common.connection_pool[0];
6113 let conn1 = &api.common.connection_pool[1];
6114
6115 let (tx0, mut rx0) = unbounded_channel::<Message>();
6116 let (tx1, mut rx1) = unbounded_channel::<Message>();
6117 {
6118 let mut st0 = conn0.state.lock().await;
6119 st0.ws_write_tx = Some(tx0);
6120 }
6121 {
6122 let mut st1 = conn1.state.lock().await;
6123 st1.ws_write_tx = Some(tx1);
6124 }
6125
6126 let fut = tokio::spawn({
6127 let api = api.clone();
6128 async move {
6129 let params = BTreeMap::new();
6130 let send_res = api
6131 .send_message::<Value>(
6132 "method",
6133 params,
6134 WebsocketMessageSendOptions {
6135 is_session_logon: Some(true),
6136 ..Default::default()
6137 },
6138 )
6139 .await
6140 .unwrap();
6141
6142 match send_res {
6143 SendWebsocketMessageResult::Multiple(v) => v,
6144 SendWebsocketMessageResult::Single(_) => {
6145 panic!("expected multiple responses")
6146 }
6147 }
6148 }
6149 });
6150
6151 let Message::Text(txt0) = rx0.recv().await.unwrap() else {
6152 panic!()
6153 };
6154 let Message::Text(txt1) = rx1.recv().await.unwrap() else {
6155 panic!()
6156 };
6157 let req0: Value = serde_json::from_str(&txt0).unwrap();
6158 let req1: Value = serde_json::from_str(&txt1).unwrap();
6159 assert_eq!(req0["method"], "method");
6160 assert_eq!(req1["method"], "method");
6161 let id = req0["id"].as_str().unwrap().to_string();
6162 assert_eq!(req1["id"].as_str().unwrap(), &id);
6163
6164 {
6165 let mut st0 = conn0.state.lock().await;
6166 let pending0 = st0.pending_requests.remove(&id).unwrap();
6167 pending0
6168 .completion
6169 .send(Ok(json!({
6170 "id": id,
6171 "result": { "ok": true },
6172 "rateLimits": []
6173 })))
6174 .unwrap();
6175 }
6176 {
6177 let mut st1 = conn1.state.lock().await;
6178 let pending1 = st1.pending_requests.remove(&id).unwrap();
6179 pending1
6180 .completion
6181 .send(Ok(json!({
6182 "id": id,
6183 "result": { "ok": true },
6184 "rateLimits": []
6185 })))
6186 .unwrap();
6187 }
6188
6189 let results = fut.await.unwrap();
6190 assert_eq!(results.len(), 2);
6191
6192 for conn in &api.common.connection_pool {
6193 let st = conn.state.lock().await;
6194 assert!(st.is_session_logged_on, "should be logged out");
6195 assert!(st.session_logon_req.is_some(), "req cleared");
6196
6197 }
6213 });
6214 }
6215
6216 #[test]
6217 fn multi_session_logout() {
6218 TOKIO_SHARED_RT.block_on(async {
6219 let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);
6220
6221 for conn in &api.common.connection_pool {
6222 let (tx, _rx) = unbounded_channel::<Message>();
6223 let mut st = conn.state.lock().await;
6224 st.ws_write_tx = Some(tx);
6225 st.is_session_logged_on = true;
6226 st.session_logon_req = Some(WebsocketSessionLogonReq {
6227 method: "method".into(),
6228 payload: BTreeMap::new(),
6229 options: WebsocketMessageSendOptions::default(),
6230 });
6231 }
6232
6233 let mut rxs = Vec::new();
6234 for conn in &api.common.connection_pool {
6235 let rx = {
6236 let (tx, rx) = unbounded_channel::<Message>();
6237 conn.state.lock().await.ws_write_tx = Some(tx);
6238 rx
6239 };
6240 rxs.push(rx);
6241 }
6242
6243 let fut = tokio::spawn({
6244 let api = api.clone();
6245 async move {
6246 let send_res = api
6247 .send_message::<Value>(
6248 "method",
6249 BTreeMap::new(),
6250 WebsocketMessageSendOptions {
6251 is_signed: false,
6252 with_api_key: false,
6253 is_session_logout: Some(true),
6254 ..Default::default()
6255 },
6256 )
6257 .await
6258 .unwrap();
6259
6260 match send_res {
6261 SendWebsocketMessageResult::Multiple(v) => v,
6262 SendWebsocketMessageResult::Single(_) => panic!("expected multi"),
6263 }
6264 }
6265 });
6266
6267 let mut ids = Vec::new();
6268 for mut rx in rxs {
6269 let Message::Text(txt) = rx.recv().await.unwrap() else {
6270 panic!()
6271 };
6272 let req: Value = serde_json::from_str(&txt).unwrap();
6273 assert_eq!(req["method"], "method");
6274 ids.push(req["id"].as_str().unwrap().to_string());
6275 }
6276
6277 assert_eq!(ids[0], ids[1]);
6278
6279 for conn in &api.common.connection_pool {
6280 let id = &ids[0];
6281 let mut st = conn.state.lock().await;
6282 let pending = st.pending_requests.remove(id).unwrap();
6283 pending
6284 .completion
6285 .send(Ok(json!({
6286 "id": id,
6287 "result": {},
6288 "rateLimits": []
6289 })))
6290 .unwrap();
6291 }
6292
6293 let results = fut.await.unwrap();
6294 assert_eq!(results.len(), 2);
6295
6296 for conn in &api.common.connection_pool {
6297 let st = conn.state.lock().await;
6298 assert!(!st.is_session_logged_on, "should be logged out");
6299 assert!(st.session_logon_req.is_none(), "req cleared");
6300 }
6301 });
6302 }
6303
6304 #[test]
6305 fn skip_signature_when_logged_on_and_auto_relogon() {
6306 TOKIO_SHARED_RT.block_on(async {
6307 let api = create_websocket_api(None, Some(WebsocketMode::Single), None);
6308 let conn = &api.common.connection_pool[0];
6309 {
6310 let mut st = conn.state.lock().await;
6311 st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6312 st.is_session_logged_on = true;
6313 }
6314
6315 let mut rx;
6316 {
6317 let mut st = conn.state.lock().await;
6318 let (tx, new_rx) = unbounded_channel::<Message>();
6319 st.ws_write_tx = Some(tx);
6320 rx = new_rx;
6321 }
6322
6323 let fut = tokio::spawn({
6324 let api = api.clone();
6325 async move {
6326 let send_res = api
6327 .send_message::<Value>(
6328 "method",
6329 BTreeMap::new(),
6330 WebsocketMessageSendOptions {
6331 is_signed: true,
6332 ..Default::default()
6333 },
6334 )
6335 .await
6336 .unwrap();
6337
6338 match send_res {
6339 SendWebsocketMessageResult::Single(resp) => resp,
6340 SendWebsocketMessageResult::Multiple(_) => {
6341 panic!("expected single")
6342 }
6343 }
6344 }
6345 });
6346
6347 let Message::Text(txt) = rx.recv().await.unwrap() else {
6348 panic!()
6349 };
6350 let req: Value = serde_json::from_str(&txt).unwrap();
6351 let p = &req["params"];
6352 assert!(p.get("timestamp").is_some());
6353 assert!(p.get("signature").is_none());
6354
6355 let id = req["id"].as_str().unwrap();
6356 let mut st = conn.state.lock().await;
6357 let pending = st.pending_requests.remove(id).unwrap();
6358 pending
6359 .completion
6360 .send(Ok(json!({
6361 "id": id,
6362 "result": {},
6363 "rateLimits": []
6364 })))
6365 .unwrap();
6366
6367 let resp = fut.await.unwrap();
6368 assert_eq!(resp.raw, json!({}));
6369 });
6370 }
6371
6372 #[test]
6373 fn include_signature_when_logged_on_and_no_auto_relogon() {
6374 TOKIO_SHARED_RT.block_on(async {
6375 let api = create_websocket_api(None, Some(WebsocketMode::Single), Some(false));
6376 let conn = &api.common.connection_pool[0];
6377 {
6378 let mut st = conn.state.lock().await;
6379 st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6380 st.is_session_logged_on = true;
6381 }
6382
6383 let mut rx;
6384 {
6385 let mut st = conn.state.lock().await;
6386 let (tx, new_rx) = unbounded_channel::<Message>();
6387 st.ws_write_tx = Some(tx);
6388 rx = new_rx;
6389 }
6390
6391 let fut = tokio::spawn({
6392 let api = api.clone();
6393 async move {
6394 let send_res = api
6395 .send_message::<Value>(
6396 "method",
6397 BTreeMap::new(),
6398 WebsocketMessageSendOptions {
6399 is_signed: true,
6400 ..Default::default()
6401 },
6402 )
6403 .await
6404 .unwrap();
6405
6406 match send_res {
6407 SendWebsocketMessageResult::Single(resp) => resp,
6408 SendWebsocketMessageResult::Multiple(_) => {
6409 panic!("expected single")
6410 }
6411 }
6412 }
6413 });
6414
6415 let Message::Text(txt) = rx.recv().await.unwrap() else {
6416 panic!()
6417 };
6418 let req: Value = serde_json::from_str(&txt).unwrap();
6419 let p = &req["params"];
6420 assert!(p.get("timestamp").is_some());
6421 assert!(p.get("signature").is_some());
6422
6423 let id = req["id"].as_str().unwrap();
6424 let mut st = conn.state.lock().await;
6425 let pending = st.pending_requests.remove(id).unwrap();
6426 pending
6427 .completion
6428 .send(Ok(json!({
6429 "id": id,
6430 "result": {},
6431 "rateLimits": []
6432 })))
6433 .unwrap();
6434
6435 let resp = fut.await.unwrap();
6436 assert_eq!(resp.raw, json!({}));
6437 });
6438 }
6439
6440 #[test]
6441 fn error_if_not_connected() {
6442 TOKIO_SHARED_RT.block_on(async {
6443 let api = create_websocket_api(None, None, None);
6444 let conn = &api.common.connection_pool[0];
6445 {
6446 let mut st = conn.state.lock().await;
6447 st.ws_write_tx = None;
6448 }
6449 let params = BTreeMap::new();
6450 let err = api
6451 .send_message::<Value>(
6452 "method",
6453 params,
6454 WebsocketMessageSendOptions {
6455 with_api_key: false,
6456 is_signed: false,
6457 ..Default::default()
6458 },
6459 )
6460 .await
6461 .unwrap_err();
6462 matches!(err, WebsocketError::NotConnected);
6463 });
6464 }
6465 }
6466
mod prepare_url {
    use super::*;

    // `prepare_url` behaviour:
    // - without a configured time unit, the URL is returned unchanged;
    // - otherwise a `timeUnit` query parameter is appended, using `?` or `&`
    //   depending on whether a query string already exists.

    #[test]
    fn no_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None, None, None);
            let original = String::from("wss://example.com/ws");
            let prepared = api.prepare_url(&original);
            assert_eq!(prepared, original);
        });
    }

    #[test]
    fn appends_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(Some(TimeUnit::Millisecond), None, None);
            let original = String::from("wss://example.com/ws");
            let expected = format!("{original}?timeUnit=millisecond");
            let prepared = api.prepare_url(&original);
            assert_eq!(prepared, expected);
        });
    }

    #[test]
    fn handles_existing_query() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(Some(TimeUnit::Microsecond), None, None);
            let original = String::from("wss://example.com/ws?foo=bar");
            let expected = format!("{original}&timeUnit=microsecond");
            let prepared = api.prepare_url(&original);
            assert_eq!(prepared, expected);
        });
    }
}
6499
6500 mod on_open {
6501 use super::*;
6502
6503 fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6504 let sig_gen = SignatureGenerator::new(
6505 Some("api_secret".to_string()),
6506 None::<_>,
6507 None::<String>,
6508 );
6509 let config = ConfigurationWebsocketApi {
6510 api_key: Some("api_key".to_string()),
6511 api_secret: Some("api_secret".to_string()),
6512 private_key: None,
6513 private_key_passphrase: None,
6514 ws_url: Some("wss://example".to_string()),
6515 mode: WebsocketMode::Single,
6516 reconnect_delay: 0,
6517 signature_gen: sig_gen,
6518 timeout: 1000,
6519 time_unit: None,
6520 auto_session_relogon: true,
6521 agent: None,
6522 user_agent: build_user_agent("product"),
6523 };
6524 let conn = WebsocketConnection::new("test-conn");
6525 let api = WebsocketApi::new(config, vec![conn.clone()]);
6526 (api, conn)
6527 }
6528
6529 #[test]
6530 fn session_relogon_on_open() {
6531 TOKIO_SHARED_RT.block_on(async {
6532 let (api, conn) = create_websocket_api_and_conn();
6533
6534 let req = WebsocketSessionLogonReq {
6535 method: "method".into(),
6536 payload: {
6537 let mut m = BTreeMap::new();
6538 m.insert("foo".into(), Value::String("bar".into()));
6539 m
6540 },
6541 options: WebsocketMessageSendOptions {
6542 with_api_key: true,
6543 is_signed: true,
6544 is_session_logon: Some(true),
6545 ..Default::default()
6546 },
6547 };
6548
6549 let (tx, mut rx) = unbounded_channel::<Message>();
6550 {
6551 let mut st = conn.state.lock().await;
6552 st.session_logon_req = Some(req.clone());
6553 st.is_session_logged_on = false;
6554 st.ws_write_tx = Some(tx);
6555 }
6556
6557 api.on_open("wss://example".to_string(), conn.clone()).await;
6558
6559 let Message::Text(raw) = rx.recv().await.unwrap() else {
6560 panic!("expected a Text message");
6561 };
6562 let msg: Value = serde_json::from_str(&raw).unwrap();
6563 assert_eq!(msg["method"], "method");
6564 assert_eq!(msg["params"]["foo"], "bar");
6565
6566 let id = msg["id"].as_str().unwrap().to_string();
6567 {
6568 let mut st = conn.state.lock().await;
6569 let pending = st.pending_requests.remove(&id).expect("pending request");
6570 pending
6571 .completion
6572 .send(Ok(json!({
6573 "id": id,
6574 "result": {},
6575 "rateLimits": []
6576 })))
6577 .unwrap();
6578 }
6579
6580 sleep(Duration::from_millis(10)).await;
6581
6582 let st = conn.state.lock().await;
6583 assert!(st.is_session_logged_on, "should now be logged on");
6584 });
6585 }
6586
6587 #[test]
6588 fn no_relogon_if_already_logged_on() {
6589 TOKIO_SHARED_RT.block_on(async {
6590 let (api, conn) = create_websocket_api_and_conn();
6591
6592 let req = WebsocketSessionLogonReq {
6593 method: "method".into(),
6594 payload: BTreeMap::new(),
6595 options: WebsocketMessageSendOptions {
6596 is_session_logon: Some(true),
6597 ..Default::default()
6598 },
6599 };
6600
6601 let (tx, mut rx) = unbounded_channel::<Message>();
6602 {
6603 let mut st = conn.state.lock().await;
6604 st.session_logon_req = Some(req);
6605 st.is_session_logged_on = true;
6606 st.ws_write_tx = Some(tx);
6607 }
6608
6609 api.on_open("wss://example".to_string(), conn.clone()).await;
6610
6611 assert!(rx.try_recv().is_err(), "no re‐logon when already on");
6612
6613 let st = conn.state.lock().await;
6614 assert!(st.is_session_logged_on);
6615 });
6616 }
6617
6618 #[test]
6619 fn session_relogon_fails_gracefully() {
6620 TOKIO_SHARED_RT.block_on(async {
6621 let (api, conn) = create_websocket_api_and_conn();
6622
6623 let req = WebsocketSessionLogonReq {
6624 method: "method".into(),
6625 payload: {
6626 let mut m = BTreeMap::new();
6627 m.insert("x".into(), Value::Number(1.into()));
6628 m
6629 },
6630 options: WebsocketMessageSendOptions {
6631 is_session_logon: Some(true),
6632 ..Default::default()
6633 },
6634 };
6635 {
6636 let mut st = conn.state.lock().await;
6637 st.session_logon_req = Some(req);
6638 st.is_session_logged_on = false;
6639 st.ws_write_tx = None;
6640 }
6641
6642 api.on_open("wss://example".into(), conn.clone()).await;
6643
6644 let st = conn.state.lock().await;
6645 assert!(
6646 !st.is_session_logged_on,
6647 "should remain logged‐off on failure"
6648 );
6649 });
6650 }
6651
6652 #[test]
6653 fn session_relogon_noop_when_no_req() {
6654 TOKIO_SHARED_RT.block_on(async {
6655 let (api, conn) = create_websocket_api_and_conn();
6656
6657 {
6658 let mut st = conn.state.lock().await;
6659 st.session_logon_req = None;
6660 st.is_session_logged_on = false;
6661 st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6662 }
6663
6664 api.on_open("wss://example".into(), conn.clone()).await;
6665
6666 let st = conn.state.lock().await;
6667 assert!(!st.is_session_logged_on, "still logged‐off");
6668 });
6669 }
6670 }
6671
6672 mod on_message {
6673 use super::*;
6674
6675 fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6676 let sig_gen = SignatureGenerator::new(
6677 Some("api_secret".to_string()),
6678 None::<_>,
6679 None::<String>,
6680 );
6681 let config = ConfigurationWebsocketApi {
6682 api_key: Some("api_key".to_string()),
6683 api_secret: Some("api_secret".to_string()),
6684 private_key: None,
6685 private_key_passphrase: None,
6686 ws_url: Some("wss://example".to_string()),
6687 mode: WebsocketMode::Single,
6688 reconnect_delay: 0,
6689 signature_gen: sig_gen,
6690 timeout: 1000,
6691 time_unit: None,
6692 auto_session_relogon: false,
6693 agent: None,
6694 user_agent: build_user_agent("product"),
6695 };
6696 let conn = WebsocketConnection::new("test");
6697 let api = WebsocketApi::new(config, vec![conn.clone()]);
6698 (api, conn)
6699 }
6700
#[test]
fn resolves_pending_and_removes_request() {
    // A status-200 message whose id matches a pending request is forwarded
    // unchanged to that request's completion channel, and the entry is then
    // removed from `pending_requests`.
    TOKIO_SHARED_RT.block_on(async {
        let (api, conn) = create_websocket_api_and_conn();
        let (completion, delivered) = oneshot::channel();
        conn.state
            .lock()
            .await
            .pending_requests
            .insert("id1".to_string(), PendingRequest { completion });

        let incoming = json!({"id":"id1","status":200,"foo":"bar"});
        api.on_message(incoming.to_string(), conn.clone()).await;

        let forwarded = delivered.await.unwrap().unwrap();
        assert_eq!(forwarded, incoming);

        let still_tracked = conn.state.lock().await.pending_requests.contains_key("id1");
        assert!(!still_tracked);
    });
}
6719
6720 #[test]
6721 fn uses_result_when_present() {
6722 TOKIO_SHARED_RT.block_on(async {
6723 let (api, conn) = create_websocket_api_and_conn();
6724 let (tx, rx) = oneshot::channel();
6725 {
6726 let mut st = conn.state.lock().await;
6727 st.pending_requests
6728 .insert("id1".to_string(), PendingRequest { completion: tx });
6729 }
6730 let msg = json!({
6731 "id": "id1",
6732 "status": 200,
6733 "response": [1,2],
6734 "result": {"a":1}
6735 });
6736 api.on_message(msg.to_string(), conn.clone()).await;
6737 let got = rx.await.unwrap().unwrap();
6738 assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
6739 });
6740 }
6741
6742 #[test]
6743 fn uses_response_when_no_result() {
6744 TOKIO_SHARED_RT.block_on(async {
6745 let (api, conn) = create_websocket_api_and_conn();
6746 let (tx, rx) = oneshot::channel();
6747 {
6748 let mut st = conn.state.lock().await;
6749 st.pending_requests
6750 .insert("id1".to_string(), PendingRequest { completion: tx });
6751 }
6752 let msg = json!({
6753 "id": "id1",
6754 "status": 200,
6755 "response": ["ok"]
6756 });
6757 api.on_message(msg.to_string(), conn.clone()).await;
6758 let got = rx.await.unwrap().unwrap();
6759 assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
6760 });
6761 }
6762
6763 #[test]
6764 fn errors_for_status_ge_400() {
6765 TOKIO_SHARED_RT.block_on(async {
6766 let (api, conn) = create_websocket_api_and_conn();
6767 let (tx, rx) = oneshot::channel();
6768 {
6769 let mut st = conn.state.lock().await;
6770 st.pending_requests
6771 .insert("bad".to_string(), PendingRequest { completion: tx });
6772 }
6773 let err_obj = json!({"code":123,"msg":"oops"});
6774 let msg = json!({"id":"bad","status":500,"error":err_obj});
6775 api.on_message(msg.to_string(), conn.clone()).await;
6776 match rx.await.unwrap() {
6777 Err(WebsocketError::ResponseError { code, message }) => {
6778 assert_eq!(code, 123);
6779 assert_eq!(message, "oops");
6780 }
6781 other => panic!("expected ResponseError, got {other:?}"),
6782 }
6783 let st = conn.state.lock().await;
6784 assert!(!st.pending_requests.contains_key("bad"));
6785 });
6786 }
6787
6788 #[test]
6789 fn ignores_unknown_id() {
6790 TOKIO_SHARED_RT.block_on(async {
6791 let (api, conn) = create_websocket_api_and_conn();
6792 let msg = json!({"id":"nope","status":200});
6793 api.on_message(msg.to_string(), conn.clone()).await;
6794 let st = conn.state.lock().await;
6795 assert!(st.pending_requests.is_empty());
6796 });
6797 }
6798
6799 #[test]
6800 fn parse_error_ignored() {
6801 TOKIO_SHARED_RT.block_on(async {
6802 let (api, conn) = create_websocket_api_and_conn();
6803 api.on_message("not json".to_string(), conn.clone()).await;
6804 let st = conn.state.lock().await;
6805 assert!(st.pending_requests.is_empty());
6806 });
6807 }
6808
6809 #[test]
6810 fn error_status_sends_error() {
6811 TOKIO_SHARED_RT.block_on(async {
6812 let (api, conn) = create_websocket_api_and_conn();
6813 let (tx, rx) = oneshot::channel();
6814 {
6815 let mut st = conn.state.lock().await;
6816 st.pending_requests
6817 .insert("err".to_string(), PendingRequest { completion: tx });
6818 }
6819 let msg = json!({
6820 "id": "err",
6821 "status": 500,
6822 "error": { "code": 42, "msg": "Bad!" }
6823 });
6824 api.on_message(msg.to_string(), conn.clone()).await;
6825 match rx.await.unwrap() {
6826 Err(WebsocketError::ResponseError { code, message }) => {
6827 assert_eq!(code, 42);
6828 assert_eq!(message, "Bad!");
6829 }
6830 other => panic!("expected ResponseError, got {other:?}"),
6831 }
6832 });
6833 }
6834
6835 #[test]
6836 fn unknown_id_logs_warning_and_leaves_pending() {
6837 TOKIO_SHARED_RT.block_on(async {
6838 let (api, conn) = create_websocket_api_and_conn();
6839 {
6840 let mut st = conn.state.lock().await;
6841 st.pending_requests.insert(
6842 "keep".to_string(),
6843 PendingRequest {
6844 completion: oneshot::channel().0,
6845 },
6846 );
6847 }
6848 api.on_message(
6849 json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
6850 conn.clone(),
6851 )
6852 .await;
6853 let st = conn.state.lock().await;
6854 assert!(st.pending_requests.contains_key("keep"));
6855 });
6856 }
6857
6858 #[test]
6859 fn server_shutdown_enqueues_reconnect() {
6860 TOKIO_SHARED_RT.block_on(async {
6861 let (api, conn) = create_websocket_api_and_conn();
6862
6863 let msg = json!({
6864 "event": { "e": "serverShutdown" }
6865 });
6866
6867 api.on_message(msg.to_string(), conn.clone()).await;
6868
6869 let st = conn.state.lock().await;
6870 assert!(st.renewal_pending);
6871 });
6872 }
6873
6874 #[test]
6875 fn server_shutdown_ignored_if_renewal_pending() {
6876 TOKIO_SHARED_RT.block_on(async {
6877 let (api, conn) = create_websocket_api_and_conn();
6878
6879 {
6880 let mut st = conn.state.lock().await;
6881 st.renewal_pending = true;
6882 }
6883
6884 let msg = json!({
6885 "event": { "e": "serverShutdown" }
6886 });
6887
6888 api.on_message(msg.to_string(), conn.clone()).await;
6889
6890 let st = conn.state.lock().await;
6891
6892 assert!(st.renewal_pending);
6893 });
6894 }
6895
6896 #[test]
6897 fn server_shutdown_ignored_if_close_initiated() {
6898 TOKIO_SHARED_RT.block_on(async {
6899 let (api, conn) = create_websocket_api_and_conn();
6900
6901 {
6902 let mut st = conn.state.lock().await;
6903 st.close_initiated = true;
6904 }
6905
6906 let msg = json!({
6907 "event": { "e": "serverShutdown" }
6908 });
6909
6910 api.on_message(msg.to_string(), conn.clone()).await;
6911
6912 let st = conn.state.lock().await;
6913
6914 assert!(!st.renewal_pending);
6915 });
6916 }
6917
6918 #[test]
6919 fn server_shutdown_does_not_touch_pending_requests() {
6920 TOKIO_SHARED_RT.block_on(async {
6921 let (api, conn) = create_websocket_api_and_conn();
6922
6923 {
6924 let mut st = conn.state.lock().await;
6925 st.pending_requests.insert(
6926 "keep".to_string(),
6927 PendingRequest {
6928 completion: oneshot::channel().0,
6929 },
6930 );
6931 }
6932
6933 let msg = json!({
6934 "event": { "e": "serverShutdown" }
6935 });
6936
6937 api.on_message(msg.to_string(), conn.clone()).await;
6938
6939 let st = conn.state.lock().await;
6940
6941 assert!(st.pending_requests.contains_key("keep"));
6942 assert!(st.renewal_pending);
6943 });
6944 }
6945 }
6946 }
6947
6948 mod websocket_streams {
6949 use super::*;
6950
mod initialisation {
    use super::*;

    /// `WebsocketStreams::new` keeps the supplied connections (same `Arc`s,
    /// same order), stores the configuration and starts not connecting.
    #[test]
    fn new_initializes_fields() {
        TOKIO_SHARED_RT.block_on(async {
            let config = ConfigurationWebsocketStreams {
                ws_url: Some("wss://example".to_string()),
                mode: WebsocketMode::Pool(2),
                reconnect_delay: 500,
                time_unit: None,
                agent: None,
                user_agent: build_user_agent("product"),
            };
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            // `config` is not used after this call, so it is moved in rather
            // than cloned (previously a redundant `config.clone()`).
            let api = WebsocketStreams::new(config, vec![conn1.clone(), conn2.clone()], vec![]);

            assert_eq!(api.common.connection_pool.len(), 2);
            assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
            assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
            assert_eq!(api.configuration.ws_url, Some("wss://example".to_string()));
            let flag = api.is_connecting.lock().await;
            assert!(!*flag);
        });
    }

    /// With two URL paths and a pool of 2, the pool grows to 4. The supplied
    /// connections stay at the front.
    #[test]
    fn new_expands_pool_when_url_paths_present() {
        TOKIO_SHARED_RT.block_on(async {
            let config = ConfigurationWebsocketStreams {
                ws_url: Some("wss://example".to_string()),
                mode: WebsocketMode::Pool(2),
                reconnect_delay: 500,
                time_unit: None,
                agent: None,
                user_agent: build_user_agent("product"),
            };

            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");

            let api = WebsocketStreams::new(
                config,
                vec![conn1.clone(), conn2.clone()],
                vec!["path1".to_string(), "path2".to_string()],
            );

            assert_eq!(api.common.connection_pool.len(), 4);
            assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
            assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
        });
    }

    /// A pool already sized for `pool_size * url_paths` is used verbatim.
    #[test]
    fn new_does_not_expand_pool_when_already_sized_for_url_paths() {
        TOKIO_SHARED_RT.block_on(async {
            let config = ConfigurationWebsocketStreams {
                ws_url: Some("wss://example".to_string()),
                mode: WebsocketMode::Pool(2),
                reconnect_delay: 500,
                time_unit: None,
                agent: None,
                user_agent: build_user_agent("product"),
            };

            let conns = vec![
                WebsocketConnection::new("c1"),
                WebsocketConnection::new("c2"),
                WebsocketConnection::new("c3"),
                WebsocketConnection::new("c4"),
            ];

            let api = WebsocketStreams::new(
                config,
                conns.clone(),
                vec!["path1".to_string(), "path2".to_string()],
            );

            assert_eq!(api.common.connection_pool.len(), 4);
            for (i, c) in conns.iter().enumerate() {
                assert!(Arc::ptr_eq(&api.common.connection_pool[i], c));
            }
        });
    }
}
7041
7042 mod connect {
7043 use super::*;
7044
7045 #[test]
7046 fn establishes_successfully() {
7047 TOKIO_SHARED_RT.block_on(async {
7048 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
7049 let port = listener.local_addr().unwrap().port();
7050
7051 tokio::spawn(async move {
7052 for _ in 0..2 {
7053 if let Ok((stream, _)) = listener.accept().await {
7054 let mut ws = accept_async(stream).await.unwrap();
7055 ws.close(None).await.ok();
7056 }
7057 }
7058 });
7059
7060 let create_websocket_streams = |ws_url: &str| {
7061 let c1 = WebsocketConnection::new("c1");
7062 let c2 = WebsocketConnection::new("c2");
7063 let config = ConfigurationWebsocketStreams {
7064 ws_url: Some(ws_url.to_string()),
7065 mode: WebsocketMode::Pool(2),
7066 reconnect_delay: 500,
7067 time_unit: None,
7068 agent: None,
7069 user_agent: build_user_agent("product"),
7070 };
7071 WebsocketStreams::new(config, vec![c1, c2], vec![])
7072 };
7073
7074 let url = format!("ws://127.0.0.1:{port}");
7075 let ws = create_websocket_streams(&url);
7076
7077 let res = ws.connect(vec!["stream1".into()]).await;
7078 assert!(res.is_ok());
7079 });
7080 }
7081
7082 #[test]
7083 fn establishes_successfully_with_url_paths() {
7084 TOKIO_SHARED_RT.block_on(async {
7085 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
7086 let addr = listener.local_addr().unwrap();
7087
7088 tokio::spawn(async move {
7089 for _ in 0..4 {
7090 if let Ok((stream, _)) = listener.accept().await {
7091 let mut ws = accept_async(stream).await.unwrap();
7092 ws.close(None).await.ok();
7093 }
7094 }
7095 });
7096
7097 let config = ConfigurationWebsocketStreams {
7098 ws_url: Some(format!("ws://{}", addr)),
7099 mode: WebsocketMode::Pool(2),
7100 reconnect_delay: 500,
7101 time_unit: None,
7102 agent: None,
7103 user_agent: build_user_agent("product"),
7104 };
7105
7106 let ws = WebsocketStreams::new(
7107 config,
7108 vec![],
7109 vec!["path1".to_string(), "path2".to_string()],
7110 );
7111
7112 let res = ws.clone().connect(vec!["stream1".into()]).await;
7113 assert!(res.is_ok());
7114 assert_eq!(ws.common.connection_pool.len(), 4);
7115 });
7116 }
7117
7118 #[test]
7119 fn connect_sets_url_path_on_connections_when_url_paths_present() {
7120 TOKIO_SHARED_RT.block_on(async {
7121 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
7122 let addr = listener.local_addr().unwrap();
7123
7124 tokio::spawn(async move {
7125 for _ in 0..4 {
7126 if let Ok((stream, _)) = listener.accept().await {
7127 let mut ws = accept_async(stream).await.unwrap();
7128 ws.close(None).await.ok();
7129 }
7130 }
7131 });
7132
7133 let config = ConfigurationWebsocketStreams {
7134 ws_url: Some(format!("ws://{}", addr)),
7135 mode: WebsocketMode::Pool(2),
7136 reconnect_delay: 500,
7137 time_unit: None,
7138 agent: None,
7139 user_agent: build_user_agent("product"),
7140 };
7141
7142 let ws = WebsocketStreams::new(
7143 config,
7144 vec![],
7145 vec!["path1".to_string(), "path2".to_string()],
7146 );
7147
7148 ws.clone().connect(vec!["stream1".into()]).await.unwrap();
7149
7150 let pool_size = ws.configuration.mode.pool_size();
7151
7152 for (i, conn) in ws.common.connection_pool.iter().enumerate() {
7153 let expected = if i < pool_size { "path1" } else { "path2" };
7154 let st = conn.state.lock().await;
7155 assert_eq!(st.url_path.as_deref(), Some(expected));
7156 }
7157 });
7158 }
7159
7160 #[test]
7161 fn refused_returns_error() {
7162 TOKIO_SHARED_RT.block_on(async {
7163 let ws = create_websocket_streams(Some("ws://127.0.0.1:9"), None, None);
7164 let res = ws.connect(vec!["stream1".into()]).await;
7165 assert!(res.is_err());
7166 });
7167 }
7168
7169 #[test]
7170 fn invalid_url_returns_error() {
7171 TOKIO_SHARED_RT.block_on(async {
7172 let ws = create_websocket_streams(Some("not-a-url"), None, None);
7173 let res = ws.connect(vec!["s".into()]).await;
7174 assert!(res.is_err());
7175 });
7176 }
7177 }
7178
mod disconnect {
    use super::*;

    /// `disconnect` wipes per-connection callbacks and queued subscriptions
    /// and empties the stream-to-connection registry.
    #[test]
    fn disconnect_clears_state_and_streams() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let connection = &streams.common.connection_pool[0];

            let mut state = connection.state.lock().await;
            state.stream_callbacks.insert("s1".to_string(), Vec::new());
            state.pending_subscriptions.push_back("s2".to_string());
            drop(state);

            streams
                .connection_streams
                .lock()
                .await
                .insert("s3".to_string(), Arc::clone(connection));

            let outcome = streams.disconnect().await;
            assert!(outcome.is_ok());

            let state = connection.state.lock().await;
            assert!(state.stream_callbacks.is_empty());
            assert!(state.pending_subscriptions.is_empty());
            drop(state);

            assert!(streams.connection_streams.lock().await.is_empty());
        });
    }
}
7209
7210 mod subscribe {
7211 use super::*;
7212
7213 #[test]
7214 fn empty_list_does_nothing() {
7215 TOKIO_SHARED_RT.block_on(async {
7216 let ws = create_websocket_streams(None, None, None);
7217 ws.clone().subscribe(Vec::new(), None, None).await;
7218 let map = ws.connection_streams.lock().await;
7219 assert!(map.is_empty());
7220 });
7221 }
7222
7223 #[test]
7224 fn queue_when_not_ready() {
7225 TOKIO_SHARED_RT.block_on(async {
7226 let ws = create_websocket_streams(None, None, None);
7227 let conn = ws.common.connection_pool[0].clone();
7228 ws.clone().subscribe(vec!["s1".into()], None, None).await;
7229 let state = conn.state.lock().await;
7230 let pending: Vec<String> =
7231 state.pending_subscriptions.iter().cloned().collect();
7232 assert_eq!(pending, vec!["s1".to_string()]);
7233 });
7234 }
7235
7236 #[test]
7237 fn only_one_subscription_per_stream() {
7238 TOKIO_SHARED_RT.block_on(async {
7239 let ws = create_websocket_streams(None, None, None);
7240 let conn = ws.common.connection_pool[0].clone();
7241 ws.clone().subscribe(vec!["s1".into()], None, None).await;
7242 ws.clone().subscribe(vec!["s1".into()], None, None).await;
7243 let state = conn.state.lock().await;
7244 let pending: Vec<String> =
7245 state.pending_subscriptions.iter().cloned().collect();
7246 assert_eq!(pending, vec!["s1".to_string()]);
7247 });
7248 }
7249
7250 #[test]
7251 fn multiple_streams_assigned() {
7252 TOKIO_SHARED_RT.block_on(async {
7253 let ws = create_websocket_streams(None, None, None);
7254 ws.clone()
7255 .subscribe(vec!["s1".into(), "s2".into()], None, None)
7256 .await;
7257 let map = ws.connection_streams.lock().await;
7258 assert!(map.contains_key("s1"));
7259 assert!(map.contains_key("s2"));
7260 });
7261 }
7262
7263 #[test]
7264 fn existing_stream_not_reassigned() {
7265 TOKIO_SHARED_RT.block_on(async {
7266 let ws = create_websocket_streams(None, None, None);
7267 ws.clone().subscribe(vec!["s1".into()], None, None).await;
7268 let first_id = {
7269 let map = ws.connection_streams.lock().await;
7270 map.get("s1").unwrap().id.clone()
7271 };
7272 ws.clone()
7273 .subscribe(vec!["s1".into(), "s2".into()], None, None)
7274 .await;
7275 let map = ws.connection_streams.lock().await;
7276 let second_id = map.get("s1").unwrap().id.clone();
7277 assert_eq!(first_id, second_id);
7278 assert!(map.contains_key("s2"));
7279 });
7280 }
7281
7282 #[test]
7283 fn queue_when_not_ready_with_url_path() {
7284 TOKIO_SHARED_RT.block_on(async {
7285 let ws = create_websocket_streams(None, None, None);
7286
7287 let conn = ws.common.connection_pool[0].clone();
7288 {
7289 let mut st = conn.state.lock().await;
7290 st.ws_write_tx = None;
7291 st.url_path = Some("path1".to_string());
7292 st.reconnection_pending = false;
7293 st.close_initiated = false;
7294 }
7295
7296 ws.clone()
7297 .subscribe(vec!["s1".into()], None, Some("path1"))
7298 .await;
7299
7300 let state = conn.state.lock().await;
7301 let pending: Vec<String> =
7302 state.pending_subscriptions.iter().cloned().collect();
7303 assert_eq!(pending, vec!["s1".to_string()]);
7304 });
7305 }
7306
7307 #[test]
7308 fn only_one_subscription_per_stream_per_url_path() {
7309 TOKIO_SHARED_RT.block_on(async {
7310 let ws = create_websocket_streams(None, None, None);
7311
7312 let conn = ws.common.connection_pool[0].clone();
7313 {
7314 let mut st = conn.state.lock().await;
7315 st.ws_write_tx = None;
7316 st.url_path = Some("path1".to_string());
7317 st.reconnection_pending = false;
7318 st.close_initiated = false;
7319 }
7320
7321 ws.clone()
7322 .subscribe(vec!["s1".into()], None, Some("path1"))
7323 .await;
7324 ws.clone()
7325 .subscribe(vec!["s1".into()], None, Some("path1"))
7326 .await;
7327
7328 let state = conn.state.lock().await;
7329 let pending: Vec<String> =
7330 state.pending_subscriptions.iter().cloned().collect();
7331 assert_eq!(pending, vec!["s1".to_string()]);
7332 });
7333 }
7334
7335 #[test]
7336 fn same_stream_can_be_subscribed_on_different_url_paths() {
7337 TOKIO_SHARED_RT.block_on(async {
7338 let ws = create_websocket_streams(None, None, None);
7339
7340 let conn1 = ws.common.connection_pool[0].clone();
7341 let conn2 = ws.common.connection_pool[1].clone();
7342
7343 {
7344 let mut st1 = conn1.state.lock().await;
7345 st1.ws_write_tx = None;
7346 st1.url_path = Some("path1".to_string());
7347 st1.reconnection_pending = false;
7348 st1.close_initiated = false;
7349 }
7350 {
7351 let mut st2 = conn2.state.lock().await;
7352 st2.ws_write_tx = None;
7353 st2.url_path = Some("path2".to_string());
7354 st2.reconnection_pending = false;
7355 st2.close_initiated = false;
7356 }
7357
7358 ws.clone()
7359 .subscribe(vec!["s1".into()], None, Some("path1"))
7360 .await;
7361 ws.clone()
7362 .subscribe(vec!["s1".into()], None, Some("path2"))
7363 .await;
7364
7365 let map = ws.connection_streams.lock().await;
7366 assert!(map.contains_key("path1::s1"));
7367 assert!(map.contains_key("path2::s1"));
7368 });
7369 }
7370 }
7371
mod unsubscribe {
    use super::*;

    /// A stream with no remaining callbacks is removed from both the
    /// registry and the connection's callback map.
    #[test]
    fn removes_stream_with_no_callbacks() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.stream_callbacks.insert("s1".to_string(), Vec::new());
            }
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s1".to_string(), conn.clone());
            }

            ws.unsubscribe(vec!["s1".to_string()], None, None).await;

            assert!(!ws.connection_streams.lock().await.contains_key("s1"));
            assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
        });
    }

    /// A stream that still has a callback is kept.
    #[test]
    fn preserves_stream_with_callbacks() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[1].clone();

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s2".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                state
                    .stream_callbacks
                    .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
            }

            ws.unsubscribe(vec!["s2".to_string()], None, None).await;

            assert!(ws.connection_streams.lock().await.contains_key("s2"));
            assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
        });
    }

    /// With several callbacks registered, the stream is kept.
    #[test]
    fn does_not_send_if_callbacks_exist() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s1".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                state.stream_callbacks.insert(
                    "s1".to_string(),
                    vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
                );
            }
            ws.unsubscribe(vec!["s1".into()], None, None).await;
            assert!(ws.connection_streams.lock().await.contains_key("s1"));
            assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
        });
    }

    /// Unsubscribing a stream that was never registered is a no-op.
    /// Previously this test asserted nothing; it now pins that the registry
    /// is left untouched.
    #[test]
    fn warns_if_not_associated() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let before = ws.connection_streams.lock().await.len();
            ws.unsubscribe(vec!["nope".into()], None, None).await;
            let after = ws.connection_streams.lock().await.len();
            assert_eq!(before, after);
            assert!(!ws.is_subscribed("nope").await);
        });
    }

    /// An empty request leaves the registry size unchanged.
    #[test]
    fn empty_list_does_nothing() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let before = ws.connection_streams.lock().await.len();
            ws.unsubscribe(Vec::<String>::new(), None, None).await;
            let after = ws.connection_streams.lock().await.len();
            assert_eq!(before, after);
        });
    }

    /// A custom id that is not a valid request id does not prevent removal.
    #[test]
    fn invalid_custom_id_falls_back() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("foo".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                let (tx, _rx) = unbounded_channel();
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert("foo".to_string(), Vec::new());
            }
            ws.unsubscribe(
                vec!["foo".into()],
                Some(StreamId::Str("bad-id".into())),
                None,
            )
            .await;
            assert!(!ws.connection_streams.lock().await.contains_key("foo"));
        });
    }

    // NOTE(review): despite its name, this test installs a write channel, so
    // it does not exercise the "no write channel" path. It currently
    // duplicates `removes_stream_with_no_callbacks`. Confirm the intended
    // behavior of `unsubscribe` when `ws_write_tx` is `None` before changing
    // the setup.
    #[test]
    fn removes_even_without_write_channel() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("x".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                let (tx, _rx) = unbounded_channel();
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert("x".to_string(), Vec::new());
            }
            ws.unsubscribe(vec!["x".into()], None, None).await;
            assert!(!ws.connection_streams.lock().await.contains_key("x"));
        });
    }

    /// Removal works on the `path::stream` key when a URL path is given.
    #[test]
    fn removes_stream_with_no_callbacks_with_url_path() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.url_path = Some("path1".to_string());
                st.stream_callbacks
                    .insert("path1::s1".to_string(), Vec::new());
            }
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("path1::s1".to_string(), conn.clone());
            }

            ws.unsubscribe(vec!["s1".to_string()], None, Some("path1"))
                .await;

            assert!(!ws.connection_streams.lock().await.contains_key("path1::s1"));
            assert!(
                !conn
                    .state
                    .lock()
                    .await
                    .stream_callbacks
                    .contains_key("path1::s1")
            );
        });
    }

    /// A path-scoped stream with a live callback is kept.
    #[test]
    fn preserves_stream_with_callbacks_with_url_path() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.url_path = Some("path1".to_string());
                st.stream_callbacks
                    .insert("path1::s2".to_string(), vec![Arc::new(|_: &Value| {})]);
            }
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("path1::s2".to_string(), conn.clone());
            }

            ws.unsubscribe(vec!["s2".to_string()], None, Some("path1"))
                .await;

            assert!(ws.connection_streams.lock().await.contains_key("path1::s2"));
            assert!(
                conn.state
                    .lock()
                    .await
                    .stream_callbacks
                    .contains_key("path1::s2")
            );
        });
    }

    /// Unsubscribing on a different URL path must not remove this path's entry.
    #[test]
    fn url_path_mismatch_does_not_remove_other_path_subscription() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
                st.url_path = Some("path1".to_string());
                st.stream_callbacks
                    .insert("path1::s1".to_string(), Vec::new());
            }
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("path1::s1".to_string(), conn.clone());
            }

            ws.unsubscribe(vec!["s1".to_string()], None, Some("path2"))
                .await;

            assert!(ws.connection_streams.lock().await.contains_key("path1::s1"));
            assert!(
                conn.state
                    .lock()
                    .await
                    .stream_callbacks
                    .contains_key("path1::s1")
            );
        });
    }
}
7626
mod is_subscribed {
    use super::*;

    /// An empty registry reports nothing as subscribed.
    #[test]
    fn returns_false_when_not_subscribed() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let subscribed = streams.is_subscribed("unknown").await;
            assert!(!subscribed);
        });
    }

    /// A plain (path-less) registry key matches its stream name.
    #[test]
    fn returns_true_when_subscribed() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let owner = streams.common.connection_pool[0].clone();
            streams
                .connection_streams
                .lock()
                .await
                .insert("stream1".to_string(), owner);

            let subscribed = streams.is_subscribed("stream1").await;
            assert!(subscribed);
        });
    }

    /// A `path::stream` key counts as a subscription to `stream`.
    #[test]
    fn returns_true_when_subscribed_with_url_path_key() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let owner = streams.common.connection_pool[0].clone();
            streams
                .connection_streams
                .lock()
                .await
                .insert("path1::stream1".to_string(), owner);

            let subscribed = streams.is_subscribed("stream1").await;
            assert!(subscribed);
        });
    }

    /// The same stream registered under two paths is still subscribed.
    #[test]
    fn returns_true_when_same_stream_subscribed_on_multiple_paths() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let first = streams.common.connection_pool[0].clone();
            let second = streams.common.connection_pool[1].clone();

            let mut registry = streams.connection_streams.lock().await;
            registry.insert("path1::stream1".to_string(), first);
            registry.insert("path2::stream1".to_string(), second);
            drop(registry);

            let subscribed = streams.is_subscribed("stream1").await;
            assert!(subscribed);
        });
    }

    /// Matching is on the full stream name, not a prefix of it.
    #[test]
    fn returns_false_when_only_similar_suffix_exists() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let owner = streams.common.connection_pool[0].clone();
            streams
                .connection_streams
                .lock()
                .await
                .insert("path1::stream10".to_string(), owner);

            let subscribed = streams.is_subscribed("stream1").await;
            assert!(!subscribed);
        });
    }
}
7692
mod stream_key {
    use super::*;

    /// No URL path: the key is the bare stream name.
    #[test]
    fn stream_key_without_url_path_returns_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let key = streams.stream_key("s1", None);
            assert_eq!(key, "s1");
        });
    }

    /// An empty URL path behaves like no path at all.
    #[test]
    fn stream_key_with_empty_url_path_returns_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let key = streams.stream_key("s1", Some(""));
            assert_eq!(key, "s1");
        });
    }

    /// A non-empty URL path is prefixed with a `::` separator.
    #[test]
    fn stream_key_with_url_path_prefixes_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            let key = streams.stream_key("s1", Some("path1"));
            assert_eq!(key, "path1::s1");
        });
    }

    /// Different URL paths give different keys for the same stream.
    #[test]
    fn stream_key_distinguishes_paths() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(None, None, None);
            for (url_path, expected) in [("path1", "path1::s1"), ("path2", "path2::s1")] {
                let key = streams.stream_key("s1", Some(url_path));
                assert_eq!(key, expected);
            }
        });
    }
}
7729
7730 mod prepare_url {
7731 use super::*;
7732
7733 #[test]
7734 fn without_time_unit_returns_base_url() {
7735 TOKIO_SHARED_RT.block_on(async {
7736 let conns = vec![
7737 WebsocketConnection::new("c1"),
7738 WebsocketConnection::new("c2"),
7739 ];
7740 let config = ConfigurationWebsocketStreams {
7741 ws_url: Some("wss://example".to_string()),
7742 mode: WebsocketMode::Single,
7743 reconnect_delay: 100,
7744 time_unit: None,
7745 agent: None,
7746 user_agent: build_user_agent("product"),
7747 };
7748 let ws = WebsocketStreams::new(config, conns, vec![]);
7749 let url = ws.prepare_url(&["s1".into(), "s2".into()], None);
7750 assert_eq!(url, "wss://example/stream?streams=s1/s2");
7751 });
7752 }
7753
7754 #[test]
7755 fn with_time_unit_appends_parameter() {
7756 TOKIO_SHARED_RT.block_on(async {
7757 let conns = vec![WebsocketConnection::new("c1")];
7758 let config = ConfigurationWebsocketStreams {
7759 ws_url: Some("wss://example".to_string()),
7760 mode: WebsocketMode::Single,
7761 reconnect_delay: 100,
7762 time_unit: Some(TimeUnit::Millisecond),
7763 agent: None,
7764 user_agent: build_user_agent("product"),
7765 };
7766 let ws = WebsocketStreams::new(config, conns, vec![]);
7767 let url = ws.prepare_url(&["a".into()], None);
7768 assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
7769 });
7770 }
7771
7772 #[test]
7773 fn multiple_streams_and_time_unit() {
7774 TOKIO_SHARED_RT.block_on(async {
7775 let conns = vec![WebsocketConnection::new("c1")];
7776 let config = ConfigurationWebsocketStreams {
7777 ws_url: Some("wss://example".to_string()),
7778 mode: WebsocketMode::Single,
7779 reconnect_delay: 100,
7780 time_unit: Some(TimeUnit::Microsecond),
7781 agent: None,
7782 user_agent: build_user_agent("product"),
7783 };
7784 let ws = WebsocketStreams::new(config, conns, vec![]);
7785 let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()], None);
7786 assert_eq!(
7787 url,
7788 "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
7789 );
7790 });
7791 }
7792
/// A URL path is inserted between the base URL and the `/stream` segment.
#[test]
fn with_url_path_prefixes_base_url() {
    TOKIO_SHARED_RT.block_on(async {
        let pool = vec![WebsocketConnection::new("c1")];
        let cfg = ConfigurationWebsocketStreams {
            ws_url: Some("wss://example".to_string()),
            mode: WebsocketMode::Single,
            reconnect_delay: 100,
            time_unit: None,
            agent: None,
            user_agent: build_user_agent("product"),
        };
        let ws_streams = WebsocketStreams::new(cfg, pool, vec![]);

        let expected = "wss://example/path1/stream?streams=s1";
        assert_eq!(ws_streams.prepare_url(&["s1".into()], Some("path1")), expected);
    });
}
7810
/// URL path prefixing and the time-unit parameter compose in one URL.
#[test]
fn with_url_path_and_time_unit_appends_parameter() {
    TOKIO_SHARED_RT.block_on(async {
        let pool = vec![WebsocketConnection::new("c1")];
        let cfg = ConfigurationWebsocketStreams {
            ws_url: Some("wss://example".to_string()),
            mode: WebsocketMode::Single,
            reconnect_delay: 100,
            time_unit: Some(TimeUnit::Millisecond),
            agent: None,
            user_agent: build_user_agent("product"),
        };
        let ws_streams = WebsocketStreams::new(cfg, pool, vec![]);

        let expected = "wss://example/path1/stream?streams=a&timeUnit=millisecond";
        assert_eq!(ws_streams.prepare_url(&["a".into()], Some("path1")), expected);
    });
}
7831
/// The same stream requested under two URL paths yields two distinct URLs.
#[test]
fn url_path_distinguishes_urls_for_same_streams() {
    TOKIO_SHARED_RT.block_on(async {
        let pool = vec![WebsocketConnection::new("c1")];
        let cfg = ConfigurationWebsocketStreams {
            ws_url: Some("wss://example".to_string()),
            mode: WebsocketMode::Single,
            reconnect_delay: 100,
            time_unit: None,
            agent: None,
            user_agent: build_user_agent("product"),
        };
        let ws_streams = WebsocketStreams::new(cfg, pool, vec![]);

        // Build both URLs first, then compare each against its expectation.
        let urls: Vec<_> = ["path1", "path2"]
            .into_iter()
            .map(|path| ws_streams.prepare_url(&["s1".into()], Some(path)))
            .collect();
        assert_eq!(urls[0], "wss://example/path1/stream?streams=s1");
        assert_eq!(urls[1], "wss://example/path2/stream?streams=s1");
    });
}
7851 }
7852
/// Tests for `WebsocketStreams::handle_stream_assignment`: requested streams
/// (optionally scoped by a URL path) are grouped per connection, and each
/// assignment is recorded in `connection_streams` under its stream key.
mod handle_stream_assignment {
    use super::*;

    /// Two fresh streams end up in a single connection group.
    #[test]
    fn assigns_new_streams_to_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let groups = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
                .await;
            let mut seen_streams = HashSet::new();
            for (_conn, streams) in &groups {
                for s in streams {
                    seen_streams.insert(s);
                }
            }
            assert_eq!(
                seen_streams,
                ["s1".to_string(), "s2".to_string()].iter().collect()
            );
            assert_eq!(groups.len(), 1);
        });
    }

    /// A stream that is already assigned is returned again alongside new
    /// streams, and it stays on the connection it was first assigned to.
    #[test]
    fn reuses_existing_connection_for_duplicate_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let mut first = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into()], None)
                .await;
            let (first_conn, _) = first.pop().unwrap();
            let groups = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into(), "s3".into()], None)
                .await;
            let mut all_streams = Vec::new();
            for (_conn, streams) in groups {
                all_streams.extend(streams);
            }
            all_streams.sort();
            assert_eq!(all_streams, vec!["s1".to_string(), "s3".to_string()]);

            // The duplicate must not have been moved to a different connection.
            let stream = "s1".to_string();
            let key = ws.stream_key(&stream, None);
            let map = ws.connection_streams.lock().await;
            assert!(Arc::ptr_eq(map.get(&key).unwrap(), &first_conn));
        });
    }

    /// Requesting no streams produces no groups.
    #[test]
    fn empty_stream_list_returns_empty() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let groups = ws.clone().handle_stream_assignment(vec![], None).await;
            assert!(groups.is_empty());
        });
    }

    /// Once a connection has started closing, new streams still get a group,
    /// and re-requesting a stream that lives on the closing connection moves
    /// it to another connection.
    #[test]
    fn closed_or_reconnecting_forces_reassignment_of_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);
            let mut groups = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into()], None)
                .await;
            let (conn, _) = groups.pop().unwrap();
            {
                let mut st = conn.state.lock().await;
                st.close_initiated = true;
            }
            let groups2 = ws
                .clone()
                .handle_stream_assignment(vec!["s2".into()], None)
                .await;
            assert_eq!(groups2.len(), 1);
            let (_new_conn, streams) = &groups2[0];
            assert_eq!(streams, &vec!["s2".to_string()]);

            // The original test never re-requested "s1", so reassignment away
            // from the closing connection was not actually exercised.
            let _ = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into()], None)
                .await;
            let stream = "s1".to_string();
            let key = ws.stream_key(&stream, None);
            let map = ws.connection_streams.lock().await;
            let assigned = map.get(&key).unwrap().clone();
            assert!(!Arc::ptr_eq(&assigned, &conn));
        });
    }

    /// An empty connection pool still yields exactly one group for the stream.
    #[test]
    fn no_available_connections_falls_back_to_one() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, Some(vec![]), None);
            let assigned = ws.handle_stream_assignment(vec!["foo".into()], None).await;
            assert_eq!(assigned.len(), 1);
            let (_conn, streams) = &assigned[0];
            assert_eq!(streams.as_slice(), &["foo".to_string()]);
        });
    }

    /// With a single pooled connection, all streams are grouped on it.
    #[test]
    fn single_connection_groups_multiple_streams() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
            let assigned = ws
                .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
                .await;
            assert_eq!(assigned.len(), 1);
            let (assigned_conn, streams) = &assigned[0];
            assert!(Arc::ptr_eq(assigned_conn, &conn));
            assert_eq!(streams.len(), 2);
            assert!(streams.contains(&"s1".to_string()));
            assert!(streams.contains(&"s2".to_string()));
        });
    }

    /// Re-requesting a stream on a healthy connection returns that same connection.
    #[test]
    fn reuse_existing_healthy_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c");
            let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
            let _ = ws.handle_stream_assignment(vec!["s1".into()], None).await;
            let second = ws.handle_stream_assignment(vec!["s1".into()], None).await;
            assert_eq!(second.len(), 1);
            let (assigned_conn, streams) = &second[0];
            assert!(Arc::ptr_eq(assigned_conn, &conn));
            assert_eq!(streams.as_slice(), &["s1".to_string()]);
        });
    }

    /// Already-assigned and new streams are returned together on one connection.
    #[test]
    fn mix_new_and_assigned_streams() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c");
            let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
            let _ = ws
                .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
                .await;
            let mixed = ws
                .handle_stream_assignment(vec!["s2".into(), "s3".into()], None)
                .await;
            assert_eq!(mixed.len(), 1);
            let (assigned_conn, streams) = &mixed[0];
            assert!(Arc::ptr_eq(assigned_conn, &conn));
            let mut got = streams.clone();
            got.sort();
            assert_eq!(got, vec!["s2".to_string(), "s3".to_string()]);
        });
    }

    /// Assignments under a URL path are recorded with `"<path>::<stream>"` keys.
    #[test]
    fn assigns_streams_with_url_path_keys() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);

            let conn = ws.common.connection_pool[0].clone();
            {
                let mut st = conn.state.lock().await;
                st.url_path = Some("path1".to_string());
                st.ws_write_tx = None;
                st.reconnection_pending = false;
                st.close_initiated = false;
            }

            let groups = ws
                .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
                .await;

            let map = ws.connection_streams.lock().await;
            assert!(map.contains_key("path1::s1"));
            assert!(map.contains_key("path1::s2"));
            assert_eq!(groups.len(), 1);

            let (_assigned_conn, streams) = &groups[0];
            let mut got = streams.clone();
            got.sort();
            assert_eq!(got, vec!["s1".to_string(), "s2".to_string()]);
        });
    }

    /// The same stream under two different URL paths produces two distinct keys.
    #[test]
    fn same_stream_on_different_paths_creates_distinct_keys() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);

            let conn1 = ws.common.connection_pool[0].clone();
            let conn2 = ws.common.connection_pool[1].clone();

            {
                let mut st = conn1.state.lock().await;
                st.url_path = Some("path1".to_string());
                st.ws_write_tx = None;
                st.reconnection_pending = false;
                st.close_initiated = false;
            }
            {
                let mut st = conn2.state.lock().await;
                st.url_path = Some("path2".to_string());
                st.ws_write_tx = None;
                st.reconnection_pending = false;
                st.close_initiated = false;
            }

            let g1 = ws
                .handle_stream_assignment(vec!["s1".into()], Some("path1"))
                .await;
            let g2 = ws
                .handle_stream_assignment(vec!["s1".into()], Some("path2"))
                .await;

            assert_eq!(g1.len(), 1);
            assert_eq!(g2.len(), 1);

            let map = ws.connection_streams.lock().await;
            assert!(map.contains_key("path1::s1"));
            assert!(map.contains_key("path2::s1"));
        });
    }

    /// Streams added later under the same path join the existing connection.
    #[test]
    fn reuses_existing_connection_for_same_path_and_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);

            let conn = ws.common.connection_pool[0].clone();
            {
                let mut st = conn.state.lock().await;
                st.url_path = Some("path1".to_string());
                st.ws_write_tx = None;
                st.reconnection_pending = false;
                st.close_initiated = false;
            }

            let first = ws
                .handle_stream_assignment(vec!["s1".into()], Some("path1"))
                .await;
            let second = ws
                .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
                .await;

            assert_eq!(first.len(), 1);
            assert_eq!(second.len(), 1);

            let map = ws.connection_streams.lock().await;
            let c1 = map.get("path1::s1").unwrap().clone();
            let c2 = map.get("path1::s2").unwrap().clone();
            assert!(Arc::ptr_eq(&c1, &c2));
        });
    }

    /// A URL-path stream on a closing connection is moved to another connection.
    #[test]
    fn closed_or_reconnecting_forces_reassignment_with_url_path() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None, None);

            let conn1 = ws.common.connection_pool[0].clone();
            let conn2 = ws.common.connection_pool[1].clone();

            {
                let mut st = conn1.state.lock().await;
                st.url_path = Some("path1".to_string());
                st.ws_write_tx = None;
                st.reconnection_pending = false;
                st.close_initiated = false;
            }
            {
                let mut st = conn2.state.lock().await;
                st.url_path = Some("path1".to_string());
                st.ws_write_tx = None;
                st.reconnection_pending = false;
                st.close_initiated = false;
            }

            let _ = ws
                .handle_stream_assignment(vec!["s1".into()], Some("path1"))
                .await;

            {
                let mut st = conn1.state.lock().await;
                st.close_initiated = true;
            }

            let _ = ws
                .handle_stream_assignment(vec!["s1".into()], Some("path1"))
                .await;

            let map = ws.connection_streams.lock().await;
            let assigned = map.get("path1::s1").unwrap().clone();
            assert!(!Arc::ptr_eq(&assigned, &conn1));
        });
    }
}
8135
8136 mod send_subscription_payload {
8137 use super::*;
8138
8139 #[test]
8140 fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
8141 TOKIO_SHARED_RT.block_on(async {
8142 let ws: Arc<WebsocketStreams> =
8143 create_websocket_streams(Some("ws://example.com"), None, None);
8144 let conn = &ws.common.connection_pool[0];
8145 let (tx, mut rx) = unbounded_channel();
8146 {
8147 let mut st = conn.state.lock().await;
8148 st.ws_write_tx = Some(tx);
8149 }
8150 let id = Some("badid".to_string());
8151 ws.send_subscription_payload(
8152 conn,
8153 &vec!["s1".to_string()],
8154 id.map(StreamId::from),
8155 );
8156 let msg = rx.recv().await.expect("no message sent");
8157 if let Message::Text(txt) = msg {
8158 let v: serde_json::Value = serde_json::from_str(&txt).unwrap();
8159 assert_eq!(v["method"], "SUBSCRIBE");
8160 let id = v["id"].as_str().unwrap();
8161 assert_ne!(id, "badid");
8162 assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id));
8163 } else {
8164 panic!("unexpected message: {msg:?}");
8165 }
8166 });
8167 }
8168
8169 #[test]
8170 fn subscribe_payload_with_and_without_custom_string_id() {
8171 TOKIO_SHARED_RT.block_on(async {
8172 let ws: Arc<WebsocketStreams> =
8173 create_websocket_streams(Some("ws://unused"), None, None);
8174 let conn = &ws.common.connection_pool[0];
8175 let (tx, mut rx) = unbounded_channel();
8176 {
8177 let mut st = conn.state.lock().await;
8178 st.ws_write_tx = Some(tx);
8179 }
8180 let id = Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string());
8181 ws.send_subscription_payload(
8182 conn,
8183 &vec!["a".to_string(), "b".to_string()],
8184 id.map(StreamId::from),
8185 );
8186 let msg1 = rx.recv().await.unwrap();
8187 ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
8188 let msg2 = rx.recv().await.unwrap();
8189
8190 if let Message::Text(txt1) = msg1 {
8191 let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
8192 assert_eq!(v1["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
8193 assert_eq!(
8194 v1["params"].as_array().unwrap(),
8195 &vec![serde_json::json!("a"), serde_json::json!("b")]
8196 );
8197 } else {
8198 panic!()
8199 }
8200
8201 if let Message::Text(txt2) = msg2 {
8202 let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
8203 assert_eq!(v2["method"], "SUBSCRIBE");
8204 let params = v2["params"].as_array().unwrap();
8205 assert_eq!(params.len(), 1);
8206 assert_eq!(params[0], "x");
8207 let id2 = v2["id"].as_str().unwrap();
8208 assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id2));
8209 } else {
8210 panic!()
8211 }
8212 });
8213 }
8214
8215 #[test]
8216 fn subscribe_payload_with_and_without_custom_integer_id() {
8217 TOKIO_SHARED_RT.block_on(async {
8218 let ws: Arc<WebsocketStreams> =
8219 create_websocket_streams(Some("ws://unused"), None, None);
8220 ws.stream_id_is_strictly_number
8221 .store(true, Ordering::Relaxed);
8222 let conn = &ws.common.connection_pool[0];
8223 let (tx, mut rx) = unbounded_channel();
8224 {
8225 let mut st = conn.state.lock().await;
8226 st.ws_write_tx = Some(tx);
8227 }
8228
8229 let id = Some(123u32);
8230
8231 ws.send_subscription_payload(
8232 conn,
8233 &vec!["a".to_string(), "b".to_string()],
8234 id.map(StreamId::from),
8235 );
8236 let msg1 = rx.recv().await.unwrap();
8237
8238 ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
8239 let msg2 = rx.recv().await.unwrap();
8240
8241 if let Message::Text(txt1) = msg1 {
8242 let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
8243 assert_eq!(v1["method"], "SUBSCRIBE");
8244 assert_eq!(v1["id"].as_u64(), Some(123));
8245 assert_eq!(
8246 v1["params"].as_array().unwrap(),
8247 &vec![serde_json::json!("a"), serde_json::json!("b")]
8248 );
8249 } else {
8250 panic!("Expected Message::Text for msg1");
8251 }
8252
8253 if let Message::Text(txt2) = msg2 {
8254 let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
8255 assert_eq!(v2["method"], "SUBSCRIBE");
8256
8257 let params = v2["params"].as_array().unwrap();
8258 assert_eq!(params.len(), 1);
8259 assert_eq!(params[0], "x");
8260
8261 let id2 = v2.get("id").expect("payload should contain id");
8262 assert!(
8263 id2.is_number(),
8264 "expected numeric id in strict-number mode, got: {id2:?}"
8265 );
8266 let n = id2.as_u64().unwrap();
8267 assert!(u32::try_from(n).is_ok(), "id should fit u32, got {n}");
8268 } else {
8269 panic!("Expected Message::Text for msg2");
8270 }
8271 });
8272 }
8273 }
8274
/// Tests for `WebsocketStreams::on_open`: subscriptions queued on the
/// connection are flushed as one SUBSCRIBE frame, and the queue is cleared
/// even when no write channel is attached.
mod on_open {
    use super::*;

    /// Both queued streams are sent in order, and the queue ends up empty.
    #[test]
    fn sends_pending_subscriptions() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let conn = &ws.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel();
            {
                let mut state = conn.state.lock().await;
                state.ws_write_tx = Some(tx);
                state
                    .pending_subscriptions
                    .extend(["foo".to_string(), "bar".to_string()]);
            }

            ws.on_open("ws://example.com".to_string(), conn.clone())
                .await;

            match rx.recv().await.expect("no subscription sent") {
                Message::Text(txt) => {
                    let payload: Value = serde_json::from_str(&txt).unwrap();
                    assert_eq!(payload["method"], "SUBSCRIBE");
                    let expected =
                        vec![Value::String("foo".into()), Value::String("bar".into())];
                    assert_eq!(payload["params"].as_array().unwrap(), &expected);
                }
                msg => panic!("unexpected message: {msg:?}"),
            }

            assert!(conn.state.lock().await.pending_subscriptions.is_empty());
        });
    }

    /// With an empty queue, nothing is written to the channel.
    #[test]
    fn with_no_pending_subscriptions_sends_nothing() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let conn = &ws.common.connection_pool[0];
            let (tx, mut rx) = unbounded_channel();
            conn.state.lock().await.ws_write_tx = Some(tx);

            ws.on_open("ws://example.com".to_string(), conn.clone())
                .await;

            assert!(rx.try_recv().is_err(), "unexpected message sent");
        });
    }

    /// The queue is drained even though there is no write channel to send on.
    #[test]
    fn clears_pending_without_write_channel() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let conn = &ws.common.connection_pool[0];
            conn.state
                .lock()
                .await
                .pending_subscriptions
                .push_back("solo".to_string());

            ws.on_open("ws://example.com".to_string(), conn.clone())
                .await;

            assert!(conn.state.lock().await.pending_subscriptions.is_empty());
        });
    }
}
8344
8345 mod on_message {
8346 use super::*;
8347
8348 #[test]
8349 fn invokes_registered_callback() {
8350 TOKIO_SHARED_RT.block_on(async {
8351 let ws: Arc<WebsocketStreams> =
8352 create_websocket_streams(Some("ws://example.com"), None, None);
8353 let conn = &ws.common.connection_pool[0];
8354 let called = Arc::new(AtomicBool::new(false));
8355 let called_clone = called.clone();
8356
8357 {
8358 let mut st = conn.state.lock().await;
8359 st.stream_callbacks
8360 .entry("stream1".to_string())
8361 .or_default()
8362 .push(
8363 (Box::new(move |_: &Value| {
8364 called_clone.store(true, Ordering::SeqCst);
8365 })
8366 as Box<dyn Fn(&Value) + Send + Sync>)
8367 .into(),
8368 );
8369 }
8370
8371 let msg = json!({
8372 "stream": "stream1",
8373 "data": { "key": "value" }
8374 })
8375 .to_string();
8376
8377 ws.on_message(msg, conn.clone()).await;
8378
8379 assert!(called.load(Ordering::SeqCst));
8380 });
8381 }
8382
8383 #[test]
8384 fn invokes_all_registered_callbacks() {
8385 TOKIO_SHARED_RT.block_on(async {
8386 let ws: Arc<WebsocketStreams> =
8387 create_websocket_streams(Some("ws://example.com"), None, None);
8388 let conn = &ws.common.connection_pool[0];
8389 let counter = Arc::new(AtomicUsize::new(0));
8390
8391 {
8392 let mut st = conn.state.lock().await;
8393 let entry = st.stream_callbacks.entry("s".into()).or_default();
8394 let c1 = counter.clone();
8395 entry.push(
8396 (Box::new(move |_: &Value| {
8397 c1.fetch_add(1, Ordering::SeqCst);
8398 }) as Box<dyn Fn(&Value) + Send + Sync>)
8399 .into(),
8400 );
8401 let c2 = counter.clone();
8402 entry.push(
8403 (Box::new(move |_: &Value| {
8404 c2.fetch_add(1, Ordering::SeqCst);
8405 }) as Box<dyn Fn(&Value) + Send + Sync>)
8406 .into(),
8407 );
8408 }
8409
8410 let msg = json!({"stream":"s","data":42}).to_string();
8411 ws.on_message(msg, conn.clone()).await;
8412
8413 assert_eq!(counter.load(Ordering::SeqCst), 2);
8414 });
8415 }
8416
8417 #[test]
8418 fn handles_null_data_field() {
8419 TOKIO_SHARED_RT.block_on(async {
8420 let ws: Arc<WebsocketStreams> =
8421 create_websocket_streams(Some("ws://example.com"), None, None);
8422 let conn = &ws.common.connection_pool[0];
8423 let called = Arc::new(AtomicUsize::new(0));
8424 {
8425 let mut st = conn.state.lock().await;
8426 st.stream_callbacks.entry("n".into()).or_default().push(
8427 (Box::new({
8428 let c = called.clone();
8429 move |data: &Value| {
8430 if data.is_null() {
8431 c.fetch_add(1, Ordering::SeqCst);
8432 }
8433 }
8434 }) as Box<dyn Fn(&Value) + Send + Sync>)
8435 .into(),
8436 );
8437 }
8438 let msg = json!({"stream":"n","data":null}).to_string();
8439 ws.on_message(msg, conn.clone()).await;
8440 assert_eq!(called.load(Ordering::SeqCst), 1);
8441 });
8442 }
8443
8444 #[test]
8445 fn with_invalid_json_does_not_panic() {
8446 TOKIO_SHARED_RT.block_on(async {
8447 let ws: Arc<WebsocketStreams> =
8448 create_websocket_streams(Some("ws://example.com"), None, None);
8449 let conn = &ws.common.connection_pool[0];
8450 let bad = "not a json";
8451 ws.on_message(bad.to_string(), conn.clone()).await;
8452 });
8453 }
8454
8455 #[test]
8456 fn without_stream_field_does_nothing() {
8457 TOKIO_SHARED_RT.block_on(async {
8458 let ws: Arc<WebsocketStreams> =
8459 create_websocket_streams(Some("ws://example.com"), None, None);
8460 let conn = &ws.common.connection_pool[0];
8461 let msg = json!({ "data": { "foo": 1 } }).to_string();
8462 ws.on_message(msg, conn.clone()).await;
8463 });
8464 }
8465
8466 #[test]
8467 fn with_unregistered_stream_does_not_panic() {
8468 TOKIO_SHARED_RT.block_on(async {
8469 let ws: Arc<WebsocketStreams> =
8470 create_websocket_streams(Some("ws://example.com"), None, None);
8471 let conn = &ws.common.connection_pool[0];
8472 let msg = json!({
8473 "stream": "nope",
8474 "data": { "foo": 1 }
8475 })
8476 .to_string();
8477 ws.on_message(msg, conn.clone()).await;
8478 });
8479 }
8480
8481 #[test]
8482 fn invokes_registered_callback_with_url_path_key() {
8483 TOKIO_SHARED_RT.block_on(async {
8484 let ws: Arc<WebsocketStreams> =
8485 create_websocket_streams(Some("ws://example.com"), None, None);
8486 let conn = &ws.common.connection_pool[0];
8487
8488 {
8489 let mut st = conn.state.lock().await;
8490 st.url_path = Some("path1".to_string());
8491 }
8492
8493 let called = Arc::new(AtomicBool::new(false));
8494 let called_clone = called.clone();
8495
8496 {
8497 let mut st = conn.state.lock().await;
8498 st.stream_callbacks
8499 .entry("path1::stream1".to_string())
8500 .or_default()
8501 .push(
8502 (Box::new(move |_: &Value| {
8503 called_clone.store(true, Ordering::SeqCst);
8504 })
8505 as Box<dyn Fn(&Value) + Send + Sync>)
8506 .into(),
8507 );
8508 }
8509
8510 let msg = json!({
8511 "stream": "stream1",
8512 "data": { "key": "value" }
8513 })
8514 .to_string();
8515
8516 ws.on_message(msg, conn.clone()).await;
8517
8518 assert!(called.load(Ordering::SeqCst));
8519 });
8520 }
8521
8522 #[test]
8523 fn does_not_invoke_callback_when_url_path_mismatch() {
8524 TOKIO_SHARED_RT.block_on(async {
8525 let ws: Arc<WebsocketStreams> =
8526 create_websocket_streams(Some("ws://example.com"), None, None);
8527 let conn = &ws.common.connection_pool[0];
8528
8529 {
8530 let mut st = conn.state.lock().await;
8531 st.url_path = Some("path2".to_string());
8532 }
8533
8534 let called = Arc::new(AtomicBool::new(false));
8535 let called_clone = called.clone();
8536
8537 {
8538 let mut st = conn.state.lock().await;
8539 st.stream_callbacks
8540 .entry("path1::stream1".to_string())
8541 .or_default()
8542 .push(
8543 (Box::new(move |_: &Value| {
8544 called_clone.store(true, Ordering::SeqCst);
8545 })
8546 as Box<dyn Fn(&Value) + Send + Sync>)
8547 .into(),
8548 );
8549 }
8550
8551 let msg = json!({
8552 "stream": "stream1",
8553 "data": { "key": "value" }
8554 })
8555 .to_string();
8556
8557 ws.on_message(msg, conn.clone()).await;
8558
8559 assert!(!called.load(Ordering::SeqCst));
8560 });
8561 }
8562
8563 #[test]
8564 fn invokes_only_callbacks_for_current_url_path_when_both_exist() {
8565 TOKIO_SHARED_RT.block_on(async {
8566 let ws: Arc<WebsocketStreams> =
8567 create_websocket_streams(Some("ws://example.com"), None, None);
8568 let conn = &ws.common.connection_pool[0];
8569
8570 {
8571 let mut st = conn.state.lock().await;
8572 st.url_path = Some("path1".to_string());
8573 }
8574
8575 let c1 = Arc::new(AtomicUsize::new(0));
8576 let c2 = Arc::new(AtomicUsize::new(0));
8577
8578 {
8579 let mut st = conn.state.lock().await;
8580
8581 let a = c1.clone();
8582 st.stream_callbacks
8583 .entry("path1::s".to_string())
8584 .or_default()
8585 .push(
8586 (Box::new(move |_: &Value| {
8587 a.fetch_add(1, Ordering::SeqCst);
8588 })
8589 as Box<dyn Fn(&Value) + Send + Sync>)
8590 .into(),
8591 );
8592
8593 let b = c2.clone();
8594 st.stream_callbacks
8595 .entry("path2::s".to_string())
8596 .or_default()
8597 .push(
8598 (Box::new(move |_: &Value| {
8599 b.fetch_add(1, Ordering::SeqCst);
8600 })
8601 as Box<dyn Fn(&Value) + Send + Sync>)
8602 .into(),
8603 );
8604 }
8605
8606 let msg = json!({"stream":"s","data":42}).to_string();
8607 ws.on_message(msg, conn.clone()).await;
8608
8609 assert_eq!(c1.load(Ordering::SeqCst), 1);
8610 assert_eq!(c2.load(Ordering::SeqCst), 0);
8611 });
8612 }
8613 }
8614
/// Tests for `WebsocketStreams::get_reconnect_url`: the URL is rebuilt from the
/// `connection_streams` entries that point at the given connection. The
/// connection's URL path is stripped from the keys and placed in the URL path.
mod get_reconnect_url {
    use super::*;

    /// Collects the `/`-separated stream names that follow `prefix` in `url`,
    /// stopping at the first `&`.
    ///
    /// Panics if `url` does not start with `prefix`.
    fn streams_after<'a>(url: &'a str, prefix: &str) -> std::collections::HashSet<&'a str> {
        let query = url.strip_prefix(prefix).unwrap();
        query.split('&').next().unwrap().split('/').collect()
    }

    /// A single mapped stream produces a single-stream URL.
    #[test]
    fn single_stream_reconnect_url() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = Arc::clone(&ws.common.connection_pool[0]);
            ws.connection_streams
                .lock()
                .await
                .insert("s1".to_string(), Arc::clone(&c0));

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(url, "ws://example.com/stream?streams=s1");
        });
    }

    /// All streams mapped to the connection are included, in any order.
    #[test]
    fn multiple_streams_same_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = Arc::clone(&ws.common.connection_pool[0]);
            {
                let mut map = ws.connection_streams.lock().await;
                for key in ["a", "b"] {
                    map.insert(key.to_string(), Arc::clone(&c0));
                }
            }

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                streams_after(&url, "ws://example.com/stream?streams="),
                ["a", "b"].into_iter().collect()
            );
        });
    }

    /// The configured time unit is appended to the reconnect URL.
    #[test]
    fn reconnect_url_with_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let mut ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
                Some(TimeUnit::Microsecond);
            let c0 = Arc::clone(&ws.common.connection_pool[0]);
            ws.connection_streams
                .lock()
                .await
                .insert("x".to_string(), Arc::clone(&c0));

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                url,
                "ws://example.com/stream?streams=x&timeUnit=microsecond"
            );
        });
    }

    /// The connection's URL path becomes part of the URL and is stripped from keys.
    #[test]
    fn reconnect_url_uses_url_path_from_connection_state() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = Arc::clone(&ws.common.connection_pool[0]);
            c0.state.lock().await.url_path = Some("path1".to_string());
            ws.connection_streams
                .lock()
                .await
                .insert("path1::s1".to_string(), Arc::clone(&c0));

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(url, "ws://example.com/path1/stream?streams=s1");
        });
    }

    /// The path prefix is stripped from every key mapped to the connection.
    #[test]
    fn reconnect_url_strips_prefix_from_multiple_keys_with_url_path() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = Arc::clone(&ws.common.connection_pool[0]);
            c0.state.lock().await.url_path = Some("path1".to_string());
            {
                let mut map = ws.connection_streams.lock().await;
                for key in ["path1::a", "path1::b"] {
                    map.insert(key.to_string(), Arc::clone(&c0));
                }
            }

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                streams_after(&url, "ws://example.com/path1/stream?streams="),
                ["a", "b"].into_iter().collect()
            );
        });
    }

    /// URL path and time unit compose in the reconnect URL.
    #[test]
    fn reconnect_url_with_url_path_and_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let mut ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
                Some(TimeUnit::Microsecond);

            let c0 = Arc::clone(&ws.common.connection_pool[0]);
            c0.state.lock().await.url_path = Some("path1".to_string());
            ws.connection_streams
                .lock()
                .await
                .insert("path1::x".to_string(), Arc::clone(&c0));

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                url,
                "ws://example.com/path1/stream?streams=x&timeUnit=microsecond"
            );
        });
    }

    /// Streams mapped to another connection are excluded, even when both
    /// connections share the same URL path.
    #[test]
    fn reconnect_url_ignores_streams_from_other_connections_even_if_same_path_prefix() {
        TOKIO_SHARED_RT.block_on(async {
            let ws: Arc<WebsocketStreams> =
                create_websocket_streams(Some("ws://example.com"), None, None);
            let c0 = Arc::clone(&ws.common.connection_pool[0]);
            let c1 = Arc::clone(&ws.common.connection_pool[1]);

            for conn in [&c0, &c1] {
                conn.state.lock().await.url_path = Some("path1".to_string());
            }

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("path1::a".to_string(), Arc::clone(&c0));
                map.insert("path1::b".to_string(), Arc::clone(&c1));
            }

            let url = ws.get_reconnect_url("default_url".into(), c0).await;
            assert_eq!(
                streams_after(&url, "ws://example.com/path1/stream?streams="),
                ["a"].into_iter().collect()
            );
        });
    }
}
8787 }
8788
8789 mod websocket_stream {
8790 use super::*;
8791
8792 mod on {
8793 use super::*;
8794
/// `on("message", ..)` on a streams-backed stream stores the typed callback and
/// appends one wrapper under the stream's key on its mapped connection.
#[test]
fn registers_callback_and_stream_callback_for_websocket_streams() {
    TOKIO_SHARED_RT.block_on(async {
        let ws_base = create_websocket_streams(Some("example.com"), None, None);
        let stream_name = "s1".to_string();
        let conn = Arc::clone(&ws_base.common.connection_pool[0]);
        let key = ws_base.stream_key(&stream_name, None);

        // Route the key to `conn` and start it with an empty callback list.
        ws_base
            .connection_streams
            .lock()
            .await
            .insert(key.clone(), Arc::clone(&conn));
        conn.state
            .lock()
            .await
            .stream_callbacks
            .insert(key.clone(), Vec::new());

        let stream = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketStreams(Arc::clone(&ws_base)),
            stream_or_id: stream_name.clone(),
            callback: Mutex::new(None),
            url_path: None,
            id: None,
            _phantom: PhantomData,
        });

        stream.on("message", |_| {}).await;

        assert!(stream.callback.lock().await.is_some());

        let registered = conn
            .state
            .lock()
            .await
            .stream_callbacks
            .get(&key)
            .unwrap()
            .len();
        assert_eq!(registered, 1);
    });
}
8834
/// Registering a `"message"` handler twice appends two wrappers under one key.
#[test]
fn message_twice_registers_two_wrappers_for_websocket_streams() {
    TOKIO_SHARED_RT.block_on(async {
        let ws_base = create_websocket_streams(Some("example.com"), None, None);
        let stream_name = "s2".to_string();
        let conn = Arc::clone(&ws_base.common.connection_pool[0]);
        let key = ws_base.stream_key(&stream_name, None);

        ws_base
            .connection_streams
            .lock()
            .await
            .insert(key.clone(), Arc::clone(&conn));
        conn.state
            .lock()
            .await
            .stream_callbacks
            .insert(key.clone(), Vec::new());

        let stream = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketStreams(Arc::clone(&ws_base)),
            stream_or_id: stream_name.clone(),
            url_path: None,
            callback: Mutex::new(None),
            id: None,
            _phantom: PhantomData,
        });

        for _ in 0..2 {
            stream.on("message", |_| {}).await;
        }

        let state = conn.state.lock().await;
        assert_eq!(state.stream_callbacks.get(&key).unwrap().len(), 2);
    });
}
8870
#[test]
fn ignores_non_message_event_for_websocket_streams() {
    // Event names other than "message" must not install a callback on the handle.
    TOKIO_SHARED_RT.block_on(async {
        let base = create_websocket_streams(Some("example.com"), None, None);
        let handle = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketStreams(base.clone()),
            stream_or_id: String::from("s"),
            url_path: None,
            id: None,
            callback: Mutex::new(None),
            _phantom: PhantomData,
        });

        handle.on("open", |_| {}).await;

        let slot = handle.callback.lock().await;
        assert!(slot.is_none());
    });
}
8888
#[test]
fn registers_callback_and_stream_callback_for_websocket_api() {
    // A "message" subscription on an API-backed handle must fill the handle's callback
    // slot and add exactly one entry under its identifier in `stream_callbacks`.
    TOKIO_SHARED_RT.block_on(async {
        let ws_base = create_websocket_api(None, None, None);

        {
            let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
            stream_callbacks.insert("id1".to_string(), Vec::new());
        }

        let stream = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
            stream_or_id: "id1".to_string(),
            url_path: None,
            callback: Mutex::new(None),
            id: None,
            _phantom: PhantomData,
        });

        // Only registration is asserted, so the callback body is a no-op. It must not
        // block on a tokio `Mutex`: `blocking_lock()` panics when called from inside an
        // async execution context, which is where dispatched callbacks would run.
        stream.on("message", |_v: Value| {}).await;

        let cb_guard = stream.callback.lock().await;
        assert!(cb_guard.is_some());

        let stream_callbacks = ws_base.stream_callbacks.lock().await;
        let callbacks = stream_callbacks.get("id1").unwrap();
        assert_eq!(callbacks.len(), 1);
    });
}
8925
#[test]
fn message_twice_registers_two_wrappers_for_websocket_api() {
    // No entry is pre-seeded for "id2": two subscriptions must create it and leave two
    // wrappers in it.
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, None, None);
        let handle = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketApi(api.clone()),
            stream_or_id: String::from("id2"),
            url_path: None,
            id: None,
            callback: Mutex::new(None),
            _phantom: PhantomData,
        });

        for _ in 0..2 {
            handle.on("message", |_| {}).await;
        }

        let registered = api.stream_callbacks.lock().await.get("id2").map(Vec::len);
        assert_eq!(registered, Some(2));
    });
}
8948
#[test]
fn ignores_non_message_event_for_websocket_api() {
    // An "open" subscription must neither set the handle's callback nor create any entry
    // in the API's `stream_callbacks` map.
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, None, None);
        let handle = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketApi(api.clone()),
            stream_or_id: String::from("id3"),
            url_path: None,
            id: None,
            callback: Mutex::new(None),
            _phantom: PhantomData,
        });

        handle.on("open", |_| {}).await;

        assert!(handle.callback.lock().await.is_none());

        let callbacks = api.stream_callbacks.lock().await;
        assert!(!callbacks.contains_key("id3"));
        assert!(callbacks.is_empty());
    });
}
8973
#[test]
fn registers_callback_for_websocket_streams_with_url_path() {
    // With a URL path set, the registration must land under the path-qualified key
    // produced by `stream_key`.
    TOKIO_SHARED_RT.block_on(async {
        let ws_base = create_websocket_streams(Some("example.com"), None, None);
        let name = String::from("s1");
        let connection = ws_base.common.connection_pool[0].clone();
        let stream_key = ws_base.stream_key(&name, Some("path1"));

        ws_base
            .connection_streams
            .lock()
            .await
            .insert(stream_key.clone(), connection.clone());
        connection
            .state
            .lock()
            .await
            .stream_callbacks
            .insert(stream_key.clone(), Vec::new());

        let handle = Arc::new(WebsocketStream::<Value> {
            websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
            stream_or_id: name.clone(),
            url_path: Some(String::from("path1")),
            id: None,
            callback: Mutex::new(None),
            _phantom: PhantomData,
        });
        handle.on("message", |_| {}).await;

        assert!(handle.callback.lock().await.is_some());

        let registered = connection
            .state
            .lock()
            .await
            .stream_callbacks
            .get(&stream_key)
            .map(Vec::len);
        assert_eq!(registered, Some(1));
    });
}
9013
#[test]
fn url_path_routes_callback_to_correct_key_when_same_stream_name_used() {
    // Two handles share a stream name but differ in URL path; each must register exactly
    // one callback under its own path-qualified key.
    TOKIO_SHARED_RT.block_on(async {
        let ws_base = create_websocket_streams(Some("example.com"), None, None);
        let name = String::from("s1");
        let connection = ws_base.common.connection_pool[0].clone();
        let paths = ["path1", "path2"];
        let keys = paths.map(|path| ws_base.stream_key(&name, Some(path)));

        {
            let mut routes = ws_base.connection_streams.lock().await;
            for key in &keys {
                routes.insert(key.clone(), connection.clone());
            }
        }
        {
            let mut state = connection.state.lock().await;
            for key in &keys {
                state.stream_callbacks.insert(key.clone(), Vec::new());
            }
        }

        // Keep every handle alive until the assertions below have run.
        let mut handles = Vec::with_capacity(paths.len());
        for path in paths {
            let handle = Arc::new(WebsocketStream::<Value> {
                websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
                stream_or_id: name.clone(),
                url_path: Some(path.to_string()),
                id: None,
                callback: Mutex::new(None),
                _phantom: PhantomData,
            });
            handle.on("message", |_| {}).await;
            handles.push(handle);
        }

        let state = connection.state.lock().await;
        for key in &keys {
            assert_eq!(state.stream_callbacks.get(key).map(Vec::len), Some(1));
        }
    });
}
9061 }
9062
9063 mod on_message {
9064 use super::*;
9065
9066 #[test]
9067 fn on_message_registers_callback_for_websocket_streams() {
9068 TOKIO_SHARED_RT.block_on(async {
9069 let ws_base = create_websocket_streams(Some("example.com"), None, None);
9070 let stream_name = "s".to_string();
9071 let conn = ws_base.common.connection_pool[0].clone();
9072 {
9073 let mut map = ws_base.connection_streams.lock().await;
9074 map.insert(stream_name.clone(), conn.clone());
9075 }
9076 {
9077 let mut state = conn.state.lock().await;
9078 state
9079 .stream_callbacks
9080 .insert(stream_name.clone(), Vec::new());
9081 }
9082 let stream = Arc::new(WebsocketStream::<Value> {
9083 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9084 stream_or_id: stream_name.clone(),
9085 url_path: None,
9086 callback: Mutex::new(None),
9087 id: None,
9088 _phantom: PhantomData,
9089 });
9090 stream.on_message(|_v| {});
9091 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
9092 assert_eq!(callbacks.len(), 1);
9093 });
9094 }
9095
9096 #[test]
9097 fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
9098 TOKIO_SHARED_RT.block_on(async {
9099 let ws_base = create_websocket_streams(Some("example.com"), None, None);
9100 let stream_name = "s".to_string();
9101 let conn = ws_base.common.connection_pool[0].clone();
9102 {
9103 let mut map = ws_base.connection_streams.lock().await;
9104 map.insert(stream_name.clone(), conn.clone());
9105 }
9106 {
9107 let mut state = conn.state.lock().await;
9108 state
9109 .stream_callbacks
9110 .insert(stream_name.clone(), Vec::new());
9111 }
9112 let stream = Arc::new(WebsocketStream::<Value> {
9113 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9114 stream_or_id: stream_name.clone(),
9115 url_path: None,
9116 callback: Mutex::new(None),
9117 id: None,
9118 _phantom: PhantomData,
9119 });
9120 stream.on_message(|_v| {});
9121 stream.on_message(|_v| {});
9122 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
9123 assert_eq!(callbacks.len(), 2);
9124 });
9125 }
9126
9127 #[test]
9128 fn on_message_registers_callback_for_websocket_api() {
9129 TOKIO_SHARED_RT.block_on(async {
9130 let ws_base = create_websocket_api(None, None, None);
9131 let identifier = "id1".to_string();
9132
9133 let stream = Arc::new(WebsocketStream::<Value> {
9134 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9135 stream_or_id: identifier.clone(),
9136 url_path: None,
9137 callback: Mutex::new(None),
9138 id: None,
9139 _phantom: PhantomData,
9140 });
9141
9142 stream.on_message(|_v: Value| {});
9143
9144 let stream_callbacks = ws_base.stream_callbacks.lock().await;
9145 let callbacks = stream_callbacks.get(&identifier).unwrap();
9146 assert_eq!(callbacks.len(), 1);
9147 });
9148 }
9149
9150 #[test]
9151 fn on_message_twice_registers_two_callbacks_for_websocket_api() {
9152 TOKIO_SHARED_RT.block_on(async {
9153 let ws_base = create_websocket_api(None, None, None);
9154 let identifier = "id2".to_string();
9155
9156 let stream = Arc::new(WebsocketStream::<Value> {
9157 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9158 stream_or_id: identifier.clone(),
9159 url_path: None,
9160 callback: Mutex::new(None),
9161 id: None,
9162 _phantom: PhantomData,
9163 });
9164
9165 stream.on_message(|_v: Value| {});
9166 stream.on_message(|_v: Value| {});
9167
9168 let stream_callbacks = ws_base.stream_callbacks.lock().await;
9169 let callbacks = stream_callbacks.get(&identifier).unwrap();
9170 assert_eq!(callbacks.len(), 2);
9171 });
9172 }
9173 }
9174
9175 mod unsubscribe {
9176 use super::*;
9177
9178 #[test]
9179 fn without_callback_does_nothing() {
9180 TOKIO_SHARED_RT.block_on(async {
9181 let ws_base = create_websocket_streams(Some("example.com"), None, None);
9182 let stream_name = "s1".to_string();
9183 let conn = ws_base.common.connection_pool[0].clone();
9184 {
9185 let mut map = ws_base.connection_streams.lock().await;
9186 map.insert(stream_name.clone(), conn.clone());
9187 }
9188 let mut state = conn.state.lock().await;
9189 state.stream_callbacks.insert(stream_name.clone(), vec![]);
9190 drop(state);
9191 let stream = Arc::new(WebsocketStream::<Value> {
9192 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9193 stream_or_id: stream_name.clone(),
9194 url_path: None,
9195 callback: Mutex::new(None),
9196 id: None,
9197 _phantom: PhantomData,
9198 });
9199 stream.unsubscribe().await;
9200 let state = conn.state.lock().await;
9201 assert!(state.stream_callbacks.contains_key(&stream_name));
9202 });
9203 }
9204
9205 #[test]
9206 fn removes_registered_callback_and_clears_state() {
9207 TOKIO_SHARED_RT.block_on(async {
9208 let ws_base = create_websocket_streams(Some("example.com"), None, None);
9209 let stream_name = "s2".to_string();
9210 let conn = ws_base.common.connection_pool[0].clone();
9211 {
9212 let mut map = ws_base.connection_streams.lock().await;
9213 map.insert(stream_name.clone(), conn.clone());
9214 }
9215 {
9216 let mut state = conn.state.lock().await;
9217 state
9218 .stream_callbacks
9219 .insert(stream_name.clone(), Vec::new());
9220 }
9221 let stream = Arc::new(WebsocketStream::<Value> {
9222 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9223 stream_or_id: stream_name.clone(),
9224 url_path: None,
9225 callback: Mutex::new(None),
9226 id: None,
9227 _phantom: PhantomData,
9228 });
9229 stream.on("message", |_| {}).await;
9230 {
9231 let guard = stream.callback.lock().await;
9232 assert!(guard.is_some());
9233 }
9234 stream.unsubscribe().await;
9235 sleep(Duration::from_millis(10)).await;
9236 let guard = stream.callback.lock().await;
9237 assert!(guard.is_none());
9238 let state = conn.state.lock().await;
9239 assert!(
9240 state
9241 .stream_callbacks
9242 .get(&stream_name)
9243 .is_none_or(std::vec::Vec::is_empty)
9244 );
9245 });
9246 }
9247
9248 #[test]
9249 fn without_callback_does_nothing_for_websocket_api() {
9250 TOKIO_SHARED_RT.block_on(async {
9251 let ws_base = create_websocket_api(None, None, None);
9252 let identifier = "id1".to_string();
9253
9254 {
9255 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
9256 stream_callbacks.insert(identifier.clone(), Vec::new());
9257 }
9258
9259 let stream = Arc::new(WebsocketStream::<Value> {
9260 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9261 stream_or_id: identifier.clone(),
9262 url_path: None,
9263 callback: Mutex::new(None),
9264 id: None,
9265 _phantom: PhantomData,
9266 });
9267
9268 stream.unsubscribe().await;
9269
9270 let stream_callbacks = ws_base.stream_callbacks.lock().await;
9271 assert!(stream_callbacks.contains_key(&identifier));
9272 let callbacks = stream_callbacks.get(&identifier).unwrap();
9273 assert!(callbacks.is_empty());
9274 });
9275 }
9276
9277 #[test]
9278 fn removes_registered_callback_and_clears_state_for_websocket_api() {
9279 TOKIO_SHARED_RT.block_on(async {
9280 let ws_base = create_websocket_api(None, None, None);
9281 let identifier = "id2".to_string();
9282
9283 {
9284 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
9285 stream_callbacks.insert(identifier.clone(), Vec::new());
9286 }
9287
9288 let stream = Arc::new(WebsocketStream::<Value> {
9289 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9290 stream_or_id: identifier.clone(),
9291 url_path: None,
9292 callback: Mutex::new(None),
9293 id: None,
9294 _phantom: PhantomData,
9295 });
9296
9297 stream.on("message", |_| {}).await;
9298
9299 {
9300 let stream_callbacks = ws_base.stream_callbacks.lock().await;
9301 let callbacks = stream_callbacks
9302 .get(&identifier)
9303 .expect("Entry for 'id2' should exist");
9304 assert_eq!(callbacks.len(), 1);
9305 }
9306
9307 stream.unsubscribe().await;
9308
9309 {
9310 let guard = stream.callback.lock().await;
9311 assert!(guard.is_none());
9312 }
9313
9314 {
9315 let stream_callbacks = ws_base.stream_callbacks.lock().await;
9316 let callbacks = stream_callbacks
9317 .get(&identifier)
9318 .expect("Entry for 'id2' should still exist");
9319 assert!(callbacks.is_empty());
9320 }
9321 });
9322 }
9323 }
9324 }
9325
9326 mod create_stream_handler {
9327 use super::*;
9328
9329 #[test]
9330 fn create_stream_handler_without_id_registers_stream() {
9331 TOKIO_SHARED_RT.block_on(async {
9332 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9333 let stream_name = "foo".to_string();
9334 let handler = create_stream_handler::<serde_json::Value>(
9335 WebsocketBase::WebsocketStreams(ws.clone()),
9336 stream_name.clone(),
9337 None,
9338 None,
9339 )
9340 .await;
9341 assert_eq!(handler.stream_or_id, stream_name);
9342 assert!(handler.id.is_none());
9343 let map = ws.connection_streams.lock().await;
9344 assert!(map.contains_key(&stream_name));
9345 });
9346 }
9347
9348 #[test]
9349 fn create_stream_handler_with_custom_string_id_registers_stream_and_id() {
9350 TOKIO_SHARED_RT.block_on(async {
9351 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9352 let stream_name = "bar".to_string();
9353 let custom_id = StreamId::from("my-custom-id".to_string());
9354 let handler = create_stream_handler::<serde_json::Value>(
9355 WebsocketBase::WebsocketStreams(ws.clone()),
9356 stream_name.clone(),
9357 Some(custom_id.clone()),
9358 None,
9359 )
9360 .await;
9361 assert_eq!(handler.stream_or_id, stream_name);
9362 assert_eq!(handler.id, Some(custom_id));
9363 let map = ws.connection_streams.lock().await;
9364 assert!(map.contains_key(&stream_name));
9365 });
9366 }
9367
9368 #[test]
9369 fn create_stream_handler_with_custom_integer_id_registers_stream_and_id() {
9370 TOKIO_SHARED_RT.block_on(async {
9371 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9372 let stream_name = "bar".to_string();
9373 let custom_id = StreamId::from(123u32);
9374 let handler = create_stream_handler::<serde_json::Value>(
9375 WebsocketBase::WebsocketStreams(ws.clone()),
9376 stream_name.clone(),
9377 Some(custom_id.clone()),
9378 None,
9379 )
9380 .await;
9381 assert_eq!(handler.stream_or_id, stream_name);
9382 assert_eq!(handler.id, Some(custom_id));
9383 let map = ws.connection_streams.lock().await;
9384 assert!(map.contains_key(&stream_name));
9385 });
9386 }
9387
9388 #[test]
9389 fn create_stream_handler_without_id_registers_api_stream() {
9390 TOKIO_SHARED_RT.block_on(async {
9391 let ws_base = create_websocket_api(None, None, None);
9392 let identifier = "foo-api".to_string();
9393
9394 let handler = create_stream_handler::<Value>(
9395 WebsocketBase::WebsocketApi(ws_base.clone()),
9396 identifier.clone(),
9397 None,
9398 None,
9399 )
9400 .await;
9401
9402 assert_eq!(handler.stream_or_id, identifier);
9403 assert!(handler.id.is_none());
9404 });
9405 }
9406
9407 #[test]
9408 fn create_stream_handler_with_custom_string_id_registers_api_stream_and_id() {
9409 TOKIO_SHARED_RT.block_on(async {
9410 let ws_base = create_websocket_api(None, None, None);
9411 let identifier = "bar-api".to_string();
9412 let custom_id = StreamId::from("custom-123".to_string());
9413
9414 let handler = create_stream_handler::<Value>(
9415 WebsocketBase::WebsocketApi(ws_base.clone()),
9416 identifier.clone(),
9417 Some(custom_id.clone()),
9418 None,
9419 )
9420 .await;
9421
9422 assert_eq!(handler.stream_or_id, identifier);
9423 assert_eq!(handler.id, Some(custom_id));
9424 });
9425 }
9426
9427 #[test]
9428 fn create_stream_handler_with_custom_integer_id_registers_api_stream_and_id() {
9429 TOKIO_SHARED_RT.block_on(async {
9430 let ws_base = create_websocket_api(None, None, None);
9431 let identifier = "bar-api".to_string();
9432 let custom_id = StreamId::from(123u32);
9433
9434 let handler = create_stream_handler::<Value>(
9435 WebsocketBase::WebsocketApi(ws_base.clone()),
9436 identifier.clone(),
9437 Some(custom_id.clone()),
9438 None,
9439 )
9440 .await;
9441
9442 assert_eq!(handler.stream_or_id, identifier);
9443 assert_eq!(handler.id, Some(custom_id));
9444 });
9445 }
9446
9447 #[test]
9448 fn websocket_streams_without_url_path_registers_stream_key() {
9449 TOKIO_SHARED_RT.block_on(async {
9450 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9451 let stream_name = "foo".to_string();
9452
9453 let handler = create_stream_handler::<Value>(
9454 WebsocketBase::WebsocketStreams(ws.clone()),
9455 stream_name.clone(),
9456 None,
9457 None,
9458 )
9459 .await;
9460
9461 assert_eq!(handler.stream_or_id, stream_name);
9462 assert!(handler.id.is_none());
9463
9464 let map = ws.connection_streams.lock().await;
9465 assert!(map.contains_key("foo"));
9466 });
9467 }
9468
9469 #[test]
9470 fn websocket_streams_with_url_path_registers_prefixed_stream_key() {
9471 TOKIO_SHARED_RT.block_on(async {
9472 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9473
9474 {
9475 let conn = ws.common.connection_pool[0].clone();
9476 let mut st = conn.state.lock().await;
9477 st.url_path = Some("path1".to_string());
9478 }
9479
9480 let stream_name = "foo".to_string();
9481
9482 let handler = create_stream_handler::<Value>(
9483 WebsocketBase::WebsocketStreams(ws.clone()),
9484 stream_name.clone(),
9485 None,
9486 Some("path1".to_string()),
9487 )
9488 .await;
9489
9490 assert_eq!(handler.stream_or_id, stream_name);
9491 assert!(handler.id.is_none());
9492
9493 let map = ws.connection_streams.lock().await;
9494 assert!(map.contains_key("path1::foo"));
9495 });
9496 }
9497
9498 #[test]
9499 fn websocket_streams_with_custom_id_preserves_id_and_registers_prefixed_key() {
9500 TOKIO_SHARED_RT.block_on(async {
9501 let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9502
9503 {
9504 let conn = ws.common.connection_pool[0].clone();
9505 let mut st = conn.state.lock().await;
9506 st.url_path = Some("path1".to_string());
9507 }
9508
9509 let stream_name = "bar".to_string();
9510 let custom_id = StreamId::from("my-custom-id".to_string());
9511
9512 let handler = create_stream_handler::<Value>(
9513 WebsocketBase::WebsocketStreams(ws.clone()),
9514 stream_name.clone(),
9515 Some(custom_id.clone()),
9516 Some("path1".to_string()),
9517 )
9518 .await;
9519
9520 assert_eq!(handler.stream_or_id, stream_name);
9521 assert_eq!(handler.id, Some(custom_id));
9522
9523 let map = ws.connection_streams.lock().await;
9524 assert!(map.contains_key("path1::bar"));
9525 });
9526 }
9527
9528 #[test]
9529 fn websocket_api_does_not_register_stream_in_connection_map() {
9530 TOKIO_SHARED_RT.block_on(async {
9531 let ws_base = create_websocket_api(None, None, None);
9532 let identifier = "foo-api".to_string();
9533
9534 let handler = create_stream_handler::<Value>(
9535 WebsocketBase::WebsocketApi(ws_base.clone()),
9536 identifier.clone(),
9537 None,
9538 Some("path1".to_string()),
9539 )
9540 .await;
9541
9542 assert_eq!(handler.stream_or_id, identifier);
9543 assert!(handler.id.is_none());
9544 });
9545 }
9546 }
9547
9548 mod websocket_connection_failure_reason {
9549 use super::*;
9550 use std::io::{Error as IoError, ErrorKind};
9551 use tokio_tungstenite::tungstenite::Error as TungsteniteError;
9552
9553 #[test]
9554 fn from_tungstenite_error_classifies_connection_closed() {
9555 let error = TungsteniteError::ConnectionClosed;
9556 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9557 assert!(matches!(
9558 reason,
9559 WebsocketConnectionFailureReason::ConnectionReset
9560 ));
9561 assert!(reason.should_reconnect());
9562 }
9563
9564 #[test]
9565 fn from_tungstenite_error_classifies_already_closed() {
9566 let error = TungsteniteError::AlreadyClosed;
9567 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9568 assert!(matches!(
9569 reason,
9570 WebsocketConnectionFailureReason::ConnectionReset
9571 ));
9572 assert!(reason.should_reconnect());
9573 }
9574
9575 #[test]
9576 fn from_tungstenite_error_classifies_io_errors() {
9577 let io_error = IoError::new(ErrorKind::ConnectionReset, "connection reset");
9579 let error = TungsteniteError::Io(io_error);
9580 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9581 assert!(matches!(
9582 reason,
9583 WebsocketConnectionFailureReason::ConnectionReset
9584 ));
9585 assert!(reason.should_reconnect());
9586
9587 let io_error = IoError::new(ErrorKind::ConnectionAborted, "connection aborted");
9589 let error = TungsteniteError::Io(io_error);
9590 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9591 assert!(matches!(
9592 reason,
9593 WebsocketConnectionFailureReason::ConnectionReset
9594 ));
9595 assert!(reason.should_reconnect());
9596
9597 let io_error = IoError::new(ErrorKind::TimedOut, "timed out");
9599 let error = TungsteniteError::Io(io_error);
9600 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9601 assert!(matches!(
9602 reason,
9603 WebsocketConnectionFailureReason::NetworkInterruption
9604 ));
9605 assert!(reason.should_reconnect());
9606
9607 let io_error = IoError::new(ErrorKind::UnexpectedEof, "unexpected eof");
9609 let error = TungsteniteError::Io(io_error);
9610 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9611 assert!(matches!(
9612 reason,
9613 WebsocketConnectionFailureReason::StreamEnded
9614 ));
9615 assert!(reason.should_reconnect());
9616
9617 let io_error = IoError::new(ErrorKind::PermissionDenied, "permission denied");
9619 let error = TungsteniteError::Io(io_error);
9620 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9621 assert!(matches!(
9622 reason,
9623 WebsocketConnectionFailureReason::AuthenticationFailure
9624 ));
9625 assert!(!reason.should_reconnect());
9626
9627 let io_error = IoError::other("other error");
9629 let error = TungsteniteError::Io(io_error);
9630 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9631 assert!(matches!(
9632 reason,
9633 WebsocketConnectionFailureReason::NetworkInterruption
9634 ));
9635 assert!(reason.should_reconnect());
9636 }
9637
9638 #[test]
9639 fn from_tungstenite_error_classifies_protocol_errors() {
9640 use tokio_tungstenite::tungstenite::error::ProtocolError;
9642 let protocol_error = ProtocolError::ResetWithoutClosingHandshake;
9643 let error = TungsteniteError::Protocol(protocol_error);
9644 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9645 assert!(matches!(
9646 reason,
9647 WebsocketConnectionFailureReason::UnexpectedClose
9648 ));
9649 assert!(reason.should_reconnect());
9650
9651 let error = TungsteniteError::Utf8;
9653 let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9654 assert!(matches!(
9655 reason,
9656 WebsocketConnectionFailureReason::ProtocolViolation
9657 ));
9658 assert!(!reason.should_reconnect());
9659 }
9660
9661 #[test]
9662 fn from_close_code_classifies_standard_codes() {
9663 let reason = WebsocketConnectionFailureReason::from_close_code(1000, false);
9665 assert!(matches!(
9666 reason,
9667 WebsocketConnectionFailureReason::NormalClose
9668 ));
9669 assert!(!reason.should_reconnect());
9670
9671 let reason = WebsocketConnectionFailureReason::from_close_code(1001, false);
9673 assert!(matches!(
9674 reason,
9675 WebsocketConnectionFailureReason::ServerTemporaryError
9676 ));
9677 assert!(reason.should_reconnect());
9678
9679 let reason = WebsocketConnectionFailureReason::from_close_code(1002, false);
9681 assert!(matches!(
9682 reason,
9683 WebsocketConnectionFailureReason::ProtocolViolation
9684 ));
9685 assert!(!reason.should_reconnect());
9686
9687 let reason = WebsocketConnectionFailureReason::from_close_code(1006, false);
9689 assert!(matches!(
9690 reason,
9691 WebsocketConnectionFailureReason::UnexpectedClose
9692 ));
9693 assert!(reason.should_reconnect());
9694
9695 let reason = WebsocketConnectionFailureReason::from_close_code(1008, false);
9697 assert!(matches!(
9698 reason,
9699 WebsocketConnectionFailureReason::PermanentServerError
9700 ));
9701 assert!(!reason.should_reconnect());
9702
9703 let reason = WebsocketConnectionFailureReason::from_close_code(1011, false);
9705 assert!(matches!(
9706 reason,
9707 WebsocketConnectionFailureReason::ServerTemporaryError
9708 ));
9709 assert!(reason.should_reconnect());
9710
9711 let reason = WebsocketConnectionFailureReason::from_close_code(1015, false);
9713 assert!(matches!(
9714 reason,
9715 WebsocketConnectionFailureReason::ConfigurationError
9716 ));
9717 assert!(!reason.should_reconnect());
9718
9719 let reason = WebsocketConnectionFailureReason::from_close_code(4000, false);
9721 assert!(matches!(
9722 reason,
9723 WebsocketConnectionFailureReason::PermanentServerError
9724 ));
9725 assert!(!reason.should_reconnect());
9726
9727 let reason = WebsocketConnectionFailureReason::from_close_code(4999, false);
9728 assert!(matches!(
9729 reason,
9730 WebsocketConnectionFailureReason::PermanentServerError
9731 ));
9732 assert!(!reason.should_reconnect());
9733
9734 let reason = WebsocketConnectionFailureReason::from_close_code(9999, false);
9736 assert!(matches!(
9737 reason,
9738 WebsocketConnectionFailureReason::UnexpectedClose
9739 ));
9740 assert!(reason.should_reconnect());
9741 }
9742
9743 #[test]
9744 fn from_close_code_handles_user_initiated() {
9745 let reason = WebsocketConnectionFailureReason::from_close_code(1000, true);
9747 assert!(matches!(
9748 reason,
9749 WebsocketConnectionFailureReason::UserInitiatedClose
9750 ));
9751 assert!(!reason.should_reconnect());
9752
9753 let reason = WebsocketConnectionFailureReason::from_close_code(1006, true);
9754 assert!(matches!(
9755 reason,
9756 WebsocketConnectionFailureReason::UserInitiatedClose
9757 ));
9758 assert!(!reason.should_reconnect());
9759
9760 let reason = WebsocketConnectionFailureReason::from_close_code(4000, true);
9761 assert!(matches!(
9762 reason,
9763 WebsocketConnectionFailureReason::UserInitiatedClose
9764 ));
9765 assert!(!reason.should_reconnect());
9766 }
9767
9768 #[test]
9769 fn should_reconnect_logic() {
9770 assert!(WebsocketConnectionFailureReason::NetworkInterruption.should_reconnect());
9772 assert!(WebsocketConnectionFailureReason::ConnectionReset.should_reconnect());
9773 assert!(WebsocketConnectionFailureReason::ServerTemporaryError.should_reconnect());
9774 assert!(WebsocketConnectionFailureReason::UnexpectedClose.should_reconnect());
9775 assert!(WebsocketConnectionFailureReason::StreamEnded.should_reconnect());
9776
9777 assert!(!WebsocketConnectionFailureReason::AuthenticationFailure.should_reconnect());
9779 assert!(!WebsocketConnectionFailureReason::ProtocolViolation.should_reconnect());
9780 assert!(!WebsocketConnectionFailureReason::ConfigurationError.should_reconnect());
9781 assert!(!WebsocketConnectionFailureReason::UserInitiatedClose.should_reconnect());
9782 assert!(!WebsocketConnectionFailureReason::PermanentServerError.should_reconnect());
9783 assert!(!WebsocketConnectionFailureReason::NormalClose.should_reconnect());
9784 }
9785
9786 #[test]
9787 fn debug_and_clone_work() {
9788 let reason = WebsocketConnectionFailureReason::NetworkInterruption;
9789 let cloned = reason;
9790 let debug_str = format!("{:?}", reason);
9791
9792 assert!(matches!(
9793 cloned,
9794 WebsocketConnectionFailureReason::NetworkInterruption
9795 ));
9796 assert!(debug_str.contains("NetworkInterruption"));
9797 }
9798 }
9799}