1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
4use regex::Regex;
5use serde::de::DeserializeOwned;
6use serde_json::{Value, json};
7use std::{
8 collections::{BTreeMap, HashMap, VecDeque},
9 io::Read,
10 marker::PhantomData,
11 mem::take,
12 sync::{
13 Arc, LazyLock,
14 atomic::{AtomicUsize, Ordering},
15 },
16 time::Duration,
17};
18use tokio::{
19 net::TcpStream,
20 select, spawn,
21 sync::{
22 Mutex, Notify, broadcast,
23 mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
24 oneshot,
25 },
26 task::JoinHandle,
27 time::{sleep, timeout},
28};
29use tokio_tungstenite::{
30 Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
31 tungstenite::{
32 Message,
33 client::IntoClientRequest,
34 protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
35 },
36};
37use tokio_util::time::DelayQueue;
38use tracing::{debug, error, info, warn};
39
40use crate::common::utils::{remove_empty_value, sort_object_params};
41
42use super::{
43 config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
44 errors::WebsocketError,
45 models::{WebsocketApiResponse, WebsocketEvent, WebsocketMode},
46 utils::{get_timestamp, random_string, validate_time_unit},
47};
48
49static ID_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^[0-9a-f]{32}$").unwrap());
50
51pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;
52
53const MAX_CONN_DURATION: Duration = Duration::from_secs(23 * 60 * 60);
54
55pub struct Subscription {
56 handle: JoinHandle<()>,
57}
58
59impl Subscription {
60 pub fn unsubscribe(self) {
75 self.handle.abort();
76 }
77}
78
79#[derive(Clone)]
80pub enum WebsocketBase {
81 WebsocketApi(Arc<WebsocketApi>),
82 WebsocketStreams(Arc<WebsocketStreams>),
83}
84
85pub struct WebsocketEventEmitter {
86 tx: broadcast::Sender<WebsocketEvent>,
87}
88
89impl Default for WebsocketEventEmitter {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl WebsocketEventEmitter {
96 #[must_use]
97 pub fn new() -> Self {
98 let (tx, _rx) = broadcast::channel(100);
99 Self { tx }
100 }
101
102 pub fn subscribe<F>(&self, mut callback: F) -> Subscription
123 where
124 F: FnMut(WebsocketEvent) + Send + 'static,
125 {
126 let mut rx = self.tx.subscribe();
127 let handle = spawn(async move {
128 while let Ok(event) = rx.recv().await {
129 callback(event);
130 }
131 });
132 Subscription { handle }
133 }
134
135 fn emit(&self, event: WebsocketEvent) {
146 let _ = self.tx.send(event);
147 }
148}
149
150#[async_trait]
165pub trait WebsocketHandler: Send + Sync + 'static {
166 async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
167 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
168 async fn get_reconnect_url(
169 &self,
170 default_url: String,
171 connection: Arc<WebsocketConnection>,
172 ) -> String;
173}
174
175pub struct PendingRequest {
176 pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
177}
178
179pub struct WebsocketConnectionState {
180 pub reconnection_pending: bool,
181 pub renewal_pending: bool,
182 pub close_initiated: bool,
183 pub pending_requests: HashMap<String, PendingRequest>,
184 pub pending_subscriptions: VecDeque<String>,
185 pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
186 pub handler: Option<Arc<dyn WebsocketHandler>>,
187 pub ws_write_tx: Option<UnboundedSender<Message>>,
188}
189
190impl Default for WebsocketConnectionState {
191 fn default() -> Self {
192 Self::new()
193 }
194}
195
196impl WebsocketConnectionState {
197 #[must_use]
198 pub fn new() -> Self {
199 Self {
200 reconnection_pending: false,
201 renewal_pending: false,
202 close_initiated: false,
203 pending_requests: HashMap::new(),
204 pending_subscriptions: VecDeque::new(),
205 stream_callbacks: HashMap::new(),
206 handler: None,
207 ws_write_tx: None,
208 }
209 }
210}
211
212pub struct WebsocketConnection {
213 pub id: String,
214 pub drain_notify: Notify,
215 pub state: Mutex<WebsocketConnectionState>,
216}
217
218impl WebsocketConnection {
219 pub fn new(id: impl Into<String>) -> Arc<Self> {
220 Arc::new(Self {
221 id: id.into(),
222 drain_notify: Notify::new(),
223 state: Mutex::new(WebsocketConnectionState::new()),
224 })
225 }
226
227 pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
228 let mut conn_state = self.state.lock().await;
229 conn_state.handler = Some(handler);
230 }
231}
232
/// A queued request to (re)open the socket of one connection.
struct ReconnectEntry {
    /// Id of the connection slot to reconnect.
    connection_id: String,
    /// URL to connect to.
    url: String,
    /// `true` for a scheduled renewal (no reconnect delay), `false` after a failure.
    is_renewal: bool,
}
238
239pub struct WebsocketCommon {
240 pub events: WebsocketEventEmitter,
241 mode: WebsocketMode,
242 round_robin_index: AtomicUsize,
243 connection_pool: Vec<Arc<WebsocketConnection>>,
244 reconnect_tx: Sender<ReconnectEntry>,
245 renewal_tx: Sender<(String, String)>,
246 reconnect_delay: usize,
247 agent: Option<AgentConnector>,
248}
249
250impl WebsocketCommon {
251 #[must_use]
252 pub fn new(
253 mut initial_pool: Vec<Arc<WebsocketConnection>>,
254 mode: WebsocketMode,
255 reconnect_delay: usize,
256 agent: Option<AgentConnector>,
257 ) -> Arc<Self> {
258 if initial_pool.is_empty() {
259 for _ in 0..mode.pool_size() {
260 let id = random_string();
261 initial_pool.push(WebsocketConnection::new(id));
262 }
263 }
264
265 let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
266 let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
267
268 let common = Arc::new(Self {
269 events: WebsocketEventEmitter::new(),
270 mode,
271 round_robin_index: AtomicUsize::new(0),
272 connection_pool: initial_pool,
273 reconnect_tx,
274 renewal_tx,
275 reconnect_delay,
276 agent,
277 });
278
279 Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
280 Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
281
282 common
283 }
284
285 fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
303 spawn(async move {
304 while let Some(entry) = reconnect_rx.recv().await {
305 info!("Scheduling reconnect for id {}", entry.connection_id);
306
307 if !entry.is_renewal {
308 sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
309 }
310
311 if let Some(conn_arc) = common
312 .connection_pool
313 .iter()
314 .find(|c| c.id == entry.connection_id)
315 .cloned()
316 {
317 let common_clone = Arc::clone(&common);
318 if let Err(err) = common_clone
319 .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
320 .await
321 {
322 error!(
323 "Reconnect failed for {} → {}: {:?}",
324 entry.connection_id, entry.url, err
325 );
326 }
327
328 sleep(Duration::from_secs(1)).await;
329 } else {
330 warn!("No connection {} found for reconnect", entry.connection_id);
331 }
332 }
333 });
334 }
335
336 fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
350 let common = Arc::clone(common);
351 spawn(async move {
352 let mut dq = DelayQueue::new();
353 let mut renewal_rx = renewal_rx;
354
355 loop {
356 select! {
357 Some((conn_id, url)) = renewal_rx.recv() => {
358 debug!("Scheduling renewal for {}", conn_id);
359 dq.insert((conn_id, url), MAX_CONN_DURATION);
360 }
361
362 Some(expired) = dq.next() => {
363 let (conn_id, default_url) = expired.into_inner();
364
365 if let Some(conn_arc) = common
366 .connection_pool
367 .iter()
368 .find(|c| c.id == conn_id)
369 .cloned()
370 {
371 debug!("Renewing connection {}", conn_id);
372 let url = common
373 .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
374 .await;
375 if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
376 connection_id: conn_id.clone(),
377 url,
378 is_renewal: true,
379 }).await {
380 error!(
381 "Failed to enqueue renewal for {}: {:?}",
382 conn_id, e
383 );
384 }
385 } else {
386 warn!("No connection {} found for renewal", conn_id);
387 }
388 }
389 }
390 }
391 });
392 }
393
394 pub async fn is_connection_ready(
413 &self,
414 connection: &WebsocketConnection,
415 allow_non_established: bool,
416 ) -> bool {
417 let conn_state = connection.state.lock().await;
418 (allow_non_established || conn_state.ws_write_tx.is_some())
419 && !conn_state.renewal_pending
420 && !conn_state.reconnection_pending
421 && !conn_state.close_initiated
422 }
423
424 async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
440 if let Some(conn_arc) = connection {
441 return self.is_connection_ready(conn_arc, false).await;
442 }
443
444 for conn_arc in &self.connection_pool {
445 if self.is_connection_ready(conn_arc, false).await {
446 return true;
447 }
448 }
449
450 false
451 }
452
453 async fn get_connection(
473 &self,
474 allow_non_established: bool,
475 ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
476 if let WebsocketMode::Single = self.mode {
477 return Ok(Arc::clone(&self.connection_pool[0]));
478 }
479
480 let mut ready = Vec::new();
481 for conn in &self.connection_pool {
482 if self.is_connection_ready(conn, allow_non_established).await {
483 ready.push(Arc::clone(conn));
484 }
485 }
486
487 if ready.is_empty() {
488 return Err(WebsocketError::NotConnected);
489 }
490
491 let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
492
493 Ok(Arc::clone(&ready[idx]))
494 }
495
496 async fn close_connection_gracefully(
513 &self,
514 ws_write_tx_to_close: UnboundedSender<Message>,
515 connection: Arc<WebsocketConnection>,
516 ) -> Result<(), WebsocketError> {
517 debug!("Waiting for pending requests to complete before disconnecting.");
518
519 let drain = async {
520 loop {
521 {
522 let conn_state = connection.state.lock().await;
523 if conn_state.pending_requests.is_empty() {
524 debug!("All pending requests completed, proceeding to close.");
525 break;
526 }
527 }
528 connection.drain_notify.notified().await;
529 }
530 };
531
532 if timeout(Duration::from_secs(30), drain).await.is_err() {
533 warn!("Timeout waiting for pending requests; forcing close.");
534 }
535
536 info!("Closing WebSocket connection for {}", connection.id);
537 let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
538 code: CloseCode::Normal,
539 reason: "".into(),
540 })));
541
542 Ok(())
543 }
544
545 async fn get_reconnect_url(
562 &self,
563 default_url: &str,
564 connection: Arc<WebsocketConnection>,
565 ) -> String {
566 if let Some(handler) = {
567 let conn_state = connection.state.lock().await;
568 conn_state.handler.clone()
569 } {
570 return handler
571 .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
572 .await;
573 }
574
575 default_url.to_string()
576 }
577
578 async fn on_open(
600 &self,
601 url: String,
602 connection: Arc<WebsocketConnection>,
603 old_ws_writer: Option<UnboundedSender<Message>>,
604 ) {
605 if let Some(handler) = {
606 let conn_state = connection.state.lock().await;
607 conn_state.handler.clone()
608 } {
609 handler.on_open(url.clone(), Arc::clone(&connection)).await;
610 }
611
612 let conn_id = &connection.id;
613 info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
614
615 {
616 let mut conn_state = connection.state.lock().await;
617
618 if conn_state.renewal_pending {
619 conn_state.renewal_pending = false;
620 drop(conn_state);
621 if let Some(tx) = old_ws_writer {
622 info!("Connection renewal in progress; closing previous connection.");
623 let _ = self
624 .close_connection_gracefully(tx, Arc::clone(&connection))
625 .await;
626 }
627 return;
628 }
629
630 if conn_state.close_initiated {
631 drop(conn_state);
632 if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
633 info!("Close initiated; closing connection.");
634 let _ = self
635 .close_connection_gracefully(tx, Arc::clone(&connection))
636 .await;
637 }
638 return;
639 }
640
641 self.events.emit(WebsocketEvent::Open);
642 }
643 }
644
645 async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
657 if let Some(handler) = connection.state.lock().await.handler.clone() {
658 let handler_clone = handler.clone();
659 let data = msg.clone();
660 let conn_clone = connection.clone();
661 spawn(async move {
662 handler_clone.on_message(data, conn_clone).await;
663 });
664 }
665 self.events.emit(WebsocketEvent::Message(msg));
666 }
667
668 async fn create_websocket(
690 url: &str,
691 agent: Option<AgentConnector>,
692 ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
693 let req = url
694 .into_client_request()
695 .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
696
697 let ws_config: Option<WebSocketConfig> = None;
698 let disable_nagle = false;
699 let connector: Option<Connector> = agent.map(|dbg| dbg.0);
700
701 let timeout_duration = Duration::from_secs(10);
702 let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
703 match timeout(timeout_duration, handshake).await {
704 Ok(Ok((ws_stream, response))) => {
705 debug!("WebSocket connected: {:?}", response);
706 Ok(ws_stream)
707 }
708 Ok(Err(e)) => {
709 let msg = e.to_string();
710 error!("WebSocket handshake failed: {}", msg);
711 Err(WebsocketError::Handshake(msg))
712 }
713 Err(_) => {
714 error!(
715 "WebSocket connection timed out after {}s",
716 timeout_duration.as_secs()
717 );
718 Err(WebsocketError::Timeout)
719 }
720 }
721 }
722
723 async fn connect_pool(self: Arc<Self>, url: &str) -> Result<(), WebsocketError> {
742 let mut tasks = FuturesUnordered::new();
743
744 for conn in &self.connection_pool {
745 let common = Arc::clone(&self);
746 let url = url.to_owned();
747 let conn_clone = Arc::clone(conn);
748
749 tasks.push(async move {
750 match common.init_connect(&url, false, Some(conn_clone)).await {
751 Ok(()) => {
752 info!("Successfully connected to {}", url);
753 Ok(())
754 }
755 Err(err) => {
756 error!("Failed to connect to {}: {:?}", url, err);
757 Err(err)
758 }
759 }
760 });
761 }
762
763 while let Some(result) = tasks.next().await {
764 result?;
765 }
766
767 Ok(())
768 }
769
770 async fn init_connect(
791 self: Arc<Self>,
792 url: &str,
793 is_renewal: bool,
794 connection: Option<Arc<WebsocketConnection>>,
795 ) -> Result<(), WebsocketError> {
796 let conn = connection.unwrap_or(self.get_connection(true).await?);
797
798 {
799 let mut conn_state = conn.state.lock().await;
800 if conn_state.renewal_pending && is_renewal {
801 info!("Renewal in progress {}→{}", conn.id, url);
802 return Ok(());
803 }
804 if conn_state.ws_write_tx.is_some() && !is_renewal {
805 info!("Exists {}; skipping {}", conn.id, url);
806 return Ok(());
807 }
808 if is_renewal {
809 conn_state.renewal_pending = true;
810 }
811 }
812
813 let ws = Self::create_websocket(url, self.agent.clone())
814 .await
815 .map_err(|e| {
816 error!("Handshake failed {}: {:?}", url, e);
817 e
818 })?;
819
820 info!("Established {} → {}", conn.id, url);
821
822 if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
823 error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
824 }
825
826 let (write_half, mut read_half) = ws.split();
827 let (tx, mut rx) = unbounded_channel::<Message>();
828
829 let old_writer = {
830 let mut conn_state = conn.state.lock().await;
831 conn_state.ws_write_tx.replace(tx.clone())
832 };
833
834 let wconn = conn.clone();
835
836 spawn(async move {
837 let mut sink = write_half;
838 while let Some(msg) = rx.recv().await {
839 if sink.send(msg).await.is_err() {
840 error!("Write error {}", wconn.id);
841 break;
842 }
843 }
844 debug!("Writer {} exit", wconn.id);
845 });
846
847 self.on_open(url.to_string(), conn.clone(), old_writer)
848 .await;
849
850 let common = self.clone();
851 let reader_conn = conn.clone();
852 let read_url = url.to_string();
853
854 spawn(async move {
855 while let Some(item) = read_half.next().await {
856 match item {
857 Ok(Message::Text(msg)) => {
858 common
859 .on_message(msg.to_string(), Arc::clone(&reader_conn))
860 .await;
861 }
862 Ok(Message::Binary(bin)) => {
863 let mut decoder = ZlibDecoder::new(&bin[..]);
864 let mut decompressed = String::new();
865 if let Err(err) = decoder.read_to_string(&mut decompressed) {
866 error!("Binary message decompress failed: {:?}", err);
867 continue;
868 }
869 common
870 .on_message(decompressed, Arc::clone(&reader_conn))
871 .await;
872 }
873 Ok(Message::Ping(payload)) => {
874 info!("PING received from server on {}", reader_conn.id);
875 common.events.emit(WebsocketEvent::Ping);
876 if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
877 let _ = tx.send(Message::Pong(payload));
878 info!(
879 "Responded PONG to server's PING message on {}",
880 reader_conn.id
881 );
882 }
883 }
884 Ok(Message::Pong(_)) => {
885 info!("Received PONG from server on {}", reader_conn.id);
886 common.events.emit(WebsocketEvent::Pong);
887 }
888 Ok(Message::Close(frame)) => {
889 let (code, reason) = frame
890 .map_or((1000, String::new()), |CloseFrame { code, reason }| {
891 (code.into(), reason.to_string())
892 });
893 common
894 .events
895 .emit(WebsocketEvent::Close(code, reason.clone()));
896
897 let mut conn_state = reader_conn.state.lock().await;
898 if !conn_state.close_initiated
899 && !is_renewal
900 && CloseCode::from(code) != CloseCode::Normal
901 {
902 warn!(
903 "Connection {} closed due to {}: {}",
904 reader_conn.id, code, reason
905 );
906 conn_state.reconnection_pending = true;
907 drop(conn_state);
908 let reconnect_url = common
909 .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
910 .await;
911 let _ = common.reconnect_tx.send(ReconnectEntry {
912 connection_id: reader_conn.id.clone(),
913 url: reconnect_url,
914 is_renewal: false,
915 });
916 }
917 break;
918 }
919 Err(e) => {
920 error!("WebSocket error on {}: {:?}", reader_conn.id, e);
921 common.events.emit(WebsocketEvent::Error(e.to_string()));
922 }
923 _ => {}
924 }
925 }
926 debug!("Reader actor for {} exiting", reader_conn.id);
927 });
928
929 Ok(())
930 }
931
932 async fn disconnect(&self) -> Result<(), WebsocketError> {
947 if !self.is_connected(None).await {
948 warn!("No active connection to close.");
949 return Ok(());
950 }
951
952 let mut shutdowns = FuturesUnordered::new();
953 for conn in &self.connection_pool {
954 {
955 let mut conn_state = conn.state.lock().await;
956 conn_state.close_initiated = true;
957 if let Some(tx) = &conn_state.ws_write_tx {
958 shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
959 }
960 }
961 }
962
963 let close_all = async {
964 while let Some(result) = shutdowns.next().await {
965 result?;
966 }
967 Ok::<(), WebsocketError>(())
968 };
969
970 match timeout(Duration::from_secs(30), close_all).await {
971 Ok(Ok(())) => {
972 info!("Disconnected all WebSocket connections successfully.");
973 Ok(())
974 }
975 Ok(Err(err)) => {
976 error!("Error while disconnecting: {:?}", err);
977 Err(err)
978 }
979 Err(_) => {
980 error!("Timed out while disconnecting WebSocket connections.");
981 Err(WebsocketError::Timeout)
982 }
983 }
984 }
985
986 async fn ping_server(&self) {
999 let mut ready = Vec::new();
1000 for conn in &self.connection_pool {
1001 if self.is_connection_ready(conn, false).await {
1002 let id = conn.id.clone();
1003 let ws_write_tx = {
1004 let conn_state = conn.state.lock().await;
1005 conn_state.ws_write_tx.clone()
1006 };
1007 ready.push((id, ws_write_tx));
1008 }
1009 }
1010
1011 if ready.is_empty() {
1012 warn!("No ready connections for PING.");
1013 return;
1014 }
1015 info!("Sending PING to {} WebSocket connections.", ready.len());
1016
1017 let mut tasks = FuturesUnordered::new();
1018 for (id, ws_write_tx_opt) in ready {
1019 if let Some(tx) = ws_write_tx_opt {
1020 tasks.push(async move {
1021 if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1022 error!("Failed to send PING to {}: {:?}", id, e);
1023 } else {
1024 debug!("Sent PING to connection {}", id);
1025 }
1026 });
1027 } else {
1028 error!("Connection {} was ready but has no write channel", id);
1029 }
1030 }
1031
1032 while tasks.next().await.is_some() {}
1033 }
1034
1035 async fn send(
1053 &self,
1054 payload: String,
1055 id: Option<String>,
1056 wait_for_reply: bool,
1057 timeout: Duration,
1058 connection: Option<Arc<WebsocketConnection>>,
1059 ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1060 let conn = if let Some(c) = connection {
1061 c
1062 } else {
1063 self.get_connection(false).await?
1064 };
1065
1066 if !self.is_connected(Some(&conn)).await {
1067 warn!("Send attempted on a non-connected socket");
1068 return Err(WebsocketError::NotConnected);
1069 }
1070
1071 let ws_write_tx = {
1072 let conn_state = conn.state.lock().await;
1073 conn_state
1074 .ws_write_tx
1075 .clone()
1076 .ok_or(WebsocketError::NotConnected)?
1077 };
1078
1079 debug!("Sending message to WebSocket on connection {}", conn.id);
1080
1081 ws_write_tx
1082 .send(Message::Text(payload.clone().into()))
1083 .map_err(|_| WebsocketError::NotConnected)?;
1084
1085 if !wait_for_reply {
1086 return Ok(None);
1087 }
1088
1089 let request_id = id.ok_or_else(|| {
1090 error!("id is required when waiting for a reply");
1091 WebsocketError::NotConnected
1092 })?;
1093
1094 let (tx, rx) = oneshot::channel();
1095 {
1096 let mut conn_state = conn.state.lock().await;
1097 conn_state
1098 .pending_requests
1099 .insert(request_id.clone(), PendingRequest { completion: tx });
1100 }
1101
1102 let conn_clone = Arc::clone(&conn);
1103 spawn(async move {
1104 sleep(timeout).await;
1105 let mut conn_state = conn_clone.state.lock().await;
1106 if let Some(pending_req) = conn_state.pending_requests.remove(&request_id) {
1107 let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1108 }
1109 });
1110
1111 Ok(Some(rx))
1112 }
1113}
1114
/// Per-request options for `WebsocketApi::send_message`.
pub struct WebsocketMessageSendOptions {
    /// Attach the configured API key as the `apiKey` parameter.
    pub with_api_key: bool,
    /// Attach `apiKey` and `timestamp`, then a `signature` over the sorted parameters.
    pub is_signed: bool,
}
1119
1120pub struct WebsocketApi {
1121 pub common: Arc<WebsocketCommon>,
1122 configuration: ConfigurationWebsocketApi,
1123 is_connecting: Arc<Mutex<bool>>,
1124 stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
1125}
1126
1127impl WebsocketApi {
1128 #[must_use]
1129 pub fn new(
1150 configuration: ConfigurationWebsocketApi,
1151 connection_pool: Vec<Arc<WebsocketConnection>>,
1152 ) -> Arc<Self> {
1153 let agent_clone = configuration.agent.clone();
1154 let common = WebsocketCommon::new(
1155 connection_pool,
1156 configuration.mode.clone(),
1157 usize::try_from(configuration.reconnect_delay)
1158 .expect("reconnect_delay should fit in usize"),
1159 agent_clone,
1160 );
1161
1162 Arc::new(Self {
1163 common: Arc::clone(&common),
1164 configuration,
1165 is_connecting: Arc::new(Mutex::new(false)),
1166 stream_callbacks: Mutex::new(HashMap::new()),
1167 })
1168 }
1169
1170 pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1192 if self.common.is_connected(None).await {
1193 info!("WebSocket connection already established");
1194 return Ok(());
1195 }
1196
1197 {
1198 let mut flag = self.is_connecting.lock().await;
1199 if *flag {
1200 info!("Already connecting...");
1201 return Ok(());
1202 }
1203 *flag = true;
1204 }
1205
1206 let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1207
1208 let handler: Arc<dyn WebsocketHandler> = self.clone();
1209 for slot in &self.common.connection_pool {
1210 slot.set_handler(handler.clone()).await;
1211 }
1212
1213 let result = select! {
1214 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1215 r = self.common.clone().connect_pool(&url) => r,
1216 };
1217
1218 {
1219 let mut flag = self.is_connecting.lock().await;
1220 *flag = false;
1221 }
1222
1223 result
1224 }
1225
1226 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1239 self.common.disconnect().await
1240 }
1241
1242 pub async fn is_connected(&self) -> bool {
1248 self.common.is_connected(None).await
1249 }
1250
1251 pub async fn ping_server(&self) {
1256 self.common.ping_server().await;
1257 }
1258
1259 pub async fn send_message<R>(
1289 &self,
1290 method: &str,
1291 mut payload: BTreeMap<String, Value>,
1292 options: WebsocketMessageSendOptions,
1293 ) -> Result<WebsocketApiResponse<R>, WebsocketError>
1294 where
1295 R: DeserializeOwned + Send + Sync + 'static,
1296 {
1297 if !self.common.is_connected(None).await {
1298 return Err(WebsocketError::NotConnected);
1299 }
1300
1301 let id = payload
1302 .get("id")
1303 .and_then(Value::as_str)
1304 .filter(|s| ID_REGEX.is_match(s))
1305 .map_or_else(random_string, String::from);
1306
1307 payload.remove("id");
1308
1309 let mut params = remove_empty_value(payload.into_iter());
1310 if options.with_api_key || options.is_signed {
1311 params.insert(
1312 "apiKey".into(),
1313 Value::String(
1314 self.configuration
1315 .api_key
1316 .clone()
1317 .expect("API key must be set"),
1318 ),
1319 );
1320 }
1321 if options.is_signed {
1322 let ts = get_timestamp();
1323 let ts_i64 = i64::try_from(ts).map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1324 params.insert(
1325 "timestamp".into(),
1326 Value::Number(serde_json::Number::from(ts_i64)),
1327 );
1328 let mut sorted_params = sort_object_params(¶ms);
1329 let sig = self
1330 .configuration
1331 .signature_gen
1332 .get_signature(&sorted_params)
1333 .map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1334 sorted_params.insert("signature".into(), Value::String(sig));
1335 params = sorted_params.into_iter().collect();
1336 }
1337
1338 let request = json!({
1339 "id": id,
1340 "method": method,
1341 "params": params,
1342 });
1343 debug!("Sending message to WebSocket API: {:?}", request);
1344
1345 let timeout = Duration::from_millis(self.configuration.timeout);
1346 let maybe_rx = self
1347 .common
1348 .send(
1349 serde_json::to_string(&request).unwrap(),
1350 Some(id.clone()),
1351 true,
1352 timeout,
1353 None,
1354 )
1355 .await?;
1356
1357 let msg: Value = if let Some(rx) = maybe_rx {
1358 rx.await.unwrap_or(Err(WebsocketError::Timeout))?
1359 } else {
1360 return Err(WebsocketError::NoResponse);
1361 };
1362
1363 let raw = msg
1364 .get("result")
1365 .or_else(|| msg.get("response"))
1366 .cloned()
1367 .unwrap_or(Value::Null);
1368
1369 let rate_limits = msg
1370 .get("rateLimits")
1371 .and_then(Value::as_array)
1372 .map(|arr| {
1373 arr.iter()
1374 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1375 .collect()
1376 })
1377 .unwrap_or_default();
1378
1379 Ok(WebsocketApiResponse {
1380 raw,
1381 rate_limits,
1382 _marker: PhantomData,
1383 })
1384 }
1385
1386 fn prepare_url(&self, ws_url: &str) -> String {
1401 let mut url = ws_url.to_string();
1402
1403 let time_unit = match &self.configuration.time_unit {
1404 Some(u) => u.to_string(),
1405 None => return url,
1406 };
1407
1408 match validate_time_unit(&time_unit) {
1409 Ok(Some(validated)) => {
1410 let sep = if url.contains('?') { '&' } else { '?' };
1411 url.push(sep);
1412 url.push_str("timeUnit=");
1413 url.push_str(validated);
1414 }
1415 Ok(None) => {}
1416 Err(e) => {
1417 error!("Invalid time unit provided: {:?}", e);
1418 }
1419 }
1420
1421 url
1422 }
1423}
1424
1425#[async_trait]
1426impl WebsocketHandler for WebsocketApi {
1427 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
1443
1444 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1464 let msg: Value = match serde_json::from_str(&data) {
1465 Ok(v) => v,
1466 Err(err) => {
1467 error!("Failed to parse WebSocket message {} – {}", data, err);
1468 return;
1469 }
1470 };
1471
1472 if let Some(id) = msg.get("id").and_then(Value::as_str) {
1473 let maybe_sender = {
1474 let mut conn_state = connection.state.lock().await;
1475 conn_state.pending_requests.remove(id)
1476 };
1477
1478 if let Some(PendingRequest { completion }) = maybe_sender {
1479 connection.drain_notify.notify_one();
1480 let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1481 if status >= 400 {
1482 let error_map = msg
1483 .get("error")
1484 .and_then(Value::as_object)
1485 .unwrap_or(&serde_json::Map::new())
1486 .clone();
1487
1488 let code = error_map
1489 .get("code")
1490 .and_then(Value::as_i64)
1491 .unwrap_or(status as i64);
1492
1493 let message = error_map
1494 .get("msg")
1495 .and_then(Value::as_str)
1496 .unwrap_or("Unknown error")
1497 .to_string();
1498
1499 let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1500 } else {
1501 let _ = completion.send(Ok(msg.clone()));
1502 }
1503 }
1504
1505 return;
1506 }
1507
1508 if let Some(event) = msg.get("event") {
1509 if event.get("e").is_some() {
1510 for callbacks in self.stream_callbacks.lock().await.values() {
1511 for callback in callbacks {
1512 callback(event);
1513 }
1514 }
1515
1516 return;
1517 }
1518 }
1519
1520 warn!(
1521 "Received response for unknown or timed-out request: {}",
1522 data
1523 );
1524 }
1525
1526 async fn get_reconnect_url(
1537 &self,
1538 default_url: String,
1539 _connection: Arc<WebsocketConnection>,
1540 ) -> String {
1541 default_url
1542 }
1543}
1544
1545pub struct WebsocketStreams {
1546 pub common: Arc<WebsocketCommon>,
1547 is_connecting: Mutex<bool>,
1548 connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
1549 configuration: ConfigurationWebsocketStreams,
1550}
1551
1552impl WebsocketStreams {
1553 #[must_use]
1568 pub fn new(
1569 configuration: ConfigurationWebsocketStreams,
1570 connection_pool: Vec<Arc<WebsocketConnection>>,
1571 ) -> Arc<Self> {
1572 let agent_clone = configuration.agent.clone();
1573 let common = WebsocketCommon::new(
1574 connection_pool,
1575 configuration.mode.clone(),
1576 usize::try_from(configuration.reconnect_delay)
1577 .expect("reconnect_delay should fit in usize"),
1578 agent_clone,
1579 );
1580 Arc::new(Self {
1581 common,
1582 is_connecting: Mutex::new(false),
1583 connection_streams: Mutex::new(HashMap::new()),
1584 configuration,
1585 })
1586 }
1587
1588 pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
1605 if self.common.is_connected(None).await {
1606 info!("WebSocket connection already established");
1607 return Ok(());
1608 }
1609
1610 {
1611 let mut flag = self.is_connecting.lock().await;
1612 if *flag {
1613 info!("Already connecting...");
1614 return Ok(());
1615 }
1616 *flag = true;
1617 }
1618
1619 let url = self.prepare_url(&streams);
1620
1621 let handler: Arc<dyn WebsocketHandler> = self.clone();
1622 for conn in &self.common.connection_pool {
1623 conn.set_handler(handler.clone()).await;
1624 }
1625
1626 let connect_res = select! {
1627 () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1628 r = self.common.clone().connect_pool(&url) => r,
1629 };
1630
1631 {
1632 let mut flag = self.is_connecting.lock().await;
1633 *flag = false;
1634 }
1635
1636 connect_res
1637 }
1638
1639 pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1655 for connection in &self.common.connection_pool {
1656 let mut conn_state = connection.state.lock().await;
1657 conn_state.stream_callbacks.clear();
1658 conn_state.pending_subscriptions.clear();
1659 }
1660 self.connection_streams.lock().await.clear();
1661 self.common.disconnect().await
1662 }
1663
1664 pub async fn is_connected(&self) -> bool {
1670 self.common.is_connected(None).await
1671 }
1672
1673 pub async fn ping_server(&self) {
1682 self.common.ping_server().await;
1683 }
1684
1685 pub async fn subscribe(self: Arc<Self>, streams: Vec<String>, id: Option<String>) {
1704 let streams: Vec<String> = {
1705 let map = self.connection_streams.lock().await;
1706 streams
1707 .into_iter()
1708 .filter(|s| !map.contains_key(s))
1709 .collect()
1710 };
1711 let connection_streams = self.handle_stream_assignment(streams).await;
1712
1713 for (conn, streams) in connection_streams {
1714 if !self.common.is_connected(Some(&conn)).await {
1715 info!(
1716 "Connection is not ready. Queuing subscription for streams: {:?}",
1717 streams
1718 );
1719 let mut conn_state = conn.state.lock().await;
1720 conn_state.pending_subscriptions.extend(streams.clone());
1721 continue;
1722 }
1723 self.send_subscription_payload(conn.clone(), streams.clone(), id.clone());
1724 }
1725 }
1726
1727 pub async fn unsubscribe(&self, streams: Vec<String>, id: Option<String>) {
1755 let request_id = id
1756 .filter(|s| ID_REGEX.is_match(s))
1757 .unwrap_or_else(random_string);
1758
1759 for stream in streams {
1760 let maybe_conn = { self.connection_streams.lock().await.get(&stream).cloned() };
1761
1762 let conn = if let Some(c) = maybe_conn {
1763 if !self.common.is_connected(Some(&c)).await {
1764 warn!(
1765 "Stream {} not associated with an active connection.",
1766 stream
1767 );
1768 continue;
1769 }
1770 c
1771 } else {
1772 warn!("Stream {} was not subscribed.", stream);
1773 continue;
1774 };
1775
1776 let callbacks = {
1777 let conn_state = conn.state.lock().await;
1778 conn_state
1779 .stream_callbacks
1780 .get(&stream)
1781 .is_none_or(std::vec::Vec::is_empty)
1782 };
1783
1784 if !callbacks {
1785 continue;
1786 }
1787
1788 let payload = json!({
1789 "method": "UNSUBSCRIBE",
1790 "params": [stream.clone()],
1791 "id": request_id,
1792 });
1793
1794 info!("UNSUBSCRIBE → {:?}", payload);
1795
1796 let common = Arc::clone(&self.common);
1797 let conn_clone = Arc::clone(&conn);
1798 let msg = serde_json::to_string(&payload).unwrap();
1799 spawn(async move {
1800 let _ = common
1801 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
1802 .await;
1803 });
1804
1805 {
1806 let mut connection_streams = self.connection_streams.lock().await;
1807 connection_streams.remove(&stream);
1808 }
1809 {
1810 let mut conn_state = conn.state.lock().await;
1811 conn_state.stream_callbacks.remove(&stream);
1812 }
1813 }
1814 }
1815
1816 pub async fn is_subscribed(&self, stream: &str) -> bool {
1830 self.connection_streams.lock().await.contains_key(stream)
1831 }
1832
1833 fn prepare_url(&self, streams: &[String]) -> String {
1849 let mut url = format!(
1850 "{}/stream?streams={}",
1851 self.configuration.ws_url.as_deref().unwrap_or(""),
1852 streams.join("/")
1853 );
1854
1855 let time_unit = match &self.configuration.time_unit {
1856 Some(u) => u.to_string(),
1857 None => return url,
1858 };
1859
1860 match validate_time_unit(&time_unit) {
1861 Ok(Some(validated)) => {
1862 let sep = if url.contains('?') { '&' } else { '?' };
1863 url.push(sep);
1864 url.push_str("timeUnit=");
1865 url.push_str(validated);
1866 }
1867 Ok(None) => {}
1868 Err(e) => {
1869 error!("Invalid time unit provided: {:?}", e);
1870 }
1871 }
1872
1873 url
1874 }
1875
1876 async fn handle_stream_assignment(
1894 &self,
1895 streams: Vec<String>,
1896 ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
1897 let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
1898
1899 for stream in streams {
1900 let mut conn_opt = {
1901 let map = self.connection_streams.lock().await;
1902 map.get(&stream).cloned()
1903 };
1904
1905 let need_new = if let Some(conn) = &conn_opt {
1906 let state = conn.state.lock().await;
1907 state.close_initiated || state.reconnection_pending
1908 } else {
1909 true
1910 };
1911
1912 if need_new {
1913 match self.common.get_connection(true).await {
1914 Ok(new_conn) => {
1915 let mut map = self.connection_streams.lock().await;
1916 map.insert(stream.clone(), new_conn.clone());
1917 conn_opt = Some(new_conn);
1918 }
1919 Err(err) => {
1920 warn!(
1921 "No available WebSocket connection to subscribe stream `{}`: {:?}",
1922 stream, err
1923 );
1924 continue;
1925 }
1926 }
1927 }
1928
1929 if let Some(conn) = conn_opt {
1930 {
1931 let mut conn_state = conn.state.lock().await;
1932 conn_state
1933 .stream_callbacks
1934 .entry(stream.clone())
1935 .or_default();
1936 }
1937 connection_streams.push((stream.clone(), conn));
1938 }
1939 }
1940
1941 let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
1942 for (stream, conn) in connection_streams {
1943 if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
1944 vec.push(stream);
1945 } else {
1946 groups.push((conn, vec![stream]));
1947 }
1948 }
1949
1950 groups
1951 }
1952
1953 fn send_subscription_payload(
1966 &self,
1967 connection: Arc<WebsocketConnection>,
1968 streams: Vec<String>,
1969 id: Option<String>,
1970 ) {
1971 let request_id = id
1972 .filter(|s| ID_REGEX.is_match(s))
1973 .unwrap_or_else(random_string);
1974
1975 let payload = json!({
1976 "method": "SUBSCRIBE",
1977 "params": streams,
1978 "id": request_id,
1979 });
1980
1981 info!("SUBSCRIBE → {:?}", payload);
1982
1983 let common = Arc::clone(&self.common);
1984 let msg = match serde_json::to_string(&payload) {
1985 Ok(s) => s,
1986 Err(e) => {
1987 error!("Failed to serialize SUBSCRIBE payload: {}", e);
1988 return;
1989 }
1990 };
1991 let conn_clone = Arc::clone(&connection);
1992
1993 spawn(async move {
1994 let _ = common
1995 .send(msg, None, false, Duration::ZERO, Some(conn_clone))
1996 .await;
1997 });
1998 }
1999}
2000
2001#[async_trait]
2002impl WebsocketHandler for WebsocketStreams {
2003 async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2020 let pending_subs: Vec<String> = {
2021 let mut conn_state = connection.state.lock().await;
2022 take(&mut conn_state.pending_subscriptions)
2023 .into_iter()
2024 .collect()
2025 };
2026
2027 if !pending_subs.is_empty() {
2028 info!("Processing queued subscriptions for connection");
2029 self.send_subscription_payload(connection.clone(), pending_subs, None);
2030 }
2031 }
2032
2033 async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2050 let msg: Value = match serde_json::from_str(&data) {
2051 Ok(v) => v,
2052 Err(err) => {
2053 error!(
2054 "Failed to parse WebSocket stream message {} – {}",
2055 data, err
2056 );
2057 return;
2058 }
2059 };
2060
2061 let (stream_name, payload) = match (
2062 msg.get("stream").and_then(Value::as_str),
2063 msg.get("data").cloned(),
2064 ) {
2065 (Some(name), Some(data)) => (name.to_string(), data),
2066 _ => return,
2067 };
2068
2069 let callbacks = {
2070 let conn_state = connection.state.lock().await;
2071 conn_state
2072 .stream_callbacks
2073 .get(&stream_name)
2074 .cloned()
2075 .unwrap_or_else(Vec::new)
2076 };
2077
2078 for callback in callbacks {
2079 callback(&payload);
2080 }
2081 }
2082
2083 async fn get_reconnect_url(
2094 &self,
2095 _default_url: String,
2096 connection: Arc<WebsocketConnection>,
2097 ) -> String {
2098 let connection_streams = self.connection_streams.lock().await;
2099 let reconnect_streams = connection_streams
2100 .iter()
2101 .filter_map(|(stream, conn_arc)| {
2102 if Arc::ptr_eq(conn_arc, &connection) {
2103 Some(stream.clone())
2104 } else {
2105 None
2106 }
2107 })
2108 .collect::<Vec<_>>();
2109 self.prepare_url(&reconnect_streams)
2110 }
2111}
2112
/// Typed handle for one stream (or id) on a WebSocket API or streams client.
///
/// `create_stream_handler` builds these. Incoming payloads are deserialized into `T`
/// before they reach the user callback registered through the handle.
pub struct WebsocketStream<T> {
    /// The client whose callback table this handle registers into.
    websocket_base: WebsocketBase,
    /// Key into the callback tables. For `WebsocketStreams` it is a stream name that is
    /// also looked up in `connection_streams`. For `WebsocketApi` it is presumably a
    /// subscription/request id (TODO confirm).
    stream_or_id: String,
    /// The type-erased callback most recently registered by `on`. It is kept so that
    /// `unsubscribe` can remove exactly this callback (compared by `Arc` pointer).
    callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
    /// Request id passed to the streams-level subscribe/unsubscribe calls.
    pub id: Option<String>,
    /// Ties the handle to its payload type `T` without storing a `T`.
    _phantom: PhantomData<T>,
}
2120
2121impl<T> WebsocketStream<T>
2122where
2123 T: DeserializeOwned + Send + 'static,
2124{
2125 async fn on<F>(&self, event: &str, callback_fn: F)
2146 where
2147 F: Fn(T) + Send + Sync + 'static,
2148 {
2149 if event != "message" {
2150 return;
2151 }
2152
2153 let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2154 Arc::new(
2155 move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2156 Ok(data) => callback_fn(data),
2157 Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2158 },
2159 );
2160
2161 {
2162 let mut guard = self.callback.lock().await;
2163 *guard = Some(cb_wrapper.clone());
2164 }
2165
2166 match &self.websocket_base {
2167 WebsocketBase::WebsocketStreams(ws_streams) => {
2168 let conn = {
2169 let map = ws_streams.connection_streams.lock().await;
2170 map.get(&self.stream_or_id)
2171 .cloned()
2172 .expect("stream must be subscribed")
2173 };
2174
2175 {
2176 let mut conn_state = conn.state.lock().await;
2177 let entry = conn_state
2178 .stream_callbacks
2179 .entry(self.stream_or_id.clone())
2180 .or_default();
2181
2182 if !entry
2183 .iter()
2184 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2185 {
2186 entry.push(cb_wrapper);
2187 }
2188 }
2189 }
2190 WebsocketBase::WebsocketApi(ws_api) => {
2191 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2192 let entry = stream_callbacks
2193 .entry(self.stream_or_id.clone())
2194 .or_default();
2195
2196 if !entry
2197 .iter()
2198 .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2199 {
2200 entry.push(cb_wrapper);
2201 }
2202 }
2203 }
2204 }
2205
2206 pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2225 where
2226 T: Send + Sync,
2227 F: Fn(T) + Send + Sync + 'static,
2228 {
2229 let handler: Arc<Self> = Arc::clone(self);
2230
2231 std::thread::spawn(move || {
2232 let rt = tokio::runtime::Builder::new_current_thread()
2233 .enable_all()
2234 .build()
2235 .expect("failed to build Tokio runtime");
2236
2237 rt.block_on(handler.on("message", callback_fn));
2238 })
2239 .join()
2240 .expect("on_message thread panicked");
2241 }
2242
2243 pub async fn unsubscribe(&self) {
2258 let maybe_cb = {
2259 let mut guard = self.callback.lock().await;
2260 guard.take()
2261 };
2262
2263 if let Some(cb) = maybe_cb {
2264 match &self.websocket_base {
2265 WebsocketBase::WebsocketStreams(ws_streams) => {
2266 let conn = {
2267 let map = ws_streams.connection_streams.lock().await;
2268 map.get(&self.stream_or_id)
2269 .cloned()
2270 .expect("stream must have been subscribed")
2271 };
2272
2273 {
2274 let mut conn_state = conn.state.lock().await;
2275 if let Some(list) = conn_state.stream_callbacks.get_mut(&self.stream_or_id)
2276 {
2277 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2278 }
2279 }
2280
2281 let stream = self.stream_or_id.clone();
2282 let id = self.id.clone();
2283 let websocket_streams_base = Arc::clone(ws_streams);
2284 spawn(async move {
2285 websocket_streams_base.unsubscribe(vec![stream], id).await;
2286 });
2287 }
2288 WebsocketBase::WebsocketApi(ws_api) => {
2289 let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2290 if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2291 list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2292 }
2293 }
2294 }
2295 }
2296 }
2297}
2298
2299pub async fn create_stream_handler<T>(
2300 websocket_base: WebsocketBase,
2301 stream_or_id: String,
2302 id: Option<String>,
2303) -> Arc<WebsocketStream<T>>
2304where
2305 T: DeserializeOwned + Send + 'static,
2306{
2307 match &websocket_base {
2308 WebsocketBase::WebsocketStreams(ws_streams) => {
2309 ws_streams
2310 .clone()
2311 .subscribe(vec![stream_or_id.clone()], id.clone())
2312 .await;
2313 }
2314 WebsocketBase::WebsocketApi(_) => {}
2315 }
2316
2317 Arc::new(WebsocketStream {
2318 websocket_base,
2319 stream_or_id,
2320 id,
2321 callback: Mutex::new(None),
2322 _phantom: PhantomData,
2323 })
2324}
2325
2326#[cfg(test)]
2327mod tests {
2328 use crate::TOKIO_SHARED_RT;
2329 use crate::common::utils::SignatureGenerator;
2330 use crate::common::websocket::{
2331 PendingRequest, ReconnectEntry, WebsocketApi, WebsocketBase, WebsocketCommon,
2332 WebsocketConnection, WebsocketEvent, WebsocketEventEmitter, WebsocketHandler,
2333 WebsocketMessageSendOptions, WebsocketMode, WebsocketStream, WebsocketStreams,
2334 create_stream_handler,
2335 };
2336 use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2337 use crate::errors::WebsocketError;
2338 use crate::models::TimeUnit;
2339 use async_trait::async_trait;
2340 use futures::{SinkExt, StreamExt};
2341 use regex::Regex;
2342 use serde_json::{Value, json};
2343 use std::collections::{BTreeMap, HashSet};
2344 use std::marker::PhantomData;
2345 use std::net::SocketAddr;
2346 use std::sync::{
2347 Arc,
2348 atomic::{AtomicBool, AtomicUsize, Ordering},
2349 };
2350 use tokio::net::TcpListener;
2351 use tokio::sync::{Mutex, mpsc::unbounded_channel, oneshot};
2352 use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
2353 use tokio_tungstenite::{accept_async, tungstenite, tungstenite::Message};
2354
2355 fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
2356 let events = Arc::new(Mutex::new(Vec::new()));
2357 let events_clone = events.clone();
2358 common.events.subscribe(move |event| {
2359 let events_clone = events_clone.clone();
2360 tokio::spawn(async move {
2361 events_clone.lock().await.push(event);
2362 });
2363 });
2364 events
2365 }
2366
2367 async fn create_connection(
2368 id: &str,
2369 has_writer: bool,
2370 reconnection_pending: bool,
2371 renewal_pending: bool,
2372 close_initiated: bool,
2373 ) -> Arc<WebsocketConnection> {
2374 let conn = WebsocketConnection::new(id);
2375 let mut st = conn.state.lock().await;
2376 st.reconnection_pending = reconnection_pending;
2377 st.renewal_pending = renewal_pending;
2378 st.close_initiated = close_initiated;
2379 if has_writer {
2380 let (tx, _) = unbounded_channel::<Message>();
2381 st.ws_write_tx = Some(tx);
2382 } else {
2383 st.ws_write_tx = None;
2384 }
2385 drop(st);
2386 conn
2387 }
2388
2389 fn create_websocket_api(time_unit: Option<TimeUnit>) -> Arc<WebsocketApi> {
2390 let sig_gen = SignatureGenerator::new(
2391 Some("api_secret".into()),
2392 None::<PrivateKey>,
2393 None::<String>,
2394 );
2395 let config = ConfigurationWebsocketApi {
2396 api_key: Some("api_key".into()),
2397 api_secret: Some("api_secret".into()),
2398 private_key: None,
2399 private_key_passphrase: None,
2400 ws_url: Some("wss://example.com".into()),
2401 mode: WebsocketMode::Single,
2402 reconnect_delay: 1000,
2403 signature_gen: sig_gen,
2404 timeout: 500,
2405 time_unit,
2406 agent: None,
2407 };
2408 let conn = WebsocketConnection::new("c1");
2409 WebsocketApi::new(config, vec![conn])
2410 }
2411
2412 fn create_websocket_streams(
2413 ws_url: Option<&str>,
2414 conns: Option<Vec<Arc<WebsocketConnection>>>,
2415 ) -> Arc<WebsocketStreams> {
2416 let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
2417 if conns.is_none() {
2418 connections.push(WebsocketConnection::new("c1"));
2419 connections.push(WebsocketConnection::new("c2"));
2420 } else {
2421 connections = conns.expect("Expected connections to be set");
2422 }
2423 let config = ConfigurationWebsocketStreams {
2424 ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
2425 mode: WebsocketMode::Single,
2426 reconnect_delay: 500,
2427 time_unit: None,
2428 agent: None,
2429 };
2430 WebsocketStreams::new(config, connections)
2431 }
2432
2433 mod event_emitter {
2434 use super::*;
2435
2436 #[test]
2437 fn event_emitter_subscribe_and_emit() {
2438 TOKIO_SHARED_RT.block_on(async {
2439 let emitter = WebsocketEventEmitter::new();
2440 let (tx, rx) = oneshot::channel();
2441 let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
2442 let tx_clone = tx.clone();
2443 let _sub = emitter.subscribe(move |event| {
2444 if let Some(sender) = tx_clone.lock().unwrap().take() {
2445 let _ = sender.send(event);
2446 }
2447 });
2448 emitter.emit(WebsocketEvent::Open);
2449 let received = timeout(Duration::from_millis(100), rx)
2450 .await
2451 .expect("timed out");
2452 assert_eq!(received, Ok(WebsocketEvent::Open));
2453 });
2454 }
2455 }
2456
2457 mod websocket_common {
2458 use super::*;
2459
2460 mod initialisation {
2461 use super::*;
2462
2463 #[test]
2464 fn single_mode() {
2465 TOKIO_SHARED_RT.block_on(async {
2466 let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None);
2467 assert_eq!(common.connection_pool.len(), 1);
2468 });
2469 }
2470
2471 #[test]
2472 fn pool_mode() {
2473 TOKIO_SHARED_RT.block_on(async {
2474 let common = WebsocketCommon::new(vec![], WebsocketMode::Pool(3), 0, None);
2475 assert_eq!(common.connection_pool.len(), 3);
2476 });
2477 }
2478 }
2479
2480 mod spawn_reconnect_loop {
2481 use super::*;
2482
2483 #[test]
2484 fn successful_reconnect_entry_triggers_init_connect() {
2485 TOKIO_SHARED_RT.block_on(async {
2486 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2487 let addr = listener.local_addr().unwrap();
2488 tokio::spawn(async move {
2489 if let Ok((stream, _)) = listener.accept().await {
2490 let mut ws = accept_async(stream).await.unwrap();
2491 let _ = ws.close(None).await;
2492 }
2493 });
2494
2495 let conn = WebsocketConnection::new("c1");
2496 let common =
2497 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 10, None);
2498 let url = format!("ws://{addr}");
2499 common
2500 .reconnect_tx
2501 .send(ReconnectEntry {
2502 connection_id: "c1".into(),
2503 url: url.clone(),
2504 is_renewal: false,
2505 })
2506 .await
2507 .unwrap();
2508
2509 sleep(Duration::from_secs(2)).await;
2510
2511 let st = conn.state.lock().await;
2512 assert!(st.ws_write_tx.is_some());
2513 });
2514 }
2515
2516 #[test]
2517 fn reconnect_entry_with_unknown_id_is_ignored() {
2518 TOKIO_SHARED_RT.block_on(async {
2519 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2520 let addr = listener.local_addr().unwrap();
2521 tokio::spawn(async move {
2522 if let Ok((stream, _)) = listener.accept().await {
2523 let mut ws = accept_async(stream).await.unwrap();
2524 let _ = ws.close(None).await;
2525 }
2526 });
2527
2528 let conn = WebsocketConnection::new("c1");
2529 let common =
2530 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 5, None);
2531 let url = format!("ws://{addr}");
2532 common
2533 .reconnect_tx
2534 .send(ReconnectEntry {
2535 connection_id: "other".into(),
2536 url,
2537 is_renewal: false,
2538 })
2539 .await
2540 .unwrap();
2541
2542 sleep(Duration::from_secs(1)).await;
2543
2544 let st = conn.state.lock().await;
2545 assert!(st.ws_write_tx.is_none());
2546 });
2547 }
2548
2549 #[test]
2550 fn renewal_entries_bypass_initial_delay() {
2551 TOKIO_SHARED_RT.block_on(async {
2552 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2553 let addr = listener.local_addr().unwrap();
2554 tokio::spawn(async move {
2555 if let Ok((stream, _)) = listener.accept().await {
2556 let mut ws = accept_async(stream).await.unwrap();
2557 let _ = ws.close(None).await;
2558 }
2559 });
2560
2561 let conn = WebsocketConnection::new("renew");
2562 let common =
2563 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 200, None);
2564 let url = format!("ws://{addr}");
2565 common
2566 .reconnect_tx
2567 .send(ReconnectEntry {
2568 connection_id: "renew".into(),
2569 url: url.clone(),
2570 is_renewal: true,
2571 })
2572 .await
2573 .unwrap();
2574
2575 sleep(Duration::from_secs(2)).await;
2576
2577 let st = conn.state.lock().await;
2578
2579 assert!(st.ws_write_tx.is_some());
2580 });
2581 }
2582
2583 #[test]
2584 fn non_renewal_entries_respect_initial_delay() {
2585 TOKIO_SHARED_RT.block_on(async {
2586 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2587 let addr = listener.local_addr().unwrap();
2588 tokio::spawn(async move {
2589 if let Ok((stream, _)) = listener.accept().await {
2590 let mut ws = accept_async(stream).await.unwrap();
2591 let _ = ws.close(None).await;
2592 }
2593 });
2594
2595 let conn = WebsocketConnection::new("nonrenew");
2596 let common =
2597 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 200, None);
2598 let url = format!("ws://{addr}");
2599 common
2600 .reconnect_tx
2601 .send(ReconnectEntry {
2602 connection_id: "nonrenew".into(),
2603 url: url.clone(),
2604 is_renewal: false,
2605 })
2606 .await
2607 .unwrap();
2608
2609 sleep(Duration::from_millis(100)).await;
2610 assert!(conn.state.lock().await.ws_write_tx.is_none());
2611
2612 sleep(Duration::from_secs(2)).await;
2613
2614 assert!(conn.state.lock().await.ws_write_tx.is_some());
2615 });
2616 }
2617 }
2618
2619 mod spawn_renewal_loop {
2620 use super::*;
2621
2622 #[tokio::test]
2623 async fn scheduling_renewal_does_not_panic_for_known_connection() {
2624 pause();
2625
2626 let conn = WebsocketConnection::new("known");
2627 let common =
2628 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2629 let url = "wss://example".to_string();
2630 common
2631 .renewal_tx
2632 .send((conn.id.clone(), url))
2633 .await
2634 .unwrap();
2635 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2636
2637 resume();
2638 }
2639
2640 #[tokio::test]
2641 async fn scheduling_renewal_ignored_for_unknown_connection() {
2642 pause();
2643
2644 let conn = WebsocketConnection::new("c1");
2645 let common =
2646 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2647 common
2648 .renewal_tx
2649 .send(("other".into(), "u".into()))
2650 .await
2651 .unwrap();
2652 advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2653
2654 resume();
2655 }
2656 }
2657
2658 mod is_connection_ready {
2659 use super::*;
2660
2661 #[test]
2662 fn is_connection_ready() {
2663 TOKIO_SHARED_RT.block_on(async {
2664 let conn = WebsocketConnection::new("c1");
2665 let common =
2666 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2667 assert!(!common.is_connection_ready(&conn, false).await);
2668 assert!(common.is_connection_ready(&conn, true).await);
2669 });
2670 }
2671
2672 #[test]
2673 fn connection_ready_basic() {
2674 TOKIO_SHARED_RT.block_on(async {
2675 let conn = create_connection("c1", true, false, false, false).await;
2676 let common =
2677 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2678 assert!(common.is_connection_ready(&conn, false).await);
2679 });
2680 }
2681
2682 #[test]
2683 fn connection_not_ready_without_writer() {
2684 TOKIO_SHARED_RT.block_on(async {
2685 let conn = create_connection("c1", false, false, false, false).await;
2686 let common =
2687 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2688 assert!(!common.is_connection_ready(&conn, false).await);
2689 assert!(common.is_connection_ready(&conn, true).await);
2690 });
2691 }
2692
2693 #[test]
2694 fn connection_not_ready_when_flagged() {
2695 TOKIO_SHARED_RT.block_on(async {
2696 let conn1 = create_connection("c1", true, true, false, false).await;
2697 let conn2 = create_connection("c2", true, false, true, false).await;
2698 let conn3 = create_connection("c3", true, false, false, true).await;
2699
2700 let common = WebsocketCommon::new(
2701 vec![conn1.clone(), conn2.clone(), conn3.clone()],
2702 WebsocketMode::Pool(3),
2703 0,
2704 None,
2705 );
2706
2707 assert!(!common.is_connection_ready(&conn1, false).await);
2708 assert!(!common.is_connection_ready(&conn2, false).await);
2709 assert!(!common.is_connection_ready(&conn3, false).await);
2710 });
2711 }
2712 }
2713
2714 mod is_connected {
2715 use super::*;
2716
2717 #[test]
2718 fn with_pool_various_connections() {
2719 TOKIO_SHARED_RT.block_on(async {
2720 let conn_a = create_connection("a", true, false, false, false).await;
2721 let conn_b = create_connection("b", false, false, false, false).await;
2722 let conn_c = create_connection("c", true, true, false, false).await;
2723 let pool = vec![conn_a.clone(), conn_b.clone(), conn_c.clone()];
2724 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None);
2725
2726 assert!(common.is_connected(None).await);
2727 assert!(common.is_connected(Some(&conn_a)).await);
2728 assert!(!common.is_connected(Some(&conn_b)).await);
2729 assert!(!common.is_connected(Some(&conn_c)).await);
2730 });
2731 }
2732
2733 #[test]
2734 fn with_pool_all_bad_connections() {
2735 TOKIO_SHARED_RT.block_on(async {
2736 let bad1 = create_connection("c1", false, false, false, false).await;
2737 let bad2 = create_connection("c2", true, true, false, false).await;
2738 let bad3 = create_connection("c3", true, false, false, true).await;
2739 let common = WebsocketCommon::new(
2740 vec![bad1, bad2, bad3],
2741 WebsocketMode::Pool(3),
2742 0,
2743 None,
2744 );
2745
2746 assert!(!common.is_connected(None).await);
2747 });
2748 }
2749
2750 #[test]
2751 fn with_pool_ignore_close_initiated() {
2752 TOKIO_SHARED_RT.block_on(async {
2753 let good = create_connection("c1", true, false, false, false).await;
2754 let closed = create_connection("c2", true, false, false, true).await;
2755 let bad = create_connection("c3", false, false, false, false).await;
2756 let common = WebsocketCommon::new(
2757 vec![closed.clone(), good.clone(), bad.clone()],
2758 WebsocketMode::Pool(3),
2759 0,
2760 None,
2761 );
2762
2763 assert!(common.is_connected(None).await);
2764 assert!(!common.is_connected(Some(&closed)).await);
2765 });
2766 }
2767 }
2768
2769 mod get_connection {
2770 use super::*;
2771
2772 #[test]
2773 fn single_mode() {
2774 TOKIO_SHARED_RT.block_on(async {
2775 let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None);
2776 let conn = common
2777 .get_connection(false)
2778 .await
2779 .expect("should get connection");
2780 assert_eq!(conn.id, common.connection_pool[0].id);
2781 });
2782 }
2783
2784 #[test]
2785 fn pool_mode_not_ready() {
2786 TOKIO_SHARED_RT.block_on(async {
2787 let common = WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None);
2788 let result = common.get_connection(false).await;
2789 assert!(matches!(
2790 result,
2791 Err(crate::errors::WebsocketError::NotConnected)
2792 ));
2793 });
2794 }
2795
2796 #[test]
2797 fn pool_mode_with_ready() {
2798 TOKIO_SHARED_RT.block_on(async {
2799 let conn1 = WebsocketConnection::new("c1");
2800 let conn2 = WebsocketConnection::new("c2");
2801 let (tx1, _rx1) = unbounded_channel();
2802 {
2803 let mut s1 = conn1.state.lock().await;
2804 s1.ws_write_tx = Some(tx1);
2805 }
2806 let pool = vec![conn1.clone(), conn2.clone()];
2807 let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None);
2808 let result = common.get_connection(false).await;
2809 assert!(result.is_ok());
2810 let chosen = result.unwrap();
2811 assert_eq!(chosen.id, conn1.id);
2812 });
2813 }
2814 }
2815
2816 mod close_connection_gracefully {
2817 use super::*;
2818
2819 #[tokio::test]
2820 async fn waits_for_pending_requests_then_closes() {
2821 pause();
2822
2823 let conn = WebsocketConnection::new("c1");
2824 let (tx, mut rx) = unbounded_channel::<Message>();
2825 let (req_tx, _req_rx) = oneshot::channel();
2826 {
2827 let mut st = conn.state.lock().await;
2828 st.pending_requests
2829 .insert("r".to_string(), PendingRequest { completion: req_tx });
2830 }
2831 let common =
2832 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2833 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2834 advance(Duration::from_secs(1)).await;
2835 {
2836 let mut st = conn.state.lock().await;
2837 st.pending_requests.clear();
2838 }
2839 conn.drain_notify.notify_waiters();
2840 advance(Duration::from_secs(1)).await;
2841 close_fut.await.unwrap();
2842 match rx.try_recv() {
2843 Ok(Message::Close(_)) => {}
2844 other => panic!("expected Close, got {other:?}"),
2845 }
2846
2847 resume();
2848 }
2849
2850 #[tokio::test]
2851 async fn force_closes_after_timeout() {
2852 pause();
2853
2854 let conn = WebsocketConnection::new("c2");
2855 let (tx, mut rx) = unbounded_channel::<Message>();
2856 let (req_tx, _req_rx) = oneshot::channel();
2857 {
2858 let mut st = conn.state.lock().await;
2859 st.pending_requests.insert(
2860 "request_id".to_string(),
2861 PendingRequest { completion: req_tx },
2862 );
2863 }
2864 let common =
2865 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2866 let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2867 advance(Duration::from_secs(30)).await;
2868 close_fut.await.unwrap();
2869 match rx.try_recv() {
2870 Ok(Message::Close(_)) => {}
2871 other => panic!("expected Close on timeout, got {other:?}"),
2872 }
2873
2874 resume();
2875 }
2876 }
2877
2878 mod get_reconnect_url {
2879 use super::*;
2880
2881 struct DummyHandler {
2882 url: String,
2883 }
2884
2885 #[async_trait::async_trait]
2886 impl WebsocketHandler for DummyHandler {
2887 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
2888 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
2889 async fn get_reconnect_url(
2890 &self,
2891 _default_url: String,
2892 _connection: Arc<WebsocketConnection>,
2893 ) -> String {
2894 self.url.clone()
2895 }
2896 }
2897
2898 #[test]
2899 fn returns_default_when_no_handler() {
2900 TOKIO_SHARED_RT.block_on(async {
2901 let conn = WebsocketConnection::new("c1");
2902 let common =
2903 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2904 let default = "wss://default".to_string();
2905 let result = common.get_reconnect_url(&default, conn.clone()).await;
2906 assert_eq!(result, default);
2907 });
2908 }
2909
2910 #[test]
2911 fn returns_handler_url_when_set() {
2912 TOKIO_SHARED_RT.block_on(async {
2913 let conn = WebsocketConnection::new("c2");
2914 let handler = Arc::new(DummyHandler {
2915 url: "wss://custom".into(),
2916 });
2917 conn.set_handler(handler).await;
2918 let common =
2919 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2920 let default = "wss://default".to_string();
2921 let result = common.get_reconnect_url(&default, conn.clone()).await;
2922 assert_eq!(result, "wss://custom");
2923 });
2924 }
2925 }
2926
2927 mod on_open {
2928 use super::*;
2929
2930 struct DummyHandler {
2931 called: Arc<Mutex<bool>>,
2932 opened_url: Arc<Mutex<Option<String>>>,
2933 }
2934
2935 #[async_trait]
2936 impl WebsocketHandler for DummyHandler {
2937 async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
2938 let mut flag = self.called.lock().await;
2939 *flag = true;
2940 let mut store = self.opened_url.lock().await;
2941 *store = Some(url);
2942 }
2943 async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
2944 async fn get_reconnect_url(
2945 &self,
2946 default_url: String,
2947 _connection: Arc<WebsocketConnection>,
2948 ) -> String {
2949 default_url
2950 }
2951 }
2952
2953 #[test]
2954 fn emits_open_and_calls_handler() {
2955 TOKIO_SHARED_RT.block_on(async {
2956 let conn = WebsocketConnection::new("c1");
2957 let called = Arc::new(Mutex::new(false));
2958 let opened_url = Arc::new(Mutex::new(None));
2959 let handler = Arc::new(DummyHandler {
2960 called: called.clone(),
2961 opened_url: opened_url.clone(),
2962 });
2963
2964 conn.set_handler(handler.clone()).await;
2965 let common =
2966 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2967 let events = subscribe_events(&common);
2968 common
2969 .on_open("wss://example.com".into(), conn.clone(), None)
2970 .await;
2971
2972 sleep(std::time::Duration::from_millis(10)).await;
2973
2974 let evs = events.lock().await;
2975 assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
2976 assert!(*called.lock().await);
2977 assert_eq!(
2978 opened_url.lock().await.as_deref(),
2979 Some("wss://example.com")
2980 );
2981 });
2982 }
2983
2984 #[test]
2985 fn handles_renewal_pending_and_closes_old_writer() {
2986 TOKIO_SHARED_RT.block_on(async {
2987 let conn = WebsocketConnection::new("c2");
2988 let (old_tx, mut old_rx) = unbounded_channel::<Message>();
2989 {
2990 let mut st = conn.state.lock().await;
2991 st.renewal_pending = true;
2992 }
2993 let common =
2994 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2995 common
2996 .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
2997 .await;
2998 assert!(!conn.state.lock().await.renewal_pending);
2999 match old_rx.try_recv() {
3000 Ok(Message::Close(_)) => {}
3001 other => panic!("expected Close, got {other:?}"),
3002 }
3003 });
3004 }
3005 }
3006
3007 mod on_message {
3008 use super::*;
3009
3010 struct DummyHandler {
3011 called_with: Arc<Mutex<Vec<String>>>,
3012 }
3013
3014 #[async_trait]
3015 impl WebsocketHandler for DummyHandler {
3016 async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
3017 async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
3018 self.called_with.lock().await.push(data);
3019 }
3020 async fn get_reconnect_url(
3021 &self,
3022 default_url: String,
3023 _connection: Arc<WebsocketConnection>,
3024 ) -> String {
3025 default_url
3026 }
3027 }
3028
3029 #[test]
3030 fn emits_message_event_without_handler() {
3031 TOKIO_SHARED_RT.block_on(async {
3032 let conn = WebsocketConnection::new("c1");
3033 let common =
3034 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3035 let events = subscribe_events(&common);
3036 common.on_message("msg".into(), conn.clone()).await;
3037
3038 sleep(Duration::from_millis(10)).await;
3039
3040 let locked = events.lock().await;
3041 assert!(
3042 locked
3043 .iter()
3044 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3045 );
3046 });
3047 }
3048
3049 #[test]
3050 fn calls_handler_and_emits_message() {
3051 TOKIO_SHARED_RT.block_on(async {
3052 let conn = WebsocketConnection::new("c2");
3053 let called = Arc::new(Mutex::new(Vec::new()));
3054 let handler = Arc::new(DummyHandler {
3055 called_with: called.clone(),
3056 });
3057 conn.set_handler(handler.clone()).await;
3058
3059 let common =
3060 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3061 let events = subscribe_events(&common);
3062 common.on_message("msg".into(), conn.clone()).await;
3063
3064 sleep(Duration::from_millis(10)).await;
3065
3066 let evs = events.lock().await;
3067 assert!(
3068 evs.iter()
3069 .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3070 );
3071 let msgs = called.lock().await;
3072 assert_eq!(msgs.as_slice(), &["msg".to_string()]);
3073 });
3074 }
3075 }
3076
3077 mod create_websocket {
3078 use super::*;
3079
3080 #[test]
3081 fn successful_connection() {
3082 TOKIO_SHARED_RT.block_on(async {
3083 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3084 let addr: SocketAddr = listener.local_addr().unwrap();
3085 tokio::spawn(async move {
3086 if let Ok((stream, _)) = listener.accept().await {
3087 if let Ok(mut ws_stream) = accept_async(stream).await {
3088 let _ = ws_stream.close(None).await;
3089 }
3090 }
3091 });
3092
3093 let url = format!("ws://{addr}");
3094 let res = WebsocketCommon::create_websocket(&url, None).await;
3095 assert!(res.is_ok(), "Expected successful handshake, got {res:?}");
3096 });
3097 }
3098
3099 #[test]
3100 fn invalid_url_returns_handshake_error() {
3101 TOKIO_SHARED_RT.block_on(async {
3102 let res = WebsocketCommon::create_websocket("not-a-valid-url", None).await;
3103 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3104 });
3105 }
3106
3107 #[test]
3108 fn unreachable_host_returns_handshake_error() {
3109 TOKIO_SHARED_RT.block_on(async {
3110 let res = WebsocketCommon::create_websocket("ws://127.0.0.1:1", None).await;
3111 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3112 });
3113 }
3114 }
3115
3116 mod connect_pool {
3117 use super::*;
3118
3119 #[test]
3120 fn connects_all_in_pool() {
3121 TOKIO_SHARED_RT.block_on(async {
3122 let pool_size = 3;
3123 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3124 let addr = listener.local_addr().unwrap();
3125 tokio::spawn(async move {
3126 for _ in 0..pool_size {
3127 if let Ok((stream, _)) = listener.accept().await {
3128 let mut ws = accept_async(stream).await.unwrap();
3129 let _ = ws.close(None).await;
3130 }
3131 }
3132 });
3133 let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
3134 .map(|i| WebsocketConnection::new(format!("c{i}")))
3135 .collect();
3136 let common = WebsocketCommon::new(
3137 conns.clone(),
3138 WebsocketMode::Pool(pool_size),
3139 0,
3140 None,
3141 );
3142 let url = format!("ws://{addr}");
3143 common.clone().connect_pool(&url).await.unwrap();
3144 for conn in conns {
3145 let st = conn.state.lock().await;
3146 assert!(st.ws_write_tx.is_some());
3147 }
3148 });
3149 }
3150
3151 #[test]
3152 fn fails_if_any_refused() {
3153 TOKIO_SHARED_RT.block_on(async {
3154 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3155 let addr = listener.local_addr().unwrap();
3156 let pool_size = 3;
3157 tokio::spawn(async move {
3158 for _ in 0..2 {
3159 if let Ok((stream, _)) = listener.accept().await {
3160 let mut ws = accept_async(stream).await.unwrap();
3161 let _ = ws.close(None).await;
3162 }
3163 }
3164 });
3165 let mut conns = Vec::new();
3166 let valid_url = format!("ws://{addr}");
3167 for i in 0..2 {
3168 conns.push(WebsocketConnection::new(format!("c{i}")));
3169 }
3170 conns.push(WebsocketConnection::new("bad"));
3171 let common = WebsocketCommon::new(
3172 conns.clone(),
3173 WebsocketMode::Pool(pool_size),
3174 0,
3175 None,
3176 );
3177 let res = common.clone().connect_pool(&valid_url).await;
3178 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3179 });
3180 }
3181
3182 #[test]
3183 fn fails_on_invalid_url() {
3184 TOKIO_SHARED_RT.block_on(async {
3185 let conns = vec![WebsocketConnection::new("c1")];
3186 let common = WebsocketCommon::new(conns, WebsocketMode::Pool(1), 0, None);
3187 let res = common.connect_pool("not-a-url").await;
3188 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3189 });
3190 }
3191
3192 #[test]
3193 fn fails_if_mixed_success_and_invalid_url() {
3194 TOKIO_SHARED_RT.block_on(async {
3195 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3196 let addr = listener.local_addr().unwrap();
3197 tokio::spawn(async move {
3198 if let Ok((stream, _)) = listener.accept().await {
3199 let mut ws = accept_async(stream).await.unwrap();
3200 let _ = ws.close(None).await;
3201 }
3202 });
3203 let good = WebsocketConnection::new("good");
3204 let bad = WebsocketConnection::new("bad");
3205 let common =
3206 WebsocketCommon::new(vec![good, bad], WebsocketMode::Pool(2), 0, None);
3207 let url = format!("ws://{addr}");
3208 let res = common.connect_pool(&url).await;
3209 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3210 });
3211 }
3212
3213 #[test]
3214 fn init_connect_invoked_for_each() {
3215 TOKIO_SHARED_RT.block_on(async {
3216 let pool_size = 2;
3217 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3218 let addr = listener.local_addr().unwrap();
3219 tokio::spawn(async move {
3220 for _ in 0..pool_size {
3221 if let Ok((stream, _)) = listener.accept().await {
3222 let mut ws = accept_async(stream).await.unwrap();
3223 let _ = ws.close(None).await;
3224 }
3225 }
3226 });
3227 let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
3228 .map(|i| WebsocketConnection::new(format!("c{i}")))
3229 .collect();
3230 let common = WebsocketCommon::new(
3231 conns.clone(),
3232 WebsocketMode::Pool(pool_size),
3233 0,
3234 None,
3235 );
3236 let url = format!("ws://{addr}");
3237 common.clone().connect_pool(&url).await.unwrap();
3238 for conn in conns {
3239 let st = conn.state.lock().await;
3240 assert!(st.ws_write_tx.is_some());
3241 }
3242 });
3243 }
3244
3245 #[test]
3246 fn single_mode_uses_first_connection() {
3247 TOKIO_SHARED_RT.block_on(async {
3248 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3249 let addr = listener.local_addr().unwrap();
3250 tokio::spawn(async move {
3251 if let Ok((stream, _)) = listener.accept().await {
3252 let mut ws = accept_async(stream).await.unwrap();
3253 let _ = ws.close(None).await;
3254 }
3255 });
3256 let conn = WebsocketConnection::new("c1");
3257 let common =
3258 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3259 let url = format!("ws://{addr}");
3260 common.connect_pool(&url).await.unwrap();
3261 let st = conn.state.lock().await;
3262 assert!(st.ws_write_tx.is_some());
3263 });
3264 }
3265 }
3266
3267 mod init_connect {
3268 use super::*;
3269
3270 #[test]
3271 fn pool_mode_none_connection_uses_first() {
3272 TOKIO_SHARED_RT.block_on(async {
3273 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3274 let addr = listener.local_addr().unwrap();
3275 tokio::spawn(async move {
3276 for _ in 0..2 {
3277 if let Ok((stream, _)) = listener.accept().await {
3278 let mut ws = accept_async(stream).await.unwrap();
3279 ws.close(None).await.ok();
3280 }
3281 }
3282 });
3283
3284 let c1 = WebsocketConnection::new("c1");
3285 let c2 = WebsocketConnection::new("c2");
3286 let common = WebsocketCommon::new(
3287 vec![c1.clone(), c2.clone()],
3288 WebsocketMode::Pool(2),
3289 0,
3290 None,
3291 );
3292 let url = format!("ws://{addr}");
3293
3294 common
3295 .clone()
3296 .init_connect(&url, false, None)
3297 .await
3298 .unwrap();
3299 let st1 = c1.state.lock().await;
3300 let st2 = c2.state.lock().await;
3301
3302 assert!(st1.ws_write_tx.is_some());
3303 assert!(st2.ws_write_tx.is_none());
3304 });
3305 }
3306
3307 #[test]
3308 fn writer_channel_can_send_text() {
3309 TOKIO_SHARED_RT.block_on(async {
3310 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3311 let addr = listener.local_addr().unwrap();
3312 let received = Arc::new(Mutex::new(None::<String>));
3313 let received_clone = received.clone();
3314
3315 tokio::spawn(async move {
3316 if let Ok((stream, _)) = listener.accept().await {
3317 let mut ws = accept_async(stream).await.unwrap();
3318 if let Some(Ok(Message::Text(txt))) = ws.next().await {
3319 *received_clone.lock().await = Some(txt.to_string());
3320 }
3321 ws.close(None).await.ok();
3322 }
3323 });
3324
3325 let conn = WebsocketConnection::new("cw");
3326 let common =
3327 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3328 let url = format!("ws://{addr}");
3329 common
3330 .clone()
3331 .init_connect(&url, false, Some(conn.clone()))
3332 .await
3333 .unwrap();
3334
3335 let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
3336 tx.send(Message::Text("ping".into())).unwrap();
3337
3338 sleep(Duration::from_millis(50)).await;
3339
3340 let lock = received.lock().await;
3341 assert_eq!(lock.as_deref(), Some("ping"));
3342 });
3343 }
3344
3345 #[test]
3346 fn responds_to_ping_with_pong() {
3347 TOKIO_SHARED_RT.block_on(async {
3348 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3349 let addr = listener.local_addr().unwrap();
3350
3351 let saw_pong = Arc::new(Mutex::new(false));
3352 let saw_pong2 = saw_pong.clone();
3353
3354 tokio::spawn(async move {
3355 if let Ok((stream, _)) = listener.accept().await {
3356 let mut ws = accept_async(stream).await.unwrap();
3357 ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
3358 if let Some(Ok(Message::Pong(payload))) = ws.next().await {
3359 if payload[..] == [1, 2, 3] {
3360 *saw_pong2.lock().await = true;
3361 }
3362 }
3363 let _ = ws.close(None).await;
3364 }
3365 });
3366
3367 let conn = WebsocketConnection::new("c-ping");
3368 let common =
3369 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3370 let url = format!("ws://{addr}");
3371 common
3372 .clone()
3373 .init_connect(&url, false, Some(conn))
3374 .await
3375 .unwrap();
3376
3377 sleep(Duration::from_millis(50)).await;
3378
3379 assert!(*saw_pong.lock().await, "server should have seen a Pong");
3380 });
3381 }
3382
3383 #[test]
3384 fn handshake_error_on_invalid_url() {
3385 TOKIO_SHARED_RT.block_on(async {
3386 let conn = WebsocketConnection::new("c-invalid");
3387 let common =
3388 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3389 let res = common
3390 .clone()
3391 .init_connect("not-a-url", false, Some(conn.clone()))
3392 .await;
3393 assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3394 });
3395 }
3396
3397 #[test]
3398 fn skip_if_writer_exists_and_not_renewal() {
3399 TOKIO_SHARED_RT.block_on(async {
3400 let conn = WebsocketConnection::new("c-writer");
3401 let (tx, mut rx) = unbounded_channel::<Message>();
3402 {
3403 let mut st = conn.state.lock().await;
3404 st.ws_write_tx = Some(tx.clone());
3405 }
3406 let common =
3407 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3408 let res = common
3409 .clone()
3410 .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
3411 .await;
3412
3413 assert!(res.is_ok());
3414 assert!(rx.try_recv().is_err());
3415 });
3416 }
3417
3418 #[test]
3419 fn short_circuit_on_already_renewing() {
3420 TOKIO_SHARED_RT.block_on(async {
3421 let conn = WebsocketConnection::new("c-renew");
3422 {
3423 let mut st = conn.state.lock().await;
3424 st.renewal_pending = true;
3425 }
3426 let common =
3427 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3428 let res = common
3429 .clone()
3430 .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
3431 .await;
3432
3433 assert!(res.is_ok());
3434 assert!(conn.state.lock().await.ws_write_tx.is_none());
3435 });
3436 }
3437
3438 #[test]
3439 fn is_renewal_true_sets_and_clears_flag() {
3440 TOKIO_SHARED_RT.block_on(async {
3441 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3442 let addr = listener.local_addr().unwrap();
3443 tokio::spawn(async move {
3444 if let Ok((stream, _)) = listener.accept().await {
3445 let mut ws = accept_async(stream).await.unwrap();
3446 let _ = ws.close(None).await;
3447 }
3448 });
3449
3450 let conn = WebsocketConnection::new("c-new-renew");
3451 let common =
3452 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3453 let url = format!("ws://{addr}");
3454 let res = common
3455 .clone()
3456 .init_connect(&url, true, Some(conn.clone()))
3457 .await;
3458
3459 assert!(res.is_ok());
3460 let st = conn.state.lock().await;
3461 assert!(st.ws_write_tx.is_some());
3462 assert!(!st.renewal_pending);
3463 });
3464 }
3465
3466 #[test]
3467 fn default_connection_selected_when_none_passed() {
3468 TOKIO_SHARED_RT.block_on(async {
3469 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3470 let addr = listener.local_addr().unwrap();
3471 tokio::spawn(async move {
3472 if let Ok((stream, _)) = listener.accept().await {
3473 let mut ws = accept_async(stream).await.unwrap();
3474 let _ = ws.close(None).await;
3475 }
3476 });
3477 let conn = WebsocketConnection::new("c-default");
3478 let common =
3479 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3480 let url = format!("ws://{addr}");
3481 let res = common.clone().init_connect(&url, false, None).await;
3482
3483 assert!(res.is_ok());
3484 assert!(conn.state.lock().await.ws_write_tx.is_some());
3485 });
3486 }
3487
3488 #[test]
3489 fn schedules_reconnect_on_abnormal_close() {
3490 TOKIO_SHARED_RT.block_on(async {
3491 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3492 let addr = listener.local_addr().unwrap();
3493 tokio::spawn(async move {
3494 if let Ok((stream, _)) = listener.accept().await {
3495 let mut ws = accept_async(stream).await.unwrap();
3496 ws.close(Some(tungstenite::protocol::CloseFrame {
3497 code: tungstenite::protocol::frame::coding::CloseCode::Abnormal,
3498 reason: "oops".into(),
3499 }))
3500 .await
3501 .ok();
3502 }
3503 });
3504 let conn = WebsocketConnection::new("c-close");
3505 let common =
3506 WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 10, None);
3507 let url = format!("ws://{addr}");
3508 common
3509 .clone()
3510 .init_connect(&url, false, Some(conn.clone()))
3511 .await
3512 .unwrap();
3513
3514 sleep(Duration::from_millis(50)).await;
3515
3516 let st = conn.state.lock().await;
3517 assert!(
3518 st.reconnection_pending,
3519 "expected reconnection_pending to be true after abnormal close"
3520 );
3521 });
3522 }
3523 }
3524
mod disconnect {
    // Tests for `WebsocketCommon::disconnect`: connections with a writer channel
    // receive a `Close` frame and are flagged `close_initiated`.
    use super::*;

