1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
4use http::header::USER_AGENT;
5use regex::Regex;
6use serde::de::DeserializeOwned;
7use serde_json::{Value, json};
8use std::{
9 collections::{BTreeMap, HashMap, VecDeque},
10 io::Read,
11 marker::PhantomData,
12 mem::take,
13 sync::{
14 Arc, LazyLock,
15 atomic::{AtomicUsize, Ordering},
16 },
17 time::Duration,
18};
19use tokio::{
20 net::TcpStream,
21 select, spawn,
22 sync::{
23 Mutex, Notify, broadcast,
24 mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
25 oneshot,
26 },
27 task::JoinHandle,
28 time::{sleep, timeout},
29};
30use tokio_tungstenite::{
31 Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
32 tungstenite::{
33 Message,
34 client::IntoClientRequest,
35 protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
36 },
37};
38use tokio_util::time::DelayQueue;
39use tracing::{debug, error, info, warn};
40
41use crate::common::utils::{remove_empty_value, sort_object_params};
42
43use super::{
44 config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
45 errors::WebsocketError,
46 models::{WebsocketApiResponse, WebsocketEvent, WebsocketMode},
47 utils::{get_timestamp, random_string, validate_time_unit},
48};
49
50static ID_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^[0-9a-f]{32}$").unwrap());
51
52pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;
53
54const MAX_CONN_DURATION: Duration = Duration::from_secs(23 * 60 * 60);
55
56pub struct Subscription {
57 handle: JoinHandle<()>,
58}
59
60impl Subscription {
61 pub fn unsubscribe(self) {
76 self.handle.abort();
77 }
78}
79
/// Either flavour of WebSocket client, for code that handles both uniformly.
/// Cloning only clones the inner `Arc`.
#[derive(Clone)]
pub enum WebsocketBase {
    /// Request/response client (see `WebsocketApi::send_message`).
    WebsocketApi(Arc<WebsocketApi>),
    /// Stream-subscription client (see `WebsocketStreams::subscribe`).
    WebsocketStreams(Arc<WebsocketStreams>),
}
85
86pub struct WebsocketEventEmitter {
87 tx: broadcast::Sender<WebsocketEvent>,
88}
89
90impl Default for WebsocketEventEmitter {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96impl WebsocketEventEmitter {
97 #[must_use]
98 pub fn new() -> Self {
99 let (tx, _rx) = broadcast::channel(100);
100 Self { tx }
101 }
102
103 pub fn subscribe<F>(&self, mut callback: F) -> Subscription
124 where
125 F: FnMut(WebsocketEvent) + Send + 'static,
126 {
127 let mut rx = self.tx.subscribe();
128 let handle = spawn(async move {
129 while let Ok(event) = rx.recv().await {
130 callback(event);
131 }
132 });
133 Subscription { handle }
134 }
135
136 fn emit(&self, event: WebsocketEvent) {
147 let _ = self.tx.send(event);
148 }
149}
150
/// Callbacks a client type (e.g. `WebsocketApi`) supplies to `WebsocketCommon`
/// for each pooled connection.
#[async_trait]
pub trait WebsocketHandler: Send + Sync + 'static {
    /// Invoked after a socket for `connection` has been established to `url`.
    async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
    /// Invoked with each text message received on `connection`; binary
    /// messages are zlib-inflated to text before reaching this callback.
    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
    /// Returns the URL to use when reconnecting or renewing `connection`;
    /// `default_url` is the URL the current socket was opened with.
    async fn get_reconnect_url(
        &self,
        default_url: String,
        connection: Arc<WebsocketConnection>,
    ) -> String;
}
175
/// A request awaiting its reply; `completion` receives the reply, a
/// server-reported error, or a timeout.
pub struct PendingRequest {
    pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
}

/// Mutable state of one pooled connection, guarded by `WebsocketConnection::state`.
pub struct WebsocketConnectionState {
    /// Set when the socket closed unexpectedly and a reconnect was queued.
    pub reconnection_pending: bool,
    /// Set while a replacement socket is being opened for a renewal.
    pub renewal_pending: bool,
    /// Set by `disconnect`; suppresses reconnects and marks the slot not ready.
    pub close_initiated: bool,
    /// Outstanding requests keyed by request id.
    pub pending_requests: HashMap<String, PendingRequest>,
    /// Streams queued for subscription until the connection is ready.
    pub pending_subscriptions: VecDeque<String>,
    /// Per-stream message callbacks.
    pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
    /// Client callbacks for this connection, installed by `connect`.
    pub handler: Option<Arc<dyn WebsocketHandler>>,
    /// Sender feeding the socket's writer task; `Some` once a socket was opened.
    pub ws_write_tx: Option<UnboundedSender<Message>>,
}
190
191impl Default for WebsocketConnectionState {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197impl WebsocketConnectionState {
198 #[must_use]
199 pub fn new() -> Self {
200 Self {
201 reconnection_pending: false,
202 renewal_pending: false,
203 close_initiated: false,
204 pending_requests: HashMap::new(),
205 pending_subscriptions: VecDeque::new(),
206 stream_callbacks: HashMap::new(),
207 handler: None,
208 ws_write_tx: None,
209 }
210 }
211}
212
213pub struct WebsocketConnection {
214 pub id: String,
215 pub drain_notify: Notify,
216 pub state: Mutex<WebsocketConnectionState>,
217}
218
219impl WebsocketConnection {
220 pub fn new(id: impl Into<String>) -> Arc<Self> {
221 Arc::new(Self {
222 id: id.into(),
223 drain_notify: Notify::new(),
224 state: Mutex::new(WebsocketConnectionState::new()),
225 })
226 }
227
228 pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
229 let mut conn_state = self.state.lock().await;
230 conn_state.handler = Some(handler);
231 }
232}
233
/// A queued request to (re)open the socket of connection `connection_id` at `url`.
struct ReconnectEntry {
    connection_id: String,
    url: String,
    /// `true` for scheduled renewals, which skip the reconnect delay.
    is_renewal: bool,
}

/// Connection-pool machinery shared by `WebsocketApi` and `WebsocketStreams`.
pub struct WebsocketCommon {
    pub events: WebsocketEventEmitter,
    mode: WebsocketMode,
    /// Rotates through ready connections in `get_connection`.
    round_robin_index: AtomicUsize,
    connection_pool: Vec<Arc<WebsocketConnection>>,
    /// Feeds the reconnect loop.
    reconnect_tx: Sender<ReconnectEntry>,
    /// Feeds the renewal scheduler with `(connection_id, url)` pairs.
    renewal_tx: Sender<(String, String)>,
    /// Delay in milliseconds before a non-renewal reconnect attempt.
    reconnect_delay: usize,
    agent: Option<AgentConnector>,
    user_agent: Option<String>,
}
251
252impl WebsocketCommon {
253 #[must_use]
254 pub fn new(
255 mut initial_pool: Vec<Arc<WebsocketConnection>>,
256 mode: WebsocketMode,
257 reconnect_delay: usize,
258 agent: Option<AgentConnector>,
259 user_agent: Option<String>,
260 ) -> Arc<Self> {
261 if initial_pool.is_empty() {
262 for _ in 0..mode.pool_size() {
263 let id = random_string();
264 initial_pool.push(WebsocketConnection::new(id));
265 }
266 }
267
268 let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
269 let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
270
271 let common = Arc::new(Self {
272 events: WebsocketEventEmitter::new(),
273 mode,
274 round_robin_index: AtomicUsize::new(0),
275 connection_pool: initial_pool,
276 reconnect_tx,
277 renewal_tx,
278 reconnect_delay,
279 agent,
280 user_agent,
281 });
282
283 Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
284 Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
285
286 common
287 }
288
289 fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
307 spawn(async move {
308 while let Some(entry) = reconnect_rx.recv().await {
309 info!("Scheduling reconnect for id {}", entry.connection_id);
310
311 if !entry.is_renewal {
312 sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
313 }
314
315 if let Some(conn_arc) = common
316 .connection_pool
317 .iter()
318 .find(|c| c.id == entry.connection_id)
319 .cloned()
320 {
321 let common_clone = Arc::clone(&common);
322 if let Err(err) = common_clone
323 .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
324 .await
325 {
326 error!(
327 "Reconnect failed for {} → {}: {:?}",
328 entry.connection_id, entry.url, err
329 );
330 }
331
332 sleep(Duration::from_secs(1)).await;
333 } else {
334 warn!("No connection {} found for reconnect", entry.connection_id);
335 }
336 }
337 });
338 }
339
340 fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
354 let common = Arc::clone(common);
355 spawn(async move {
356 let mut dq = DelayQueue::new();
357 let mut renewal_rx = renewal_rx;
358
359 loop {
360 select! {
361 Some((conn_id, url)) = renewal_rx.recv() => {
362 debug!("Scheduling renewal for {}", conn_id);
363 dq.insert((conn_id, url), MAX_CONN_DURATION);
364 }
365
366 Some(expired) = dq.next() => {
367 let (conn_id, default_url) = expired.into_inner();
368
369 if let Some(conn_arc) = common
370 .connection_pool
371 .iter()
372 .find(|c| c.id == conn_id)
373 .cloned()
374 {
375 debug!("Renewing connection {}", conn_id);
376 let url = common
377 .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
378 .await;
379 if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
380 connection_id: conn_id.clone(),
381 url,
382 is_renewal: true,
383 }).await {
384 error!(
385 "Failed to enqueue renewal for {}: {:?}",
386 conn_id, e
387 );
388 }
389 } else {
390 warn!("No connection {} found for renewal", conn_id);
391 }
392 }
393 }
394 }
395 });
396 }
397
398 pub async fn is_connection_ready(
417 &self,
418 connection: &WebsocketConnection,
419 allow_non_established: bool,
420 ) -> bool {
421 let conn_state = connection.state.lock().await;
422 (allow_non_established || conn_state.ws_write_tx.is_some())
423 && !conn_state.renewal_pending
424 && !conn_state.reconnection_pending
425 && !conn_state.close_initiated
426 }
427
428 async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
444 if let Some(conn_arc) = connection {
445 return self.is_connection_ready(conn_arc, false).await;
446 }
447
448 for conn_arc in &self.connection_pool {
449 if self.is_connection_ready(conn_arc, false).await {
450 return true;
451 }
452 }
453
454 false
455 }
456
457 async fn get_connection(
477 &self,
478 allow_non_established: bool,
479 ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
480 if let WebsocketMode::Single = self.mode {
481 return Ok(Arc::clone(&self.connection_pool[0]));
482 }
483
484 let mut ready = Vec::new();
485 for conn in &self.connection_pool {
486 if self.is_connection_ready(conn, allow_non_established).await {
487 ready.push(Arc::clone(conn));
488 }
489 }
490
491 if ready.is_empty() {
492 return Err(WebsocketError::NotConnected);
493 }
494
495 let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
496
497 Ok(Arc::clone(&ready[idx]))
498 }
499
500 async fn close_connection_gracefully(
517 &self,
518 ws_write_tx_to_close: UnboundedSender<Message>,
519 connection: Arc<WebsocketConnection>,
520 ) -> Result<(), WebsocketError> {
521 debug!("Waiting for pending requests to complete before disconnecting.");
522
523 let drain = async {
524 loop {
525 {
526 let conn_state = connection.state.lock().await;
527 if conn_state.pending_requests.is_empty() {
528 debug!("All pending requests completed, proceeding to close.");
529 break;
530 }
531 }
532 connection.drain_notify.notified().await;
533 }
534 };
535
536 if timeout(Duration::from_secs(30), drain).await.is_err() {
537 warn!("Timeout waiting for pending requests; forcing close.");
538 }
539
540 info!("Closing WebSocket connection for {}", connection.id);
541 let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
542 code: CloseCode::Normal,
543 reason: "".into(),
544 })));
545
546 Ok(())
547 }
548
549 async fn get_reconnect_url(
566 &self,
567 default_url: &str,
568 connection: Arc<WebsocketConnection>,
569 ) -> String {
570 if let Some(handler) = {
571 let conn_state = connection.state.lock().await;
572 conn_state.handler.clone()
573 } {
574 return handler
575 .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
576 .await;
577 }
578
579 default_url.to_string()
580 }
581
582 async fn on_open(
604 &self,
605 url: String,
606 connection: Arc<WebsocketConnection>,
607 old_ws_writer: Option<UnboundedSender<Message>>,
608 ) {
609 if let Some(handler) = {
610 let conn_state = connection.state.lock().await;
611 conn_state.handler.clone()
612 } {
613 handler.on_open(url.clone(), Arc::clone(&connection)).await;
614 }
615
616 let conn_id = &connection.id;
617 info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
618
619 {
620 let mut conn_state = connection.state.lock().await;
621
622 if conn_state.renewal_pending {
623 conn_state.renewal_pending = false;
624 drop(conn_state);
625 if let Some(tx) = old_ws_writer {
626 info!("Connection renewal in progress; closing previous connection.");
627 let _ = self
628 .close_connection_gracefully(tx, Arc::clone(&connection))
629 .await;
630 }
631 return;
632 }
633
634 if conn_state.close_initiated {
635 drop(conn_state);
636 if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
637 info!("Close initiated; closing connection.");
638 let _ = self
639 .close_connection_gracefully(tx, Arc::clone(&connection))
640 .await;
641 }
642 return;
643 }
644
645 self.events.emit(WebsocketEvent::Open);
646 }
647 }
648
649 async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
661 if let Some(handler) = connection.state.lock().await.handler.clone() {
662 let handler_clone = handler.clone();
663 let data = msg.clone();
664 let conn_clone = connection.clone();
665 spawn(async move {
666 handler_clone.on_message(data, conn_clone).await;
667 });
668 }
669 self.events.emit(WebsocketEvent::Message(msg));
670 }
671
672 async fn create_websocket(
695 url: &str,
696 agent: Option<AgentConnector>,
697 user_agent: Option<String>,
698 ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
699 let mut req = url
700 .into_client_request()
701 .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
702
703 if let Some(ua) = user_agent {
704 req.headers_mut().insert(USER_AGENT, ua.parse().unwrap());
705 }
706
707 let ws_config: Option<WebSocketConfig> = None;
708 let disable_nagle = false;
709 let connector: Option<Connector> = agent.map(|dbg| dbg.0);
710
711 let timeout_duration = Duration::from_secs(10);
712 let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
713 match timeout(timeout_duration, handshake).await {
714 Ok(Ok((ws_stream, response))) => {
715 debug!("WebSocket connected: {:?}", response);
716 Ok(ws_stream)
717 }
718 Ok(Err(e)) => {
719 let msg = e.to_string();
720 error!("WebSocket handshake failed: {}", msg);
721 Err(WebsocketError::Handshake(msg))
722 }
723 Err(_) => {
724 error!(
725 "WebSocket connection timed out after {}s",
726 timeout_duration.as_secs()
727 );
728 Err(WebsocketError::Timeout)
729 }
730 }
731 }
732
733 async fn connect_pool(self: Arc<Self>, url: &str) -> Result<(), WebsocketError> {
752 let mut tasks = FuturesUnordered::new();
753
754 for conn in &self.connection_pool {
755 let common = Arc::clone(&self);
756 let url = url.to_owned();
757 let conn_clone = Arc::clone(conn);
758
759 tasks.push(async move {
760 match common.init_connect(&url, false, Some(conn_clone)).await {
761 Ok(()) => {
762 info!("Successfully connected to {}", url);
763 Ok(())
764 }
765 Err(err) => {
766 error!("Failed to connect to {}: {:?}", url, err);
767 Err(err)
768 }
769 }
770 });
771 }
772
773 while let Some(result) = tasks.next().await {
774 result?;
775 }
776
777 Ok(())
778 }
779
780 async fn init_connect(
801 self: Arc<Self>,
802 url: &str,
803 is_renewal: bool,
804 connection: Option<Arc<WebsocketConnection>>,
805 ) -> Result<(), WebsocketError> {
806 let conn = connection.unwrap_or(self.get_connection(true).await?);
807
808 {
809 let mut conn_state = conn.state.lock().await;
810 if conn_state.renewal_pending && is_renewal {
811 info!("Renewal in progress {}→{}", conn.id, url);
812 return Ok(());
813 }
814 if conn_state.ws_write_tx.is_some() && !is_renewal {
815 info!("Exists {}; skipping {}", conn.id, url);
816 return Ok(());
817 }
818 if is_renewal {
819 conn_state.renewal_pending = true;
820 }
821 }
822
823 let ws = Self::create_websocket(url, self.agent.clone(), self.user_agent.clone())
824 .await
825 .map_err(|e| {
826 error!("Handshake failed {}: {:?}", url, e);
827 e
828 })?;
829
830 info!("Established {} → {}", conn.id, url);
831
832 if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
833 error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
834 }
835
836 let (write_half, mut read_half) = ws.split();
837 let (tx, mut rx) = unbounded_channel::<Message>();
838
839 let old_writer = {
840 let mut conn_state = conn.state.lock().await;
841 conn_state.ws_write_tx.replace(tx.clone())
842 };
843
844 let wconn = conn.clone();
845
846 spawn(async move {
847 let mut sink = write_half;
848 while let Some(msg) = rx.recv().await {
849 if sink.send(msg).await.is_err() {
850 error!("Write error {}", wconn.id);
851 break;
852 }
853 }
854 debug!("Writer {} exit", wconn.id);
855 });
856
857 self.on_open(url.to_string(), conn.clone(), old_writer)
858 .await;
859
860 let common = self.clone();
861 let reader_conn = conn.clone();
862 let read_url = url.to_string();
863
864 spawn(async move {
865 while let Some(item) = read_half.next().await {
866 match item {
867 Ok(Message::Text(msg)) => {
868 common
869 .on_message(msg.to_string(), Arc::clone(&reader_conn))
870 .await;
871 }
872 Ok(Message::Binary(bin)) => {
873 let mut decoder = ZlibDecoder::new(&bin[..]);
874 let mut decompressed = String::new();
875 if let Err(err) = decoder.read_to_string(&mut decompressed) {
876 error!("Binary message decompress failed: {:?}", err);
877 continue;
878 }
879 common
880 .on_message(decompressed, Arc::clone(&reader_conn))
881 .await;
882 }
883 Ok(Message::Ping(payload)) => {
884 info!("PING received from server on {}", reader_conn.id);
885 common.events.emit(WebsocketEvent::Ping);
886 if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
887 let _ = tx.send(Message::Pong(payload));
888 info!(
889 "Responded PONG to server's PING message on {}",
890 reader_conn.id
891 );
892 }
893 }
894 Ok(Message::Pong(_)) => {
895 info!("Received PONG from server on {}", reader_conn.id);
896 common.events.emit(WebsocketEvent::Pong);
897 }
898 Ok(Message::Close(frame)) => {
899 let (code, reason) = frame
900 .map_or((1000, String::new()), |CloseFrame { code, reason }| {
901 (code.into(), reason.to_string())
902 });
903 common
904 .events
905 .emit(WebsocketEvent::Close(code, reason.clone()));
906
907 let mut conn_state = reader_conn.state.lock().await;
908 if !conn_state.close_initiated
909 && !is_renewal
910 && CloseCode::from(code) != CloseCode::Normal
911 {
912 warn!(
913 "Connection {} closed due to {}: {}",
914 reader_conn.id, code, reason
915 );
916 conn_state.reconnection_pending = true;
917 drop(conn_state);
918 let reconnect_url = common
919 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
920 .await;
921
922 let _ = common
923 .reconnect_tx
924 .send(ReconnectEntry {
925 connection_id: reader_conn.id.clone(),
926 url: reconnect_url,
927 is_renewal: false,
928 })
929 .await;
930 }
931 break;
932 }
933 Err(e) => {
934 error!("WebSocket error on {}: {:?}", reader_conn.id, e);
935 common.events.emit(WebsocketEvent::Error(e.to_string()));
936 }
937 _ => {}
938 }
939 }
940 debug!("Reader actor for {} exiting", reader_conn.id);
941 });
942
943 Ok(())
944 }
945 async fn disconnect(&self) -> Result<(), WebsocketError> {
960 if !self.is_connected(None).await {
961 warn!("No active connection to close.");
962 return Ok(());
963 }
964
965 let mut shutdowns = FuturesUnordered::new();
966 for conn in &self.connection_pool {
967 {
968 let mut conn_state = conn.state.lock().await;
969 conn_state.close_initiated = true;
970 if let Some(tx) = &conn_state.ws_write_tx {
971 shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
972 }
973 }
974 }
975
976 let close_all = async {
977 while let Some(result) = shutdowns.next().await {
978 result?;
979 }
980 Ok::<(), WebsocketError>(())
981 };
982
983 match timeout(Duration::from_secs(30), close_all).await {
984 Ok(Ok(())) => {
985 info!("Disconnected all WebSocket connections successfully.");
986 Ok(())
987 }
988 Ok(Err(err)) => {
989 error!("Error while disconnecting: {:?}", err);
990 Err(err)
991 }
992 Err(_) => {
993 error!("Timed out while disconnecting WebSocket connections.");
994 Err(WebsocketError::Timeout)
995 }
996 }
997 }
998
999 async fn ping_server(&self) {
1012 let mut ready = Vec::new();
1013 for conn in &self.connection_pool {
1014 if self.is_connection_ready(conn, false).await {
1015 let id = conn.id.clone();
1016 let ws_write_tx = {
1017 let conn_state = conn.state.lock().await;
1018 conn_state.ws_write_tx.clone()
1019 };
1020 ready.push((id, ws_write_tx));
1021 }
1022 }
1023
1024 if ready.is_empty() {
1025 warn!("No ready connections for PING.");
1026 return;
1027 }
1028 info!("Sending PING to {} WebSocket connections.", ready.len());
1029
1030 let mut tasks = FuturesUnordered::new();
1031 for (id, ws_write_tx_opt) in ready {
1032 if let Some(tx) = ws_write_tx_opt {
1033 tasks.push(async move {
1034 if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1035 error!("Failed to send PING to {}: {:?}", id, e);
1036 } else {
1037 debug!("Sent PING to connection {}", id);
1038 }
1039 });
1040 } else {
1041 error!("Connection {} was ready but has no write channel", id);
1042 }
1043 }
1044
1045 while tasks.next().await.is_some() {}
1046 }
1047
1048 async fn send(
1066 &self,
1067 payload: String,
1068 id: Option<String>,
1069 wait_for_reply: bool,
1070 timeout: Duration,
1071 connection: Option<Arc<WebsocketConnection>>,
1072 ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1073 let conn = if let Some(c) = connection {
1074 c
1075 } else {
1076 self.get_connection(false).await?
1077 };
1078
1079 if !self.is_connected(Some(&conn)).await {
1080 warn!("Send attempted on a non-connected socket");
1081 return Err(WebsocketError::NotConnected);
1082 }
1083
1084 let ws_write_tx = {
1085 let conn_state = conn.state.lock().await;
1086 conn_state
1087 .ws_write_tx
1088 .clone()
1089 .ok_or(WebsocketError::NotConnected)?
1090 };
1091
1092 debug!("Sending message to WebSocket on connection {}", conn.id);
1093
1094 ws_write_tx
1095 .send(Message::Text(payload.clone().into()))
1096 .map_err(|_| WebsocketError::NotConnected)?;
1097
1098 if !wait_for_reply {
1099 return Ok(None);
1100 }
1101
1102 let request_id = id.ok_or_else(|| {
1103 error!("id is required when waiting for a reply");
1104 WebsocketError::NotConnected
1105 })?;
1106
1107 let (tx, rx) = oneshot::channel();
1108 {
1109 let mut conn_state = conn.state.lock().await;
1110 conn_state
1111 .pending_requests
1112 .insert(request_id.clone(), PendingRequest { completion: tx });
1113 }
1114
1115 let conn_clone = Arc::clone(&conn);
1116 spawn(async move {
1117 sleep(timeout).await;
1118 let mut conn_state = conn_clone.state.lock().await;
1119 if let Some(pending_req) = conn_state.pending_requests.remove(&request_id) {
1120 let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1121 }
1122 });
1123
1124 Ok(Some(rx))
1125 }
1126}
1127
/// Per-request options for `WebsocketApi::send_message`.
pub struct WebsocketMessageSendOptions {
    /// Add the configured `apiKey` to the request params.
    pub with_api_key: bool,
    /// Add `apiKey`, `timestamp` and a `signature` over the sorted params.
    pub is_signed: bool,
}