    #[test]
    fn returns_ok_when_no_connections_are_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);

            let outcome = common.disconnect().await;

            assert!(outcome.is_ok());
            assert!(!conn.state.lock().await.close_initiated);
        });
    }

    #[test]
    fn closes_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let first = WebsocketConnection::new("c1");
            let second = WebsocketConnection::new("c2");
            let (first_tx, mut first_rx) = unbounded_channel::<Message>();
            let (second_tx, mut second_rx) = unbounded_channel::<Message>();
            first.state.lock().await.ws_write_tx = Some(first_tx);
            second.state.lock().await.ws_write_tx = Some(second_tx);

            let common = WebsocketCommon::new(
                vec![first.clone(), second.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
            );
            // NOTE(review): if `disconnect` is an `async fn`, this future is lazy
            // and the sleep below does not overlap with it — confirm the intent.
            let pending = common.disconnect();
            sleep(Duration::from_millis(50)).await;
            pending.await.unwrap();

            assert!(first.state.lock().await.close_initiated);
            assert!(second.state.lock().await.close_initiated);

            match (first_rx.try_recv(), second_rx.try_recv()) {
                (Ok(Message::Close(_)), Ok(Message::Close(_))) => {}
                other => panic!("expected two Closes, got {other:?}"),
            }
        });
    }

    #[test]
    fn does_not_mark_close_initiated_if_no_writer() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c-new");
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);

            common.disconnect().await.unwrap();

            assert!(!conn.state.lock().await.close_initiated);
        });
    }

    #[test]
    fn mixed_pool_marks_all_and_closes_only_writers() {
        TOKIO_SHARED_RT.block_on(async {
            let with_writer = WebsocketConnection::new("with");
            let without_writer = WebsocketConnection::new("without");
            let (tx, mut rx) = unbounded_channel::<Message>();
            with_writer.state.lock().await.ws_write_tx = Some(tx);

            let common = WebsocketCommon::new(
                vec![with_writer.clone(), without_writer.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
            );
            let pending = common.disconnect();
            sleep(Duration::from_millis(50)).await;
            pending.await.unwrap();

            assert!(with_writer.state.lock().await.close_initiated);
            assert!(without_writer.state.lock().await.close_initiated);
            assert!(matches!(rx.try_recv(), Ok(Message::Close(_))));
        });
    }

    #[test]
    fn after_disconnect_not_connected() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            // `_rx` is kept alive so the writer channel stays open.
            let (tx, _rx) = unbounded_channel::<Message>();
            conn.state.lock().await.ws_write_tx = Some(tx);

            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            common.disconnect().await.unwrap();

            assert!(!common.is_connected(Some(&conn)).await);
        });
    }
}
3634
mod ping_server {
    // Tests for `WebsocketCommon::ping_server`: an empty-payload Ping goes to
    // every connection that has a writer and no blocking flags.
    use super::*;

    #[test]
    fn sends_ping_to_all_ready_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let mut wired = Vec::with_capacity(3);
            for idx in 0..3 {
                let conn = WebsocketConnection::new(format!("c{idx}"));
                let (tx, rx) = unbounded_channel::<Message>();
                conn.state.lock().await.ws_write_tx = Some(tx);
                wired.push((conn, rx));
            }

            let pool: Vec<_> = wired.iter().map(|(conn, _)| conn.clone()).collect();
            let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None);
            common.ping_server().await;

            for (_, mut rx) in wired {
                match rx.try_recv() {
                    Ok(Message::Ping(payload)) if payload.is_empty() => {}
                    other => panic!("expected empty-payload Ping, got {other:?}"),
                }
            }
        });
    }

    #[test]
    fn skips_not_ready_and_partial() {
        TOKIO_SHARED_RT.block_on(async {
            let ready = WebsocketConnection::new("ready");
            let not_ready = WebsocketConnection::new("not-ready");
            let (ready_tx, mut ready_rx) = unbounded_channel::<Message>();
            ready.state.lock().await.ws_write_tx = Some(ready_tx);
            not_ready.state.lock().await.ws_write_tx = None;

            let common = WebsocketCommon::new(
                vec![ready.clone(), not_ready.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
            );
            common.ping_server().await;

            match ready_rx.try_recv() {
                Ok(Message::Ping(payload)) if payload.is_empty() => {}
                other => panic!("expected Ping on ready, got {other:?}"),
            }
        });
    }

    #[test]
    fn no_ping_when_flags_block() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                // A pending reconnection must suppress the ping.
                let mut state = conn.state.lock().await;
                state.ws_write_tx = Some(tx);
                state.reconnection_pending = true;
            }

            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            common.ping_server().await;

            assert!(rx.try_recv().is_err());
        });
    }
}
3712
mod send {
    // Tests for `WebsocketCommon::send`: round-robin writer selection, sync
    // (fire-and-forget) sends, async request/response sends and the request
    // timeout path.
    use super::*;