/// Request/response client for the WebSocket API.
pub struct WebsocketApi {
    pub common: Arc<WebsocketCommon>,
    configuration: ConfigurationWebsocketApi,
    /// Guards against concurrent `connect` calls.
    is_connecting: Arc<Mutex<bool>>,
    /// Callbacks invoked for event messages that carry an `"e"` field.
    stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
}
1139
1140impl WebsocketApi {
1141 #[must_use]
1142 pub fn new(
1163 configuration: ConfigurationWebsocketApi,
1164 connection_pool: Vec<Arc<WebsocketConnection>>,
1165 ) -> Arc<Self> {
1166 let agent_clone = configuration.agent.clone();
1167 let user_agent_clone = configuration.user_agent.clone();
1168 let common = WebsocketCommon::new(
1169 connection_pool,
1170 configuration.mode.clone(),
1171 usize::try_from(configuration.reconnect_delay)
1172 .expect("reconnect_delay should fit in usize"),
1173 agent_clone,
1174 Some(user_agent_clone),
1175 );
1176
1177 Arc::new(Self {
1178 common: Arc::clone(&common),
1179 configuration,
1180 is_connecting: Arc::new(Mutex::new(false)),
1181 stream_callbacks: Mutex::new(HashMap::new()),
1182 })
1183 }
1184
1185 pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1207 if self.common.is_connected(None).await {
1208 info!("WebSocket connection already established");
1209 return Ok(());
1210 }
1211
1212 {
1213 let mut flag = self.is_connecting.lock().await;
1214 if *flag {
1215 info!("Already connecting...");
1216 return Ok(());
1217 }
1218 *flag = true;
1219 }
1220
1221 let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1222
1223 let handler: Arc<dyn WebsocketHandler> = self.clone();
1224 for slot in &self.common.connection_pool {
1225 slot.set_handler(handler.clone()).await;
1226 }
1227
1228 let result = select! {
1229 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1230 r = self.common.clone().connect_pool(&url) => r,
1231 };
1232
1233 {
1234 let mut flag = self.is_connecting.lock().await;
1235 *flag = false;
1236 }
1237
1238 result
1239 }
1240
1241 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1254 self.common.disconnect().await
1255 }
1256
1257 pub async fn is_connected(&self) -> bool {
1263 self.common.is_connected(None).await
1264 }
1265
1266 pub async fn ping_server(&self) {
1271 self.common.ping_server().await;
1272 }
1273
1274 pub async fn send_message<R>(
1304 &self,
1305 method: &str,
1306 mut payload: BTreeMap<String, Value>,
1307 options: WebsocketMessageSendOptions,
1308 ) -> Result<WebsocketApiResponse<R>, WebsocketError>
1309 where
1310 R: DeserializeOwned + Send + Sync + 'static,
1311 {
1312 if !self.common.is_connected(None).await {
1313 return Err(WebsocketError::NotConnected);
1314 }
1315
1316 let id = payload
1317 .get("id")
1318 .and_then(Value::as_str)
1319 .filter(|s| ID_REGEX.is_match(s))
1320 .map_or_else(random_string, String::from);
1321
1322 payload.remove("id");
1323
1324 let mut params = remove_empty_value(payload.into_iter());
1325 if options.with_api_key || options.is_signed {
1326 params.insert(
1327 "apiKey".into(),
1328 Value::String(
1329 self.configuration
1330 .api_key
1331 .clone()
1332 .expect("API key must be set"),
1333 ),
1334 );
1335 }
1336 if options.is_signed {
1337 let ts = get_timestamp();
1338 let ts_i64 = i64::try_from(ts).map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1339 params.insert(
1340 "timestamp".into(),
1341 Value::Number(serde_json::Number::from(ts_i64)),
1342 );
1343 let mut sorted_params = sort_object_params(¶ms);
1344 let sig = self
1345 .configuration
1346 .signature_gen
1347 .get_signature(&sorted_params)
1348 .map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1349 sorted_params.insert("signature".into(), Value::String(sig));
1350 params = sorted_params.into_iter().collect();
1351 }
1352
1353 let request = json!({
1354 "id": id,
1355 "method": method,
1356 "params": params,
1357 });
1358 debug!("Sending message to WebSocket API: {:?}", request);
1359
1360 let timeout = Duration::from_millis(self.configuration.timeout);
1361 let maybe_rx = self
1362 .common
1363 .send(
1364 serde_json::to_string(&request).unwrap(),
1365 Some(id.clone()),
1366 true,
1367 timeout,
1368 None,
1369 )
1370 .await?;
1371
1372 let msg: Value = if let Some(rx) = maybe_rx {
1373 rx.await.unwrap_or(Err(WebsocketError::Timeout))?
1374 } else {
1375 return Err(WebsocketError::NoResponse);
1376 };
1377
1378 let raw = msg
1379 .get("result")
1380 .or_else(|| msg.get("response"))
1381 .cloned()
1382 .unwrap_or(Value::Null);
1383
1384 let rate_limits = msg
1385 .get("rateLimits")
1386 .and_then(Value::as_array)
1387 .map(|arr| {
1388 arr.iter()
1389 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1390 .collect()
1391 })
1392 .unwrap_or_default();
1393
1394 Ok(WebsocketApiResponse {
1395 raw,
1396 rate_limits,
1397 _marker: PhantomData,
1398 })
1399 }
1400
1401 fn prepare_url(&self, ws_url: &str) -> String {
1416 let mut url = ws_url.to_string();
1417
1418 let time_unit = match &self.configuration.time_unit {
1419 Some(u) => u.to_string(),
1420 None => return url,
1421 };
1422
1423 match validate_time_unit(&time_unit) {
1424 Ok(Some(validated)) => {
1425 let sep = if url.contains('?') { '&' } else { '?' };
1426 url.push(sep);
1427 url.push_str("timeUnit=");
1428 url.push_str(validated);
1429 }
1430 Ok(None) => {}
1431 Err(e) => {
1432 error!("Invalid time unit provided: {:?}", e);
1433 }
1434 }
1435
1436 url
1437 }
1438}
1439
1440#[async_trait]
1441impl WebsocketHandler for WebsocketApi {
1442 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
1458
1459 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1479 let msg: Value = match serde_json::from_str(&data) {
1480 Ok(v) => v,
1481 Err(err) => {
1482 error!("Failed to parse WebSocket message {} – {}", data, err);
1483 return;
1484 }
1485 };
1486
1487 if let Some(id) = msg.get("id").and_then(Value::as_str) {
1488 let maybe_sender = {
1489 let mut conn_state = connection.state.lock().await;
1490 conn_state.pending_requests.remove(id)
1491 };
1492
1493 if let Some(PendingRequest { completion }) = maybe_sender {
1494 connection.drain_notify.notify_one();
1495 let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1496 if status >= 400 {
1497 let error_map = msg
1498 .get("error")
1499 .and_then(Value::as_object)
1500 .unwrap_or(&serde_json::Map::new())
1501 .clone();
1502
1503 let code = error_map
1504 .get("code")
1505 .and_then(Value::as_i64)
1506 .unwrap_or(status.try_into().unwrap());
1507
1508 let message = error_map
1509 .get("msg")
1510 .and_then(Value::as_str)
1511 .unwrap_or("Unknown error")
1512 .to_string();
1513
1514 let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1515 } else {
1516 let _ = completion.send(Ok(msg.clone()));
1517 }
1518 }
1519
1520 return;
1521 }
1522
1523 if let Some(event) = msg.get("event") {
1524 if event.get("e").is_some() {
1525 for callbacks in self.stream_callbacks.lock().await.values() {
1526 for callback in callbacks {
1527 callback(event);
1528 }
1529 }
1530
1531 return;
1532 }
1533 }
1534
1535 warn!(
1536 "Received response for unknown or timed-out request: {}",
1537 data
1538 );
1539 }
1540
1541 async fn get_reconnect_url(
1552 &self,
1553 default_url: String,
1554 _connection: Arc<WebsocketConnection>,
1555 ) -> String {
1556 default_url
1557 }
1558}
1559
/// Stream-subscription client built on the shared connection pool.
pub struct WebsocketStreams {
    pub common: Arc<WebsocketCommon>,
    /// Guards against concurrent `connect` calls.
    is_connecting: Mutex<bool>,
    /// Which pooled connection each subscribed stream is assigned to.
    connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
    configuration: ConfigurationWebsocketStreams,
}
1566
1567impl WebsocketStreams {
1568 #[must_use]
1583 pub fn new(
1584 configuration: ConfigurationWebsocketStreams,
1585 connection_pool: Vec<Arc<WebsocketConnection>>,
1586 ) -> Arc<Self> {
1587 let agent_clone = configuration.agent.clone();
1588 let user_agent_clone = configuration.user_agent.clone();
1589 let common = WebsocketCommon::new(
1590 connection_pool,
1591 configuration.mode.clone(),
1592 usize::try_from(configuration.reconnect_delay)
1593 .expect("reconnect_delay should fit in usize"),
1594 agent_clone,
1595 Some(user_agent_clone),
1596 );
1597 Arc::new(Self {
1598 common,
1599 is_connecting: Mutex::new(false),
1600 connection_streams: Mutex::new(HashMap::new()),
1601 configuration,
1602 })
1603 }
1604
1605 pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
1622 if self.common.is_connected(None).await {
1623 info!("WebSocket connection already established");
1624 return Ok(());
1625 }
1626
1627 {
1628 let mut flag = self.is_connecting.lock().await;
1629 if *flag {
1630 info!("Already connecting...");
1631 return Ok(());
1632 }
1633 *flag = true;
1634 }
1635
1636 let url = self.prepare_url(&streams);
1637
1638 let handler: Arc<dyn WebsocketHandler> = self.clone();
1639 for conn in &self.common.connection_pool {
1640 conn.set_handler(handler.clone()).await;
1641 }
1642
1643 let connect_res = select! {
1644 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1645 r = self.common.clone().connect_pool(&url) => r,
1646 };
1647
1648 {
1649 let mut flag = self.is_connecting.lock().await;
1650 *flag = false;
1651 }
1652
1653 connect_res
1654 }
1655
1656 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1672 for connection in &self.common.connection_pool {
1673 let mut conn_state = connection.state.lock().await;
1674 conn_state.stream_callbacks.clear();
1675 conn_state.pending_subscriptions.clear();
1676 }
1677 self.connection_streams.lock().await.clear();
1678 self.common.disconnect().await
1679 }
1680
1681 pub async fn is_connected(&self) -> bool {
1687 self.common.is_connected(None).await
1688 }
1689
1690 pub async fn ping_server(&self) {
1699 self.common.ping_server().await;
1700 }
1701
1702 pub async fn subscribe(self: Arc<Self>, streams: Vec<String>, id: Option<String>) {
1721 let streams: Vec<String> = {
1722 let map = self.connection_streams.lock().await;
1723 streams
1724 .into_iter()
1725 .filter(|s| !map.contains_key(s))
1726 .collect()
1727 };
1728 let connection_streams = self.handle_stream_assignment(streams).await;
1729
1730 for (conn, streams) in connection_streams {
1731 if !self.common.is_connected(Some(&conn)).await {
1732 info!(
1733 "Connection is not ready. Queuing subscription for streams: {:?}",
1734 streams
1735 );
1736 let mut conn_state = conn.state.lock().await;
1737 conn_state.pending_subscriptions.extend(streams.clone());
1738 continue;
1739 }
1740 self.send_subscription_payload(&conn, &streams, id.clone());
1741 }
1742 }
1743
1744 pub async fn unsubscribe(&self, streams: Vec<String>, id: Option<String>) {
1772 let request_id = id
1773 .filter(|s| ID_REGEX.is_match(s))
1774 .unwrap_or_else(random_string);
1775
1776 for stream in streams {
1777 let maybe_conn = { self.connection_streams.lock().await.get(&stream).cloned() };
1778
1779 let conn = if let Some(c) = maybe_conn {
1780 if !self.common.is_connected(Some(&c)).await {
1781 warn!(
1782 "Stream {} not associated with an active connection.",
1783 stream
1784 );
1785 continue;
1786 }
1787 c
1788 } else {
1789 warn!("Stream {} was not subscribed.", stream);
1790 continue;
1791 };
1792
1793 let callbacks = {
1794 let conn_state = conn.state.lock().await;
1795 conn_state
1796 .stream_callbacks
1797 .get(&stream)
1798 .is_none_or(std::vec::Vec::is_empty)
1799 };
1800
1801 if !callbacks {
1802 continue;
1803 }
1804
1805 let payload = json!({
1806 "method": "UNSUBSCRIBE",
1807 "params": [stream.clone()],
1808 "id": request_id,
1809 });
1810
1811 info!("UNSUBSCRIBE → {:?}", payload);
1812
1813 let common = Arc::clone(&self.common);
1814 let conn_clone = Arc::clone(&conn);
1815 let msg = serde_json::to_string(&payload).unwrap();
1816 spawn(async move {
1817 let _ = common
1818 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
1819 .await;
1820 });
1821
1822 {
1823 let mut connection_streams = self.connection_streams.lock().await;
1824 connection_streams.remove(&stream);
1825 }
1826 {
1827 let mut conn_state = conn.state.lock().await;
1828 conn_state.stream_callbacks.remove(&stream);
1829 }
1830 }
1831 }
1832
1833 pub async fn is_subscribed(&self, stream: &str) -> bool {
1847 self.connection_streams.lock().await.contains_key(stream)
1848 }
1849
1850 fn prepare_url(&self, streams: &[String]) -> String {
1866 let mut url = format!(
1867 "{}/stream?streams={}",
1868 self.configuration.ws_url.as_deref().unwrap_or(""),
1869 streams.join("/")
1870 );
1871
1872 let time_unit = match &self.configuration.time_unit {
1873 Some(u) => u.to_string(),
1874 None => return url,
1875 };
1876
1877 match validate_time_unit(&time_unit) {
1878 Ok(Some(validated)) => {
1879 let sep = if url.contains('?') { '&' } else { '?' };
1880 url.push(sep);
1881 url.push_str("timeUnit=");
1882 url.push_str(validated);
1883 }
1884 Ok(None) => {}
1885 Err(e) => {
1886 error!("Invalid time unit provided: {:?}", e);
1887 }
1888 }
1889
1890 url
1891 }
1892
1893 async fn handle_stream_assignment(
1911 &self,
1912 streams: Vec<String>,
1913 ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
1914 let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
1915
1916 for stream in streams {
1917 let mut conn_opt = {
1918 let map = self.connection_streams.lock().await;
1919 map.get(&stream).cloned()
1920 };
1921
1922 let need_new = if let Some(conn) = &conn_opt {
1923 let state = conn.state.lock().await;
1924 state.close_initiated || state.reconnection_pending
1925 } else {
1926 true
1927 };
1928
1929 if need_new {
1930 match self.common.get_connection(true).await {
1931 Ok(new_conn) => {
1932 let mut map = self.connection_streams.lock().await;
1933 map.insert(stream.clone(), new_conn.clone());
1934 conn_opt = Some(new_conn);
1935 }
1936 Err(err) => {
1937 warn!(
1938 "No available WebSocket connection to subscribe stream `{}`: {:?}",
1939 stream, err
1940 );
1941 continue;
1942 }
1943 }
1944 }
1945
1946 if let Some(conn) = conn_opt {
1947 {
1948 let mut conn_state = conn.state.lock().await;
1949 conn_state
1950 .stream_callbacks
1951 .entry(stream.clone())
1952 .or_default();
1953 }
1954 connection_streams.push((stream.clone(), conn));
1955 }
1956 }
1957
1958 let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
1959 for (stream, conn) in connection_streams {
1960 if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
1961 vec.push(stream);
1962 } else {
1963 groups.push((conn, vec![stream]));
1964 }
1965 }
1966
1967 groups
1968 }
1969
1970 fn send_subscription_payload(
1983 &self,
1984 connection: &Arc<WebsocketConnection>,
1985 streams: &Vec<String>,
1986 id: Option<String>,
1987 ) {
1988 let request_id = id
1989 .filter(|s| ID_REGEX.is_match(s))
1990 .unwrap_or_else(random_string);
1991
1992 let payload = json!({
1993 "method": "SUBSCRIBE",
1994 "params": streams,
1995 "id": request_id,
1996 });
1997
1998 info!("SUBSCRIBE → {:?}", payload);
1999
2000 let common = Arc::clone(&self.common);
2001 let msg = match serde_json::to_string(&payload) {
2002 Ok(s) => s,
2003 Err(e) => {
2004 error!("Failed to serialize SUBSCRIBE payload: {}", e);
2005 return;
2006 }
2007 };
2008 let conn_clone = Arc::clone(connection);
2009
2010 spawn(async move {
2011 let _ = common
2012 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2013 .await;
2014 });
2015 }
2016}
2017
2018#[async_trait]
2019impl WebsocketHandler for WebsocketStreams {
2020 async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2037 let pending_subs: Vec<String> = {
2038 let mut conn_state = connection.state.lock().await;
2039 take(&mut conn_state.pending_subscriptions)
2040 .into_iter()
2041 .collect()
2042 };
2043
2044 if !pending_subs.is_empty() {
2045 info!("Processing queued subscriptions for connection");
2046 self.send_subscription_payload(&connection, &pending_subs, None);
2047 }
2048 }
2049
2050 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2067 let msg: Value = match serde_json::from_str(&data) {
2068 Ok(v) => v,
2069 Err(err) => {
2070 error!(
2071 "Failed to parse WebSocket stream message {} – {}",
2072 data, err
2073 );
2074 return;
2075 }
2076 };
2077
2078 let (stream_name, payload) = match (
2079 msg.get("stream").and_then(Value::as_str),
2080 msg.get("data").cloned(),
2081 ) {
2082 (Some(name), Some(data)) => (name.to_string(), data),
2083 _ => return,
2084 };
2085
2086 let callbacks = {
2087 let conn_state = connection.state.lock().await;
2088 conn_state
2089 .stream_callbacks
2090 .get(&stream_name)
2091 .cloned()
2092 .unwrap_or_else(Vec::new)
2093 };
2094
2095 for callback in callbacks {
2096 callback(&payload);
2097 }
2098 }
2099
2100 async fn get_reconnect_url(
2111 &self,
2112 _default_url: String,
2113 connection: Arc<WebsocketConnection>,
2114 ) -> String {
2115 let connection_streams = self.connection_streams.lock().await;
2116 let reconnect_streams = connection_streams
2117 .iter()
2118 .filter_map(|(stream, conn_arc)| {
2119 if Arc::ptr_eq(conn_arc, &connection) {
2120 Some(stream.clone())
2121 } else {
2122 None
2123 }
2124 })
2125 .collect::<Vec<_>>();
2126 self.prepare_url(&reconnect_streams)
2127 }
2128}
2129
/// Typed handle for one stream, whose JSON payloads are deserialized into `T`
/// and passed to a user callback registered through `on_message`.
///
/// The handle is bound to either a streams client or a WebSocket API client.
pub struct WebsocketStream<T> {
    // Client this handle registers its callback with.
    websocket_base: WebsocketBase,
    // Key into the per-stream callback maps. For the streams client this is
    // the stream name. For the API client it is presumably a subscription
    // identifier — TODO confirm against callers.
    stream_or_id: String,
    // Type-erased callback most recently registered by this handle. It is kept
    // so `unsubscribe` can remove exactly this callback from the shared list.
    callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
    // Request id passed on to the SUBSCRIBE (via `create_stream_handler`) and
    // UNSUBSCRIBE requests that this handle triggers on the streams client.
    pub id: Option<String>,
    // Records the payload type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
2137
2138impl<T> WebsocketStream<T>
2139where
2140 T: DeserializeOwned + Send + 'static,
2141{
2142 async fn on<F>(&self, event: &str, callback_fn: F)
2163 where
2164 F: Fn(T) + Send + Sync + 'static,
2165 {
2166 if event != "message" {
2167 return;
2168 }
2169
2170 let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2171 Arc::new(
2172 move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2173 Ok(data) => callback_fn(data),
2174 Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2175 },
2176 );
2177
2178 {
2179 let mut guard = self.callback.lock().await;
2180 *guard = Some(cb_wrapper.clone());
2181 }
2182
2183 match &self.websocket_base {
2184 WebsocketBase::WebsocketStreams(ws_streams) => {
2185 let conn = {
2186 let map = ws_streams.connection_streams.lock().await;
2187 map.get(&self.stream_or_id)
2188 .cloned()
2189 .expect("stream must be subscribed")
2190 };
2191
2192 {
2193 let mut conn_state = conn.state.lock().await;
2194 let entry = conn_state
2195 .stream_callbacks
2196 .entry(self.stream_or_id.clone())
2197 .or_default();
2198
2199 if !entry
2200 .iter()
2201 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2202 {
2203 entry.push(cb_wrapper);
2204 }
2205 }
2206 }
2207 WebsocketBase::WebsocketApi(ws_api) => {
2208 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2209 let entry = stream_callbacks
2210 .entry(self.stream_or_id.clone())
2211 .or_default();
2212
2213 if !entry
2214 .iter()
2215 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2216 {
2217 entry.push(cb_wrapper);
2218 }
2219 }
2220 }
2221 }
2222
2223 pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2242 where
2243 T: Send + Sync,
2244 F: Fn(T) + Send + Sync + 'static,
2245 {
2246 let handler: Arc<Self> = Arc::clone(self);
2247
2248 std::thread::spawn(move || {
2249 let rt = tokio::runtime::Builder::new_current_thread()
2250 .enable_all()
2251 .build()
2252 .expect("failed to build Tokio runtime");
2253
2254 rt.block_on(handler.on("message", callback_fn));
2255 })
2256 .join()
2257 .expect("on_message thread panicked");
2258 }
2259
2260 pub async fn unsubscribe(&self) {
2275 let maybe_cb = {
2276 let mut guard = self.callback.lock().await;
2277 guard.take()
2278 };
2279
2280 if let Some(cb) = maybe_cb {
2281 match &self.websocket_base {
2282 WebsocketBase::WebsocketStreams(ws_streams) => {
2283 let conn = {
2284 let map = ws_streams.connection_streams.lock().await;
2285 map.get(&self.stream_or_id)
2286 .cloned()
2287 .expect("stream must have been subscribed")
2288 };
2289
2290 {
2291 let mut conn_state = conn.state.lock().await;
2292 if let Some(list) = conn_state.stream_callbacks.get_mut(&self.stream_or_id)
2293 {
2294 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2295 }
2296 }
2297
2298 let stream = self.stream_or_id.clone();
2299 let id = self.id.clone();
2300 let websocket_streams_base = Arc::clone(ws_streams);
2301 spawn(async move {
2302 websocket_streams_base.unsubscribe(vec![stream], id).await;
2303 });
2304 }
2305 WebsocketBase::WebsocketApi(ws_api) => {
2306 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2307 if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2308 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2309 }
2310 }
2311 }
2312 }
2313 }
2314}
2315
2316pub async fn create_stream_handler<T>(
2317 websocket_base: WebsocketBase,
2318 stream_or_id: String,
2319 id: Option<String>,
2320) -> Arc<WebsocketStream<T>>
2321where
2322 T: DeserializeOwned + Send + 'static,
2323{
2324 match &websocket_base {
2325 WebsocketBase::WebsocketStreams(ws_streams) => {
2326 ws_streams
2327 .clone()
2328 .subscribe(vec![stream_or_id.clone()], id.clone())
2329 .await;
2330 }
2331 WebsocketBase::WebsocketApi(_) => {}
2332 }
2333
2334 Arc::new(WebsocketStream {
2335 websocket_base,
2336 stream_or_id,
2337 id,
2338 callback: Mutex::new(None),
2339 _phantom: PhantomData,
2340 })
2341}
2342
2343#[cfg(test)]
2344mod tests {
2345 use crate::TOKIO_SHARED_RT;
2346 use crate::common::utils::{SignatureGenerator, build_user_agent};
2347 use crate::common::websocket::{
2348 PendingRequest, ReconnectEntry, WebsocketApi, WebsocketBase, WebsocketCommon,
2349 WebsocketConnection, WebsocketEvent, WebsocketEventEmitter, WebsocketHandler,
2350 WebsocketMessageSendOptions, WebsocketMode, WebsocketStream, WebsocketStreams,
2351 create_stream_handler,
2352 };
2353 use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2354 use crate::errors::WebsocketError;
2355 use crate::models::TimeUnit;
2356 use async_trait::async_trait;
2357 use futures::{SinkExt, StreamExt};
2358 use http::header::USER_AGENT;
2359 use regex::Regex;
2360 use serde_json::{Value, json};
2361 use std::collections::{BTreeMap, HashSet};
2362 use std::marker::PhantomData;
2363 use std::net::SocketAddr;
2364 use std::sync::{
2365 Arc,
2366 atomic::{AtomicBool, AtomicUsize, Ordering},
2367 };
2368 use tokio::net::TcpListener;
2369 use tokio::sync::{Mutex, mpsc::unbounded_channel, oneshot};
2370 use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
2371 use tokio_tungstenite::{accept_async, accept_hdr_async, tungstenite, tungstenite::Message};
2372 use tungstenite::handshake::server::Request;
2373
2374 fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
2375 let events = Arc::new(Mutex::new(Vec::new()));
2376 let events_clone = events.clone();
2377 common.events.subscribe(move |event| {
2378 let events_clone = events_clone.clone();
2379 tokio::spawn(async move {
2380 events_clone.lock().await.push(event);
2381 });
2382 });
2383 events
2384 }
2385
2386 async fn create_connection(
2387 id: &str,
2388 has_writer: bool,
2389 reconnection_pending: bool,
2390 renewal_pending: bool,
2391 close_initiated: bool,
2392 ) -> Arc<WebsocketConnection> {
2393 let conn = WebsocketConnection::new(id);
2394 let mut st = conn.state.lock().await;
2395 st.reconnection_pending = reconnection_pending;
2396 st.renewal_pending = renewal_pending;
2397 st.close_initiated = close_initiated;
2398 if has_writer {
2399 let (tx, _) = unbounded_channel::<Message>();
2400 st.ws_write_tx = Some(tx);
2401 } else {
2402 st.ws_write_tx = None;
2403 }
2404 drop(st);
2405 conn
2406 }
2407
2408 fn create_websocket_api(time_unit: Option<TimeUnit>) -> Arc<WebsocketApi> {
2409 let sig_gen = SignatureGenerator::new(
2410 Some("api_secret".into()),
2411 None::<PrivateKey>,
2412 None::<String>,
2413 );
2414 let config = ConfigurationWebsocketApi {
2415 api_key: Some("api_key".into()),
2416 api_secret: Some("api_secret".into()),
2417 private_key: None,
2418 private_key_passphrase: None,
2419 ws_url: Some("wss://example.com".into()),
2420 mode: WebsocketMode::Single,
2421 reconnect_delay: 1000,
2422 signature_gen: sig_gen,
2423 timeout: 500,
2424 time_unit,
2425 agent: None,
2426 user_agent: build_user_agent("product"),
2427 };
2428 let conn = WebsocketConnection::new("c1");
2429 WebsocketApi::new(config, vec![conn])
2430 }
2431
2432 fn create_websocket_streams(
2433 ws_url: Option<&str>,
2434 conns: Option<Vec<Arc<WebsocketConnection>>>,
2435 ) -> Arc<WebsocketStreams> {
2436 let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
2437 if conns.is_none() {
2438 connections.push(WebsocketConnection::new("c1"));
2439 connections.push(WebsocketConnection::new("c2"));
2440 } else {
2441 connections = conns.expect("Expected connections to be set");
2442 }
2443 let config = ConfigurationWebsocketStreams {
2444 ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
2445 mode: WebsocketMode::Single,
2446 reconnect_delay: 500,
2447 time_unit: None,
2448 agent: None,
2449 user_agent: build_user_agent("product"),
2450 };
2451 WebsocketStreams::new(config, connections)
2452 }
2453
2454 mod event_emitter {
2455 use super::*;
2456
2457 #[test]
2458 fn event_emitter_subscribe_and_emit() {
2459 TOKIO_SHARED_RT.block_on(async {
2460 let emitter = WebsocketEventEmitter::new();
2461 let (tx, rx) = oneshot::channel();
2462 let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
2463 let tx_clone = tx.clone();
2464 let _sub = emitter.subscribe(move |event| {
2465 if let Some(sender) = tx_clone.lock().unwrap().take() {
2466 let _ = sender.send(event);
2467 }
2468 });
2469 emitter.emit(WebsocketEvent::Open);
2470 let received = timeout(Duration::from_millis(100), rx)
2471 .await
2472 .expect("timed out");
2473 assert_eq!(received, Ok(WebsocketEvent::Open));
2474 });
2475 }
2476 }
2477
2478 mod websocket_common {
2479 use super::*;
2480
2481 mod initialisation {
2482 use super::*;
2483
2484 #[test]
2485 fn single_mode() {
2486 TOKIO_SHARED_RT.block_on(async {
2487 let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
2488 assert_eq!(common.connection_pool.len(), 1);
2489 });
2490 }
2491
2492 #[test]
2493 fn pool_mode() {
2494 TOKIO_SHARED_RT.block_on(async {
2495 let common =
2496 WebsocketCommon::new(vec![], WebsocketMode::Pool(3), 0, None, None);
2497 assert_eq!(common.connection_pool.len(), 3);
2498 });
2499 }
2500 }
2501
2502 mod spawn_reconnect_loop {
2503 use super::*;
2504
2505 #[test]
2506 fn successful_reconnect_entry_triggers_init_connect() {
2507 TOKIO_SHARED_RT.block_on(async {
2508 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2509 let addr = listener.local_addr().unwrap();
2510 tokio::spawn(async move {
2511 if let Ok((stream, _)) = listener.accept().await {
2512 let mut ws = accept_async(stream).await.unwrap();
2513 let _ = ws.close(None).await;
2514 }
2515 });
2516
2517 let conn = WebsocketConnection::new("c1");
2518 let common = WebsocketCommon::new(
2519 vec![conn.clone()],
2520 WebsocketMode::Single,
2521 10,
2522 None,
2523 None,
2524 );
2525 let url = format!("ws://{addr}");
2526 common
2527 .reconnect_tx
2528 .send(ReconnectEntry {
2529 connection_id: "c1".into(),
2530 url: url.clone(),
2531 is_renewal: false,
2532 })
2533 .await
2534 .unwrap();
2535
2536 sleep(Duration::from_secs(2)).await;
2537
2538 let st = conn.state.lock().await;
2539 assert!(st.ws_write_tx.is_some());
2540 });
2541 }
2542
2543 #[test]
2544 fn reconnect_entry_with_unknown_id_is_ignored() {
2545 TOKIO_SHARED_RT.block_on(async {
2546 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2547 let addr = listener.local_addr().unwrap();
2548 tokio::spawn(async move {
2549 if let Ok((stream, _)) = listener.accept().await {
2550 let mut ws = accept_async(stream).await.unwrap();
2551 let _ = ws.close(None).await;
2552 }
2553 });
2554
2555 let conn = WebsocketConnection::new("c1");
2556 let common = WebsocketCommon::new(
2557 vec![conn.clone()],
2558 WebsocketMode::Single,
2559 5,
2560 None,
2561 None,
2562 );
2563 let url = format!("ws://{addr}");
2564 common
2565 .reconnect_tx
2566 .send(ReconnectEntry {
2567 connection_id: "other".into(),
2568 url,
2569 is_renewal: false,
2570 })
2571 .await
2572 .unwrap();
2573
2574 sleep(Duration::from_secs(1)).await;
2575
2576 let st = conn.state.lock().await;
2577 assert!(st.ws_write_tx.is_none());
2578 });
2579 }
2580
2581 #[test]
2582 fn renewal_entries_bypass_initial_delay() {
2583 TOKIO_SHARED_RT.block_on(async {
2584 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2585 let addr = listener.local_addr().unwrap();
2586 tokio::spawn(async move {
2587 if let Ok((stream, _)) = listener.accept().await {
2588 let mut ws = accept_async(stream).await.unwrap();
2589 let _ = ws.close(None).await;
2590 }
2591 });
2592
2593 let conn = WebsocketConnection::new("renew");
2594 let common = WebsocketCommon::new(
2595 vec![conn.clone()],
2596 WebsocketMode::Single,
2597 200,
2598 None,
2599 None,
2600 );
2601 let url = format!("ws://{addr}");
2602 common
2603 .reconnect_tx
2604 .send(ReconnectEntry {
2605 connection_id: "renew".into(),
2606 url: url.clone(),
2607 is_renewal: true,
2608 })
2609 .await
2610 .unwrap();
2611
2612 sleep(Duration::from_secs(2)).await;
2613
2614 let st = conn.state.lock().await;
2615
2616 assert!(st.ws_write_tx.is_some());
2617 });
2618 }
2619
2620 #[test]
2621 fn non_renewal_entries_respect_initial_delay() {
2622 TOKIO_SHARED_RT.block_on(async {
2623 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2624 let addr = listener.local_addr().unwrap();
2625 tokio::spawn(async move {
2626 if let Ok((stream, _)) = listener.accept().await {
2627 let mut ws = accept_async(stream).await.unwrap();
2628 let _ = ws.close(None).await;
2629 }
2630 });
2631
2632 let conn = WebsocketConnection::new("nonrenew");
2633 let common = WebsocketCommon::new(
2634 vec![conn.clone()],
2635 WebsocketMode::Single,
2636 200,
2637 None,
2638 None,
2639 );
2640 let url = format!("ws://{addr}");
2641 common
2642 .reconnect_tx
2643 .send(ReconnectEntry {
2644 connection_id: "nonrenew".into(),
2645 url: url.clone(),
2646 is_renewal: false,
2647 })
2648 .await
2649 .unwrap();
2650
2651 sleep(Duration::from_millis(100)).await;
2652 assert!(conn.state.lock().await.ws_write_tx.is_none());
2653
2654 sleep(Duration::from_secs(2)).await;
2655
2656 assert!(conn.state.lock().await.ws_write_tx.is_some());
2657 });
2658 }
2659 }
2660
2661 mod spawn_renewal_loop {
2662 use super::*;
2663
2664 #[tokio::test]
2665 async fn scheduling_renewal_does_not_panic_for_known_connection() {
2666 pause();
2667
2668 let conn = WebsocketConnection::new("known");
2669 let common =
2670 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2671 let url = "wss://example".to_string();
2672 common
2673 .renewal_tx
2674 .send((conn.id.clone(), url))
2675 .await
2676 .unwrap();
2677 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2678
2679 resume();
2680 }
2681
2682 #[tokio::test]
2683 async fn scheduling_renewal_ignored_for_unknown_connection() {
2684 pause();
2685
2686 let conn = WebsocketConnection::new("c1");
2687 let common =
2688 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2689 common
2690 .renewal_tx
2691 .send(("other".into(), "u".into()))
2692 .await
2693 .unwrap();
2694 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2695
2696 resume();
2697 }
2698 }
2699
2700 mod is_connection_ready {
2701 use super::*;
2702
2703 #[test]
2704 fn is_connection_ready() {
2705 TOKIO_SHARED_RT.block_on(async {
2706 let conn = WebsocketConnection::new("c1");
2707 let common = WebsocketCommon::new(
2708 vec![conn.clone()],
2709 WebsocketMode::Single,
2710 0,
2711 None,
2712 None,
2713 );
2714 assert!(!common.is_connection_ready(&conn, false).await);
2715 assert!(common.is_connection_ready(&conn, true).await);
2716 });
2717 }
2718
2719 #[test]
2720 fn connection_ready_basic() {
2721 TOKIO_SHARED_RT.block_on(async {
2722 let conn = create_connection("c1", true, false, false, false).await;
2723 let common = WebsocketCommon::new(
2724 vec![conn.clone()],
2725 WebsocketMode::Single,
2726 0,
2727 None,
2728 None,
2729 );
2730 assert!(common.is_connection_ready(&conn, false).await);
2731 });
2732 }
2733
2734 #[test]
2735 fn connection_not_ready_without_writer() {
2736 TOKIO_SHARED_RT.block_on(async {
2737 let conn = create_connection("c1", false, false, false, false).await;
2738 let common = WebsocketCommon::new(
2739 vec![conn.clone()],
2740 WebsocketMode::Single,
2741 0,
2742 None,
2743 None,
2744 );
2745 assert!(!common.is_connection_ready(&conn, false).await);
2746 assert!(common.is_connection_ready(&conn, true).await);
2747 });
2748 }
2749
2750 #[test]
2751 fn connection_not_ready_when_flagged() {
2752 TOKIO_SHARED_RT.block_on(async {
2753 let conn1 = create_connection("c1", true, true, false, false).await;
2754 let conn2 = create_connection("c2", true, false, true, false).await;
2755 let conn3 = create_connection("c3", true, false, false, true).await;
2756
2757 let common = WebsocketCommon::new(
2758 vec![conn1.clone(), conn2.clone(), conn3.clone()],
2759 WebsocketMode::Pool(3),
2760 0,
2761 None,
2762 None,
2763 );
2764
2765 assert!(!common.is_connection_ready(&conn1, false).await);
2766 assert!(!common.is_connection_ready(&conn2, false).await);
2767 assert!(!common.is_connection_ready(&conn3, false).await);
2768 });
2769 }
2770 }
2771
2772 mod is_connected {
2773 use super::*;
2774
2775 #[test]
2776 fn with_pool_various_connections() {
2777 TOKIO_SHARED_RT.block_on(async {
2778 let conn_a = create_connection("a", true, false, false, false).await;
2779 let conn_b = create_connection("b", false, false, false, false).await;
2780 let conn_c = create_connection("c", true, true, false, false).await;
2781 let pool = vec![conn_a.clone(), conn_b.clone(), conn_c.clone()];
2782 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None, None);
2783
2784 assert!(common.is_connected(None).await);
2785 assert!(common.is_connected(Some(&conn_a)).await);
2786 assert!(!common.is_connected(Some(&conn_b)).await);
2787 assert!(!common.is_connected(Some(&conn_c)).await);
2788 });
2789 }
2790
2791 #[test]
2792 fn with_pool_all_bad_connections() {
2793 TOKIO_SHARED_RT.block_on(async {
2794 let bad1 = create_connection("c1", false, false, false, false).await;
2795 let bad2 = create_connection("c2", true, true, false, false).await;
2796 let bad3 = create_connection("c3", true, false, false, true).await;
2797 let common = WebsocketCommon::new(
2798 vec![bad1, bad2, bad3],
2799 WebsocketMode::Pool(3),
2800 0,
2801 None,
2802 None,
2803 );
2804
2805 assert!(!common.is_connected(None).await);
2806 });
2807 }
2808
2809 #[test]
2810 fn with_pool_ignore_close_initiated() {
2811 TOKIO_SHARED_RT.block_on(async {
2812 let good = create_connection("c1", true, false, false, false).await;
2813 let closed = create_connection("c2", true, false, false, true).await;
2814 let bad = create_connection("c3", false, false, false, false).await;
2815 let common = WebsocketCommon::new(
2816 vec![closed.clone(), good.clone(), bad.clone()],
2817 WebsocketMode::Pool(3),
2818 0,
2819 None,
2820 None,
2821 );
2822
2823 assert!(common.is_connected(None).await);
2824 assert!(!common.is_connected(Some(&closed)).await);
2825 });
2826 }
2827 }
2828
2829 mod get_connection {
2830 use super::*;
2831
2832 #[test]
2833 fn single_mode() {
2834 TOKIO_SHARED_RT.block_on(async {
2835 let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
2836 let conn = common
2837 .get_connection(false)
2838 .await
2839 .expect("should get connection");
2840 assert_eq!(conn.id, common.connection_pool[0].id);
2841 });
2842 }
2843
2844 #[test]
2845 fn pool_mode_not_ready() {
2846 TOKIO_SHARED_RT.block_on(async {
2847 let common =
2848 WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
2849 let result = common.get_connection(false).await;
2850 assert!(matches!(
2851 result,
2852 Err(crate::errors::WebsocketError::NotConnected)
2853 ));
2854 });
2855 }
2856
2857 #[test]
2858 fn pool_mode_with_ready() {
2859 TOKIO_SHARED_RT.block_on(async {
2860 let conn1 = WebsocketConnection::new("c1");
2861 let conn2 = WebsocketConnection::new("c2");
2862 let (tx1, _rx1) = unbounded_channel();
2863 {
2864 let mut s1 = conn1.state.lock().await;
2865 s1.ws_write_tx = Some(tx1);
2866 }
2867 let pool = vec![conn1.clone(), conn2.clone()];
2868 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
2869 let result = common.get_connection(false).await;
2870 assert!(result.is_ok());
2871 let chosen = result.unwrap();
2872 assert_eq!(chosen.id, conn1.id);
2873 });
2874 }
2875 }
2876
2877 mod close_connection_gracefully {
2878 use super::*;
2879
2880 #[tokio::test]
2881 async fn waits_for_pending_requests_then_closes() {
2882 pause();
2883
2884 let conn = WebsocketConnection::new("c1");
2885 let (tx, mut rx) = unbounded_channel::<Message>();
2886 let (req_tx, _req_rx) = oneshot::channel();
2887 {
2888 let mut st = conn.state.lock().await;
2889 st.pending_requests
2890 .insert("r".to_string(), PendingRequest { completion: req_tx });
2891 }
2892 let common =
2893 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2894 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2895 advance(Duration::from_secs(1)).await;
2896 {
2897 let mut st = conn.state.lock().await;
2898 st.pending_requests.clear();
2899 }
2900 conn.drain_notify.notify_waiters();
2901 advance(Duration::from_secs(1)).await;
2902 close_fut.await.unwrap();
2903 match rx.try_recv() {
2904 Ok(Message::Close(_)) => {}
2905 other => panic!("expected Close, got {other:?}"),
2906 }
2907
2908 resume();
2909 }
2910
2911 #[tokio::test]
2912 async fn force_closes_after_timeout() {
2913 pause();
2914
2915 let conn = WebsocketConnection::new("c2");
2916 let (tx, mut rx) = unbounded_channel::<Message>();
2917 let (req_tx, _req_rx) = oneshot::channel();
2918 {
2919 let mut st = conn.state.lock().await;
2920 st.pending_requests.insert(
2921 "request_id".to_string(),
2922 PendingRequest { completion: req_tx },
2923 );
2924 }
2925 let common =
2926 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2927 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2928 advance(Duration::from_secs(30)).await;
2929 close_fut.await.unwrap();
2930 match rx.try_recv() {
2931 Ok(Message::Close(_)) => {}
2932 other => panic!("expected Close on timeout, got {other:?}"),
2933 }
2934
2935 resume();
2936 }
2937 }
2938
2939 mod get_reconnect_url {
2940 use super::*;
2941
2942 struct DummyHandler {
2943 url: String,
2944 }
2945
2946 #[async_trait::async_trait]
2947 impl WebsocketHandler for DummyHandler {
2948 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
2949 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
2950 async fn get_reconnect_url(
2951 &self,
2952 _default_url: String,
2953 _connection: Arc<WebsocketConnection>,
2954 ) -> String {
2955 self.url.clone()
2956 }
2957 }
2958
2959 #[test]
2960 fn returns_default_when_no_handler() {
2961 TOKIO_SHARED_RT.block_on(async {
2962 let conn = WebsocketConnection::new("c1");
2963 let common = WebsocketCommon::new(
2964 vec![conn.clone()],
2965 WebsocketMode::Single,
2966 0,
2967 None,
2968 None,
2969 );
2970 let default = "wss://default".to_string();
2971 let result = common.get_reconnect_url(&default, conn.clone()).await;
2972 assert_eq!(result, default);
2973 });
2974 }
2975
2976 #[test]
2977 fn returns_handler_url_when_set() {
2978 TOKIO_SHARED_RT.block_on(async {
2979 let conn = WebsocketConnection::new("c2");
2980 let handler = Arc::new(DummyHandler {
2981 url: "wss://custom".into(),
2982 });
2983 conn.set_handler(handler).await;
2984 let common = WebsocketCommon::new(
2985 vec![conn.clone()],
2986 WebsocketMode::Single,
2987 0,
2988 None,
2989 None,
2990 );
2991 let default = "wss://default".to_string();
2992 let result = common.get_reconnect_url(&default, conn.clone()).await;
2993 assert_eq!(result, "wss://custom");
2994 });
2995 }
2996 }
2997
2998 mod on_open {
2999 use super::*;
3000
3001 struct DummyHandler {
3002 called: Arc<Mutex<bool>>,
3003 opened_url: Arc<Mutex<Option<String>>>,
3004 }
3005
3006 #[async_trait]
3007 impl WebsocketHandler for DummyHandler {
3008 async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
3009 let mut flag = self.called.lock().await;
3010 *flag = true;
3011 let mut store = self.opened_url.lock().await;
3012 *store = Some(url);
3013 }
3014 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
3015 async fn get_reconnect_url(
3016 &self,
3017 default_url: String,
3018 _connection: Arc<WebsocketConnection>,
3019 ) -> String {
3020 default_url
3021 }
3022 }
3023
3024 #[test]
3025 fn emits_open_and_calls_handler() {
3026 TOKIO_SHARED_RT.block_on(async {
3027 let conn = WebsocketConnection::new("c1");
3028 let called = Arc::new(Mutex::new(false));
3029 let opened_url = Arc::new(Mutex::new(None));
3030 let handler = Arc::new(DummyHandler {
3031 called: called.clone(),
3032 opened_url: opened_url.clone(),
3033 });
3034
3035 conn.set_handler(handler.clone()).await;
3036 let common = WebsocketCommon::new(
3037 vec![conn.clone()],
3038 WebsocketMode::Single,
3039 0,
3040 None,
3041 None,
3042 );
3043 let events = subscribe_events(&common);
3044 common
3045 .on_open("wss://example.com".into(), conn.clone(), None)
3046 .await;
3047
3048 sleep(std::time::Duration::from_millis(10)).await;
3049
3050 let evs = events.lock().await;
3051 assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
3052 assert!(*called.lock().await);
3053 assert_eq!(
3054 opened_url.lock().await.as_deref(),
3055 Some("wss://example.com")
3056 );
3057 });
3058 }
3059
3060 #[test]
3061 fn handles_renewal_pending_and_closes_old_writer() {
3062 TOKIO_SHARED_RT.block_on(async {
3063 let conn = WebsocketConnection::new("c2");
3064 let (old_tx, mut old_rx) = unbounded_channel::<Message>();
3065 {
3066 let mut st = conn.state.lock().await;
3067 st.renewal_pending = true;
3068 }
3069 let common = WebsocketCommon::new(
3070 vec![conn.clone()],
3071 WebsocketMode::Single,
3072 0,
3073 None,
3074 None,
3075 );
3076 common
3077 .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
3078 .await;
3079 assert!(!conn.state.lock().await.renewal_pending);
3080 match old_rx.try_recv() {
3081 Ok(Message::Close(_)) => {}
3082 other => panic!("expected Close, got {other:?}"),
3083 }
3084 });
3085 }
3086 }
3087
3088 mod on_message {
3089 use super::*;
3090
3091 struct DummyHandler {
3092 called_with: Arc<Mutex<Vec<String>>>,
3093 }
3094
3095 #[async_trait]
3096 impl WebsocketHandler for DummyHandler {
3097 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
3098 async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
3099 self.called_with.lock().await.push(data);
3100 }
3101 async fn get_reconnect_url(
3102 &self,
3103 default_url: String,
3104 _connection: Arc<WebsocketConnection>,
3105 ) -> String {
3106 default_url
3107 }
3108 }
3109
3110 #[test]
3111 fn emits_message_event_without_handler() {
3112 TOKIO_SHARED_RT.block_on(async {
3113 let conn = WebsocketConnection::new("c1");
3114 let common = WebsocketCommon::new(
3115 vec![conn.clone()],
3116 WebsocketMode::Single,
3117 0,
3118 None,
3119 None,
3120 );
3121 let events = subscribe_events(&common);
3122 common.on_message("msg".into(), conn.clone()).await;
3123
3124 sleep(Duration::from_millis(10)).await;
3125
3126 let locked = events.lock().await;
3127 assert!(
3128 locked
3129 .iter()
3130 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3131 );
3132 });
3133 }
3134
3135 #[test]
3136 fn calls_handler_and_emits_message() {
3137 TOKIO_SHARED_RT.block_on(async {
3138 let conn = WebsocketConnection::new("c2");
3139 let called = Arc::new(Mutex::new(Vec::new()));
3140 let handler = Arc::new(DummyHandler {
3141 called_with: called.clone(),
3142 });
3143 conn.set_handler(handler.clone()).await;
3144
3145 let common = WebsocketCommon::new(
3146 vec![conn.clone()],
3147 WebsocketMode::Single,
3148 0,
3149 None,
3150 None,
3151 );
3152 let events = subscribe_events(&common);
3153 common.on_message("msg".into(), conn.clone()).await;
3154
3155 sleep(Duration::from_millis(10)).await;
3156
3157 let evs = events.lock().await;
3158 assert!(
3159 evs.iter()
3160 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3161 );
3162 let msgs = called.lock().await;
3163 assert_eq!(msgs.as_slice(), &["msg".to_string()]);
3164 });
3165 }
3166 }
3167
/// Tests for `WebsocketCommon::create_websocket`: a successful handshake against
/// a local server (including the `User-Agent` header sent), and the
/// `WebsocketError::Handshake` error returned for malformed or unreachable URLs.
mod create_websocket {
    use super::*;