    #[test]
    fn round_robin_send_without_specific() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx1, mut rx1) = unbounded_channel::<Message>();
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = Some(tx1);
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
            );

            let res1 = common
                .send("a".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res1.is_none());

            let res2 = common
                .send("b".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res2.is_none());

            // Consecutive sends without a target rotate across the pool.
            match rx1.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "a"),
                other => panic!("expected Text \"a\" on c1, got {other:?}"),
            }
            match rx2.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "b"),
                other => panic!("expected Text \"b\" on c2, got {other:?}"),
            }
        });
    }

    #[test]
    fn round_robin_skips_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut s1 = conn1.state.lock().await;
                s1.ws_write_tx = None;
            }
            {
                let mut s2 = conn2.state.lock().await;
                s2.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
            );
            let res = common
                .send("bar".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap();
            assert!(res.is_none());
            match rx2.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "bar"),
                other => panic!("unexpected {other:?}"),
            }
        });
    }

    #[test]
    fn sync_send_on_specific_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let (tx2, mut rx2) = unbounded_channel::<Message>();
            {
                let mut st = conn2.state.lock().await;
                st.ws_write_tx = Some(tx2);
            }
            let common = WebsocketCommon::new(
                vec![conn1.clone(), conn2.clone()],
                WebsocketMode::Pool(2),
                0,
                None,
            );
            let res = common
                .send(
                    "payload".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn2.clone()),
                )
                .await
                .unwrap();
            assert!(res.is_none());
            match rx2.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "payload"),
                other => panic!("expected Text, got {other:?}"),
            }
        });
    }

    #[test]
    fn sync_send_with_id_does_not_insert_pending() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let res = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap();
            assert!(res.is_none());
            assert!(conn.state.lock().await.pending_requests.is_empty());
            match rx.try_recv().unwrap() {
                Message::Text(t) => assert_eq!(t, "msg"),
                other => panic!("unexpected {other:?}"),
            }
        });
    }

    #[test]
    fn sync_send_error_if_not_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let err = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    false,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    #[test]
    fn sync_send_error_when_no_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let err = common
                .send("msg".into(), None, false, Duration::from_secs(1), None)
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    #[test]
    fn async_send_and_receive() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let fut = common
                .send(
                    "hello".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(5),
                    Some(conn.clone()),
                )
                .await
                .unwrap()
                .unwrap();
            match rx.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "hello"),
                other => panic!("expected Text, got {other:?}"),
            }
            {
                // Complete the pending request as the response reader would.
                let mut st = conn.state.lock().await;
                let pr = st.pending_requests.remove("id").unwrap();
                pr.completion.send(Ok(serde_json::json!("ok"))).unwrap();
            }
            let resp = fut.await.unwrap().unwrap();
            assert_eq!(resp, serde_json::json!("ok"));
        });
    }

    #[test]
    fn async_send_default_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let fut = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(5),
                    None,
                )
                .await
                .unwrap()
                .unwrap();
            match rx.try_recv() {
                Ok(Message::Text(t)) => assert_eq!(t, "msg"),
                _ => panic!("no text"),
            }
            {
                let mut st = conn.state.lock().await;
                let pr = st.pending_requests.remove("id").unwrap();
                pr.completion.send(Ok(serde_json::json!(123))).unwrap();
            }
            let resp = fut.await.unwrap().unwrap();
            assert_eq!(resp, serde_json::json!(123));
        });
    }

    #[test]
    fn async_send_error_if_no_id() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let err = common
                .send(
                    "msg".into(),
                    None,
                    true,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }

    #[test]
    fn timeout_rejects_async() {
        TOKIO_SHARED_RT.block_on(async {
            // Freeze the runtime clock so the 1 s request timeout can be driven
            // deterministically with `advance` instead of a real wait.
            pause();
            let conn = WebsocketConnection::new("c1");
            let (tx, _rx) = unbounded_channel::<Message>();
            {
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let fut = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(1),
                    Some(conn.clone()),
                )
                .await
                .unwrap()
                .unwrap();
            advance(Duration::from_secs(1)).await;
            let outcome = fut.await;
            let still_pending = conn.state.lock().await.pending_requests.contains_key("id");

            // TOKIO_SHARED_RT is reused by the other tests, so unfreeze its clock
            // before asserting. Otherwise a later `pause()` panics on the
            // already-frozen clock, and auto-advance can fire other tests'
            // real-I/O timeouts early.
            tokio::time::resume();

            let res = outcome.unwrap();
            assert!(res.is_err(), "expected timeout error");
            assert!(!still_pending);
        });
    }

    #[test]
    fn async_send_errors_if_no_connection_ready() {
        TOKIO_SHARED_RT.block_on(async {
            let conn = WebsocketConnection::new("c1");
            let common =
                WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
            let err = common
                .send(
                    "msg".into(),
                    Some("id".into()),
                    true,
                    Duration::from_secs(1),
                    None,
                )
                .await
                .unwrap_err();
            assert!(matches!(err, WebsocketError::NotConnected));
        });
    }
}
4046 }
4047
4048 mod websocket_api {
4049 use super::*;
4050
mod initialisation {
    // Tests for `WebsocketApi::new`: the shared `WebsocketCommon` is built from
    // the supplied pool and mode, and the API starts out not connecting.
    use super::*;

    #[test]
    fn new_initializes_common() {
        TOKIO_SHARED_RT.block_on(async {
            let pool = vec![WebsocketConnection::new("id")];

            let config = ConfigurationWebsocketApi {
                api_key: Some("api_key".to_string()),
                api_secret: Some("api_secret".to_string()),
                private_key: None,
                private_key_passphrase: None,
                ws_url: Some("wss://example".to_string()),
                mode: WebsocketMode::Single,
                reconnect_delay: 1000,
                signature_gen: SignatureGenerator::new(
                    Some("api_secret".to_string()),
                    None::<PrivateKey>,
                    None::<String>,
                ),
                timeout: 500,
                time_unit: None,
                agent: None,
            };

            let api = WebsocketApi::new(config, pool);

            assert_eq!(api.common.connection_pool.len(), 1);
            assert_eq!(api.common.mode, WebsocketMode::Single);
            assert!(!*api.is_connecting.lock().await);
        });
    }
}
4090
4091 mod connect {
4092 use super::*;
4093
/// Connecting with no writer installed attempts a real connection. The attempt
/// may fail against the bogus host, but it must not surface as a `Timeout`.
#[test]
fn connect_when_not_connected_establishes() {
    TOKIO_SHARED_RT.block_on(async {
        let connection = WebsocketConnection::new("id");
        // Start explicitly from a state with no writer channel.
        connection.state.lock().await.ws_write_tx = None;

        let signature = SignatureGenerator::new(
            Some("api_secret".into()),
            None::<PrivateKey>,
            None::<String>,
        );
        let config = ConfigurationWebsocketApi {
            api_key: Some("api_key".into()),
            api_secret: Some("api_secret".to_string()),
            private_key: None,
            private_key_passphrase: None,
            ws_url: Some("ws://doesnotexist:1".to_string()),
            mode: WebsocketMode::Single,
            reconnect_delay: 0,
            signature_gen: signature,
            timeout: 10,
            time_unit: None,
            agent: None,
        };

        let api = WebsocketApi::new(config, vec![connection.clone()]);
        let outcome = api.clone().connect().await;
        assert!(!matches!(outcome, Err(WebsocketError::Timeout)));
    });
}
4125
/// `connect` returns `Ok` without dialling when the connection already has a
/// writer channel.
#[test]
fn already_connected_returns_ok() {
    TOKIO_SHARED_RT.block_on(async {
        let connection = WebsocketConnection::new("id2");
        let (writer, _) = unbounded_channel();
        connection.state.lock().await.ws_write_tx = Some(writer);

        let signature = SignatureGenerator::new(
            Some("api_secret".to_string()),
            None::<PrivateKey>,
            None::<String>,
        );
        let config = ConfigurationWebsocketApi {
            api_key: Some("api_key".to_string()),
            api_secret: Some("api_secret".to_string()),
            private_key: None,
            private_key_passphrase: None,
            ws_url: Some("ws://example.com".to_string()),
            mode: WebsocketMode::Single,
            reconnect_delay: 0,
            signature_gen: signature,
            timeout: 10,
            time_unit: None,
            agent: None,
        };

        let api = WebsocketApi::new(config, vec![connection.clone()]);
        assert!(api.connect().await.is_ok());
    });
}
4158
4159 #[test]
4160 fn not_connected_returns_error() {
4161 TOKIO_SHARED_RT.block_on(async {
4162 let conn = WebsocketConnection::new("id1");
4163 let sig = SignatureGenerator::new(
4164 Some("api_secret".to_string()),
4165 None::<PrivateKey>,
4166 None::<String>,
4167 );
4168 let cfg = ConfigurationWebsocketApi {
4169 api_key: Some("api_key".to_string()),
4170 api_secret: Some("api_secret".to_string()),
4171 private_key: None,
4172 private_key_passphrase: None,
4173 ws_url: Some("ws://127.0.0.1:9".to_string()),
4174 mode: WebsocketMode::Single,
4175 reconnect_delay: 0,
4176 signature_gen: sig,
4177 timeout: 10,
4178 time_unit: None,
4179 agent: None,
4180 };
4181 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4182 let res = api.connect().await;
4183 assert!(res.is_err());
4184 });
4185 }
4186
4187 #[test]
4188 fn concurrent_calls_both_error_or_ok() {
4189 TOKIO_SHARED_RT.block_on(async {
4190 let conn = WebsocketConnection::new("id3");
4191 let sig = SignatureGenerator::new(
4192 Some("api_secret".to_string()),
4193 None::<PrivateKey>,
4194 None::<String>,
4195 );
4196 let cfg = ConfigurationWebsocketApi {
4197 api_key: Some("api_key".to_string()),
4198 api_secret: Some("api_secret".to_string()),
4199 private_key: None,
4200 private_key_passphrase: None,
4201 ws_url: Some("wss://invalid-domain".to_string()),
4202 mode: WebsocketMode::Single,
4203 reconnect_delay: 0,
4204 signature_gen: sig,
4205 timeout: 10,
4206 time_unit: None,
4207 agent: None,
4208 };
4209 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4210 let fut1 = tokio::spawn(api.clone().connect());
4211 let fut2 = tokio::spawn(api.clone().connect());
4212 let r1 = fut1.await.unwrap();
4213 let r2 = fut2.await.unwrap();
4214
4215 assert!(r1.is_err());
4216 assert!(r2.is_err() || r2.is_ok());
4217 });
4218 }
4219
4220 #[test]
4221 fn pool_failure_is_propagated() {
4222 TOKIO_SHARED_RT.block_on(async {
4223 let conn = WebsocketConnection::new("w");
4224 let sig = SignatureGenerator::new(
4225 Some("api_secret".to_string()),
4226 None::<PrivateKey>,
4227 None::<String>,
4228 );
4229 let cfg = ConfigurationWebsocketApi {
4230 api_key: Some("api_key".into()),
4231 api_secret: Some("api_secret".to_string()),
4232 private_key: None,
4233 private_key_passphrase: None,
4234 ws_url: Some("ws://doesnotexist:1".to_string()),
4235 mode: WebsocketMode::Single,
4236 reconnect_delay: 0,
4237 signature_gen: sig,
4238 timeout: 10,
4239 time_unit: None,
4240 agent: None,
4241 };
4242 let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4243 let res = api.clone().connect().await;
4244 match res {
4245 Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
4246 _ => panic!("expected handshake or timeout error"),
4247 }
4248 });
4249 }
4250 }
4251
4252 mod send_message {
4253 use super::*;
4254
4255 #[test]
4256 fn unsigned_message() {
4257 TOKIO_SHARED_RT.block_on(async {
4258 let api = create_websocket_api(None);
4259 let conn = &api.common.connection_pool[0];
4260 let (tx, mut rx) = unbounded_channel::<Message>();
4261 {
4262 let mut st = conn.state.lock().await;
4263 st.ws_write_tx = Some(tx);
4264 }
4265
4266 let fut = tokio::spawn({
4267 let api = api.clone();
4268 async move {
4269 let mut params = BTreeMap::new();
4270 params.insert("foo".into(), Value::String("bar".into()));
4271 api.send_message::<Value>(
4272 "mymethod",
4273 params,
4274 WebsocketMessageSendOptions {
4275 with_api_key: false,
4276 is_signed: false,
4277 },
4278 )
4279 .await
4280 .unwrap()
4281 }
4282 });
4283
4284 let sent = rx.recv().await.unwrap();
4285 let txt = if let Message::Text(s) = sent {
4286 s
4287 } else {
4288 panic!()
4289 };
4290 let req: Value = serde_json::from_str(&txt).unwrap();
4291 assert_eq!(req["method"], "mymethod");
4292 assert!(req["params"]["foo"] == "bar");
4293 assert!(req["params"].get("apiKey").is_none());
4294 assert!(req["params"].get("timestamp").is_none());
4295 assert!(req["params"].get("signature").is_none());
4296
4297 let id = req["id"].as_str().unwrap().to_string();
4298 let mut st = conn.state.lock().await;
4299 let pending = st.pending_requests.remove(&id).unwrap();
4300 let reply = json!({
4301 "id": id,
4302 "result": { "x": 42 },
4303 "rateLimits": [{ "limit": 7 }]
4304 });
4305 pending.completion.send(Ok(reply)).unwrap();
4306
4307 let resp = fut.await.unwrap();
4308 let rate_limits = resp.rate_limits.unwrap_or_default();
4309
4310 assert!(rate_limits.is_empty());
4311 assert_eq!(resp.raw, json!({"x": 42}));
4312 });
4313 }
4314
4315 #[test]
4316 fn with_api_key_only() {
4317 TOKIO_SHARED_RT.block_on(async {
4318 let api = create_websocket_api(None);
4319 let conn = &api.common.connection_pool[0];
4320 let (tx, mut rx) = unbounded_channel::<Message>();
4321 {
4322 let mut st = conn.state.lock().await;
4323 st.ws_write_tx = Some(tx);
4324 }
4325
4326 let fut = tokio::spawn({
4327 let api = api.clone();
4328 async move {
4329 let params = BTreeMap::new();
4330 api.send_message::<Value>(
4331 "foo",
4332 params,
4333 WebsocketMessageSendOptions {
4334 with_api_key: true,
4335 is_signed: false,
4336 },
4337 )
4338 .await
4339 .unwrap()
4340 }
4341 });
4342
4343 let txt = if let Message::Text(s) = rx.recv().await.unwrap() {
4344 s
4345 } else {
4346 panic!()
4347 };
4348 let req: Value = serde_json::from_str(&txt).unwrap();
4349 assert_eq!(req["params"]["apiKey"], "api_key");
4350
4351 let id = req["id"].as_str().unwrap().to_string();
4352 let mut st = conn.state.lock().await;
4353 let pending = st.pending_requests.remove(&id).unwrap();
4354 pending
4355 .completion
4356 .send(Ok(json!({
4357 "id": id,
4358 "result": {},
4359 "rateLimits": []
4360 })))
4361 .unwrap();
4362
4363 let resp = fut.await.unwrap();
4364
4365 assert_eq!(resp.raw, json!({}));
4366 assert!(st.pending_requests.is_empty());
4367 });
4368 }
4369
4370 #[test]
4371 fn signed_message_has_timestamp_and_signature() {
4372 TOKIO_SHARED_RT.block_on(async {
4373 let api = create_websocket_api(None);
4374 let conn = &api.common.connection_pool[0];
4375 let (tx, mut rx) = unbounded_channel::<Message>();
4376 {
4377 let mut st = conn.state.lock().await;
4378 st.ws_write_tx = Some(tx);
4379 }
4380
4381 let fut = tokio::spawn({
4382 let api = api.clone();
4383 async move {
4384 let mut params = BTreeMap::new();
4385 params.insert("foo".into(), Value::String("bar".into()));
4386 api.send_message::<Value>(
4387 "method",
4388 params,
4389 WebsocketMessageSendOptions {
4390 with_api_key: true,
4391 is_signed: true,
4392 },
4393 )
4394 .await
4395 .unwrap()
4396 }
4397 });
4398
4399 let txt = if let Message::Text(s) = rx.recv().await.unwrap() {
4400 s
4401 } else {
4402 panic!()
4403 };
4404 let req: Value = serde_json::from_str(&txt).unwrap();
4405 let p = &req["params"];
4406 assert!(p["apiKey"] == "api_key");
4407 assert!(p["timestamp"].is_number());
4408 assert!(p["signature"].is_string());
4409
4410 let id = req["id"].as_str().unwrap().to_string();
4411 let mut st = conn.state.lock().await;
4412 let pending = st.pending_requests.remove(&id).unwrap();
4413 pending
4414 .completion
4415 .send(Ok(json!({
4416 "id": id,
4417 "result": { "ok": true },
4418 "rateLimits": []
4419 })))
4420 .unwrap();
4421
4422 let resp = fut.await.unwrap();
4423 assert_eq!(resp.raw, json!({ "ok": true }));
4424 });
4425 }
4426
4427 #[test]
4428 fn error_if_not_connected() {
4429 TOKIO_SHARED_RT.block_on(async {
4430 let api = create_websocket_api(None);
4431 let conn = &api.common.connection_pool[0];
4432 {
4433 let mut st = conn.state.lock().await;
4434 st.ws_write_tx = None;
4435 }
4436 let params = BTreeMap::new();
4437 let err = api
4438 .send_message::<Value>(
4439 "method",
4440 params,
4441 WebsocketMessageSendOptions {
4442 with_api_key: false,
4443 is_signed: false,
4444 },
4445 )
4446 .await
4447 .unwrap_err();
4448 matches!(err, WebsocketError::NotConnected);
4449 });
4450 }
4451 }
4452
mod prepare_url {
    use super::*;

    /// Without a configured time unit the URL passes through untouched.
    #[test]
    fn no_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(None);
            let original = String::from("wss://example.com/ws");
            let prepared = api.prepare_url(&original);
            assert_eq!(prepared, original);
        });
    }

    /// A time unit is appended as the first query parameter.
    #[test]
    fn appends_time_unit() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(Some(TimeUnit::Millisecond));
            let base = String::from("wss://example.com/ws");
            let expected = format!("{base}?timeUnit=millisecond");
            assert_eq!(api.prepare_url(&base), expected);
        });
    }

    /// A URL that already has a query string gets the time unit joined with `&`.
    #[test]
    fn handles_existing_query() {
        TOKIO_SHARED_RT.block_on(async {
            let api = create_websocket_api(Some(TimeUnit::Microsecond));
            let base = String::from("wss://example.com/ws?foo=bar");
            let expected = format!("{base}&timeUnit=microsecond");
            assert_eq!(api.prepare_url(&base), expected);
        });
    }
}
4485
4486 mod on_message {
4487 use super::*;
4488
4489 fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
4490 let sig_gen = SignatureGenerator::new(
4491 Some("api_secret".to_string()),
4492 None::<_>,
4493 None::<String>,
4494 );
4495 let config = ConfigurationWebsocketApi {
4496 api_key: Some("api_key".to_string()),
4497 api_secret: Some("api_secret".to_string()),
4498 private_key: None,
4499 private_key_passphrase: None,
4500 ws_url: Some("wss://example".to_string()),
4501 mode: WebsocketMode::Single,
4502 reconnect_delay: 0,
4503 signature_gen: sig_gen,
4504 timeout: 1000,
4505 time_unit: None,
4506 agent: None,
4507 };
4508 let conn = WebsocketConnection::new("test");
4509 let api = WebsocketApi::new(config, vec![conn.clone()]);
4510 (api, conn)
4511 }
4512
4513 #[test]
4514 fn resolves_pending_and_removes_request() {
4515 TOKIO_SHARED_RT.block_on(async {
4516 let (api, conn) = create_websocket_api_and_conn();
4517 let (tx, rx) = oneshot::channel();
4518 {
4519 let mut st = conn.state.lock().await;
4520 st.pending_requests
4521 .insert("id1".to_string(), PendingRequest { completion: tx });
4522 }
4523 let msg = json!({"id":"id1","status":200,"foo":"bar"});
4524 api.on_message(msg.to_string(), conn.clone()).await;
4525 let got = rx.await.unwrap().unwrap();
4526 assert_eq!(got, msg);
4527 let st = conn.state.lock().await;
4528 assert!(st.pending_requests.get("id1").is_none());
4529 });
4530 }
4531
4532 #[test]
4533 fn uses_result_when_present() {
4534 TOKIO_SHARED_RT.block_on(async {
4535 let (api, conn) = create_websocket_api_and_conn();
4536 let (tx, rx) = oneshot::channel();
4537 {
4538 let mut st = conn.state.lock().await;
4539 st.pending_requests
4540 .insert("id1".to_string(), PendingRequest { completion: tx });
4541 }
4542 let msg = json!({
4543 "id": "id1",
4544 "status": 200,
4545 "response": [1,2],
4546 "result": {"a":1}
4547 });
4548 api.on_message(msg.to_string(), conn.clone()).await;
4549 let got = rx.await.unwrap().unwrap();
4550 assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
4551 });
4552 }
4553
4554 #[test]
4555 fn uses_response_when_no_result() {
4556 TOKIO_SHARED_RT.block_on(async {
4557 let (api, conn) = create_websocket_api_and_conn();
4558 let (tx, rx) = oneshot::channel();
4559 {
4560 let mut st = conn.state.lock().await;
4561 st.pending_requests
4562 .insert("id1".to_string(), PendingRequest { completion: tx });
4563 }
4564 let msg = json!({
4565 "id": "id1",
4566 "status": 200,
4567 "response": ["ok"]
4568 });
4569 api.on_message(msg.to_string(), conn.clone()).await;
4570 let got = rx.await.unwrap().unwrap();
4571 assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
4572 });
4573 }
4574
4575 #[test]
4576 fn errors_for_status_ge_400() {
4577 TOKIO_SHARED_RT.block_on(async {
4578 let (api, conn) = create_websocket_api_and_conn();
4579 let (tx, rx) = oneshot::channel();
4580 {
4581 let mut st = conn.state.lock().await;
4582 st.pending_requests
4583 .insert("bad".to_string(), PendingRequest { completion: tx });
4584 }
4585 let err_obj = json!({"code":123,"msg":"oops"});
4586 let msg = json!({"id":"bad","status":500,"error":err_obj});
4587 api.on_message(msg.to_string(), conn.clone()).await;
4588 match rx.await.unwrap() {
4589 Err(WebsocketError::ResponseError { code, message }) => {
4590 assert_eq!(code, 123);
4591 assert_eq!(message, "oops");
4592 }
4593 other => panic!("expected ResponseError, got {other:?}"),
4594 }
4595 let st = conn.state.lock().await;
4596 assert!(st.pending_requests.get("bad").is_none());
4597 });
4598 }
4599
4600 #[test]
4601 fn ignores_unknown_id() {
4602 TOKIO_SHARED_RT.block_on(async {
4603 let (api, conn) = create_websocket_api_and_conn();
4604 let msg = json!({"id":"nope","status":200});
4605 api.on_message(msg.to_string(), conn.clone()).await;
4606 let st = conn.state.lock().await;
4607 assert!(st.pending_requests.is_empty());
4608 });
4609 }
4610
4611 #[test]
4612 fn parse_error_ignored() {
4613 TOKIO_SHARED_RT.block_on(async {
4614 let (api, conn) = create_websocket_api_and_conn();
4615 api.on_message("not json".to_string(), conn.clone()).await;
4616 let st = conn.state.lock().await;
4617 assert!(st.pending_requests.is_empty());
4618 });
4619 }
4620
4621 #[test]
4622 fn error_status_sends_error() {
4623 TOKIO_SHARED_RT.block_on(async {
4624 let (api, conn) = create_websocket_api_and_conn();
4625 let (tx, rx) = oneshot::channel();
4626 {
4627 let mut st = conn.state.lock().await;
4628 st.pending_requests
4629 .insert("err".to_string(), PendingRequest { completion: tx });
4630 }
4631 let msg = json!({
4632 "id": "err",
4633 "status": 500,
4634 "error": { "code": 42, "msg": "Bad!" }
4635 });
4636 api.on_message(msg.to_string(), conn.clone()).await;
4637 match rx.await.unwrap() {
4638 Err(WebsocketError::ResponseError { code, message }) => {
4639 assert_eq!(code, 42);
4640 assert_eq!(message, "Bad!");
4641 }
4642 other => panic!("expected ResponseError, got {other:?}"),
4643 }
4644 });
4645 }
4646
4647 #[test]
4648 fn unknown_id_logs_warning_and_leaves_pending() {
4649 TOKIO_SHARED_RT.block_on(async {
4650 let (api, conn) = create_websocket_api_and_conn();
4651 {
4652 let mut st = conn.state.lock().await;
4653 st.pending_requests.insert(
4654 "keep".to_string(),
4655 PendingRequest {
4656 completion: oneshot::channel().0,
4657 },
4658 );
4659 }
4660 api.on_message(
4661 json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
4662 conn.clone(),
4663 )
4664 .await;
4665 let st = conn.state.lock().await;
4666 assert!(st.pending_requests.contains_key("keep"));
4667 });
4668 }
4669 }
4670 }
4671
4672 mod websocket_streams {
4673 use super::*;
4674
mod initialisation {
    use super::*;

    /// `WebsocketStreams::new` must:
    /// - keep the supplied connections in order,
    /// - start with no stream-to-connection assignments,
    /// - build stream URLs from the configured `ws_url`.
    ///
    /// (This suite previously constructed a `WebsocketApi` here, so
    /// `WebsocketStreams::new` was never exercised.)
    #[test]
    fn new_initializes_fields() {
        TOKIO_SHARED_RT.block_on(async {
            let config = ConfigurationWebsocketStreams {
                ws_url: Some("wss://example".to_string()),
                mode: WebsocketMode::Single,
                reconnect_delay: 1000,
                time_unit: None,
                agent: None,
            };
            let conn1 = WebsocketConnection::new("c1");
            let conn2 = WebsocketConnection::new("c2");
            let ws = WebsocketStreams::new(config, vec![conn1.clone(), conn2.clone()]);

            assert_eq!(ws.common.connection_pool.len(), 2);
            assert!(Arc::ptr_eq(&ws.common.connection_pool[0], &conn1));
            assert!(Arc::ptr_eq(&ws.common.connection_pool[1], &conn2));
            assert!(ws.connection_streams.lock().await.is_empty());
            // The configured base URL is what stream URLs are built from.
            assert_eq!(
                ws.prepare_url(&["s1".into()]),
                "wss://example/stream?streams=s1"
            );
        });
    }
}
4712
mod connect {
    use super::*;

    /// A local server completes the handshake for each of the two pooled
    /// connections, then closes them. `connect` succeeds.
    #[test]
    fn establishes_successfully() {
        TOKIO_SHARED_RT.block_on(async {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();

            // One accept per pooled connection. Failed accepts are skipped.
            tokio::spawn(async move {
                for _ in 0..2 {
                    let Ok((socket, _)) = listener.accept().await else {
                        continue;
                    };
                    let mut server_ws = accept_async(socket).await.unwrap();
                    server_ws.close(None).await.ok();
                }
            });

            let pool = vec![WebsocketConnection::new("c1"), WebsocketConnection::new("c2")];
            let config = ConfigurationWebsocketStreams {
                ws_url: Some(format!("ws://127.0.0.1:{port}")),
                mode: WebsocketMode::Pool(2),
                reconnect_delay: 500,
                time_unit: None,
                agent: None,
            };
            let streams = WebsocketStreams::new(config, pool);

            let outcome = streams.connect(vec!["stream1".into()]).await;
            assert!(outcome.is_ok());
        });
    }

    /// Nothing listens on port 9, so the connection is refused.
    #[test]
    fn refused_returns_error() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(Some("ws://127.0.0.1:9"), None);
            let outcome = streams.connect(vec!["stream1".into()]).await;
            assert!(outcome.is_err());
        });
    }

    /// A string that is not a URL fails to connect.
    #[test]
    fn invalid_url_returns_error() {
        TOKIO_SHARED_RT.block_on(async {
            let streams = create_websocket_streams(Some("not-a-url"), None);
            let outcome = streams.connect(vec!["s".into()]).await;
            assert!(outcome.is_err());
        });
    }
}
4770
mod disconnect {
    use super::*;

    /// `disconnect` wipes a connection's stream callbacks and queued
    /// subscriptions, and forgets every stream-to-connection assignment.
    #[test]
    fn disconnect_clears_state_and_streams() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let first = &ws.common.connection_pool[0];

            // Seed the state that disconnect is expected to clear.
            {
                let mut guard = first.state.lock().await;
                guard.stream_callbacks.insert("s1".to_string(), Vec::new());
                guard.pending_subscriptions.push_back("s2".to_string());
            }
            ws.connection_streams
                .lock()
                .await
                .insert("s3".to_string(), Arc::clone(first));

            assert!(ws.disconnect().await.is_ok());

            {
                let guard = first.state.lock().await;
                assert!(guard.stream_callbacks.is_empty());
                assert!(guard.pending_subscriptions.is_empty());
            }
            assert!(ws.connection_streams.lock().await.is_empty());
        });
    }
}
4801
4802 mod subscribe {
4803 use super::*;
4804
4805 #[test]
4806 fn empty_list_does_nothing() {
4807 TOKIO_SHARED_RT.block_on(async {
4808 let ws = create_websocket_streams(None, None);
4809 ws.clone().subscribe(Vec::new(), None).await;
4810 let map = ws.connection_streams.lock().await;
4811 assert!(map.is_empty());
4812 });
4813 }
4814
4815 #[test]
4816 fn queue_when_not_ready() {
4817 TOKIO_SHARED_RT.block_on(async {
4818 let ws = create_websocket_streams(None, None);
4819 let conn = ws.common.connection_pool[0].clone();
4820 ws.clone().subscribe(vec!["s1".into()], None).await;
4821 let state = conn.state.lock().await;
4822 let pending: Vec<String> =
4823 state.pending_subscriptions.iter().cloned().collect();
4824 assert_eq!(pending, vec!["s1".to_string()]);
4825 });
4826 }
4827
4828 #[test]
4829 fn only_one_subscription_per_stream() {
4830 TOKIO_SHARED_RT.block_on(async {
4831 let ws = create_websocket_streams(None, None);
4832 let conn = ws.common.connection_pool[0].clone();
4833 ws.clone().subscribe(vec!["s1".into()], None).await;
4834 ws.clone().subscribe(vec!["s1".into()], None).await;
4835 let state = conn.state.lock().await;
4836 let pending: Vec<String> =
4837 state.pending_subscriptions.iter().cloned().collect();
4838 assert_eq!(pending, vec!["s1".to_string()]);
4839 });
4840 }
4841
4842 #[test]
4843 fn multiple_streams_assigned() {
4844 TOKIO_SHARED_RT.block_on(async {
4845 let ws = create_websocket_streams(None, None);
4846 ws.clone()
4847 .subscribe(vec!["s1".into(), "s2".into()], None)
4848 .await;
4849 let map = ws.connection_streams.lock().await;
4850 assert!(map.contains_key("s1"));
4851 assert!(map.contains_key("s2"));
4852 });
4853 }
4854
4855 #[test]
4856 fn existing_stream_not_reassigned() {
4857 TOKIO_SHARED_RT.block_on(async {
4858 let ws = create_websocket_streams(None, None);
4859 ws.clone().subscribe(vec!["s1".into()], None).await;
4860 let first_id = {
4861 let map = ws.connection_streams.lock().await;
4862 map.get("s1").unwrap().id.clone()
4863 };
4864 ws.clone()
4865 .subscribe(vec!["s1".into(), "s2".into()], None)
4866 .await;
4867 let map = ws.connection_streams.lock().await;
4868 let second_id = map.get("s1").unwrap().id.clone();
4869 assert_eq!(first_id, second_id);
4870 assert!(map.contains_key("s2"));
4871 });
4872 }
4873 }
4874
mod unsubscribe {
    use super::*;