    /// A local server completes the handshake and checks the client sent the
    /// `User-Agent` passed as the third argument.
    #[test]
    fn successful_connection() {
        TOKIO_SHARED_RT.block_on(async {
            // Port 0 lets the OS pick a free port; read it back to build the client URL.
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr: SocketAddr = listener.local_addr().unwrap();

            let expected_ua = build_user_agent("product");
            let expected_ua_clone = expected_ua.clone();

            // Server side: accept one connection and inspect the handshake request.
            // NOTE(review): a failing assert inside this spawned task is not reported
            // directly; it presumably shows up only as the client-side handshake
            // failure asserted below.
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let callback = |req: &Request, resp| {
                        let got = req
                            .headers()
                            .get(USER_AGENT)
                            .expect("no USER_AGENT header in WS handshake")
                            .to_str()
                            .expect("invalid USER_AGENT header");
                        assert_eq!(got, expected_ua_clone, "User-Agent mismatch");
                        Ok(resp)
                    };
                    let _ = accept_hdr_async(stream, callback).await.unwrap();
                }
            });

            let url = format!("ws://{addr}");
            let res =
                WebsocketCommon::create_websocket(&url, None, Some(expected_ua)).await;
            assert!(res.is_ok(), "handshake failed: {res:?}");
        });
    }

    /// A string that is not a URL is reported as a handshake error.
    #[test]
    fn invalid_url_returns_handshake_error() {
        TOKIO_SHARED_RT.block_on(async {
            let res =
                WebsocketCommon::create_websocket("not-a-valid-url", None, None).await;
            assert!(matches!(res, Err(WebsocketError::Handshake(_))));
        });
    }

    /// Connecting to a port with (presumably) nothing listening is reported as a
    /// handshake error.
    #[test]
    fn unreachable_host_returns_handshake_error() {
        TOKIO_SHARED_RT.block_on(async {
            let res =
                WebsocketCommon::create_websocket("ws://127.0.0.1:1", None, None).await;
            assert!(matches!(res, Err(WebsocketError::Handshake(_))));
        });
    }
}
3221
/// Tests for `WebsocketCommon::connect_pool`. Every pool member must connect;
/// the tests below expect a single failed connection to fail the whole call
/// with `WebsocketError::Handshake`. A populated `ws_write_tx` is what these
/// tests use to detect a connected member.
mod connect_pool {
    use super::*;