    /// A callback-free stream on a connection with a write channel is
    /// unsubscribed. It is dropped from both the stream map and the connection's
    /// callback table.
    #[test]
    fn removes_stream_with_no_callbacks() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let conn = ws.common.connection_pool[0].clone();

            {
                let (tx, _rx) = unbounded_channel::<Message>();
                let mut st = conn.state.lock().await;
                st.ws_write_tx = Some(tx);
            }

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s1".to_string(), conn.clone());
            }
            {
                let mut st = conn.state.lock().await;
                st.stream_callbacks.insert("s1".to_string(), Vec::new());
            }

            ws.unsubscribe(vec!["s1".to_string()], None).await;

            assert!(!ws.connection_streams.lock().await.contains_key("s1"));
            assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
        });
    }

    /// A stream that still has callbacks stays assigned.
    #[test]
    fn preserves_stream_with_callbacks() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let conn = ws.common.connection_pool[1].clone();

            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s2".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                state
                    .stream_callbacks
                    .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
            }

            ws.unsubscribe(vec!["s2".to_string()], None).await;

            assert!(ws.connection_streams.lock().await.contains_key("s2"));
            assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
        });
    }

    /// With callbacks registered, no UNSUBSCRIBE frame is written and the stream
    /// stays assigned.
    #[test]
    fn does_not_send_if_callbacks_exist() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let conn = ws.common.connection_pool[0].clone();
            // A live write channel lets the test observe a decision not to send,
            // rather than an inability to send.
            let (tx, mut rx) = unbounded_channel::<Message>();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("s1".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert(
                    "s1".to_string(),
                    vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
                );
            }
            ws.unsubscribe(vec!["s1".into()], None).await;
            assert!(
                rx.try_recv().is_err(),
                "no frame should be sent while callbacks exist"
            );
            assert!(ws.connection_streams.lock().await.contains_key("s1"));
            assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
        });
    }

    /// Unsubscribing a stream that was never assigned is a no-op.
    #[test]
    fn warns_if_not_associated() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            ws.unsubscribe(vec!["nope".into()], None).await;
            assert!(ws.connection_streams.lock().await.is_empty());
        });
    }

    /// An empty stream list leaves the assignment map unchanged.
    #[test]
    fn empty_list_does_nothing() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let before = ws.connection_streams.lock().await.len();
            ws.unsubscribe(Vec::<String>::new(), None).await;
            let after = ws.connection_streams.lock().await.len();
            assert_eq!(before, after);
        });
    }

    /// An invalid custom request id does not prevent the unsubscribe from
    /// completing.
    #[test]
    fn invalid_custom_id_falls_back() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let conn = ws.common.connection_pool[0].clone();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("foo".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                let (tx, _rx) = unbounded_channel();
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert("foo".to_string(), Vec::new());
            }
            ws.unsubscribe(vec!["foo".into()], Some("bad-id".into()))
                .await;
            assert!(!ws.connection_streams.lock().await.contains_key("foo"));
        });
    }

    /// A callback-free stream is removed from the assignment map.
    // NOTE(review): despite its name, this test installs a write channel (whose
    // receiver is dropped at the end of the setup block), so it does not cover
    // the "no write channel" case. Confirm whether removal is expected without one.
    #[test]
    fn removes_even_without_write_channel() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let conn = ws.common.connection_pool[0].clone();
            {
                let mut map = ws.connection_streams.lock().await;
                map.insert("x".to_string(), conn.clone());
            }
            {
                let mut state = conn.state.lock().await;
                let (tx, _rx) = unbounded_channel();
                state.ws_write_tx = Some(tx);
                state.stream_callbacks.insert("x".to_string(), Vec::new());
            }
            ws.unsubscribe(vec!["x".into()], None).await;
            assert!(!ws.connection_streams.lock().await.contains_key("x"));
        });
    }
}
5012
mod is_subscribed {
    use super::*;