    /// A server that accepts `pool_size` handshakes lets every pooled
    /// connection install a writer channel.
    #[test]
    fn connects_all_in_pool() {
        TOKIO_SHARED_RT.block_on(async {
            let pool_size = 3;
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            // Server: accept exactly `pool_size` connections, complete the
            // handshake, then close each one.
            tokio::spawn(async move {
                for _ in 0..pool_size {
                    if let Ok((stream, _)) = listener.accept().await {
                        let mut ws = accept_async(stream).await.unwrap();
                        let _ = ws.close(None).await;
                    }
                }
            });
            let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
                .map(|i| WebsocketConnection::new(format!("c{i}")))
                .collect();
            let common = WebsocketCommon::new(
                conns.clone(),
                WebsocketMode::Pool(pool_size),
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common.clone().connect_pool(&url).await.unwrap();
            for conn in conns {
                let st = conn.state.lock().await;
                assert!(st.ws_write_tx.is_some());
            }
        });
    }

    /// The server accepts only two of the three connections and then its task
    /// ends (dropping the listener), so the third attempt is expected to fail.
    /// The "bad" id is only a label; all three connections use the same URL.
    #[test]
    fn fails_if_any_refused() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let pool_size = 3;
            tokio::spawn(async move {
                for _ in 0..2 {
                    if let Ok((stream, _)) = listener.accept().await {
                        let mut ws = accept_async(stream).await.unwrap();
                        let _ = ws.close(None).await;
                    }
                }
            });
            let mut conns = Vec::new();
            let valid_url = format!("ws://{addr}");
            for i in 0..2 {
                conns.push(WebsocketConnection::new(format!("c{i}")));
            }
            conns.push(WebsocketConnection::new("bad"));
            let common = WebsocketCommon::new(
                conns.clone(),
                WebsocketMode::Pool(pool_size),
                0,
                None,
                None,
            );
            let res = common.clone().connect_pool(&valid_url).await;
            assert!(matches!(res, Err(WebsocketError::Handshake(_))));
        });
    }

    /// A malformed URL fails the pool connect with a handshake error.
    #[test]
    fn fails_on_invalid_url() {
        TOKIO_SHARED_RT.block_on(async {
            let conns = vec![WebsocketConnection::new("c1")];
            let common = WebsocketCommon::new(conns, WebsocketMode::Pool(1), 0, None, None);
            let res = common.connect_pool("not-a-url").await;
            assert!(matches!(res, Err(WebsocketError::Handshake(_))));
        });
    }

    /// NOTE(review): despite the name, both connections use the same valid URL.
    /// The failure comes from the server accepting only one connection, which
    /// makes this largely a two-connection variant of `fails_if_any_refused`.
    #[test]
    fn fails_if_mixed_success_and_invalid_url() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.close(None).await;
                }
            });
            let good = WebsocketConnection::new("good");
            let bad = WebsocketConnection::new("bad");
            let common = WebsocketCommon::new(
                vec![good, bad],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            let res = common.connect_pool(&url).await;
            assert!(matches!(res, Err(WebsocketError::Handshake(_))));
        });
    }

    /// Each pool member ends up with a writer after `connect_pool`.
    /// NOTE(review): this is currently identical in substance to
    /// `connects_all_in_pool` (only the pool size differs).
    #[test]
    fn init_connect_invoked_for_each() {
        TOKIO_SHARED_RT.block_on(async {
            let pool_size = 2;
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                for _ in 0..pool_size {
                    if let Ok((stream, _)) = listener.accept().await {
                        let mut ws = accept_async(stream).await.unwrap();
                        let _ = ws.close(None).await;
                    }
                }
            });
            let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
                .map(|i| WebsocketConnection::new(format!("c{i}")))
                .collect();
            let common = WebsocketCommon::new(
                conns.clone(),
                WebsocketMode::Pool(pool_size),
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common.clone().connect_pool(&url).await.unwrap();
            for conn in conns {
                let st = conn.state.lock().await;
                assert!(st.ws_write_tx.is_some());
            }
        });
    }

    /// In `Single` mode, `connect_pool` connects the one configured connection.
    #[test]
    fn single_mode_uses_first_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.close(None).await;
                }
            });
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common.connect_pool(&url).await.unwrap();
            let st = conn.state.lock().await;
            assert!(st.ws_write_tx.is_some());
        });
    }
}
3385
/// Tests for `WebsocketCommon::init_connect(url, is_renewal, connection)`:
/// default connection selection, the writer channel, ping/pong handling,
/// short-circuit paths for existing writers and pending renewals, and
/// reconnect scheduling after an abnormal close.
mod init_connect {
    use super::*;

    /// With `connection = None` in pool mode, only the first pool member is
    /// connected; the second keeps no writer.
    #[test]
    fn pool_mode_none_connection_uses_first() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                for _ in 0..2 {
                    if let Ok((stream, _)) = listener.accept().await {
                        let mut ws = accept_async(stream).await.unwrap();
                        ws.close(None).await.ok();
                    }
                }
            });

            let c1 = WebsocketConnection::new("c1");
            let c2 = WebsocketConnection::new("c2");
            let common = WebsocketCommon::new(
                vec![c1.clone(), c2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");

            common
                .clone()
                .init_connect(&url, false, None)
                .await
                .unwrap();
            let st1 = c1.state.lock().await;
            let st2 = c2.state.lock().await;

            assert!(st1.ws_write_tx.is_some());
            assert!(st2.ws_write_tx.is_none());
        });
    }

    /// A message pushed into the installed `ws_write_tx` reaches the server as
    /// a text frame.
    #[test]
    fn writer_channel_can_send_text() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let received = Arc::new(Mutex::new(None::<String>));
            let received_clone = received.clone();

            // Server: record the first text frame, then close.
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    if let Some(Ok(Message::Text(txt))) = ws.next().await {
                        *received_clone.lock().await = Some(txt.to_string());
                    }
                    ws.close(None).await.ok();
                }
            });

            let conn = WebsocketConnection::new("cw");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn.clone()))
                .await
                .unwrap();

            let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
            tx.send(Message::Text("ping".into())).unwrap();

            // Wall-clock wait for the frame to cross the loopback socket.
            sleep(Duration::from_millis(50)).await;

            let lock = received.lock().await;
            assert_eq!(lock.as_deref(), Some("ping"));
        });
    }

    /// A server Ping with payload `[1, 2, 3]` is answered with a Pong carrying
    /// the same payload.
    #[test]
    fn responds_to_ping_with_pong() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            let saw_pong = Arc::new(Mutex::new(false));
            let saw_pong2 = saw_pong.clone();

            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
                    if let Some(Ok(Message::Pong(payload))) = ws.next().await {
                        if payload[..] == [1, 2, 3] {
                            *saw_pong2.lock().await = true;
                        }
                    }
                    let _ = ws.close(None).await;
                }
            });

            let conn = WebsocketConnection::new("c-ping");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn))
                .await
                .unwrap();

            // Wall-clock wait for the ping/pong round trip.
            sleep(Duration::from_millis(50)).await;

            assert!(*saw_pong.lock().await, "server should have seen a Pong");
        });
    }

    /// A malformed URL surfaces as a handshake error.
    #[test]
    fn handshake_error_on_invalid_url() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-invalid");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .clone()
                .init_connect("not-a-url", false, Some(conn.clone()))
                .await;
            assert!(matches!(res, Err(WebsocketError::Handshake(_))));
        });
    }

    /// A connection that already has a writer is left alone when this is not a
    /// renewal. The call returns `Ok`, and nothing is written to the existing
    /// writer. Port 1 presumably has no listener, so a real connect attempt
    /// would have failed.
    #[test]
    fn skip_if_writer_exists_and_not_renewal() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-writer");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx.clone());
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .clone()
                .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
                .await;

            assert!(res.is_ok());
            assert!(rx.try_recv().is_err());
        });
    }

    /// A renewal request while `renewal_pending` is already set returns `Ok`
    /// without installing a writer.
    #[test]
    fn short_circuit_on_already_renewing() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-renew");
            {
                let mut st = conn.state.lock().await;
                st.renewal_pending = true;
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .clone()
                .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
                .await;

            assert!(res.is_ok());
            assert!(conn.state.lock().await.ws_write_tx.is_none());
        });
    }

    /// A successful renewal installs a writer and leaves `renewal_pending`
    /// cleared. Only the final cleared state is observed here.
    #[test]
    fn is_renewal_true_sets_and_clears_flag() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.close(None).await;
                }
            });

            let conn = WebsocketConnection::new("c-new-renew");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            let res = common
                .clone()
                .init_connect(&url, true, Some(conn.clone()))
                .await;

            assert!(res.is_ok());
            let st = conn.state.lock().await;
            assert!(st.ws_write_tx.is_some());
            assert!(!st.renewal_pending);
        });
    }

    /// In `Single` mode with `connection = None`, the sole connection is used.
    #[test]
    fn default_connection_selected_when_none_passed() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    let _ = ws.close(None).await;
                }
            });
            let conn = WebsocketConnection::new("c-default");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            let res = common.clone().init_connect(&url, false, None).await;

            assert!(res.is_ok());
            assert!(conn.state.lock().await.ws_write_tx.is_some());
        });
    }

    /// When the server closes with `CloseCode::Abnormal`, the connection is
    /// flagged `reconnection_pending`. The third `WebsocketCommon::new` argument
    /// is 10 here (presumably the reconnect delay).
    #[test]
    fn schedules_reconnect_on_abnormal_close() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                if let Ok((stream, _)) = listener.accept().await {
                    let mut ws = accept_async(stream).await.unwrap();
                    ws.close(Some(tungstenite::protocol::CloseFrame {
                        code: tungstenite::protocol::frame::coding::CloseCode::Abnormal,
                        reason: "oops".into(),
                    }))
                    .await
                    .ok();
                }
            });
            let conn = WebsocketConnection::new("c-close");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                10,
                None,
                None,
            );
            let url = format!("ws://{addr}");
            common
                .clone()
                .init_connect(&url, false, Some(conn.clone()))
                .await
                .unwrap();

            // Wall-clock wait for the close frame to be processed.
            sleep(Duration::from_millis(50)).await;

            let st = conn.state.lock().await;
            assert!(
                st.reconnection_pending,
                "expected reconnection_pending to be true after abnormal close"
            );
        });
    }
}
3684
/// Tests for `WebsocketCommon::disconnect`: which connections get
/// `close_initiated` and which receive a Close frame on their writer.
mod disconnect {
    use super::*;

    /// A connection without a writer is not ready. `disconnect` still returns
    /// `Ok` and leaves `close_initiated` unset.
    #[test]
    fn returns_ok_when_no_connections_are_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common.disconnect().await;

            assert!(res.is_ok());
            assert!(!conn.state.lock().await.close_initiated);
        });
    }

    /// Every connection with a writer gets `close_initiated` set and a Close
    /// frame queued on its writer channel.
    #[test]
    fn closes_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx1, mut rx1) = unbounded_channel::<Message>();
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = Some(tx1);
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            // Await directly. Awaiting `disconnect()` already waits for it to
            // finish, so sleeping between creating and awaiting the future only
            // added latency.
            common.disconnect().await.unwrap();

            assert!(conn1.state.lock().await.close_initiated);
            assert!(conn2.state.lock().await.close_initiated);

            match (rx1.try_recv(), rx2.try_recv()) {
                (Ok(Message::Close(_)), Ok(Message::Close(_))) => {}
                other => panic!("expected two Closes, got {other:?}"),
            }
        });
    }

    /// A single connection without a writer is not marked `close_initiated`.
    /// NOTE(review): overlaps with `returns_ok_when_no_connections_are_ready`.
    #[test]
    fn does_not_mark_close_initiated_if_no_writer() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-new");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.disconnect().await.unwrap();

            assert!(!conn.state.lock().await.close_initiated);
        });
    }

    /// In a pool where at least one member has a writer, every member is marked
    /// `close_initiated`, but only members with a writer receive a Close frame.
    #[test]
    fn mixed_pool_marks_all_and_closes_only_writers() {
        TOKIO_SHARED_RT.block_on(async {
            let conn_w = WebsocketConnection::new("with");
            let conn_wo = WebsocketConnection::new("without");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn_w.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn_w.clone(), conn_wo.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            // Await directly; see `closes_all_ready_connections`.
            common.disconnect().await.unwrap();

            assert!(conn_w.state.lock().await.close_initiated);
            assert!(conn_wo.state.lock().await.close_initiated);
            assert!(matches!(rx.try_recv(), Ok(Message::Close(_))));
        });
    }

    /// After `disconnect`, `is_connected` reports `false` for the connection.
    #[test]
    fn after_disconnect_not_connected() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            // Keep the receiver alive so the writer channel is not closed while
            // `disconnect` (presumably) queues its Close frame on it.
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.disconnect().await.unwrap();
            assert!(!common.is_connected(Some(&conn)).await);
        });
    }
}
3811
/// Tests for `WebsocketCommon::ping_server`: an empty-payload Ping goes to
/// every ready connection, and connections with blocking flags are skipped.
mod ping_server {
    use super::*;

    /// Each of three ready connections receives exactly an empty Ping.
    #[test]
    fn sends_ping_to_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let mut connections = Vec::new();
            let mut readers = Vec::new();
            for idx in 0..3 {
                let connection = WebsocketConnection::new(format!("c{idx}"));
                let (writer, reader) = unbounded_channel::<Message>();
                connection.state.lock().await.ws_write_tx = Some(writer);
                connections.push(connection);
                readers.push(reader);
            }

            let common = WebsocketCommon::new(
                connections.clone(),
                WebsocketMode::Pool(3),
                0,
                None,
                None,
            );
            common.ping_server().await;

            for mut reader in readers {
                let received = reader.try_recv();
                assert!(
                    matches!(&received, Ok(Message::Ping(payload)) if payload.is_empty()),
                    "expected empty-payload Ping, got {received:?}"
                );
            }
        });
    }

    /// Only the connection that has a writer is pinged.
    #[test]
    fn skips_not_ready_and_partial() {
        TOKIO_SHARED_RT.block_on(async {
            let ready = WebsocketConnection::new("ready");
            let not_ready = WebsocketConnection::new("not-ready");
            let (writer, mut reader) = unbounded_channel::<Message>();
            ready.state.lock().await.ws_write_tx = Some(writer);
            not_ready.state.lock().await.ws_write_tx = None;

            let common = WebsocketCommon::new(
                vec![ready.clone(), not_ready.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            common.ping_server().await;

            let received = reader.try_recv();
            assert!(
                matches!(&received, Ok(Message::Ping(payload)) if payload.is_empty()),
                "expected Ping on ready, got {received:?}"
            );
        });
    }

    /// A connection with `reconnection_pending` set is not pinged even though it
    /// has a writer.
    #[test]
    fn no_ping_when_flags_block() {
        TOKIO_SHARED_RT.block_on(async {
            let connection = WebsocketConnection::new("c1");
            let (writer, mut reader) = unbounded_channel::<Message>();
            {
                let mut state = connection.state.lock().await;
                state.reconnection_pending = true;
                state.ws_write_tx = Some(writer);
            }

            let common = WebsocketCommon::new(
                vec![connection.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            common.ping_server().await;

            let nothing_sent = reader.try_recv().is_err();
            assert!(nothing_sent);
        });
    }
}
3896
/// Tests for `WebsocketCommon::send(payload, id, await_response, timeout, connection)`.
/// Covers round-robin selection, explicit target connections,
/// fire-and-forget sends (third argument `false`, result `Ok(None)`),
/// request/response sends (third argument `true`, result `Ok(Some(future))`),
/// and timeouts.
mod send {
    use super::*;

    /// Without a target connection, consecutive sends rotate through the pool.
    #[test]
    fn round_robin_send_without_specific() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx1, mut rx1) = unbounded_channel::<Message>();
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = Some(tx1);
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );

            let res1 = common
                .send("a".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res1.is_none());

            let res2 = common
                .send("b".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res2.is_none());

            let Message::Text(first) = rx1.try_recv().unwrap() else {
                panic!("expected a Text frame on c1");
            };
            assert_eq!(first, "a");

            let Message::Text(second) = rx2.try_recv().unwrap() else {
                panic!("expected a Text frame on c2");
            };
            assert_eq!(second, "b");
        });
    }

    /// Round-robin selection skips a connection that has no writer.
    #[test]
    fn round_robin_skips_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = None;
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let res = common
                .send("bar".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res.is_none());
            match rx2.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "bar"),
                other => panic!("unexpected {other:?}"),
            }
        });
    }

    /// A fire-and-forget send to an explicit connection writes only to that
    /// connection.
    #[test]
    fn sync_send_on_specific_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut st = conn2.state.lock().await;
                st.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
                None,
            );
            let res = common
                .send(
                    "payload".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn2.clone()),
                )
                .await
                .unwrap();
            assert!(res.is_none());
            match rx2.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "payload"),
                other => panic!("expected Text, got {other:?}"),
            }
        });
    }

    /// A fire-and-forget send does not register a pending request, even when
    /// given an id.
    #[test]
    fn sync_send_with_id_does_not_insert_pending() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let res = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap();
            assert!(res.is_none());
            assert!(conn.state.lock().await.pending_requests.is_empty());
            match rx.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "msg"),
                other => panic!("unexpected {other:?}"),
            }
        });
    }

    /// Targeting a connection without a writer fails with `NotConnected`.
    #[test]
    fn sync_send_error_if_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    /// With no target and no ready connection, the send fails with `NotConnected`.
    #[test]
    fn sync_send_error_when_no_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send("msg".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    /// A request/response send registers a pending request under its id.
    /// Completing that request resolves the returned future with the value.
    #[test]
    fn async_send_and_receive() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let fut = common
                .send(
                    "hello".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(5),
                    Some(conn.clone()),
                )
                .await
                .unwrap()
                .unwrap();
            match rx.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "hello"),
                other => panic!("expected Text, got {other:?}"),
            }
            // Play the role of the response dispatcher: complete the pending request.
            {
                let mut st = conn.state.lock().await;
                let pr = st.pending_requests.remove("id").unwrap();
                pr.completion.send(Ok(serde_json::json!("ok"))).unwrap();
            }
            let resp = fut.await.unwrap().unwrap();
            assert_eq!(resp, serde_json::json!("ok"));
        });
    }

    /// Same as `async_send_and_receive`, but the connection is chosen by default.
    #[test]
    fn async_send_default_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let fut = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(5),
                    None,
                )
                .await
                .unwrap()
                .unwrap();
            match rx.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "msg"),
                _ => panic!("no text"),
            }
            {
                let mut st = conn.state.lock().await;
                let pr = st.pending_requests.remove("id").unwrap();
                pr.completion.send(Ok(serde_json::json!(123))).unwrap();
            }
            let resp = fut.await.unwrap().unwrap();
            assert_eq!(resp, serde_json::json!(123));
        });
    }

    /// A request/response send without an id is rejected; the current contract
    /// reports this as `NotConnected`.
    #[test]
    fn async_send_error_if_no_id() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send(
                    "msg".into(),
                    None,
                    true,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    /// An unanswered request resolves to an error once its timeout elapses,
    /// and its pending entry is removed.
    #[test]
    fn timeout_rejects_async() {
        TOKIO_SHARED_RT.block_on(async {
            // Freeze the runtime clock so the 1 s timeout can be driven with `advance`.
            pause();
            let conn = WebsocketConnection::new("c1");
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let fut = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap()
                .unwrap();
            advance(Duration::from_secs(1)).await;
            let outcome = fut.await;
            // TOKIO_SHARED_RT is also used by the other tests here. Unfreeze the
            // clock before any assertion can panic; otherwise later tests inherit
            // a paused clock, and a second `pause()` would panic.
            tokio::time::resume();
            let res = outcome.unwrap();
            assert!(res.is_err(), "expected timeout error");
            assert!(!conn.state.lock().await.pending_requests.contains_key("id"));
        });
    }

    /// With no target and no ready connection, a request/response send fails
    /// with `NotConnected`.
    #[test]
    fn async_send_errors_if_no_connection_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common = WebsocketCommon::new(
                vec![conn.clone()],
                WebsocketMode::Single,
                0,
                None,
                None,
            );
            let err = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(1),
                    None,
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }
}
4273 }
4274
4275 mod websocket_api {
4276 use super::*;
4277
4278 mod initialisation {
4279 use super::*;
4280
4281 #[test]
4282 fn new_initializes_common() {
4283 TOKIO_SHARED_RT.block_on(async {
4284 let conn = WebsocketConnection::new("id");
4285 let pool = vec![conn.clone()];
4286
4287 let sig_gen = SignatureGenerator::new(
4288 Some("api_secret".to_string()),
4289 None::<PrivateKey>,
4290 None::<String>,
4291 );
4292
4293 let config = ConfigurationWebsocketApi {
4294 api_key: Some("api_key".to_string()),
4295 api_secret: Some("api_secret".to_string()),
4296 private_key: None,
4297 private_key_passphrase: None,
4298 ws_url: Some("wss://example".to_string()),
4299 mode: WebsocketMode::Single,
4300 reconnect_delay: 1000,
4301 signature_gen: sig_gen,
4302 timeout: 500,
4303 time_unit: None,
4304 agent: None,
4305 user_agent: build_user_agent("product"),
4306 };
4307
4308 let api = WebsocketApi::new(config, pool.clone());
4309
4310 assert_eq!(api.common.connection_pool.len(), 1);
4311 assert_eq!(api.common.mode, WebsocketMode::Single);
4312
4313 let flag = *api.is_connecting.lock().await;
4314 assert!(!flag);
4315 });
4316 }
4317 }
4318
4319 mod connect {
4320 use super::*;
4321
4322 #[test]
4323 fn connect_when_not_connected_establishes() {
4324 TOKIO_SHARED_RT.block_on(async {
4325 let conn = WebsocketConnection::new("id");
4326 {
4327 let mut st = conn.state.lock().await;
4328 st.ws_write_tx = None;
4329 }
4330 let sig = SignatureGenerator::new(
4331 Some("api_secret".into()),
4332 None::<PrivateKey>,
4333 None::<String>,
4334 );
4335 let cfg = ConfigurationWebsocketApi {
4336 api_key: Some("api_key".into()),
4337 api_secret: Some("api_secret".to_string()),
4338 private_key: None,
4339 private_key_passphrase: None,
4340 ws_url: Some("ws://doesnotexist:1".to_string()),
4341 mode: WebsocketMode::Single,
4342 reconnect_delay: 0,
4343 signature_gen: sig,
4344 timeout: 10,
4345 time_unit: None,
4346 agent: None,
4347 user_agent: build_user_agent("product"),
4348 };
4349 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4350 let res = api.clone().connect().await;
4351 assert!(!matches!(res, Err(WebsocketError::Timeout)));
4352 });
4353 }
4354
4355 #[test]
4356 fn already_connected_returns_ok() {
4357 TOKIO_SHARED_RT.block_on(async {
4358 let conn = WebsocketConnection::new("id2");
4359 let (tx, _) = unbounded_channel();
4360 {
4361 let mut st = conn.state.lock().await;
4362 st.ws_write_tx = Some(tx);
4363 }
4364 let sig = SignatureGenerator::new(
4365 Some("api_secret".to_string()),
4366 None::<PrivateKey>,
4367 None::<String>,
4368 );
4369 let cfg = ConfigurationWebsocketApi {
4370 api_key: Some("api_key".to_string()),
4371 api_secret: Some("api_secret".to_string()),
4372 private_key: None,
4373 private_key_passphrase: None,
4374 ws_url: Some("ws://example.com".to_string()),
4375 mode: WebsocketMode::Single,
4376 reconnect_delay: 0,
4377 signature_gen: sig,
4378 timeout: 10,
4379 time_unit: None,
4380 agent: None,
4381 user_agent: build_user_agent("product"),
4382 };
4383 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4384 let res = api.connect().await;
4385 assert!(res.is_ok());
4386 });
4387 }
4388
4389 #[test]
4390 fn not_connected_returns_error() {
4391 TOKIO_SHARED_RT.block_on(async {
4392 let conn = WebsocketConnection::new("id1");
4393 let sig = SignatureGenerator::new(
4394 Some("api_secret".to_string()),
4395 None::<PrivateKey>,
4396 None::<String>,
4397 );
4398 let cfg = ConfigurationWebsocketApi {
4399 api_key: Some("api_key".to_string()),
4400 api_secret: Some("api_secret".to_string()),
4401 private_key: None,
4402 private_key_passphrase: None,
4403 ws_url: Some("ws://127.0.0.1:9".to_string()),
4404 mode: WebsocketMode::Single,
4405 reconnect_delay: 0,
4406 signature_gen: sig,
4407 timeout: 10,
4408 time_unit: None,
4409 agent: None,
4410 user_agent: build_user_agent("product"),
4411 };
4412 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4413 let res = api.connect().await;
4414 assert!(res.is_err());
4415 });
4416 }
4417
4418 #[test]
4419 fn concurrent_calls_both_error_or_ok() {
4420 TOKIO_SHARED_RT.block_on(async {
4421 let conn = WebsocketConnection::new("id3");
4422 let sig = SignatureGenerator::new(
4423 Some("api_secret".to_string()),
4424 None::<PrivateKey>,
4425 None::<String>,
4426 );
4427 let cfg = ConfigurationWebsocketApi {
4428 api_key: Some("api_key".to_string()),
4429 api_secret: Some("api_secret".to_string()),
4430 private_key: None,
4431 private_key_passphrase: None,
4432 ws_url: Some("wss://invalid-domain".to_string()),
4433 mode: WebsocketMode::Single,
4434 reconnect_delay: 0,
4435 signature_gen: sig,
4436 timeout: 10,
4437 time_unit: None,
4438 agent: None,
4439 user_agent: build_user_agent("product"),
4440 };
4441 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4442 let fut1 = tokio::spawn(api.clone().connect());
4443 let fut2 = tokio::spawn(api.clone().connect());
4444 let r1 = fut1.await.unwrap();
4445 let r2 = fut2.await.unwrap();
4446
4447 assert!(r1.is_err());
4448 assert!(r2.is_err() || r2.is_ok());
4449 });
4450 }
4451
4452 #[test]
4453 fn pool_failure_is_propagated() {
4454 TOKIO_SHARED_RT.block_on(async {
4455 let conn = WebsocketConnection::new("w");
4456 let sig = SignatureGenerator::new(
4457 Some("api_secret".to_string()),
4458 None::<PrivateKey>,
4459 None::<String>,
4460 );
4461 let cfg = ConfigurationWebsocketApi {
4462 api_key: Some("api_key".into()),
4463 api_secret: Some("api_secret".to_string()),
4464 private_key: None,
4465 private_key_passphrase: None,
4466 ws_url: Some("ws://doesnotexist:1".to_string()),
4467 mode: WebsocketMode::Single,
4468 reconnect_delay: 0,
4469 signature_gen: sig,
4470 timeout: 10,
4471 time_unit: None,
4472 agent: None,
4473 user_agent: build_user_agent("product"),
4474 };
4475 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4476 let res = api.clone().connect().await;
4477 match res {
4478 Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
4479 _ => panic!("expected handshake or timeout error"),
4480 }
4481 });
4482 }
4483 }
4484
4485 mod send_message {
4486 use super::*;
4487
4488 #[test]
4489 fn unsigned_message() {
4490 TOKIO_SHARED_RT.block_on(async {
4491 let api = create_websocket_api(None);
4492 let conn = &api.common.connection_pool[0];
4493 let (tx, mut rx) = unbounded_channel::<Message>();
4494 {
4495 let mut st = conn.state.lock().await;
4496 st.ws_write_tx = Some(tx);
4497 }
4498
4499 let fut = tokio::spawn({
4500 let api = api.clone();
4501 async move {
4502 let mut params = BTreeMap::new();
4503 params.insert("foo".into(), Value::String("bar".into()));
4504 api.send_message::<Value>(
4505 "mymethod",
4506 params,
4507 WebsocketMessageSendOptions {
4508 with_api_key: false,
4509 is_signed: false,
4510 },
4511 )
4512 .await
4513 .unwrap()
4514 }
4515 });
4516
4517 let sent = rx.recv().await.unwrap();
4518 let Message::Text(txt) = sent else { panic!() };
4519 let req: Value = serde_json::from_str(&txt).unwrap();
4520 assert_eq!(req["method"], "mymethod");
4521 assert!(req["params"]["foo"] == "bar");
4522 assert!(req["params"].get("apiKey").is_none());
4523 assert!(req["params"].get("timestamp").is_none());
4524 assert!(req["params"].get("signature").is_none());
4525
4526 let id = req["id"].as_str().unwrap().to_string();
4527 let mut st = conn.state.lock().await;
4528 let pending = st.pending_requests.remove(&id).unwrap();
4529 let reply = json!({
4530 "id": id,
4531 "result": { "x": 42 },
4532 "rateLimits": [{ "limit": 7 }]
4533 });
4534 pending.completion.send(Ok(reply)).unwrap();
4535
4536 let resp = fut.await.unwrap();
4537 let rate_limits = resp.rate_limits.unwrap_or_default();
4538
4539 assert!(rate_limits.is_empty());
4540 assert_eq!(resp.raw, json!({"x": 42}));
4541 });
4542 }
4543
4544 #[test]
4545 fn with_api_key_only() {
4546 TOKIO_SHARED_RT.block_on(async {
4547 let api = create_websocket_api(None);
4548 let conn = &api.common.connection_pool[0];
4549 let (tx, mut rx) = unbounded_channel::<Message>();
4550 {
4551 let mut st = conn.state.lock().await;
4552 st.ws_write_tx = Some(tx);
4553 }
4554
4555 let fut = tokio::spawn({
4556 let api = api.clone();
4557 async move {
4558 let params = BTreeMap::new();
4559 api.send_message::<Value>(
4560 "foo",
4561 params,
4562 WebsocketMessageSendOptions {
4563 with_api_key: true,
4564 is_signed: false,
4565 },
4566 )
4567 .await
4568 .unwrap()
4569 }
4570 });
4571
4572 let Message::Text(txt) = rx.recv().await.unwrap() else {
4573 panic!()
4574 };
4575 let req: Value = serde_json::from_str(&txt).unwrap();
4576 assert_eq!(req["params"]["apiKey"], "api_key");
4577
4578 let id = req["id"].as_str().unwrap().to_string();
4579 let mut st = conn.state.lock().await;
4580 let pending = st.pending_requests.remove(&id).unwrap();
4581 pending
4582 .completion
4583 .send(Ok(json!({
4584 "id": id,
4585 "result": {},
4586 "rateLimits": []
4587 })))
4588 .unwrap();
4589
4590 let resp = fut.await.unwrap();
4591
4592 assert_eq!(resp.raw, json!({}));
4593 assert!(st.pending_requests.is_empty());
4594 });
4595 }
4596
4597 #[test]
4598 fn signed_message_has_timestamp_and_signature() {
4599 TOKIO_SHARED_RT.block_on(async {
4600 let api = create_websocket_api(None);
4601 let conn = &api.common.connection_pool[0];
4602 let (tx, mut rx) = unbounded_channel::<Message>();
4603 {
4604 let mut st = conn.state.lock().await;
4605 st.ws_write_tx = Some(tx);
4606 }
4607
4608 let fut = tokio::spawn({
4609 let api = api.clone();
4610 async move {
4611 let mut params = BTreeMap::new();
4612 params.insert("foo".into(), Value::String("bar".into()));
4613 api.send_message::<Value>(
4614 "method",
4615 params,
4616 WebsocketMessageSendOptions {
4617 with_api_key: true,
4618 is_signed: true,
4619 },
4620 )
4621 .await
4622 .unwrap()
4623 }
4624 });
4625
4626 let Message::Text(txt) = rx.recv().await.unwrap() else {
4627 panic!()
4628 };
4629 let req: Value = serde_json::from_str(&txt).unwrap();
4630 let p = &req["params"];
4631 assert!(p["apiKey"] == "api_key");
4632 assert!(p["timestamp"].is_number());
4633 assert!(p["signature"].is_string());
4634
4635 let id = req["id"].as_str().unwrap().to_string();
4636 let mut st = conn.state.lock().await;
4637 let pending = st.pending_requests.remove(&id).unwrap();
4638 pending
4639 .completion
4640 .send(Ok(json!({
4641 "id": id,
4642 "result": { "ok": true },
4643 "rateLimits": []
4644 })))
4645 .unwrap();
4646
4647 let resp = fut.await.unwrap();
4648 assert_eq!(resp.raw, json!({ "ok": true }));
4649 });
4650 }
4651
4652 #[test]
4653 fn error_if_not_connected() {
4654 TOKIO_SHARED_RT.block_on(async {
4655 let api = create_websocket_api(None);
4656 let conn = &api.common.connection_pool[0];
4657 {
4658 let mut st = conn.state.lock().await;
4659 st.ws_write_tx = None;
4660 }
4661 let params = BTreeMap::new();
4662 let err = api
4663 .send_message::<Value>(
4664 "method",
4665 params,
4666 WebsocketMessageSendOptions {
4667 with_api_key: false,
4668 is_signed: false,
4669 },
4670 )
4671 .await
4672 .unwrap_err();
4673 matches!(err, WebsocketError::NotConnected);
4674 });
4675 }
4676 }
4677
4678 mod prepare_url {
4679 use super::*;
4680
4681 #[test]
4682 fn no_time_unit() {
4683 TOKIO_SHARED_RT.block_on(async {
4684 let api = create_websocket_api(None);
4685 let url = "wss://example.com/ws".to_string();
4686 assert_eq!(api.prepare_url(&url), url);
4687 });
4688 }
4689
4690 #[test]
4691 fn appends_time_unit() {
4692 TOKIO_SHARED_RT.block_on(async {
4693 let api = create_websocket_api(Some(TimeUnit::Millisecond));
4694 let base = "wss://example.com/ws".to_string();
4695 let got = api.prepare_url(&base);
4696 assert_eq!(got, format!("{base}?timeUnit=millisecond"));
4697 });
4698 }
4699
4700 #[test]
4701 fn handles_existing_query() {
4702 TOKIO_SHARED_RT.block_on(async {
4703 let api = create_websocket_api(Some(TimeUnit::Microsecond));
4704 let base = "wss://example.com/ws?foo=bar".to_string();
4705 let got = api.prepare_url(&base);
4706 assert_eq!(got, format!("{base}&timeUnit=microsecond"));
4707 });
4708 }
4709 }
4710
        mod on_message {
            use super::*;

            /// Builds a single-connection `WebsocketApi` and returns it together with
            /// that connection, so tests can seed `pending_requests` directly.
            fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
                let sig_gen = SignatureGenerator::new(
                    Some("api_secret".to_string()),
                    None::<_>,
                    None::<String>,
                );
                let config = ConfigurationWebsocketApi {
                    api_key: Some("api_key".to_string()),
                    api_secret: Some("api_secret".to_string()),
                    private_key: None,
                    private_key_passphrase: None,
                    ws_url: Some("wss://example".to_string()),
                    mode: WebsocketMode::Single,
                    reconnect_delay: 0,
                    signature_gen: sig_gen,
                    timeout: 1000,
                    time_unit: None,
                    agent: None,
                    user_agent: build_user_agent("product"),
                };
                let conn = WebsocketConnection::new("test");
                let api = WebsocketApi::new(config, vec![conn.clone()]);
                (api, conn)
            }

            /// A successful reply whose id matches a pending request is delivered
            /// verbatim (the whole parsed message) on its completion channel, and the
            /// entry is removed.
            #[test]
            fn resolves_pending_and_removes_request() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    let (tx, rx) = oneshot::channel();
                    {
                        let mut st = conn.state.lock().await;
                        st.pending_requests
                            .insert("id1".to_string(), PendingRequest { completion: tx });
                    }
                    let msg = json!({"id":"id1","status":200,"foo":"bar"});
                    api.on_message(msg.to_string(), conn.clone()).await;
                    let got = rx.await.unwrap().unwrap();
                    assert_eq!(got, msg);
                    let st = conn.state.lock().await;
                    assert!(!st.pending_requests.contains_key("id1"));
                });
            }

            /// When both `result` and `response` are present, the delivered message
            /// still carries the original `result` object.
            #[test]
            fn uses_result_when_present() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    let (tx, rx) = oneshot::channel();
                    {
                        let mut st = conn.state.lock().await;
                        st.pending_requests
                            .insert("id1".to_string(), PendingRequest { completion: tx });
                    }
                    let msg = json!({
                        "id": "id1",
                        "status": 200,
                        "response": [1,2],
                        "result": {"a":1}
                    });
                    api.on_message(msg.to_string(), conn.clone()).await;
                    let got = rx.await.unwrap().unwrap();
                    assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
                });
            }

            /// With only `response` present, the delivered message carries it intact.
            #[test]
            fn uses_response_when_no_result() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    let (tx, rx) = oneshot::channel();
                    {
                        let mut st = conn.state.lock().await;
                        st.pending_requests
                            .insert("id1".to_string(), PendingRequest { completion: tx });
                    }
                    let msg = json!({
                        "id": "id1",
                        "status": 200,
                        "response": ["ok"]
                    });
                    api.on_message(msg.to_string(), conn.clone()).await;
                    let got = rx.await.unwrap().unwrap();
                    assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
                });
            }

            /// A status of 500 with an `error` object is turned into
            /// `WebsocketError::ResponseError` (code from `code`, message from `msg`),
            /// and the pending entry is removed.
            #[test]
            fn errors_for_status_ge_400() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    let (tx, rx) = oneshot::channel();
                    {
                        let mut st = conn.state.lock().await;
                        st.pending_requests
                            .insert("bad".to_string(), PendingRequest { completion: tx });
                    }
                    let err_obj = json!({"code":123,"msg":"oops"});
                    let msg = json!({"id":"bad","status":500,"error":err_obj});
                    api.on_message(msg.to_string(), conn.clone()).await;
                    match rx.await.unwrap() {
                        Err(WebsocketError::ResponseError { code, message }) => {
                            assert_eq!(code, 123);
                            assert_eq!(message, "oops");
                        }
                        other => panic!("expected ResponseError, got {other:?}"),
                    }
                    let st = conn.state.lock().await;
                    assert!(!st.pending_requests.contains_key("bad"));
                });
            }

            /// A reply for an id that was never registered leaves the (empty) pending
            /// map empty and does not panic.
            #[test]
            fn ignores_unknown_id() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    let msg = json!({"id":"nope","status":200});
                    api.on_message(msg.to_string(), conn.clone()).await;
                    let st = conn.state.lock().await;
                    assert!(st.pending_requests.is_empty());
                });
            }

            /// Non-JSON input is tolerated: no panic, and no pending entries are created.
            #[test]
            fn parse_error_ignored() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    api.on_message("not json".to_string(), conn.clone()).await;
                    let st = conn.state.lock().await;
                    assert!(st.pending_requests.is_empty());
                });
            }

            /// Same contract as `errors_for_status_ge_400`, with a different code and
            /// message.
            #[test]
            fn error_status_sends_error() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    let (tx, rx) = oneshot::channel();
                    {
                        let mut st = conn.state.lock().await;
                        st.pending_requests
                            .insert("err".to_string(), PendingRequest { completion: tx });
                    }
                    let msg = json!({
                        "id": "err",
                        "status": 500,
                        "error": { "code": 42, "msg": "Bad!" }
                    });
                    api.on_message(msg.to_string(), conn.clone()).await;
                    match rx.await.unwrap() {
                        Err(WebsocketError::ResponseError { code, message }) => {
                            assert_eq!(code, 42);
                            assert_eq!(message, "Bad!");
                        }
                        other => panic!("expected ResponseError, got {other:?}"),
                    }
                });
            }

            /// A reply for an unknown id must not disturb other pending requests.
            #[test]
            fn unknown_id_logs_warning_and_leaves_pending() {
                TOKIO_SHARED_RT.block_on(async {
                    let (api, conn) = create_websocket_api_and_conn();
                    {
                        let mut st = conn.state.lock().await;
                        // Only the sender is kept; the receiver is dropped at once, which
                        // is fine because this entry is never completed.
                        st.pending_requests.insert(
                            "keep".to_string(),
                            PendingRequest {
                                completion: oneshot::channel().0,
                            },
                        );
                    }
                    api.on_message(
                        json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
                        conn.clone(),
                    )
                    .await;
                    let st = conn.state.lock().await;
                    assert!(st.pending_requests.contains_key("keep"));
                });
            }
        }