    /// A stream that was never assigned reports as not subscribed.
    #[test]
    fn returns_false_when_not_subscribed() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let subscribed = ws.is_subscribed("unknown").await;
            assert!(!subscribed);
        });
    }

    /// A stream present in the assignment map reports as subscribed.
    #[test]
    fn returns_true_when_subscribed() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let first = Arc::clone(&ws.common.connection_pool[0]);
            ws.connection_streams
                .lock()
                .await
                .insert("stream1".to_string(), first);
            let subscribed = ws.is_subscribed("stream1").await;
            assert!(subscribed);
        });
    }
}
5037
5038 mod prepare_url {
5039 use super::*;
5040
5041 #[test]
5042 fn without_time_unit_returns_base_url() {
5043 TOKIO_SHARED_RT.block_on(async {
5044 let conns = vec![
5045 WebsocketConnection::new("c1"),
5046 WebsocketConnection::new("c2"),
5047 ];
5048 let config = ConfigurationWebsocketStreams {
5049 ws_url: Some("wss://example".to_string()),
5050 mode: WebsocketMode::Single,
5051 reconnect_delay: 100,
5052 time_unit: None,
5053 agent: None,
5054 };
5055 let ws = WebsocketStreams::new(config, conns);
5056 let url = ws.prepare_url(&["s1".into(), "s2".into()]);
5057 assert_eq!(url, "wss://example/stream?streams=s1/s2");
5058 });
5059 }
5060
5061 #[test]
5062 fn with_time_unit_appends_parameter() {
5063 TOKIO_SHARED_RT.block_on(async {
5064 let conns = vec![WebsocketConnection::new("c1")];
5065 let config = ConfigurationWebsocketStreams {
5066 ws_url: Some("wss://example".to_string()),
5067 mode: WebsocketMode::Single,
5068 reconnect_delay: 100,
5069 time_unit: Some(TimeUnit::Millisecond),
5070 agent: None,
5071 };
5072 let ws = WebsocketStreams::new(config, conns);
5073 let url = ws.prepare_url(&["a".into()]);
5074 assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
5075 });
5076 }
5077
5078 #[test]
5079 fn multiple_streams_and_time_unit() {
5080 TOKIO_SHARED_RT.block_on(async {
5081 let conns = vec![WebsocketConnection::new("c1")];
5082 let config = ConfigurationWebsocketStreams {
5083 ws_url: Some("wss://example".to_string()),
5084 mode: WebsocketMode::Single,
5085 reconnect_delay: 100,
5086 time_unit: Some(TimeUnit::Microsecond),
5087 agent: None,
5088 };
5089 let ws = WebsocketStreams::new(config, conns);
5090 let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()]);
5091 assert_eq!(
5092 url,
5093 "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
5094 );
5095 });
5096 }
5097 }
5098
mod handle_stream_assignment {
    use super::*;