4896 }
4897
4898 mod websocket_streams {
4899 use super::*;
4900
4901 mod initialisation {
4902 use super::*;
4903
4904 #[test]
4905 fn new_initializes_fields() {
4906 TOKIO_SHARED_RT.block_on(async {
4907 let sig_gen = SignatureGenerator::new(
4908 Some("api_secret".to_string()),
4909 None::<PrivateKey>,
4910 None::<String>,
4911 );
4912 let config = ConfigurationWebsocketApi {
4913 api_key: Some("api_key".to_string()),
4914 api_secret: Some("api_secret".to_string()),
4915 private_key: None,
4916 private_key_passphrase: None,
4917 ws_url: Some("wss://example".to_string()),
4918 mode: WebsocketMode::Single,
4919 reconnect_delay: 1000,
4920 signature_gen: sig_gen.clone(),
4921 timeout: 500,
4922 time_unit: None,
4923 agent: None,
4924 user_agent: build_user_agent("product"),
4925 };
4926 let conn1 = WebsocketConnection::new("c1");
4927 let conn2 = WebsocketConnection::new("c2");
4928 let api = WebsocketApi::new(config.clone(), vec![conn1.clone(), conn2.clone()]);
4929
4930 assert_eq!(api.common.connection_pool.len(), 2);
4931 assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
4932 assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
4933 assert_eq!(api.configuration.ws_url, Some("wss://example".to_string()));
4934 let flag = api.is_connecting.lock().await;
4935 assert!(!*flag);
4936 });
4937 }
4938 }
4939
4940 mod connect {
4941 use super::*;
4942
4943 #[test]
4944 fn establishes_successfully() {
4945 TOKIO_SHARED_RT.block_on(async {
4946 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4947 let port = listener.local_addr().unwrap().port();
4948
4949 tokio::spawn(async move {
4950 for _ in 0..2 {
4951 if let Ok((stream, _)) = listener.accept().await {
4952 let mut ws = accept_async(stream).await.unwrap();
4953 ws.close(None).await.ok();
4954 }
4955 }
4956 });
4957
4958 let create_websocket_streams = |ws_url: &str| {
4959 let c1 = WebsocketConnection::new("c1");
4960 let c2 = WebsocketConnection::new("c2");
4961 let config = ConfigurationWebsocketStreams {
4962 ws_url: Some(ws_url.to_string()),
4963 mode: WebsocketMode::Pool(2),
4964 reconnect_delay: 500,
4965 time_unit: None,
4966 agent: None,
4967 user_agent: build_user_agent("product"),
4968 };
4969 WebsocketStreams::new(config, vec![c1, c2])
4970 };
4971
4972 let url = format!("ws://127.0.0.1:{port}");
4973 let ws = create_websocket_streams(&url);
4974
4975 let res = ws.connect(vec!["stream1".into()]).await;
4976 assert!(res.is_ok());
4977 });
4978 }
4979
4980 #[test]
4981 fn refused_returns_error() {
4982 TOKIO_SHARED_RT.block_on(async {
4983 let ws = create_websocket_streams(Some("ws://127.0.0.1:9"), None);
4984 let res = ws.connect(vec!["stream1".into()]).await;
4985 assert!(res.is_err());
4986 });
4987 }
4988
4989 #[test]
4990 fn invalid_url_returns_error() {
4991 TOKIO_SHARED_RT.block_on(async {
4992 let ws = create_websocket_streams(Some("not-a-url"), None);
4993 let res = ws.connect(vec!["s".into()]).await;
4994 assert!(res.is_err());
4995 });
4996 }
4997 }
4998
4999 mod disconnect {
5000 use super::*;
5001
5002 #[test]
5003 fn disconnect_clears_state_and_streams() {
5004 TOKIO_SHARED_RT.block_on(async {
5005 let ws = create_websocket_streams(None, None);
5006 let conn = &ws.common.connection_pool[0];
5007 {
5008 let mut state = conn.state.lock().await;
5009 state.stream_callbacks.insert("s1".to_string(), Vec::new());
5010 state.pending_subscriptions.push_back("s2".to_string());
5011 }
5012 {
5013 let mut map = ws.connection_streams.lock().await;
5014 map.insert("s3".to_string(), Arc::clone(conn));
5015 }
5016
5017 let res = ws.disconnect().await;
5018 assert!(res.is_ok());
5019
5020 let state = conn.state.lock().await;
5021 assert!(state.stream_callbacks.is_empty());
5022 assert!(state.pending_subscriptions.is_empty());
5023
5024 let map = ws.connection_streams.lock().await;
5025 assert!(map.is_empty());
5026 });
5027 }
5028 }
5029
5030 mod subscribe {
5031 use super::*;
5032
5033 #[test]
5034 fn empty_list_does_nothing() {
5035 TOKIO_SHARED_RT.block_on(async {
5036 let ws = create_websocket_streams(None, None);
5037 ws.clone().subscribe(Vec::new(), None).await;
5038 let map = ws.connection_streams.lock().await;
5039 assert!(map.is_empty());
5040 });
5041 }
5042
5043 #[test]
5044 fn queue_when_not_ready() {
5045 TOKIO_SHARED_RT.block_on(async {
5046 let ws = create_websocket_streams(None, None);
5047 let conn = ws.common.connection_pool[0].clone();
5048 ws.clone().subscribe(vec!["s1".into()], None).await;
5049 let state = conn.state.lock().await;
5050 let pending: Vec<String> =
5051 state.pending_subscriptions.iter().cloned().collect();
5052 assert_eq!(pending, vec!["s1".to_string()]);
5053 });
5054 }
5055
5056 #[test]
5057 fn only_one_subscription_per_stream() {
5058 TOKIO_SHARED_RT.block_on(async {
5059 let ws = create_websocket_streams(None, None);
5060 let conn = ws.common.connection_pool[0].clone();
5061 ws.clone().subscribe(vec!["s1".into()], None).await;
5062 ws.clone().subscribe(vec!["s1".into()], None).await;
5063 let state = conn.state.lock().await;
5064 let pending: Vec<String> =
5065 state.pending_subscriptions.iter().cloned().collect();
5066 assert_eq!(pending, vec!["s1".to_string()]);
5067 });
5068 }
5069
5070 #[test]
5071 fn multiple_streams_assigned() {
5072 TOKIO_SHARED_RT.block_on(async {
5073 let ws = create_websocket_streams(None, None);
5074 ws.clone()
5075 .subscribe(vec!["s1".into(), "s2".into()], None)
5076 .await;
5077 let map = ws.connection_streams.lock().await;
5078 assert!(map.contains_key("s1"));
5079 assert!(map.contains_key("s2"));
5080 });
5081 }
5082
5083 #[test]
5084 fn existing_stream_not_reassigned() {
5085 TOKIO_SHARED_RT.block_on(async {
5086 let ws = create_websocket_streams(None, None);
5087 ws.clone().subscribe(vec!["s1".into()], None).await;
5088 let first_id = {
5089 let map = ws.connection_streams.lock().await;
5090 map.get("s1").unwrap().id.clone()
5091 };
5092 ws.clone()
5093 .subscribe(vec!["s1".into(), "s2".into()], None)
5094 .await;
5095 let map = ws.connection_streams.lock().await;
5096 let second_id = map.get("s1").unwrap().id.clone();
5097 assert_eq!(first_id, second_id);
5098 assert!(map.contains_key("s2"));
5099 });
5100 }
5101 }
5102
        mod unsubscribe {
            use super::*;

            /// A stream with an empty callback list is dropped from both the
            /// stream-to-connection map and the connection's `stream_callbacks`.
            #[test]
            fn removes_stream_with_no_callbacks() {
                TOKIO_SHARED_RT.block_on(async {
                    let ws = create_websocket_streams(None, None);
                    let conn = ws.common.connection_pool[0].clone();

                    // Make the connection writable; `_rx` is dropped at the end of this
                    // scope, so later sends on `tx` may fail (not asserted here).
                    {
                        let (tx, _rx) = unbounded_channel::<Message>();
                        let mut st = conn.state.lock().await;
                        st.ws_write_tx = Some(tx);
                    }

                    {
                        let mut map = ws.connection_streams.lock().await;
                        map.insert("s1".to_string(), conn.clone());
                    }
                    {
                        let mut st = conn.state.lock().await;
                        st.stream_callbacks.insert("s1".to_string(), Vec::new());
                    }

                    ws.unsubscribe(vec!["s1".to_string()], None).await;

                    assert!(!ws.connection_streams.lock().await.contains_key("s1"));
                    assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
                });
            }

            /// A stream that still has a registered callback is kept in both the map and
            /// `stream_callbacks`.
            #[test]
            fn preserves_stream_with_callbacks() {
                TOKIO_SHARED_RT.block_on(async {
                    let ws = create_websocket_streams(None, None);
                    // Uses the second pooled connection, so the helper's default pool
                    // must contain at least two connections.
                    let conn = ws.common.connection_pool[1].clone();

                    {
                        let mut map = ws.connection_streams.lock().await;
                        map.insert("s2".to_string(), conn.clone());
                    }
                    {
                        let mut state = conn.state.lock().await;
                        state
                            .stream_callbacks
                            .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
                    }

                    ws.unsubscribe(vec!["s2".to_string()], None).await;

                    assert!(ws.connection_streams.lock().await.contains_key("s2"));
                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
                });
            }

            /// With several callbacks registered, the stream stays in both the map and
            /// `stream_callbacks`. Outgoing frames are not inspected.
            #[test]
            fn does_not_send_if_callbacks_exist() {
                TOKIO_SHARED_RT.block_on(async {
                    let ws = create_websocket_streams(None, None);
                    let conn = ws.common.connection_pool[0].clone();
                    {
                        let mut map = ws.connection_streams.lock().await;
                        map.insert("s1".to_string(), conn.clone());
                    }
                    {
                        let mut state = conn.state.lock().await;
                        state.stream_callbacks.insert(
                            "s1".to_string(),
                            vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
                        );
                    }
                    ws.unsubscribe(vec!["s1".into()], None).await;
                    assert!(ws.connection_streams.lock().await.contains_key("s1"));
                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
                });
            }

            /// Unsubscribing from a stream that was never assigned must not panic; no
            /// further assertion is made.
            #[test]
            fn warns_if_not_associated() {
                TOKIO_SHARED_RT.block_on(async {
                    let ws = create_websocket_streams(None, None);
                    ws.unsubscribe(vec!["nope".into()], None).await;
                });
            }

            /// An empty unsubscribe list leaves the map size unchanged.
            #[test]
            fn empty_list_does_nothing() {
                TOKIO_SHARED_RT.block_on(async {
                    let ws = create_websocket_streams(None, None);
                    let before = ws.connection_streams.lock().await.len();
                    ws.unsubscribe(Vec::<String>::new(), None).await;
                    let after = ws.connection_streams.lock().await.len();
                    assert_eq!(before, after);
                });
            }

            /// A custom id that is not a valid request id (the file's `ID_REGEX` expects
            /// 32 lowercase hex chars) does not prevent the stream from being removed.
            #[test]
            fn invalid_custom_id_falls_back() {
                TOKIO_SHARED_RT.block_on(async {
                    let ws = create_websocket_streams(None, None);
                    let conn = ws.common.connection_pool[0].clone();
                    {
                        let mut map = ws.connection_streams.lock().await;
                        map.insert("foo".to_string(), conn.clone());
                    }
                    {
                        let mut state = conn.state.lock().await;
                        let (tx, _rx) = unbounded_channel();
                        state.ws_write_tx = Some(tx);
                        state.stream_callbacks.insert("foo".to_string(), Vec::new());
                    }
                    ws.unsubscribe(vec!["foo".into()], Some("bad-id".into()))
                        .await;
                    assert!(!ws.connection_streams.lock().await.contains_key("foo"));
                });
            }

            /// A stream with no callbacks is removed from the map.
            // NOTE(review): despite its name, this test installs a write channel, so it
            // does not cover the "no write channel" path. Consider leaving
            // `ws_write_tx` as `None` once the intended behaviour is confirmed.
            #[test]
            fn removes_even_without_write_channel() {
                TOKIO_SHARED_RT.block_on(async {
                    let ws = create_websocket_streams(None, None);
                    let conn = ws.common.connection_pool[0].clone();
                    {
                        let mut map = ws.connection_streams.lock().await;
                        map.insert("x".to_string(), conn.clone());
                    }
                    {
                        let mut state = conn.state.lock().await;
                        let (tx, _rx) = unbounded_channel();
                        state.ws_write_tx = Some(tx);
                        state.stream_callbacks.insert("x".to_string(), Vec::new());
                    }
                    ws.unsubscribe(vec!["x".into()], None).await;
                    assert!(!ws.connection_streams.lock().await.contains_key("x"));
                });
            }
        }
5240
5241 mod is_subscribed {
5242 use super::*;
5243
5244 #[test]
5245 fn returns_false_when_not_subscribed() {
5246 TOKIO_SHARED_RT.block_on(async {
5247 let ws = create_websocket_streams(None, None);
5248 assert!(!ws.is_subscribed("unknown").await);
5249 });
5250 }
5251
5252 #[test]
5253 fn returns_true_when_subscribed() {
5254 TOKIO_SHARED_RT.block_on(async {
5255 let ws = create_websocket_streams(None, None);
5256 let conn = ws.common.connection_pool[0].clone();
5257 {
5258 let mut map = ws.connection_streams.lock().await;
5259 map.insert("stream1".to_string(), conn);
5260 }
5261 assert!(ws.is_subscribed("stream1").await);
5262 });
5263 }
5264 }
5265
5266 mod prepare_url {
5267 use super::*;
5268
5269 #[test]
5270 fn without_time_unit_returns_base_url() {
5271 TOKIO_SHARED_RT.block_on(async {
5272 let conns = vec![
5273 WebsocketConnection::new("c1"),
5274 WebsocketConnection::new("c2"),
5275 ];
5276 let config = ConfigurationWebsocketStreams {
5277 ws_url: Some("wss://example".to_string()),
5278 mode: WebsocketMode::Single,
5279 reconnect_delay: 100,
5280 time_unit: None,
5281 agent: None,
5282 user_agent: build_user_agent("product"),
5283 };
5284 let ws = WebsocketStreams::new(config, conns);
5285 let url = ws.prepare_url(&["s1".into(), "s2".into()]);
5286 assert_eq!(url, "wss://example/stream?streams=s1/s2");
5287 });
5288 }
5289
5290 #[test]
5291 fn with_time_unit_appends_parameter() {
5292 TOKIO_SHARED_RT.block_on(async {
5293 let conns = vec![WebsocketConnection::new("c1")];
5294 let config = ConfigurationWebsocketStreams {
5295 ws_url: Some("wss://example".to_string()),
5296 mode: WebsocketMode::Single,
5297 reconnect_delay: 100,
5298 time_unit: Some(TimeUnit::Millisecond),
5299 agent: None,
5300 user_agent: build_user_agent("product"),
5301 };
5302 let ws = WebsocketStreams::new(config, conns);
5303 let url = ws.prepare_url(&["a".into()]);
5304 assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
5305 });
5306 }
5307
5308 #[test]
5309 fn multiple_streams_and_time_unit() {
5310 TOKIO_SHARED_RT.block_on(async {
5311 let conns = vec![WebsocketConnection::new("c1")];
5312 let config = ConfigurationWebsocketStreams {
5313 ws_url: Some("wss://example".to_string()),
5314 mode: WebsocketMode::Single,
5315 reconnect_delay: 100,
5316 time_unit: Some(TimeUnit::Microsecond),
5317 agent: None,
5318 user_agent: build_user_agent("product"),
5319 };
5320 let ws = WebsocketStreams::new(config, conns);
5321 let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()]);
5322 assert_eq!(
5323 url,
5324 "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
5325 );
5326 });
5327 }
5328 }
5329
5330 mod handle_stream_assignment {
5331 use super::*;
5332
5333 #[test]
5334 fn assigns_new_streams_to_connections() {
5335 TOKIO_SHARED_RT.block_on(async {
5336 let ws = create_websocket_streams(None, None);
5337 let groups = ws
5338 .clone()
5339 .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5340 .await;
5341 let mut seen_streams = HashSet::new();
5342 for (_conn, streams) in &groups {
5343 for s in streams {
5344 seen_streams.insert(s);
5345 }
5346 }
5347 assert_eq!(
5348 seen_streams,
5349 ["s1".to_string(), "s2".to_string()].iter().collect()
5350 );
5351 assert_eq!(groups.len(), 1);
5352 });
5353 }
5354
5355 #[test]
5356 fn reuses_existing_connection_for_duplicate_stream() {
5357 TOKIO_SHARED_RT.block_on(async {
5358 let ws = create_websocket_streams(None, None);
5359 let _ = ws.clone().handle_stream_assignment(vec!["s1".into()]).await;
5360 let groups = ws
5361 .clone()
5362 .handle_stream_assignment(vec!["s1".into(), "s3".into()])
5363 .await;
5364 let mut all_streams = Vec::new();
5365 for (_conn, streams) in groups {
5366 all_streams.extend(streams);
5367 }
5368 all_streams.sort();
5369 assert_eq!(all_streams, vec!["s1".to_string(), "s3".to_string()]);
5370 });
5371 }
5372
5373 #[test]
5374 fn empty_stream_list_returns_empty() {
5375 TOKIO_SHARED_RT.block_on(async {
5376 let ws = create_websocket_streams(None, None);
5377 let groups = ws.clone().handle_stream_assignment(vec![]).await;
5378 assert!(groups.is_empty());
5379 });
5380 }
5381
5382 #[test]
5383 fn closed_or_reconnecting_forces_reassignment_of_stream() {
5384 TOKIO_SHARED_RT.block_on(async {
5385 let ws = create_websocket_streams(None, None);
5386 let mut groups = ws.clone().handle_stream_assignment(vec!["s1".into()]).await;
5387 let (conn, _) = groups.pop().unwrap();
5388 {
5389 let mut st = conn.state.lock().await;
5390 st.close_initiated = true;
5391 }
5392 let groups2 = ws.clone().handle_stream_assignment(vec!["s2".into()]).await;
5393 assert_eq!(groups2.len(), 1);
5394 let (_new_conn, streams) = &groups2[0];
5395 assert_eq!(streams, &vec!["s2".to_string()]);
5396 });
5397 }
5398
5399 #[test]
5400 fn no_available_connections_falls_back_to_one() {
5401 TOKIO_SHARED_RT.block_on(async {
5402 let ws = create_websocket_streams(None, Some(vec![]));
5403 let assigned = ws.handle_stream_assignment(vec!["foo".into()]).await;
5404 assert_eq!(assigned.len(), 1);
5405 let (_conn, streams) = &assigned[0];
5406 assert_eq!(streams.as_slice(), &["foo".to_string()]);
5407 });
5408 }
5409
5410 #[test]
5411 fn single_connection_groups_multiple_streams() {
5412 TOKIO_SHARED_RT.block_on(async {
5413 let conn = WebsocketConnection::new("c1");
5414 let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5415 let assigned = ws
5416 .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5417 .await;
5418 assert_eq!(assigned.len(), 1);
5419 let (assigned_conn, streams) = &assigned[0];
5420 assert!(Arc::ptr_eq(assigned_conn, &conn));
5421 assert_eq!(streams.len(), 2);
5422 assert!(streams.contains(&"s1".to_string()));
5423 assert!(streams.contains(&"s2".to_string()));
5424 });
5425 }
5426
5427 #[test]
5428 fn reuse_existing_healthy_connection() {
5429 TOKIO_SHARED_RT.block_on(async {
5430 let conn = WebsocketConnection::new("c");
5431 let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5432 let _ = ws.handle_stream_assignment(vec!["s1".into()]).await;
5433 let second = ws.handle_stream_assignment(vec!["s1".into()]).await;
5434 assert_eq!(second.len(), 1);
5435 let (assigned_conn, streams) = &second[0];
5436 assert!(Arc::ptr_eq(assigned_conn, &conn));
5437 assert_eq!(streams.as_slice(), &["s1".to_string()]);
5438 });
5439 }
5440
5441 #[test]
5442 fn mix_new_and_assigned_streams() {
5443 TOKIO_SHARED_RT.block_on(async {
5444 let conn = WebsocketConnection::new("c");
5445 let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5446 let _ = ws
5447 .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5448 .await;
5449 let mixed = ws
5450 .handle_stream_assignment(vec!["s2".into(), "s3".into()])
5451 .await;
5452 assert_eq!(mixed.len(), 1);
5453 let (assigned_conn, streams) = &mixed[0];
5454 assert!(Arc::ptr_eq(assigned_conn, &conn));
5455 let mut got = streams.clone();
5456 got.sort();
5457 assert_eq!(got, vec!["s2".to_string(), "s3".to_string()]);
5458 });
5459 }
5460 }
5461
5462 mod send_subscription_payload {
5463 use super::*;
5464
5465 #[test]
5466 fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
5467 TOKIO_SHARED_RT.block_on(async {
5468 let ws: Arc<WebsocketStreams> =
5469 create_websocket_streams(Some("ws://example.com"), None);
5470 let conn = &ws.common.connection_pool[0];
5471 let (tx, mut rx) = unbounded_channel();
5472 {
5473 let mut st = conn.state.lock().await;
5474 st.ws_write_tx = Some(tx);
5475 }
5476 ws.send_subscription_payload(
5477 conn,
5478 &vec!["s1".to_string()],
5479 Some("badid".to_string()),
5480 );
5481 let msg = rx.recv().await.expect("no message sent");
5482 if let Message::Text(txt) = msg {
5483 let v: serde_json::Value = serde_json::from_str(&txt).unwrap();
5484 assert_eq!(v["method"], "SUBSCRIBE");
5485 let id = v["id"].as_str().unwrap();
5486 assert_ne!(id, "badid");
5487 assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id));
5488 } else {
5489 panic!("unexpected message: {msg:?}");
5490 }
5491 });
5492 }
5493
5494 #[test]
5495 fn subscribe_payload_with_and_without_custom_id() {
5496 TOKIO_SHARED_RT.block_on(async {
5497 let ws: Arc<WebsocketStreams> =
5498 create_websocket_streams(Some("ws://unused"), None);
5499 let conn = &ws.common.connection_pool[0];
5500 let (tx, mut rx) = unbounded_channel();
5501 {
5502 let mut st = conn.state.lock().await;
5503 st.ws_write_tx = Some(tx);
5504 }
5505 ws.send_subscription_payload(
5506 conn,
5507 &vec!["a".to_string(), "b".to_string()],
5508 Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string()),
5509 );
5510 let msg1 = rx.recv().await.unwrap();
5511 ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
5512 let msg2 = rx.recv().await.unwrap();
5513
5514 if let Message::Text(txt1) = msg1 {
5515 let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
5516 assert_eq!(v1["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
5517 assert_eq!(
5518 v1["params"].as_array().unwrap(),
5519 &vec![serde_json::json!("a"), serde_json::json!("b")]
5520 );
5521 } else {
5522 panic!()
5523 }
5524
5525 if let Message::Text(txt2) = msg2 {
5526 let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
5527 assert_eq!(v2["method"], "SUBSCRIBE");
5528 let params = v2["params"].as_array().unwrap();
5529 assert_eq!(params.len(), 1);
5530 assert_eq!(params[0], "x");
5531 let id2 = v2["id"].as_str().unwrap();
5532 assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id2));
5533 } else {
5534 panic!()
5535 }
5536 });
5537 }
5538 }
5539
5540 mod on_open {
5541 use super::*;
5542
5543 #[test]
5544 fn sends_pending_subscriptions() {
5545 TOKIO_SHARED_RT.block_on(async {
5546 let ws: Arc<WebsocketStreams> =
5547 create_websocket_streams(Some("ws://example.com"), None);
5548 let conn = &ws.common.connection_pool[0];
5549 let (tx, mut rx) = unbounded_channel();
5550 {
5551 let mut st = conn.state.lock().await;
5552 st.ws_write_tx = Some(tx);
5553 st.pending_subscriptions.push_back("foo".to_string());
5554 st.pending_subscriptions.push_back("bar".to_string());
5555 }
5556 ws.on_open("ws://example.com".to_string(), conn.clone())
5557 .await;
5558 let msg = rx.recv().await.expect("no subscription sent");
5559 if let Message::Text(txt) = msg {
5560 let v: Value = serde_json::from_str(&txt).unwrap();
5561 assert_eq!(v["method"], "SUBSCRIBE");
5562 let params = v["params"].as_array().unwrap();
5563 assert_eq!(
5564 params,
5565 &vec![Value::String("foo".into()), Value::String("bar".into())]
5566 );
5567 } else {
5568 panic!("unexpected message: {msg:?}");
5569 }
5570 let st_after = conn.state.lock().await;
5571 assert!(st_after.pending_subscriptions.is_empty());
5572 });
5573 }
5574
5575 #[test]
5576 fn with_no_pending_subscriptions_sends_nothing() {
5577 TOKIO_SHARED_RT.block_on(async {
5578 let ws: Arc<WebsocketStreams> =
5579 create_websocket_streams(Some("ws://example.com"), None);
5580 let conn = &ws.common.connection_pool[0];
5581 let (tx, mut rx) = unbounded_channel();
5582 {
5583 let mut st = conn.state.lock().await;
5584 st.ws_write_tx = Some(tx);
5585 }
5586 ws.on_open("ws://example.com".to_string(), conn.clone())
5587 .await;
5588 assert!(rx.try_recv().is_err(), "unexpected message sent");
5589 });
5590 }
5591
5592 #[test]
5593 fn clears_pending_without_write_channel() {
5594 TOKIO_SHARED_RT.block_on(async {
5595 let ws: Arc<WebsocketStreams> =
5596 create_websocket_streams(Some("ws://example.com"), None);
5597 let conn = &ws.common.connection_pool[0];
5598 {
5599 let mut st = conn.state.lock().await;
5600 st.pending_subscriptions.push_back("solo".to_string());
5601 }
5602 ws.on_open("ws://example.com".to_string(), conn.clone())
5603 .await;
5604 let st_after = conn.state.lock().await;
5605 assert!(st_after.pending_subscriptions.is_empty());
5606 });
5607 }
5608 }
5609
5610 mod on_message {
5611 use super::*;
5612
5613 #[test]
5614 fn invokes_registered_callback() {
5615 TOKIO_SHARED_RT.block_on(async {
5616 let ws: Arc<WebsocketStreams> =
5617 create_websocket_streams(Some("ws://example.com"), None);
5618 let conn = &ws.common.connection_pool[0];
5619 let called = Arc::new(AtomicBool::new(false));
5620 let called_clone = called.clone();
5621
5622 {
5623 let mut st = conn.state.lock().await;
5624 st.stream_callbacks
5625 .entry("stream1".to_string())
5626 .or_default()
5627 .push(
5628 (Box::new(move |_: &Value| {
5629 called_clone.store(true, Ordering::SeqCst);
5630 })
5631 as Box<dyn Fn(&Value) + Send + Sync>)
5632 .into(),
5633 );
5634 }
5635
5636 let msg = json!({
5637 "stream": "stream1",
5638 "data": { "key": "value" }
5639 })
5640 .to_string();
5641
5642 ws.on_message(msg, conn.clone()).await;
5643
5644 assert!(called.load(Ordering::SeqCst));
5645 });
5646 }
5647
5648 #[test]
5649 fn invokes_all_registered_callbacks() {
5650 TOKIO_SHARED_RT.block_on(async {
5651 let ws: Arc<WebsocketStreams> =
5652 create_websocket_streams(Some("ws://example.com"), None);
5653 let conn = &ws.common.connection_pool[0];
5654 let counter = Arc::new(AtomicUsize::new(0));
5655
5656 {
5657 let mut st = conn.state.lock().await;
5658 let entry = st.stream_callbacks.entry("s".into()).or_default();
5659 let c1 = counter.clone();
5660 entry.push(
5661 (Box::new(move |_: &Value| {
5662 c1.fetch_add(1, Ordering::SeqCst);
5663 }) as Box<dyn Fn(&Value) + Send + Sync>)
5664 .into(),
5665 );
5666 let c2 = counter.clone();
5667 entry.push(
5668 (Box::new(move |_: &Value| {
5669 c2.fetch_add(1, Ordering::SeqCst);
5670 }) as Box<dyn Fn(&Value) + Send + Sync>)
5671 .into(),
5672 );
5673 }
5674
5675 let msg = json!({"stream":"s","data":42}).to_string();
5676 ws.on_message(msg, conn.clone()).await;
5677
5678 assert_eq!(counter.load(Ordering::SeqCst), 2);
5679 });
5680 }
5681
5682 #[test]
5683 fn handles_null_data_field() {
5684 TOKIO_SHARED_RT.block_on(async {
5685 let ws: Arc<WebsocketStreams> =
5686 create_websocket_streams(Some("ws://example.com"), None);
5687 let conn = &ws.common.connection_pool[0];
5688 let called = Arc::new(AtomicUsize::new(0));
5689 {
5690 let mut st = conn.state.lock().await;
5691 st.stream_callbacks.entry("n".into()).or_default().push(
5692 (Box::new({
5693 let c = called.clone();
5694 move |data: &Value| {
5695 if data.is_null() {
5696 c.fetch_add(1, Ordering::SeqCst);
5697 }
5698 }
5699 }) as Box<dyn Fn(&Value) + Send + Sync>)
5700 .into(),
5701 );
5702 }
5703 let msg = json!({"stream":"n","data":null}).to_string();
5704 ws.on_message(msg, conn.clone()).await;
5705 assert_eq!(called.load(Ordering::SeqCst), 1);
5706 });
5707 }
5708
5709 #[test]
5710 fn with_invalid_json_does_not_panic() {
5711 TOKIO_SHARED_RT.block_on(async {
5712 let ws: Arc<WebsocketStreams> =
5713 create_websocket_streams(Some("ws://example.com"), None);
5714 let conn = &ws.common.connection_pool[0];
5715 let bad = "not a json";
5716 ws.on_message(bad.to_string(), conn.clone()).await;
5717 });
5718 }
5719
5720 #[test]
5721 fn without_stream_field_does_nothing() {
5722 TOKIO_SHARED_RT.block_on(async {
5723 let ws: Arc<WebsocketStreams> =
5724 create_websocket_streams(Some("ws://example.com"), None);
5725 let conn = &ws.common.connection_pool[0];
5726 let msg = json!({ "data": { "foo": 1 } }).to_string();
5727 ws.on_message(msg, conn.clone()).await;
5728 });
5729 }
5730
5731 #[test]
5732 fn with_unregistered_stream_does_not_panic() {
5733 TOKIO_SHARED_RT.block_on(async {
5734 let ws: Arc<WebsocketStreams> =
5735 create_websocket_streams(Some("ws://example.com"), None);
5736 let conn = &ws.common.connection_pool[0];
5737 let msg = json!({
5738 "stream": "nope",
5739 "data": { "foo": 1 }
5740 })
5741 .to_string();
5742 ws.on_message(msg, conn.clone()).await;
5743 });
5744 }
5745 }
5746
5747 mod get_reconnect_url {
5748 use super::*;
5749
5750 #[test]
5751 fn single_stream_reconnect_url() {
5752 TOKIO_SHARED_RT.block_on(async {
5753 let ws: Arc<WebsocketStreams> =
5754 create_websocket_streams(Some("ws://example.com"), None);
5755 let c0 = ws.common.connection_pool[0].clone();
5756 {
5757 let mut map = ws.connection_streams.lock().await;
5758 map.insert("s1".to_string(), c0.clone());
5759 }
5760 let url = ws.get_reconnect_url("default_url".into(), c0).await;
5761 assert_eq!(url, "ws://example.com/stream?streams=s1");
5762 });
5763 }
5764
5765 #[test]
5766 fn multiple_streams_same_connection() {
5767 TOKIO_SHARED_RT.block_on(async {
5768 let ws: Arc<WebsocketStreams> =
5769 create_websocket_streams(Some("ws://example.com"), None);
5770 let c0 = ws.common.connection_pool[0].clone();
5771 {
5772 let mut map = ws.connection_streams.lock().await;
5773 map.insert("a".to_string(), c0.clone());
5774 map.insert("b".to_string(), c0.clone());
5775 }
5776 let url = ws.get_reconnect_url("default_url".into(), c0).await;
5777 let suffix = url
5778 .strip_prefix("ws://example.com/stream?streams=")
5779 .unwrap();
5780 let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
5781 let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
5782 assert_eq!(set, ["a", "b"].iter().copied().collect());
5783 });
5784 }
5785
5786 #[test]
5787 fn reconnect_url_with_time_unit() {
5788 TOKIO_SHARED_RT.block_on(async {
5789 let mut ws: Arc<WebsocketStreams> =
5790 create_websocket_streams(Some("ws://example.com"), None);
5791 Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
5792 Some(TimeUnit::Microsecond);
5793 let c0 = ws.common.connection_pool[0].clone();
5794 {
5795 let mut map = ws.connection_streams.lock().await;
5796 map.insert("x".to_string(), c0.clone());
5797 }
5798 let url = ws.get_reconnect_url("default_url".into(), c0).await;
5799 assert_eq!(
5800 url,
5801 "ws://example.com/stream?streams=x&timeUnit=microsecond"
5802 );
5803 });
5804 }
5805 }
5806 }
5807
5808 mod websocket_stream {
5809 use super::*;
5810
5811 mod on {
5812 use super::*;
5813
5814 #[test]
5815 fn registers_callback_and_stream_callback_for_websocket_streams() {
5816 TOKIO_SHARED_RT.block_on(async {
5817 let ws_base = create_websocket_streams(Some("example.com"), None);
5818 let stream_name = "s1".to_string();
5819 let conn = ws_base.common.connection_pool[0].clone();
5820 {
5821 let mut map = ws_base.connection_streams.lock().await;
5822 map.insert(stream_name.clone(), conn.clone());
5823 }
5824 {
5825 let mut state = conn.state.lock().await;
5826 state
5827 .stream_callbacks
5828 .insert(stream_name.clone(), Vec::new());
5829 }
5830 let stream = Arc::new(WebsocketStream::<Value> {
5831 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5832 stream_or_id: stream_name.clone(),
5833 callback: Mutex::new(None),
5834 id: None,
5835 _phantom: PhantomData,
5836 });
5837 let called = Arc::new(Mutex::new(false));
5838 let called_clone = called.clone();
5839 stream
5840 .on("message", move |v: Value| {
5841 let mut lock = called_clone.blocking_lock();
5842 *lock = v == Value::String("x".into());
5843 })
5844 .await;
5845 let cb_guard = stream.callback.lock().await;
5846 assert!(cb_guard.is_some());
5847 let cbs = {
5848 let state = conn.state.lock().await;
5849 state.stream_callbacks.get(&stream_name).unwrap().clone()
5850 };
5851 assert_eq!(cbs.len(), 1);
5852 });
5853 }
5854
5855 #[test]
5856 fn message_twice_registers_two_wrappers_for_websocket_streams() {
5857 TOKIO_SHARED_RT.block_on(async {
5858 let ws_base = create_websocket_streams(Some("example.com"), None);
5859 let stream_name = "s2".to_string();
5860 let conn = ws_base.common.connection_pool[0].clone();
5861 {
5862 let mut map = ws_base.connection_streams.lock().await;
5863 map.insert(stream_name.clone(), conn.clone());
5864 }
5865 {
5866 let mut state = conn.state.lock().await;
5867 state
5868 .stream_callbacks
5869 .insert(stream_name.clone(), Vec::new());
5870 }
5871 let stream = Arc::new(WebsocketStream::<Value> {
5872 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5873 stream_or_id: stream_name.clone(),
5874 callback: Mutex::new(None),
5875 id: None,
5876 _phantom: PhantomData,
5877 });
5878 stream.on("message", |_| {}).await;
5879 stream.on("message", |_| {}).await;
5880 let state = conn.state.lock().await;
5881 let callbacks = state.stream_callbacks.get(&stream_name).unwrap();
5882 assert_eq!(callbacks.len(), 2);
5883 });
5884 }
5885
5886 #[test]
5887 fn ignores_non_message_event_for_websocket_streams() {
5888 TOKIO_SHARED_RT.block_on(async {
5889 let ws_base = create_websocket_streams(Some("example.com"), None);
5890 let stream = Arc::new(WebsocketStream::<Value> {
5891 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5892 stream_or_id: "s".into(),
5893 callback: Mutex::new(None),
5894 id: None,
5895 _phantom: PhantomData,
5896 });
5897 stream.on("open", |_| {}).await;
5898 let guard = stream.callback.lock().await;
5899 assert!(guard.is_none());
5900 });
5901 }
5902
5903 #[test]
5904 fn registers_callback_and_stream_callback_for_websocket_api() {
5905 TOKIO_SHARED_RT.block_on(async {
5906 let ws_base = create_websocket_api(None);
5907
5908 {
5909 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5910 stream_callbacks.insert("id1".to_string(), Vec::new());
5911 }
5912
5913 let stream = Arc::new(WebsocketStream::<Value> {
5914 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5915 stream_or_id: "id1".to_string(),
5916 callback: Mutex::new(None),
5917 id: None,
5918 _phantom: PhantomData,
5919 });
5920
5921 let called = Arc::new(Mutex::new(false));
5922 let called_clone = called.clone();
5923 stream
5924 .on("message", move |v: Value| {
5925 let mut lock = called_clone.blocking_lock();
5926 *lock = v == Value::String("x".into());
5927 })
5928 .await;
5929
5930 let cb_guard = stream.callback.lock().await;
5931 assert!(cb_guard.is_some());
5932
5933 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5934 let callbacks = stream_callbacks.get("id1").unwrap();
5935 assert_eq!(callbacks.len(), 1);
5936 });
5937 }
5938
5939 #[test]
5940 fn message_twice_registers_two_wrappers_for_websocket_api() {
5941 TOKIO_SHARED_RT.block_on(async {
5942 let ws_base = create_websocket_api(None);
5943
5944 let stream = Arc::new(WebsocketStream::<Value> {
5945 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5946 stream_or_id: "id2".to_string(),
5947 callback: Mutex::new(None),
5948 id: None,
5949 _phantom: PhantomData,
5950 });
5951
5952 stream.on("message", |_| {}).await;
5953 stream.on("message", |_| {}).await;
5954
5955 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5956 let callbacks = stream_callbacks.get("id2").unwrap();
5957 assert_eq!(callbacks.len(), 2);
5958 });
5959 }
5960
5961 #[test]
5962 fn ignores_non_message_event_for_websocket_api() {
5963 TOKIO_SHARED_RT.block_on(async {
5964 let ws_base = create_websocket_api(None);
5965
5966 let stream = Arc::new(WebsocketStream::<Value> {
5967 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5968 stream_or_id: "id3".into(),
5969 callback: Mutex::new(None),
5970 id: None,
5971 _phantom: PhantomData,
5972 });
5973
5974 stream.on("open", |_| {}).await;
5975
5976 let guard = stream.callback.lock().await;
5977 assert!(guard.is_none());
5978
5979 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5980 assert!(stream_callbacks.get("id3").is_none());
5981 assert!(stream_callbacks.is_empty());
5982 });
5983 }
5984 }
5985
5986 mod on_message {
5987 use super::*;
5988
5989 #[test]
5990 fn on_message_registers_callback_for_websocket_streams() {
5991 TOKIO_SHARED_RT.block_on(async {
5992 let ws_base = create_websocket_streams(Some("example.com"), None);
5993 let stream_name = "s".to_string();
5994 let conn = ws_base.common.connection_pool[0].clone();
5995 {
5996 let mut map = ws_base.connection_streams.lock().await;
5997 map.insert(stream_name.clone(), conn.clone());
5998 }
5999 {
6000 let mut state = conn.state.lock().await;
6001 state
6002 .stream_callbacks
6003 .insert(stream_name.clone(), Vec::new());
6004 }
6005 let stream = Arc::new(WebsocketStream::<Value> {
6006 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6007 stream_or_id: stream_name.clone(),
6008 callback: Mutex::new(None),
6009 id: None,
6010 _phantom: PhantomData,
6011 });
6012 stream.on_message(|_v| {});
6013 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
6014 assert_eq!(callbacks.len(), 1);
6015 });
6016 }
6017
6018 #[test]
6019 fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
6020 TOKIO_SHARED_RT.block_on(async {
6021 let ws_base = create_websocket_streams(Some("example.com"), None);
6022 let stream_name = "s".to_string();
6023 let conn = ws_base.common.connection_pool[0].clone();
6024 {
6025 let mut map = ws_base.connection_streams.lock().await;
6026 map.insert(stream_name.clone(), conn.clone());
6027 }
6028 {
6029 let mut state = conn.state.lock().await;
6030 state
6031 .stream_callbacks
6032 .insert(stream_name.clone(), Vec::new());
6033 }
6034 let stream = Arc::new(WebsocketStream::<Value> {
6035 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6036 stream_or_id: stream_name.clone(),
6037 callback: Mutex::new(None),
6038 id: None,
6039 _phantom: PhantomData,
6040 });
6041 stream.on_message(|_v| {});
6042 stream.on_message(|_v| {});
6043 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
6044 assert_eq!(callbacks.len(), 2);
6045 });
6046 }
6047
6048 #[test]
6049 fn on_message_registers_callback_for_websocket_api() {
6050 TOKIO_SHARED_RT.block_on(async {
6051 let ws_base = create_websocket_api(None);
6052 let identifier = "id1".to_string();
6053
6054 let stream = Arc::new(WebsocketStream::<Value> {
6055 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6056 stream_or_id: identifier.clone(),
6057 callback: Mutex::new(None),
6058 id: None,
6059 _phantom: PhantomData,
6060 });
6061
6062 stream.on_message(|_v: Value| {});
6063
6064 let stream_callbacks = ws_base.stream_callbacks.lock().await;
6065 let callbacks = stream_callbacks.get(&identifier).unwrap();
6066 assert_eq!(callbacks.len(), 1);
6067 });
6068 }
6069
6070 #[test]
6071 fn on_message_twice_registers_two_callbacks_for_websocket_api() {
6072 TOKIO_SHARED_RT.block_on(async {
6073 let ws_base = create_websocket_api(None);
6074 let identifier = "id2".to_string();
6075
6076 let stream = Arc::new(WebsocketStream::<Value> {
6077 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6078 stream_or_id: identifier.clone(),
6079 callback: Mutex::new(None),
6080 id: None,
6081 _phantom: PhantomData,
6082 });
6083
6084 stream.on_message(|_v: Value| {});
6085 stream.on_message(|_v: Value| {});
6086
6087 let stream_callbacks = ws_base.stream_callbacks.lock().await;
6088 let callbacks = stream_callbacks.get(&identifier).unwrap();
6089 assert_eq!(callbacks.len(), 2);
6090 });
6091 }
6092 }
6093
6094 mod unsubscribe {
6095 use super::*;
6096
6097 #[test]
6098 fn without_callback_does_nothing() {
6099 TOKIO_SHARED_RT.block_on(async {
6100 let ws_base = create_websocket_streams(Some("example.com"), None);
6101 let stream_name = "s1".to_string();
6102 let conn = ws_base.common.connection_pool[0].clone();
6103 {
6104 let mut map = ws_base.connection_streams.lock().await;
6105 map.insert(stream_name.clone(), conn.clone());
6106 }
6107 let mut state = conn.state.lock().await;
6108 state.stream_callbacks.insert(stream_name.clone(), vec![]);
6109 drop(state);
6110 let stream = Arc::new(WebsocketStream::<Value> {
6111 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6112 stream_or_id: stream_name.clone(),
6113 callback: Mutex::new(None),
6114 id: None,
6115 _phantom: PhantomData,
6116 });
6117 stream.unsubscribe().await;
6118 let state = conn.state.lock().await;
6119 assert!(state.stream_callbacks.contains_key(&stream_name));
6120 });
6121 }
6122
6123 #[test]
6124 fn removes_registered_callback_and_clears_state() {
6125 TOKIO_SHARED_RT.block_on(async {
6126 let ws_base = create_websocket_streams(Some("example.com"), None);
6127 let stream_name = "s2".to_string();
6128 let conn = ws_base.common.connection_pool[0].clone();
6129 {
6130 let mut map = ws_base.connection_streams.lock().await;
6131 map.insert(stream_name.clone(), conn.clone());
6132 }
6133 {
6134 let mut state = conn.state.lock().await;
6135 state
6136 .stream_callbacks
6137 .insert(stream_name.clone(), Vec::new());
6138 }
6139 let stream = Arc::new(WebsocketStream::<Value> {
6140 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6141 stream_or_id: stream_name.clone(),
6142 callback: Mutex::new(None),
6143 id: None,
6144 _phantom: PhantomData,
6145 });
6146 stream.on("message", |_| {}).await;
6147 {
6148 let guard = stream.callback.lock().await;
6149 assert!(guard.is_some());
6150 }
6151 stream.unsubscribe().await;
6152 sleep(Duration::from_millis(10)).await;
6153 let guard = stream.callback.lock().await;
6154 assert!(guard.is_none());
6155 let state = conn.state.lock().await;
6156 assert!(
6157 state
6158 .stream_callbacks
6159 .get(&stream_name)
6160 .is_none_or(std::vec::Vec::is_empty)
6161 );
6162 });
6163 }
6164
6165 #[test]
6166 fn without_callback_does_nothing_for_websocket_api() {
6167 TOKIO_SHARED_RT.block_on(async {
6168 let ws_base = create_websocket_api(None);
6169 let identifier = "id1".to_string();
6170
6171 {
6172 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
6173 stream_callbacks.insert(identifier.clone(), Vec::new());
6174 }
6175
6176 let stream = Arc::new(WebsocketStream::<Value> {
6177 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6178 stream_or_id: identifier.clone(),
6179 callback: Mutex::new(None),
6180 id: None,
6181 _phantom: PhantomData,
6182 });
6183
6184 stream.unsubscribe().await;
6185
6186 let stream_callbacks = ws_base.stream_callbacks.lock().await;
6187 assert!(stream_callbacks.contains_key(&identifier));
6188 let callbacks = stream_callbacks.get(&identifier).unwrap();
6189 assert!(callbacks.is_empty());
6190 });
6191 }
6192
6193 #[test]
6194 fn removes_registered_callback_and_clears_state_for_websocket_api() {
6195 TOKIO_SHARED_RT.block_on(async {
6196 let ws_base = create_websocket_api(None);
6197 let identifier = "id2".to_string();
6198
6199 {
6200 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
6201 stream_callbacks.insert(identifier.clone(), Vec::new());
6202 }
6203
6204 let stream = Arc::new(WebsocketStream::<Value> {
6205 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6206 stream_or_id: identifier.clone(),
6207 callback: Mutex::new(None),
6208 id: None,
6209 _phantom: PhantomData,
6210 });
6211
6212 stream.on("message", |_| {}).await;
6213
6214 {
6215 let stream_callbacks = ws_base.stream_callbacks.lock().await;
6216 let callbacks = stream_callbacks
6217 .get(&identifier)
6218 .expect("Entry for 'id2' should exist");
6219 assert_eq!(callbacks.len(), 1);
6220 }
6221
6222 stream.unsubscribe().await;
6223
6224 {
6225 let guard = stream.callback.lock().await;
6226 assert!(guard.is_none());
6227 }
6228
6229 {
6230 let stream_callbacks = ws_base.stream_callbacks.lock().await;
6231 let callbacks = stream_callbacks
6232 .get(&identifier)
6233 .expect("Entry for 'id2' should still exist");
6234 assert!(callbacks.is_empty());
6235 }
6236 });
6237 }
6238 }
6239 }
6240
6241 mod create_stream_handler {
6242 use super::*;
6243
6244 #[test]
6245 fn create_stream_handler_without_id_registers_stream() {
6246 TOKIO_SHARED_RT.block_on(async {
6247 let ws = create_websocket_streams(Some("ws://example.com"), None);
6248 let stream_name = "foo".to_string();
6249 let handler = create_stream_handler::<serde_json::Value>(
6250 WebsocketBase::WebsocketStreams(ws.clone()),
6251 stream_name.clone(),
6252 None,
6253 )
6254 .await;
6255 assert_eq!(handler.stream_or_id, stream_name);
6256 assert!(handler.id.is_none());
6257 let map = ws.connection_streams.lock().await;
6258 assert!(map.contains_key(&stream_name));
6259 });
6260 }
6261
6262 #[test]
6263 fn create_stream_handler_with_custom_id_registers_stream_and_id() {
6264 TOKIO_SHARED_RT.block_on(async {
6265 let ws = create_websocket_streams(Some("ws://example.com"), None);
6266 let stream_name = "bar".to_string();
6267 let custom_id = Some("my-custom-id".to_string());
6268 let handler = create_stream_handler::<serde_json::Value>(
6269 WebsocketBase::WebsocketStreams(ws.clone()),
6270 stream_name.clone(),
6271 custom_id.clone(),
6272 )
6273 .await;
6274 assert_eq!(handler.stream_or_id, stream_name);
6275 assert_eq!(handler.id, custom_id);
6276 let map = ws.connection_streams.lock().await;
6277 assert!(map.contains_key(&stream_name));
6278 });
6279 }
6280
6281 #[test]
6282 fn create_stream_handler_without_id_registers_api_stream() {
6283 TOKIO_SHARED_RT.block_on(async {
6284 let ws_base = create_websocket_api(None);
6285 let identifier = "foo-api".to_string();
6286
6287 let handler = create_stream_handler::<Value>(
6288 WebsocketBase::WebsocketApi(ws_base.clone()),
6289 identifier.clone(),
6290 None,
6291 )
6292 .await;
6293
6294 assert_eq!(handler.stream_or_id, identifier);
6295 assert!(handler.id.is_none());
6296 });
6297 }
6298
6299 #[test]
6300 fn create_stream_handler_with_custom_id_registers_api_stream_and_id() {
6301 TOKIO_SHARED_RT.block_on(async {
6302 let ws_base = create_websocket_api(None);
6303 let identifier = "bar-api".to_string();
6304 let custom_id = Some("custom-123".to_string());
6305
6306 let handler = create_stream_handler::<Value>(
6307 WebsocketBase::WebsocketApi(ws_base.clone()),
6308 identifier.clone(),
6309 custom_id.clone(),
6310 )
6311 .await;
6312
6313 assert_eq!(handler.stream_or_id, identifier);
6314 assert_eq!(handler.id, custom_id);
6315 });
6316 }
6317 }
6318}