    /// Fresh streams are all assigned, and both land in a single group.
    #[test]
    fn assigns_new_streams_to_connections() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let groups = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into(), "s2".into()])
                .await;
            let seen: HashSet<&String> = groups
                .iter()
                .flat_map(|(_, streams)| streams)
                .collect();
            let wanted = ["s1".to_string(), "s2".to_string()];
            let expected: HashSet<&String> = wanted.iter().collect();
            assert_eq!(seen, expected);
            assert_eq!(groups.len(), 1);
        });
    }

    /// Requesting an already-assigned stream alongside a new one returns both.
    #[test]
    fn reuses_existing_connection_for_duplicate_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let _ = ws.clone().handle_stream_assignment(vec!["s1".into()]).await;
            let groups = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into(), "s3".into()])
                .await;
            let mut every_stream: Vec<String> = groups
                .into_iter()
                .flat_map(|(_, streams)| streams)
                .collect();
            every_stream.sort();
            assert_eq!(every_stream, vec!["s1".to_string(), "s3".to_string()]);
        });
    }

    /// No streams in, no groups out.
    #[test]
    fn empty_stream_list_returns_empty() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let groups = ws.clone().handle_stream_assignment(vec![]).await;
            assert!(groups.is_empty());
        });
    }

    /// After the first connection is marked closing, a further request still
    /// yields exactly one group.
    // NOTE(review): the second request is for a new stream `s2`, not `s1`, so
    // this does not actually exercise reassignment of a stream whose connection
    // is closing. Confirm intent.
    #[test]
    fn closed_or_reconnecting_forces_reassignment_of_stream() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, None);
            let (first_owner, _) = ws
                .clone()
                .handle_stream_assignment(vec!["s1".into()])
                .await
                .pop()
                .unwrap();
            first_owner.state.lock().await.close_initiated = true;

            let regrouped = ws.clone().handle_stream_assignment(vec!["s2".into()]).await;
            assert_eq!(regrouped.len(), 1);
            let (_, streams) = &regrouped[0];
            assert_eq!(streams, &vec!["s2".to_string()]);
        });
    }

    /// An empty pool still produces one group holding the requested stream.
    #[test]
    fn no_available_connections_falls_back_to_one() {
        TOKIO_SHARED_RT.block_on(async {
            let ws = create_websocket_streams(None, Some(vec![]));
            let assigned = ws.handle_stream_assignment(vec!["foo".into()]).await;
            assert_eq!(assigned.len(), 1);
            let [(_, streams)] = assigned.as_slice() else {
                unreachable!("length checked above");
            };
            assert_eq!(streams.as_slice(), &["foo".to_string()]);
        });
    }

    /// With a single connection, all streams are grouped onto it.
    #[test]
    fn single_connection_groups_multiple_streams() {
        TOKIO_SHARED_RT.block_on(async {
            let only = WebsocketConnection::new("c1");
            let ws = create_websocket_streams(None, Some(vec![only.clone()]));
            let assigned = ws
                .handle_stream_assignment(vec!["s1".into(), "s2".into()])
                .await;
            assert_eq!(assigned.len(), 1);
            let [(owner, streams)] = assigned.as_slice() else {
                unreachable!("length checked above");
            };
            assert!(Arc::ptr_eq(owner, &only));
            assert_eq!(streams.len(), 2);
            for name in ["s1", "s2"] {
                assert!(streams.contains(&name.to_string()));
            }
        });
    }

    /// Re-requesting an assigned stream on a healthy connection reuses it.
    #[test]
    fn reuse_existing_healthy_connection() {
        TOKIO_SHARED_RT.block_on(async {
            let only = WebsocketConnection::new("c");
            let ws = create_websocket_streams(None, Some(vec![only.clone()]));
            let _ = ws.handle_stream_assignment(vec!["s1".into()]).await;
            let repeat = ws.handle_stream_assignment(vec!["s1".into()]).await;
            assert_eq!(repeat.len(), 1);
            let [(owner, streams)] = repeat.as_slice() else {
                unreachable!("length checked above");
            };
            assert!(Arc::ptr_eq(owner, &only));
            assert_eq!(streams.as_slice(), &["s1".to_string()]);
        });
    }

    /// A mix of assigned and new streams is returned as one group on the single
    /// connection.
    #[test]
    fn mix_new_and_assigned_streams() {
        TOKIO_SHARED_RT.block_on(async {
            let only = WebsocketConnection::new("c");
            let ws = create_websocket_streams(None, Some(vec![only.clone()]));
            let _ = ws
                .handle_stream_assignment(vec!["s1".into(), "s2".into()])
                .await;
            let mixed = ws
                .handle_stream_assignment(vec!["s2".into(), "s3".into()])
                .await;
            assert_eq!(mixed.len(), 1);
            let [(owner, streams)] = mixed.as_slice() else {
                unreachable!("length checked above");
            };
            assert!(Arc::ptr_eq(owner, &only));
            let mut sorted = streams.clone();
            sorted.sort();
            assert_eq!(sorted, vec!["s2".to_string(), "s3".to_string()]);
        });
    }
}
5230
5231 mod send_subscription_payload {
5232 use super::*;
5233
5234 #[test]
5235 fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
5236 TOKIO_SHARED_RT.block_on(async {
5237 let ws: Arc<WebsocketStreams> =
5238 create_websocket_streams(Some("ws://example.com"), None);
5239 let conn = &ws.common.connection_pool[0];
5240 let (tx, mut rx) = unbounded_channel();
5241 {
5242 let mut st = conn.state.lock().await;
5243 st.ws_write_tx = Some(tx);
5244 }
5245 ws.send_subscription_payload(
5246 conn.clone(),
5247 vec!["s1".to_string()],
5248 Some("badid".to_string()),
5249 );
5250 let msg = rx.recv().await.expect("no message sent");
5251 if let Message::Text(txt) = msg {
5252 let v: serde_json::Value = serde_json::from_str(&txt).unwrap();
5253 assert_eq!(v["method"], "SUBSCRIBE");
5254 let id = v["id"].as_str().unwrap();
5255 assert_ne!(id, "badid");
5256 assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id));
5257 } else {
5258 panic!("unexpected message: {msg:?}");
5259 }
5260 });
5261 }
5262
5263 #[test]
5264 fn subscribe_payload_with_and_without_custom_id() {
5265 TOKIO_SHARED_RT.block_on(async {
5266 let ws: Arc<WebsocketStreams> =
5267 create_websocket_streams(Some("ws://unused"), None);
5268 let conn = &ws.common.connection_pool[0];
5269 let (tx, mut rx) = unbounded_channel();
5270 {
5271 let mut st = conn.state.lock().await;
5272 st.ws_write_tx = Some(tx);
5273 }
5274 ws.send_subscription_payload(
5275 conn.clone(),
5276 vec!["a".to_string(), "b".to_string()],
5277 Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string()),
5278 );
5279 let msg1 = rx.recv().await.unwrap();
5280 ws.send_subscription_payload(conn.clone(), vec!["x".to_string()], None);
5281 let msg2 = rx.recv().await.unwrap();
5282
5283 if let Message::Text(txt1) = msg1 {
5284 let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
5285 assert_eq!(v1["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
5286 assert_eq!(
5287 v1["params"].as_array().unwrap(),
5288 &vec![serde_json::json!("a"), serde_json::json!("b")]
5289 );
5290 } else {
5291 panic!()
5292 }
5293
5294 if let Message::Text(txt2) = msg2 {
5295 let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
5296 assert_eq!(v2["method"], "SUBSCRIBE");
5297 let params = v2["params"].as_array().unwrap();
5298 assert_eq!(params.len(), 1);
5299 assert_eq!(params[0], "x");
5300 let id2 = v2["id"].as_str().unwrap();
5301 assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id2));
5302 } else {
5303 panic!()
5304 }
5305 });
5306 }
5307 }
5308
5309 mod on_open {
5310 use super::*;
5311
5312 #[test]
5313 fn sends_pending_subscriptions() {
5314 TOKIO_SHARED_RT.block_on(async {
5315 let ws: Arc<WebsocketStreams> =
5316 create_websocket_streams(Some("ws://example.com"), None);
5317 let conn = &ws.common.connection_pool[0];
5318 let (tx, mut rx) = unbounded_channel();
5319 {
5320 let mut st = conn.state.lock().await;
5321 st.ws_write_tx = Some(tx);
5322 st.pending_subscriptions.push_back("foo".to_string());
5323 st.pending_subscriptions.push_back("bar".to_string());
5324 }
5325 ws.on_open("ws://example.com".to_string(), conn.clone())
5326 .await;
5327 let msg = rx.recv().await.expect("no subscription sent");
5328 if let Message::Text(txt) = msg {
5329 let v: Value = serde_json::from_str(&txt).unwrap();
5330 assert_eq!(v["method"], "SUBSCRIBE");
5331 let params = v["params"].as_array().unwrap();
5332 assert_eq!(
5333 params,
5334 &vec![Value::String("foo".into()), Value::String("bar".into())]
5335 );
5336 } else {
5337 panic!("unexpected message: {msg:?}");
5338 }
5339 let st_after = conn.state.lock().await;
5340 assert!(st_after.pending_subscriptions.is_empty());
5341 });
5342 }
5343
5344 #[test]
5345 fn with_no_pending_subscriptions_sends_nothing() {
5346 TOKIO_SHARED_RT.block_on(async {
5347 let ws: Arc<WebsocketStreams> =
5348 create_websocket_streams(Some("ws://example.com"), None);
5349 let conn = &ws.common.connection_pool[0];
5350 let (tx, mut rx) = unbounded_channel();
5351 {
5352 let mut st = conn.state.lock().await;
5353 st.ws_write_tx = Some(tx);
5354 }
5355 ws.on_open("ws://example.com".to_string(), conn.clone())
5356 .await;
5357 assert!(rx.try_recv().is_err(), "unexpected message sent");
5358 });
5359 }
5360
5361 #[test]
5362 fn clears_pending_without_write_channel() {
5363 TOKIO_SHARED_RT.block_on(async {
5364 let ws: Arc<WebsocketStreams> =
5365 create_websocket_streams(Some("ws://example.com"), None);
5366 let conn = &ws.common.connection_pool[0];
5367 {
5368 let mut st = conn.state.lock().await;
5369 st.pending_subscriptions.push_back("solo".to_string());
5370 }
5371 ws.on_open("ws://example.com".to_string(), conn.clone())
5372 .await;
5373 let st_after = conn.state.lock().await;
5374 assert!(st_after.pending_subscriptions.is_empty());
5375 });
5376 }
5377 }
5378
5379 mod on_message {
5380 use super::*;
5381
5382 #[test]
5383 fn invokes_registered_callback() {
5384 TOKIO_SHARED_RT.block_on(async {
5385 let ws: Arc<WebsocketStreams> =
5386 create_websocket_streams(Some("ws://example.com"), None);
5387 let conn = &ws.common.connection_pool[0];
5388 let called = Arc::new(AtomicBool::new(false));
5389 let called_clone = called.clone();
5390
5391 {
5392 let mut st = conn.state.lock().await;
5393 st.stream_callbacks
5394 .entry("stream1".to_string())
5395 .or_default()
5396 .push(
5397 (Box::new(move |_: &Value| {
5398 called_clone.store(true, Ordering::SeqCst);
5399 })
5400 as Box<dyn Fn(&Value) + Send + Sync>)
5401 .into(),
5402 );
5403 }
5404
5405 let msg = json!({
5406 "stream": "stream1",
5407 "data": { "key": "value" }
5408 })
5409 .to_string();
5410
5411 ws.on_message(msg, conn.clone()).await;
5412
5413 assert!(called.load(Ordering::SeqCst));
5414 });
5415 }
5416
5417 #[test]
5418 fn invokes_all_registered_callbacks() {
5419 TOKIO_SHARED_RT.block_on(async {
5420 let ws: Arc<WebsocketStreams> =
5421 create_websocket_streams(Some("ws://example.com"), None);
5422 let conn = &ws.common.connection_pool[0];
5423 let counter = Arc::new(AtomicUsize::new(0));
5424
5425 {
5426 let mut st = conn.state.lock().await;
5427 let entry = st.stream_callbacks.entry("s".into()).or_default();
5428 let c1 = counter.clone();
5429 entry.push(
5430 (Box::new(move |_: &Value| {
5431 c1.fetch_add(1, Ordering::SeqCst);
5432 }) as Box<dyn Fn(&Value) + Send + Sync>)
5433 .into(),
5434 );
5435 let c2 = counter.clone();
5436 entry.push(
5437 (Box::new(move |_: &Value| {
5438 c2.fetch_add(1, Ordering::SeqCst);
5439 }) as Box<dyn Fn(&Value) + Send + Sync>)
5440 .into(),
5441 );
5442 }
5443
5444 let msg = json!({"stream":"s","data":42}).to_string();
5445 ws.on_message(msg, conn.clone()).await;
5446
5447 assert_eq!(counter.load(Ordering::SeqCst), 2);
5448 });
5449 }
5450
5451 #[test]
5452 fn handles_null_data_field() {
5453 TOKIO_SHARED_RT.block_on(async {
5454 let ws: Arc<WebsocketStreams> =
5455 create_websocket_streams(Some("ws://example.com"), None);
5456 let conn = &ws.common.connection_pool[0];
5457 let called = Arc::new(AtomicUsize::new(0));
5458 {
5459 let mut st = conn.state.lock().await;
5460 st.stream_callbacks.entry("n".into()).or_default().push(
5461 (Box::new({
5462 let c = called.clone();
5463 move |data: &Value| {
5464 if data.is_null() {
5465 c.fetch_add(1, Ordering::SeqCst);
5466 }
5467 }
5468 }) as Box<dyn Fn(&Value) + Send + Sync>)
5469 .into(),
5470 );
5471 }
5472 let msg = json!({"stream":"n","data":null}).to_string();
5473 ws.on_message(msg, conn.clone()).await;
5474 assert_eq!(called.load(Ordering::SeqCst), 1);
5475 });
5476 }
5477
5478 #[test]
5479 fn with_invalid_json_does_not_panic() {
5480 TOKIO_SHARED_RT.block_on(async {
5481 let ws: Arc<WebsocketStreams> =
5482 create_websocket_streams(Some("ws://example.com"), None);
5483 let conn = &ws.common.connection_pool[0];
5484 let bad = "not a json";
5485 ws.on_message(bad.to_string(), conn.clone()).await;
5486 });
5487 }
5488
5489 #[test]
5490 fn without_stream_field_does_nothing() {
5491 TOKIO_SHARED_RT.block_on(async {
5492 let ws: Arc<WebsocketStreams> =
5493 create_websocket_streams(Some("ws://example.com"), None);
5494 let conn = &ws.common.connection_pool[0];
5495 let msg = json!({ "data": { "foo": 1 } }).to_string();
5496 ws.on_message(msg, conn.clone()).await;
5497 });
5498 }
5499
5500 #[test]
5501 fn with_unregistered_stream_does_not_panic() {
5502 TOKIO_SHARED_RT.block_on(async {
5503 let ws: Arc<WebsocketStreams> =
5504 create_websocket_streams(Some("ws://example.com"), None);
5505 let conn = &ws.common.connection_pool[0];
5506 let msg = json!({
5507 "stream": "nope",
5508 "data": { "foo": 1 }
5509 })
5510 .to_string();
5511 ws.on_message(msg, conn.clone()).await;
5512 });
5513 }
5514 }
5515
5516 mod get_reconnect_url {
5517 use super::*;
5518
5519 #[test]
5520 fn single_stream_reconnect_url() {
5521 TOKIO_SHARED_RT.block_on(async {
5522 let ws: Arc<WebsocketStreams> =
5523 create_websocket_streams(Some("ws://example.com"), None);
5524 let c0 = ws.common.connection_pool[0].clone();
5525 {
5526 let mut map = ws.connection_streams.lock().await;
5527 map.insert("s1".to_string(), c0.clone());
5528 }
5529 let url = ws.get_reconnect_url("default_url".into(), c0).await;
5530 assert_eq!(url, "ws://example.com/stream?streams=s1");
5531 });
5532 }
5533
5534 #[test]
5535 fn multiple_streams_same_connection() {
5536 TOKIO_SHARED_RT.block_on(async {
5537 let ws: Arc<WebsocketStreams> =
5538 create_websocket_streams(Some("ws://example.com"), None);
5539 let c0 = ws.common.connection_pool[0].clone();
5540 {
5541 let mut map = ws.connection_streams.lock().await;
5542 map.insert("a".to_string(), c0.clone());
5543 map.insert("b".to_string(), c0.clone());
5544 }
5545 let url = ws.get_reconnect_url("default_url".into(), c0).await;
5546 let suffix = url
5547 .strip_prefix("ws://example.com/stream?streams=")
5548 .unwrap();
5549 let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
5550 let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
5551 assert_eq!(set, ["a", "b"].iter().copied().collect());
5552 });
5553 }
5554
5555 #[test]
5556 fn reconnect_url_with_time_unit() {
5557 TOKIO_SHARED_RT.block_on(async {
5558 let mut ws: Arc<WebsocketStreams> =
5559 create_websocket_streams(Some("ws://example.com"), None);
5560 Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
5561 Some(TimeUnit::Microsecond);
5562 let c0 = ws.common.connection_pool[0].clone();
5563 {
5564 let mut map = ws.connection_streams.lock().await;
5565 map.insert("x".to_string(), c0.clone());
5566 }
5567 let url = ws.get_reconnect_url("default_url".into(), c0).await;
5568 assert_eq!(
5569 url,
5570 "ws://example.com/stream?streams=x&timeUnit=microsecond"
5571 );
5572 });
5573 }
5574 }
5575 }
5576
5577 mod websocket_stream {
5578 use super::*;
5579
5580 mod on {
5581 use super::*;
5582
5583 #[test]
5584 fn registers_callback_and_stream_callback_for_websocket_streams() {
5585 TOKIO_SHARED_RT.block_on(async {
5586 let ws_base = create_websocket_streams(Some("example.com"), None);
5587 let stream_name = "s1".to_string();
5588 let conn = ws_base.common.connection_pool[0].clone();
5589 {
5590 let mut map = ws_base.connection_streams.lock().await;
5591 map.insert(stream_name.clone(), conn.clone());
5592 }
5593 {
5594 let mut state = conn.state.lock().await;
5595 state
5596 .stream_callbacks
5597 .insert(stream_name.clone(), Vec::new());
5598 }
5599 let stream = Arc::new(WebsocketStream::<Value> {
5600 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5601 stream_or_id: stream_name.clone(),
5602 callback: Mutex::new(None),
5603 id: None,
5604 _phantom: PhantomData,
5605 });
5606 let called = Arc::new(Mutex::new(false));
5607 let called_clone = called.clone();
5608 stream
5609 .on("message", move |v: Value| {
5610 let mut lock = called_clone.blocking_lock();
5611 *lock = v == Value::String("x".into());
5612 })
5613 .await;
5614 let cb_guard = stream.callback.lock().await;
5615 assert!(cb_guard.is_some());
5616 let cbs = {
5617 let state = conn.state.lock().await;
5618 state.stream_callbacks.get(&stream_name).unwrap().clone()
5619 };
5620 assert_eq!(cbs.len(), 1);
5621 });
5622 }
5623
5624 #[test]
5625 fn message_twice_registers_two_wrappers_for_websocket_streams() {
5626 TOKIO_SHARED_RT.block_on(async {
5627 let ws_base = create_websocket_streams(Some("example.com"), None);
5628 let stream_name = "s2".to_string();
5629 let conn = ws_base.common.connection_pool[0].clone();
5630 {
5631 let mut map = ws_base.connection_streams.lock().await;
5632 map.insert(stream_name.clone(), conn.clone());
5633 }
5634 {
5635 let mut state = conn.state.lock().await;
5636 state
5637 .stream_callbacks
5638 .insert(stream_name.clone(), Vec::new());
5639 }
5640 let stream = Arc::new(WebsocketStream::<Value> {
5641 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5642 stream_or_id: stream_name.clone(),
5643 callback: Mutex::new(None),
5644 id: None,
5645 _phantom: PhantomData,
5646 });
5647 stream.on("message", |_| {}).await;
5648 stream.on("message", |_| {}).await;
5649 let state = conn.state.lock().await;
5650 let callbacks = state.stream_callbacks.get(&stream_name).unwrap();
5651 assert_eq!(callbacks.len(), 2);
5652 });
5653 }
5654
5655 #[test]
5656 fn ignores_non_message_event_for_websocket_streams() {
5657 TOKIO_SHARED_RT.block_on(async {
5658 let ws_base = create_websocket_streams(Some("example.com"), None);
5659 let stream = Arc::new(WebsocketStream::<Value> {
5660 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5661 stream_or_id: "s".into(),
5662 callback: Mutex::new(None),
5663 id: None,
5664 _phantom: PhantomData,
5665 });
5666 stream.on("open", |_| {}).await;
5667 let guard = stream.callback.lock().await;
5668 assert!(guard.is_none());
5669 });
5670 }
5671
5672 #[test]
5673 fn registers_callback_and_stream_callback_for_websocket_api() {
5674 TOKIO_SHARED_RT.block_on(async {
5675 let ws_base = create_websocket_api(None);
5676
5677 {
5678 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5679 stream_callbacks.insert("id1".to_string(), Vec::new());
5680 }
5681
5682 let stream = Arc::new(WebsocketStream::<Value> {
5683 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5684 stream_or_id: "id1".to_string(),
5685 callback: Mutex::new(None),
5686 id: None,
5687 _phantom: PhantomData,
5688 });
5689
5690 let called = Arc::new(Mutex::new(false));
5691 let called_clone = called.clone();
5692 stream
5693 .on("message", move |v: Value| {
5694 let mut lock = called_clone.blocking_lock();
5695 *lock = v == Value::String("x".into());
5696 })
5697 .await;
5698
5699 let cb_guard = stream.callback.lock().await;
5700 assert!(cb_guard.is_some());
5701
5702 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5703 let callbacks = stream_callbacks.get("id1").unwrap();
5704 assert_eq!(callbacks.len(), 1);
5705 });
5706 }
5707
5708 #[test]
5709 fn message_twice_registers_two_wrappers_for_websocket_api() {
5710 TOKIO_SHARED_RT.block_on(async {
5711 let ws_base = create_websocket_api(None);
5712
5713 let stream = Arc::new(WebsocketStream::<Value> {
5714 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5715 stream_or_id: "id2".to_string(),
5716 callback: Mutex::new(None),
5717 id: None,
5718 _phantom: PhantomData,
5719 });
5720
5721 stream.on("message", |_| {}).await;
5722 stream.on("message", |_| {}).await;
5723
5724 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5725 let callbacks = stream_callbacks.get("id2").unwrap();
5726 assert_eq!(callbacks.len(), 2);
5727 });
5728 }
5729
5730 #[test]
5731 fn ignores_non_message_event_for_websocket_api() {
5732 TOKIO_SHARED_RT.block_on(async {
5733 let ws_base = create_websocket_api(None);
5734
5735 let stream = Arc::new(WebsocketStream::<Value> {
5736 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5737 stream_or_id: "id3".into(),
5738 callback: Mutex::new(None),
5739 id: None,
5740 _phantom: PhantomData,
5741 });
5742
5743 stream.on("open", |_| {}).await;
5744
5745 let guard = stream.callback.lock().await;
5746 assert!(guard.is_none());
5747
5748 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5749 assert!(stream_callbacks.get("id3").is_none());
5750 assert!(stream_callbacks.is_empty());
5751 });
5752 }
5753 }
5754
5755 mod on_message {
5756 use super::*;
5757
5758 #[test]
5759 fn on_message_registers_callback_for_websocket_streams() {
5760 TOKIO_SHARED_RT.block_on(async {
5761 let ws_base = create_websocket_streams(Some("example.com"), None);
5762 let stream_name = "s".to_string();
5763 let conn = ws_base.common.connection_pool[0].clone();
5764 {
5765 let mut map = ws_base.connection_streams.lock().await;
5766 map.insert(stream_name.clone(), conn.clone());
5767 }
5768 {
5769 let mut state = conn.state.lock().await;
5770 state
5771 .stream_callbacks
5772 .insert(stream_name.clone(), Vec::new());
5773 }
5774 let stream = Arc::new(WebsocketStream::<Value> {
5775 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5776 stream_or_id: stream_name.clone(),
5777 callback: Mutex::new(None),
5778 id: None,
5779 _phantom: PhantomData,
5780 });
5781 stream.on_message(|_v| {});
5782 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
5783 assert_eq!(callbacks.len(), 1);
5784 });
5785 }
5786
5787 #[test]
5788 fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
5789 TOKIO_SHARED_RT.block_on(async {
5790 let ws_base = create_websocket_streams(Some("example.com"), None);
5791 let stream_name = "s".to_string();
5792 let conn = ws_base.common.connection_pool[0].clone();
5793 {
5794 let mut map = ws_base.connection_streams.lock().await;
5795 map.insert(stream_name.clone(), conn.clone());
5796 }
5797 {
5798 let mut state = conn.state.lock().await;
5799 state
5800 .stream_callbacks
5801 .insert(stream_name.clone(), Vec::new());
5802 }
5803 let stream = Arc::new(WebsocketStream::<Value> {
5804 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5805 stream_or_id: stream_name.clone(),
5806 callback: Mutex::new(None),
5807 id: None,
5808 _phantom: PhantomData,
5809 });
5810 stream.on_message(|_v| {});
5811 stream.on_message(|_v| {});
5812 let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
5813 assert_eq!(callbacks.len(), 2);
5814 });
5815 }
5816
5817 #[test]
5818 fn on_message_registers_callback_for_websocket_api() {
5819 TOKIO_SHARED_RT.block_on(async {
5820 let ws_base = create_websocket_api(None);
5821 let identifier = "id1".to_string();
5822
5823 let stream = Arc::new(WebsocketStream::<Value> {
5824 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5825 stream_or_id: identifier.clone(),
5826 callback: Mutex::new(None),
5827 id: None,
5828 _phantom: PhantomData,
5829 });
5830
5831 stream.on_message(|_v: Value| {});
5832
5833 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5834 let callbacks = stream_callbacks.get(&identifier).unwrap();
5835 assert_eq!(callbacks.len(), 1);
5836 });
5837 }
5838
5839 #[test]
5840 fn on_message_twice_registers_two_callbacks_for_websocket_api() {
5841 TOKIO_SHARED_RT.block_on(async {
5842 let ws_base = create_websocket_api(None);
5843 let identifier = "id2".to_string();
5844
5845 let stream = Arc::new(WebsocketStream::<Value> {
5846 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5847 stream_or_id: identifier.clone(),
5848 callback: Mutex::new(None),
5849 id: None,
5850 _phantom: PhantomData,
5851 });
5852
5853 stream.on_message(|_v: Value| {});
5854 stream.on_message(|_v: Value| {});
5855
5856 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5857 let callbacks = stream_callbacks.get(&identifier).unwrap();
5858 assert_eq!(callbacks.len(), 2);
5859 });
5860 }
5861 }
5862
5863 mod unsubscribe {
5864 use super::*;
5865
5866 #[test]
5867 fn without_callback_does_nothing() {
5868 TOKIO_SHARED_RT.block_on(async {
5869 let ws_base = create_websocket_streams(Some("example.com"), None);
5870 let stream_name = "s1".to_string();
5871 let conn = ws_base.common.connection_pool[0].clone();
5872 {
5873 let mut map = ws_base.connection_streams.lock().await;
5874 map.insert(stream_name.clone(), conn.clone());
5875 }
5876 let mut state = conn.state.lock().await;
5877 state.stream_callbacks.insert(stream_name.clone(), vec![]);
5878 drop(state);
5879 let stream = Arc::new(WebsocketStream::<Value> {
5880 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5881 stream_or_id: stream_name.clone(),
5882 callback: Mutex::new(None),
5883 id: None,
5884 _phantom: PhantomData,
5885 });
5886 stream.unsubscribe().await;
5887 let state = conn.state.lock().await;
5888 assert!(state.stream_callbacks.contains_key(&stream_name));
5889 });
5890 }
5891
5892 #[test]
5893 fn removes_registered_callback_and_clears_state() {
5894 TOKIO_SHARED_RT.block_on(async {
5895 let ws_base = create_websocket_streams(Some("example.com"), None);
5896 let stream_name = "s2".to_string();
5897 let conn = ws_base.common.connection_pool[0].clone();
5898 {
5899 let mut map = ws_base.connection_streams.lock().await;
5900 map.insert(stream_name.clone(), conn.clone());
5901 }
5902 {
5903 let mut state = conn.state.lock().await;
5904 state
5905 .stream_callbacks
5906 .insert(stream_name.clone(), Vec::new());
5907 }
5908 let stream = Arc::new(WebsocketStream::<Value> {
5909 websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5910 stream_or_id: stream_name.clone(),
5911 callback: Mutex::new(None),
5912 id: None,
5913 _phantom: PhantomData,
5914 });
5915 stream.on("message", |_| {}).await;
5916 {
5917 let guard = stream.callback.lock().await;
5918 assert!(guard.is_some());
5919 }
5920 stream.unsubscribe().await;
5921 sleep(Duration::from_millis(10)).await;
5922 let guard = stream.callback.lock().await;
5923 assert!(guard.is_none());
5924 let state = conn.state.lock().await;
5925 assert!(
5926 state
5927 .stream_callbacks
5928 .get(&stream_name)
5929 .is_none_or(std::vec::Vec::is_empty)
5930 );
5931 });
5932 }
5933
5934 #[test]
5935 fn without_callback_does_nothing_for_websocket_api() {
5936 TOKIO_SHARED_RT.block_on(async {
5937 let ws_base = create_websocket_api(None);
5938 let identifier = "id1".to_string();
5939
5940 {
5941 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5942 stream_callbacks.insert(identifier.clone(), Vec::new());
5943 }
5944
5945 let stream = Arc::new(WebsocketStream::<Value> {
5946 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5947 stream_or_id: identifier.clone(),
5948 callback: Mutex::new(None),
5949 id: None,
5950 _phantom: PhantomData,
5951 });
5952
5953 stream.unsubscribe().await;
5954
5955 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5956 assert!(stream_callbacks.contains_key(&identifier));
5957 let callbacks = stream_callbacks.get(&identifier).unwrap();
5958 assert!(callbacks.is_empty());
5959 });
5960 }
5961
5962 #[test]
5963 fn removes_registered_callback_and_clears_state_for_websocket_api() {
5964 TOKIO_SHARED_RT.block_on(async {
5965 let ws_base = create_websocket_api(None);
5966 let identifier = "id2".to_string();
5967
5968 {
5969 let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5970 stream_callbacks.insert(identifier.clone(), Vec::new());
5971 }
5972
5973 let stream = Arc::new(WebsocketStream::<Value> {
5974 websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5975 stream_or_id: identifier.clone(),
5976 callback: Mutex::new(None),
5977 id: None,
5978 _phantom: PhantomData,
5979 });
5980
5981 stream.on("message", |_| {}).await;
5982
5983 {
5984 let stream_callbacks = ws_base.stream_callbacks.lock().await;
5985 let callbacks = stream_callbacks
5986 .get(&identifier)
5987 .expect("Entry for 'id2' should exist");
5988 assert_eq!(callbacks.len(), 1);
5989 }
5990
5991 stream.unsubscribe().await;
5992
5993 {
5994 let guard = stream.callback.lock().await;
5995 assert!(guard.is_none());
5996 }
5997
5998 {
5999 let stream_callbacks = ws_base.stream_callbacks.lock().await;
6000 let callbacks = stream_callbacks
6001 .get(&identifier)
6002 .expect("Entry for 'id2' should still exist");
6003 assert!(callbacks.is_empty());
6004 }
6005 });
6006 }
6007 }
6008 }
6009
6010 mod create_stream_handler {
6011 use super::*;
6012
6013 #[test]
6014 fn create_stream_handler_without_id_registers_stream() {
6015 TOKIO_SHARED_RT.block_on(async {
6016 let ws = create_websocket_streams(Some("ws://example.com"), None);
6017 let stream_name = "foo".to_string();
6018 let handler = create_stream_handler::<serde_json::Value>(
6019 WebsocketBase::WebsocketStreams(ws.clone()),
6020 stream_name.clone(),
6021 None,
6022 )
6023 .await;
6024 assert_eq!(handler.stream_or_id, stream_name);
6025 assert!(handler.id.is_none());
6026 let map = ws.connection_streams.lock().await;
6027 assert!(map.contains_key(&stream_name));
6028 });
6029 }
6030
6031 #[test]
6032 fn create_stream_handler_with_custom_id_registers_stream_and_id() {
6033 TOKIO_SHARED_RT.block_on(async {
6034 let ws = create_websocket_streams(Some("ws://example.com"), None);
6035 let stream_name = "bar".to_string();
6036 let custom_id = Some("my-custom-id".to_string());
6037 let handler = create_stream_handler::<serde_json::Value>(
6038 WebsocketBase::WebsocketStreams(ws.clone()),
6039 stream_name.clone(),
6040 custom_id.clone(),
6041 )
6042 .await;
6043 assert_eq!(handler.stream_or_id, stream_name);
6044 assert_eq!(handler.id, custom_id);
6045 let map = ws.connection_streams.lock().await;
6046 assert!(map.contains_key(&stream_name));
6047 });
6048 }
6049
6050 #[test]
6051 fn create_stream_handler_without_id_registers_api_stream() {
6052 TOKIO_SHARED_RT.block_on(async {
6053 let ws_base = create_websocket_api(None);
6054 let identifier = "foo-api".to_string();
6055
6056 let handler = create_stream_handler::<Value>(
6057 WebsocketBase::WebsocketApi(ws_base.clone()),
6058 identifier.clone(),
6059 None,
6060 )
6061 .await;
6062
6063 assert_eq!(handler.stream_or_id, identifier);
6064 assert!(handler.id.is_none());
6065 });
6066 }
6067
6068 #[test]
6069 fn create_stream_handler_with_custom_id_registers_api_stream_and_id() {
6070 TOKIO_SHARED_RT.block_on(async {
6071 let ws_base = create_websocket_api(None);
6072 let identifier = "bar-api".to_string();
6073 let custom_id = Some("custom-123".to_string());
6074
6075 let handler = create_stream_handler::<Value>(
6076 WebsocketBase::WebsocketApi(ws_base.clone()),
6077 identifier.clone(),
6078 custom_id.clone(),
6079 )
6080 .await;
6081
6082 assert_eq!(handler.stream_or_id, identifier);
6083 assert_eq!(handler.id, custom_id);
6084 });
6085 }
6086 }
6087}