binance_sdk/common/
websocket.rs

1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
4use regex::Regex;
5use serde::de::DeserializeOwned;
6use serde_json::{Value, json};
7use std::{
8    collections::{BTreeMap, HashMap, VecDeque},
9    io::Read,
10    marker::PhantomData,
11    mem::take,
12    sync::{
13        Arc, LazyLock,
14        atomic::{AtomicUsize, Ordering},
15    },
16    time::Duration,
17};
18use tokio::{
19    net::TcpStream,
20    select, spawn,
21    sync::{
22        Mutex, Notify, broadcast,
23        mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
24        oneshot,
25    },
26    task::JoinHandle,
27    time::{sleep, timeout},
28};
29use tokio_tungstenite::{
30    Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
31    tungstenite::{
32        Message,
33        client::IntoClientRequest,
34        protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
35    },
36};
37use tokio_util::time::DelayQueue;
38use tracing::{debug, error, info, warn};
39
40use crate::common::utils::{remove_empty_value, sort_object_params};
41
42use super::{
43    config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
44    errors::WebsocketError,
45    models::{WebsocketApiResponse, WebsocketEvent, WebsocketMode},
46    utils::{get_timestamp, random_string, validate_time_unit},
47};
48
49static ID_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^[0-9a-f]{32}$").unwrap());
50
51pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;
52
/// How long a connection may live before it is proactively renewed: 23 hours.
const MAX_CONN_DURATION: Duration = Duration::from_secs(60 * 60 * 23);
54
55pub struct Subscription {
56    handle: JoinHandle<()>,
57}
58
59impl Subscription {
60    /// Cancels the ongoing WebSocket event subscription and stops the event processing task.
61    ///
62    /// This method aborts the background task responsible for receiving and processing
63    /// WebSocket events, effectively unsubscribing from further event notifications.
64    ///
65    /// # Examples
66    ///
67    ///
68    /// let emitter = `WebsocketEventEmitter::new()`;
69    /// let subscription = emitter.subscribe(|event| {
70    ///     // Handle WebSocket event
71    /// });
72    /// `subscription.unsubscribe()`; // Stop receiving events
73    ///
74    pub fn unsubscribe(self) {
75        self.handle.abort();
76    }
77}
78
79#[derive(Clone)]
80pub enum WebsocketBase {
81    WebsocketApi(Arc<WebsocketApi>),
82    WebsocketStreams(Arc<WebsocketStreams>),
83}
84
85pub struct WebsocketEventEmitter {
86    tx: broadcast::Sender<WebsocketEvent>,
87}
88
89impl Default for WebsocketEventEmitter {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl WebsocketEventEmitter {
96    #[must_use]
97    pub fn new() -> Self {
98        let (tx, _rx) = broadcast::channel(100);
99        Self { tx }
100    }
101
102    /// Subscribes to WebSocket events and returns a `Subscription` that allows receiving events.
103    ///
104    /// # Arguments
105    ///
106    /// * `callback` - A mutable function that will be called for each received WebSocket event.
107    ///
108    /// # Returns
109    ///
110    /// A `Subscription` that can be used to manage the event subscription.
111    ///
112    /// # Examples
113    ///
114    ///
115    /// let emitter = `WebsocketEventEmitter::new()`;
116    /// let subscription = emitter.subscribe(|event| {
117    ///     // Handle WebSocket event
118    /// });
119    /// // Later, unsubscribe if needed
120    /// `subscription.unsubscribe()`;
121    ///
122    pub fn subscribe<F>(&self, mut callback: F) -> Subscription
123    where
124        F: FnMut(WebsocketEvent) + Send + 'static,
125    {
126        let mut rx = self.tx.subscribe();
127        let handle = spawn(async move {
128            while let Ok(event) = rx.recv().await {
129                callback(event);
130            }
131        });
132        Subscription { handle }
133    }
134
135    /// Sends a WebSocket event to all subscribers of this event emitter.
136    ///
137    /// # Arguments
138    ///
139    /// * `event` - The WebSocket event to be sent.
140    ///
141    /// # Remarks
142    ///
143    /// This method uses a broadcast channel to distribute the event to all registered subscribers.
144    /// If no subscribers are currently listening, the event is silently dropped.
145    fn emit(&self, event: WebsocketEvent) {
146        let _ = self.tx.send(event);
147    }
148}
149
150/// A trait defining the lifecycle and behavior of a WebSocket connection.
151///
152/// This trait provides methods for handling WebSocket connection events,
153/// including connection opening, message handling, and reconnection URL retrieval.
154///
155/// # Methods
156///
157/// * `on_open`: Called when a WebSocket connection is established
158/// * `on_message`: Called when a message is received over the WebSocket
159/// * `get_reconnect_url`: Determines the URL to use for reconnecting
160///
161/// # Thread Safety
162///
163/// Implementors must be safely shareable across threads, as indicated by the `Send + Sync + 'static` bounds.
164#[async_trait]
165pub trait WebsocketHandler: Send + Sync + 'static {
166    async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
167    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
168    async fn get_reconnect_url(
169        &self,
170        default_url: String,
171        connection: Arc<WebsocketConnection>,
172    ) -> String;
173}
174
175pub struct PendingRequest {
176    pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
177}
178
179pub struct WebsocketConnectionState {
180    pub reconnection_pending: bool,
181    pub renewal_pending: bool,
182    pub close_initiated: bool,
183    pub pending_requests: HashMap<String, PendingRequest>,
184    pub pending_subscriptions: VecDeque<String>,
185    pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
186    pub handler: Option<Arc<dyn WebsocketHandler>>,
187    pub ws_write_tx: Option<UnboundedSender<Message>>,
188}
189
190impl Default for WebsocketConnectionState {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
196impl WebsocketConnectionState {
197    #[must_use]
198    pub fn new() -> Self {
199        Self {
200            reconnection_pending: false,
201            renewal_pending: false,
202            close_initiated: false,
203            pending_requests: HashMap::new(),
204            pending_subscriptions: VecDeque::new(),
205            stream_callbacks: HashMap::new(),
206            handler: None,
207            ws_write_tx: None,
208        }
209    }
210}
211
212pub struct WebsocketConnection {
213    pub id: String,
214    pub drain_notify: Notify,
215    pub state: Mutex<WebsocketConnectionState>,
216}
217
218impl WebsocketConnection {
219    pub fn new(id: impl Into<String>) -> Arc<Self> {
220        Arc::new(Self {
221            id: id.into(),
222            drain_notify: Notify::new(),
223            state: Mutex::new(WebsocketConnectionState::new()),
224        })
225    }
226
227    pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
228        let mut conn_state = self.state.lock().await;
229        conn_state.handler = Some(handler);
230    }
231}
232
/// A request, queued for the reconnect loop, to (re)connect one pooled connection.
struct ReconnectEntry {
    /// Id of the pooled connection to reconnect.
    connection_id: String,
    /// URL to connect to.
    url: String,
    /// `true` for a scheduled renewal, which skips the reconnect back-off delay.
    is_renewal: bool,
}
238
239pub struct WebsocketCommon {
240    pub events: WebsocketEventEmitter,
241    mode: WebsocketMode,
242    round_robin_index: AtomicUsize,
243    connection_pool: Vec<Arc<WebsocketConnection>>,
244    reconnect_tx: Sender<ReconnectEntry>,
245    renewal_tx: Sender<(String, String)>,
246    reconnect_delay: usize,
247    agent: Option<AgentConnector>,
248}
249
250impl WebsocketCommon {
251    #[must_use]
252    pub fn new(
253        mut initial_pool: Vec<Arc<WebsocketConnection>>,
254        mode: WebsocketMode,
255        reconnect_delay: usize,
256        agent: Option<AgentConnector>,
257    ) -> Arc<Self> {
258        if initial_pool.is_empty() {
259            for _ in 0..mode.pool_size() {
260                let id = random_string();
261                initial_pool.push(WebsocketConnection::new(id));
262            }
263        }
264
265        let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
266        let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
267
268        let common = Arc::new(Self {
269            events: WebsocketEventEmitter::new(),
270            mode,
271            round_robin_index: AtomicUsize::new(0),
272            connection_pool: initial_pool,
273            reconnect_tx,
274            renewal_tx,
275            reconnect_delay,
276            agent,
277        });
278
279        Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
280        Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
281
282        common
283    }
284
285    /// Spawns an asynchronous loop to handle websocket reconnection attempts
286    ///
287    /// This method manages reconnection logic for websocket connections, including:
288    /// - Scheduling reconnects with a configurable delay
289    /// - Finding the appropriate connection in the connection pool
290    /// - Attempting to reinitialize the connection
291    /// - Logging reconnection failures or warnings
292    ///
293    /// # Arguments
294    /// * `common` - A shared reference to the `WebsocketCommon` instance
295    /// * `reconnect_rx` - A receiver channel for reconnection entries
296    ///
297    /// # Behavior
298    /// - Waits for reconnection entries from the channel
299    /// - Applies a configurable delay before attempting reconnection
300    /// - Attempts to reinitialize the connection with the provided URL
301    /// - Handles and logs any reconnection errors
302    fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
303        spawn(async move {
304            while let Some(entry) = reconnect_rx.recv().await {
305                info!("Scheduling reconnect for id {}", entry.connection_id);
306
307                if !entry.is_renewal {
308                    sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
309                }
310
311                if let Some(conn_arc) = common
312                    .connection_pool
313                    .iter()
314                    .find(|c| c.id == entry.connection_id)
315                    .cloned()
316                {
317                    let common_clone = Arc::clone(&common);
318                    if let Err(err) = common_clone
319                        .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
320                        .await
321                    {
322                        error!(
323                            "Reconnect failed for {} → {}: {:?}",
324                            entry.connection_id, entry.url, err
325                        );
326                    }
327
328                    sleep(Duration::from_secs(1)).await;
329                } else {
330                    warn!("No connection {} found for reconnect", entry.connection_id);
331                }
332            }
333        });
334    }
335
336    /// Spawns an asynchronous loop to manage connection renewals
337    ///
338    /// This method handles the periodic renewal of websocket connections by:
339    /// - Maintaining a delay queue for connection expiration
340    /// - Receiving renewal requests for specific connections
341    /// - Triggering reconnection when a connection reaches its maximum duration
342    /// - Attempting to find and renew connections in the connection pool
343    ///
344    /// # Behavior
345    /// - Listens for renewal requests on a channel
346    /// - Tracks connection expiration using a delay queue
347    /// - Initiates reconnection process when a connection expires
348    /// - Handles and logs any renewal failures
349    fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
350        let common = Arc::clone(common);
351        spawn(async move {
352            let mut dq = DelayQueue::new();
353            let mut renewal_rx = renewal_rx;
354
355            loop {
356                select! {
357                    Some((conn_id, url)) = renewal_rx.recv() => {
358                        debug!("Scheduling renewal for {}", conn_id);
359                        dq.insert((conn_id, url), MAX_CONN_DURATION);
360                    }
361
362                    Some(expired) = dq.next() => {
363                        let (conn_id, default_url) = expired.into_inner();
364
365                        if let Some(conn_arc) = common
366                            .connection_pool
367                            .iter()
368                            .find(|c| c.id == conn_id)
369                            .cloned()
370                        {
371                            debug!("Renewing connection {}", conn_id);
372                            let url = common
373                                .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
374                                .await;
375                            if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
376                                connection_id: conn_id.clone(),
377                                url,
378                                is_renewal: true,
379                            }).await {
380                                error!(
381                                    "Failed to enqueue renewal for {}: {:?}",
382                                    conn_id, e
383                                );
384                            }
385                        } else {
386                            warn!("No connection {} found for renewal", conn_id);
387                        }
388                    }
389                }
390            }
391        });
392    }
393
394    /// Checks if a WebSocket connection is ready for use.
395    ///
396    /// # Arguments
397    ///
398    /// * `connection` - The WebSocket connection to check
399    /// * `allow_non_established` - If true, allows connections that are not fully established
400    ///
401    /// # Returns
402    ///
403    /// `true` if the connection is ready, `false` otherwise
404    ///
405    /// # Behavior
406    ///
407    /// A connection is considered ready if:
408    /// - It has a write channel (unless `allow_non_established` is true)
409    /// - No renewal is pending
410    /// - No reconnection is pending
411    /// - No close has been initiated
412    pub async fn is_connection_ready(
413        &self,
414        connection: &WebsocketConnection,
415        allow_non_established: bool,
416    ) -> bool {
417        let conn_state = connection.state.lock().await;
418        (allow_non_established || conn_state.ws_write_tx.is_some())
419            && !conn_state.renewal_pending
420            && !conn_state.reconnection_pending
421            && !conn_state.close_initiated
422    }
423
424    /// Checks if a WebSocket connection is established.
425    ///
426    /// # Arguments
427    ///
428    /// * `connection` - Optional specific WebSocket connection to check
429    ///
430    /// # Returns
431    ///
432    /// `true` if a connection is ready and established, `false` otherwise
433    ///
434    /// # Behavior
435    ///
436    /// - If a specific connection is provided, checks only that connection
437    /// - If no connection is provided, checks all connections in the pool
438    /// - A connection is considered established if it is ready and not in a non-established state
439    async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
440        if let Some(conn_arc) = connection {
441            return self.is_connection_ready(conn_arc, false).await;
442        }
443
444        for conn_arc in &self.connection_pool {
445            if self.is_connection_ready(conn_arc, false).await {
446                return true;
447            }
448        }
449
450        false
451    }
452
453    /// Retrieves a WebSocket connection from the connection pool.
454    ///
455    /// # Arguments
456    ///
457    /// * `allow_non_established` - If `true`, allows selecting a connection that is not fully established
458    ///
459    /// # Returns
460    ///
461    /// An `Arc` to a `WebsocketConnection` from the pool, selected using round-robin strategy
462    ///
463    /// # Errors
464    ///
465    /// Returns `WebsocketError::NotConnected` if no suitable connection is available
466    ///
467    /// # Behavior
468    ///
469    /// - For single connection mode, returns the first connection
470    /// - For multi-connection mode, selects a ready connection using round-robin
471    /// - Filters connections based on `allow_non_established` parameter
472    async fn get_connection(
473        &self,
474        allow_non_established: bool,
475    ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
476        if let WebsocketMode::Single = self.mode {
477            return Ok(Arc::clone(&self.connection_pool[0]));
478        }
479
480        let mut ready = Vec::new();
481        for conn in &self.connection_pool {
482            if self.is_connection_ready(conn, allow_non_established).await {
483                ready.push(Arc::clone(conn));
484            }
485        }
486
487        if ready.is_empty() {
488            return Err(WebsocketError::NotConnected);
489        }
490
491        let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
492
493        Ok(Arc::clone(&ready[idx]))
494    }
495
496    /// Gracefully closes a WebSocket connection by waiting for pending requests to complete.
497    ///
498    /// # Arguments
499    ///
500    /// * `ws_write_tx_to_close` - Sender channel for sending close message
501    /// * `connection` - Shared reference to the WebSocket connection
502    ///
503    /// # Behavior
504    ///
505    /// - Waits up to 30 seconds for all pending requests to complete
506    /// - Logs debug and warning messages during the closing process
507    /// - Sends a normal close frame to the WebSocket
508    ///
509    /// # Returns
510    ///
511    /// `Ok(())` if connection closes successfully, otherwise a `WebsocketError`
512    async fn close_connection_gracefully(
513        &self,
514        ws_write_tx_to_close: UnboundedSender<Message>,
515        connection: Arc<WebsocketConnection>,
516    ) -> Result<(), WebsocketError> {
517        debug!("Waiting for pending requests to complete before disconnecting.");
518
519        let drain = async {
520            loop {
521                {
522                    let conn_state = connection.state.lock().await;
523                    if conn_state.pending_requests.is_empty() {
524                        debug!("All pending requests completed, proceeding to close.");
525                        break;
526                    }
527                }
528                connection.drain_notify.notified().await;
529            }
530        };
531
532        if timeout(Duration::from_secs(30), drain).await.is_err() {
533            warn!("Timeout waiting for pending requests; forcing close.");
534        }
535
536        info!("Closing WebSocket connection for {}", connection.id);
537        let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
538            code: CloseCode::Normal,
539            reason: "".into(),
540        })));
541
542        Ok(())
543    }
544
545    /// Retrieves the URL to use for reconnecting to the WebSocket.
546    ///
547    /// # Arguments
548    ///
549    /// * `default_url` - The default URL to use if no custom reconnect URL is provided
550    /// * `connection` - A shared reference to the WebSocket connection
551    ///
552    /// # Returns
553    ///
554    /// The URL to use for reconnecting, either from a custom handler or the default URL
555    ///
556    /// # Behavior
557    ///
558    /// - Checks if a connection handler is available
559    /// - If a handler exists, calls its `get_reconnect_url` method
560    /// - Otherwise, returns the default URL
561    async fn get_reconnect_url(
562        &self,
563        default_url: &str,
564        connection: Arc<WebsocketConnection>,
565    ) -> String {
566        if let Some(handler) = {
567            let conn_state = connection.state.lock().await;
568            conn_state.handler.clone()
569        } {
570            return handler
571                .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
572                .await;
573        }
574
575        default_url.to_string()
576    }
577
578    /// Handles the WebSocket connection opening event.
579    ///
580    /// This method is called when a WebSocket connection is successfully established. It performs
581    /// the following key actions:
582    /// - Invokes the connection handler's `on_open` method if a handler is present
583    /// - Logs connection information
584    /// - Handles connection renewal and close scenarios
585    /// - Emits a WebSocket open event
586    ///
587    /// # Arguments
588    ///
589    /// * `url` - The URL of the WebSocket server
590    /// * `connection` - A shared reference to the WebSocket connection
591    /// * `old_ws_writer` - Optional previous WebSocket writer for graceful connection handling
592    ///
593    /// # Behavior
594    ///
595    /// - If a connection handler exists, calls its `on_open` method
596    /// - Checks for pending renewal or close states
597    /// - Closes the previous connection if renewal is in progress
598    /// - Emits an open event if the connection is successfully established
599    async fn on_open(
600        &self,
601        url: String,
602        connection: Arc<WebsocketConnection>,
603        old_ws_writer: Option<UnboundedSender<Message>>,
604    ) {
605        if let Some(handler) = {
606            let conn_state = connection.state.lock().await;
607            conn_state.handler.clone()
608        } {
609            handler.on_open(url.clone(), Arc::clone(&connection)).await;
610        }
611
612        let conn_id = &connection.id;
613        info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
614
615        {
616            let mut conn_state = connection.state.lock().await;
617
618            if conn_state.renewal_pending {
619                conn_state.renewal_pending = false;
620                drop(conn_state);
621                if let Some(tx) = old_ws_writer {
622                    info!("Connection renewal in progress; closing previous connection.");
623                    let _ = self
624                        .close_connection_gracefully(tx, Arc::clone(&connection))
625                        .await;
626                }
627                return;
628            }
629
630            if conn_state.close_initiated {
631                drop(conn_state);
632                if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
633                    info!("Close initiated; closing connection.");
634                    let _ = self
635                        .close_connection_gracefully(tx, Arc::clone(&connection))
636                        .await;
637                }
638                return;
639            }
640
641            self.events.emit(WebsocketEvent::Open);
642        }
643    }
644
645    /// Handles an incoming WebSocket message
646    ///
647    /// # Arguments
648    ///
649    /// * `msg` - The received message as a string
650    /// * `connection` - A shared reference to the WebSocket connection
651    ///
652    /// # Behavior
653    ///
654    /// - If a connection handler exists, spawns an async task to call its `on_message` method
655    /// - Emits a `WebsocketEvent::Message` event with the received message
656    async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
657        if let Some(handler) = connection.state.lock().await.handler.clone() {
658            let handler_clone = handler.clone();
659            let data = msg.clone();
660            let conn_clone = connection.clone();
661            spawn(async move {
662                handler_clone.on_message(data, conn_clone).await;
663            });
664        }
665        self.events.emit(WebsocketEvent::Message(msg));
666    }
667
668    /// Creates a WebSocket connection with optional configuration and agent
669    ///
670    /// # Arguments
671    ///
672    /// * `url` - The WebSocket server URL to connect to
673    /// * `agent` - Optional agent connector for configuring the connection
674    ///
675    /// # Returns
676    ///
677    /// A `Result` containing the established WebSocket stream or a `WebsocketError`
678    ///
679    /// # Errors
680    ///
681    /// Returns a `WebsocketError` if:
682    /// - The WebSocket handshake fails
683    /// - The connection times out after 10 seconds
684    ///
685    /// # Behavior
686    ///
687    /// Attempts to establish a WebSocket connection with a configurable timeout,
688    /// supporting optional TLS and custom connectors
689    async fn create_websocket(
690        url: &str,
691        agent: Option<AgentConnector>,
692    ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
693        let req = url
694            .into_client_request()
695            .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
696
697        let ws_config: Option<WebSocketConfig> = None;
698        let disable_nagle = false;
699        let connector: Option<Connector> = agent.map(|dbg| dbg.0);
700
701        let timeout_duration = Duration::from_secs(10);
702        let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
703        match timeout(timeout_duration, handshake).await {
704            Ok(Ok((ws_stream, response))) => {
705                debug!("WebSocket connected: {:?}", response);
706                Ok(ws_stream)
707            }
708            Ok(Err(e)) => {
709                let msg = e.to_string();
710                error!("WebSocket handshake failed: {}", msg);
711                Err(WebsocketError::Handshake(msg))
712            }
713            Err(_) => {
714                error!(
715                    "WebSocket connection timed out after {}s",
716                    timeout_duration.as_secs()
717                );
718                Err(WebsocketError::Timeout)
719            }
720        }
721    }
722
723    /// Connects to a WebSocket URL for all connections in the connection pool concurrently
724    ///
725    /// # Arguments
726    ///
727    /// * `url` - The WebSocket server URL to connect to
728    ///
729    /// # Returns
730    ///
731    /// A `Result` indicating whether all connections were successfully established
732    ///
733    /// # Errors
734    ///
735    /// Returns a `WebsocketError` if any connection in the pool fails to establish
736    ///
737    /// # Behavior
738    ///
739    /// Attempts to initialize a WebSocket connection for each connection in the pool
740    /// concurrently, logging successes and failures for each connection attempt
741    async fn connect_pool(self: Arc<Self>, url: &str) -> Result<(), WebsocketError> {
742        let mut tasks = FuturesUnordered::new();
743
744        for conn in &self.connection_pool {
745            let common = Arc::clone(&self);
746            let url = url.to_owned();
747            let conn_clone = Arc::clone(conn);
748
749            tasks.push(async move {
750                match common.init_connect(&url, false, Some(conn_clone)).await {
751                    Ok(()) => {
752                        info!("Successfully connected to {}", url);
753                        Ok(())
754                    }
755                    Err(err) => {
756                        error!("Failed to connect to {}: {:?}", url, err);
757                        Err(err)
758                    }
759                }
760            });
761        }
762
763        while let Some(result) = tasks.next().await {
764            result?;
765        }
766
767        Ok(())
768    }
769
770    /// Initializes a WebSocket connection for a specific connection in the pool
771    ///
772    /// # Arguments
773    ///
774    /// * `url` - The WebSocket server URL to connect to
775    /// * `is_renewal` - Flag indicating whether this is a connection renewal attempt
776    /// * `connection` - Optional specific WebSocket connection to use, otherwise selects from the pool
777    ///
778    /// # Returns
779    ///
780    /// A `Result` indicating whether the connection was successfully established
781    ///
782    /// # Errors
783    ///
784    /// Returns a `WebsocketError` if the connection fails to initialize or establish
785    ///
786    /// # Behavior
787    ///
788    /// Handles connection establishment, splitting read/write streams, spawning reader/writer tasks,
789    /// and managing connection state including renewal, reconnection, and error handling
790    async fn init_connect(
791        self: Arc<Self>,
792        url: &str,
793        is_renewal: bool,
794        connection: Option<Arc<WebsocketConnection>>,
795    ) -> Result<(), WebsocketError> {
796        let conn = connection.unwrap_or(self.get_connection(true).await?);
797
798        {
799            let mut conn_state = conn.state.lock().await;
800            if conn_state.renewal_pending && is_renewal {
801                info!("Renewal in progress {}→{}", conn.id, url);
802                return Ok(());
803            }
804            if conn_state.ws_write_tx.is_some() && !is_renewal {
805                info!("Exists {}; skipping {}", conn.id, url);
806                return Ok(());
807            }
808            if is_renewal {
809                conn_state.renewal_pending = true;
810            }
811        }
812
813        let ws = Self::create_websocket(url, self.agent.clone())
814            .await
815            .map_err(|e| {
816                error!("Handshake failed {}: {:?}", url, e);
817                e
818            })?;
819
820        info!("Established {} → {}", conn.id, url);
821
822        if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
823            error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
824        }
825
826        let (write_half, mut read_half) = ws.split();
827        let (tx, mut rx) = unbounded_channel::<Message>();
828
829        let old_writer = {
830            let mut conn_state = conn.state.lock().await;
831            conn_state.ws_write_tx.replace(tx.clone())
832        };
833
834        let wconn = conn.clone();
835
836        spawn(async move {
837            let mut sink = write_half;
838            while let Some(msg) = rx.recv().await {
839                if sink.send(msg).await.is_err() {
840                    error!("Write error {}", wconn.id);
841                    break;
842                }
843            }
844            debug!("Writer {} exit", wconn.id);
845        });
846
847        self.on_open(url.to_string(), conn.clone(), old_writer)
848            .await;
849
850        let common = self.clone();
851        let reader_conn = conn.clone();
852        let read_url = url.to_string();
853
854        spawn(async move {
855            while let Some(item) = read_half.next().await {
856                match item {
857                    Ok(Message::Text(msg)) => {
858                        common
859                            .on_message(msg.to_string(), Arc::clone(&reader_conn))
860                            .await;
861                    }
862                    Ok(Message::Binary(bin)) => {
863                        let mut decoder = ZlibDecoder::new(&bin[..]);
864                        let mut decompressed = String::new();
865                        if let Err(err) = decoder.read_to_string(&mut decompressed) {
866                            error!("Binary message decompress failed: {:?}", err);
867                            continue;
868                        }
869                        common
870                            .on_message(decompressed, Arc::clone(&reader_conn))
871                            .await;
872                    }
873                    Ok(Message::Ping(payload)) => {
874                        info!("PING received from server on {}", reader_conn.id);
875                        common.events.emit(WebsocketEvent::Ping);
876                        if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
877                            let _ = tx.send(Message::Pong(payload));
878                            info!(
879                                "Responded PONG to server's PING message on {}",
880                                reader_conn.id
881                            );
882                        }
883                    }
884                    Ok(Message::Pong(_)) => {
885                        info!("Received PONG from server on {}", reader_conn.id);
886                        common.events.emit(WebsocketEvent::Pong);
887                    }
888                    Ok(Message::Close(frame)) => {
889                        let (code, reason) = frame
890                            .map_or((1000, String::new()), |CloseFrame { code, reason }| {
891                                (code.into(), reason.to_string())
892                            });
893                        common
894                            .events
895                            .emit(WebsocketEvent::Close(code, reason.clone()));
896
897                        let mut conn_state = reader_conn.state.lock().await;
898                        if !conn_state.close_initiated
899                            && !is_renewal
900                            && CloseCode::from(code) != CloseCode::Normal
901                        {
902                            warn!(
903                                "Connection {} closed due to {}: {}",
904                                reader_conn.id, code, reason
905                            );
906                            conn_state.reconnection_pending = true;
907                            drop(conn_state);
908                            let reconnect_url = common
909                                .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
910                                .await;
911                            let _ = common.reconnect_tx.send(ReconnectEntry {
912                                connection_id: reader_conn.id.clone(),
913                                url: reconnect_url,
914                                is_renewal: false,
915                            });
916                        }
917                        break;
918                    }
919                    Err(e) => {
920                        error!("WebSocket error on {}: {:?}", reader_conn.id, e);
921                        common.events.emit(WebsocketEvent::Error(e.to_string()));
922                    }
923                    _ => {}
924                }
925            }
926            debug!("Reader actor for {} exiting", reader_conn.id);
927        });
928
929        Ok(())
930    }
931
932    /// Gracefully disconnects all active WebSocket connections.
933    ///
934    /// This method attempts to close all connections in the connection pool within a 30-second timeout.
935    /// It marks each connection as close-initiated and attempts to close them gracefully.
936    ///
937    /// # Returns
938    ///
939    /// - `Ok(())` if all connections are successfully closed
940    /// - `Err(WebsocketError)` if there are errors during disconnection or a timeout occurs
941    ///
942    /// # Errors
943    ///
944    /// Returns `WebsocketError::Timeout` if disconnection takes longer than 30 seconds
945    ///
946    async fn disconnect(&self) -> Result<(), WebsocketError> {
947        if !self.is_connected(None).await {
948            warn!("No active connection to close.");
949            return Ok(());
950        }
951
952        let mut shutdowns = FuturesUnordered::new();
953        for conn in &self.connection_pool {
954            {
955                let mut conn_state = conn.state.lock().await;
956                conn_state.close_initiated = true;
957                if let Some(tx) = &conn_state.ws_write_tx {
958                    shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
959                }
960            }
961        }
962
963        let close_all = async {
964            while let Some(result) = shutdowns.next().await {
965                result?;
966            }
967            Ok::<(), WebsocketError>(())
968        };
969
970        match timeout(Duration::from_secs(30), close_all).await {
971            Ok(Ok(())) => {
972                info!("Disconnected all WebSocket connections successfully.");
973                Ok(())
974            }
975            Ok(Err(err)) => {
976                error!("Error while disconnecting: {:?}", err);
977                Err(err)
978            }
979            Err(_) => {
980                error!("Timed out while disconnecting WebSocket connections.");
981                Err(WebsocketError::Timeout)
982            }
983        }
984    }
985
986    /// Sends a PING message to all ready WebSocket connections.
987    ///
988    /// This method iterates through the connection pool, identifies ready connections,
989    /// and sends a PING message to each of them. It logs the number of connections
990    /// being pinged and handles any send errors individually.
991    ///
992    /// # Behavior
993    ///
994    /// - Skips connections that are not ready
995    /// - Logs a warning if no connections are ready
996    /// - Sends PING messages concurrently
997    /// - Logs debug/error messages for each PING attempt
998    async fn ping_server(&self) {
999        let mut ready = Vec::new();
1000        for conn in &self.connection_pool {
1001            if self.is_connection_ready(conn, false).await {
1002                let id = conn.id.clone();
1003                let ws_write_tx = {
1004                    let conn_state = conn.state.lock().await;
1005                    conn_state.ws_write_tx.clone()
1006                };
1007                ready.push((id, ws_write_tx));
1008            }
1009        }
1010
1011        if ready.is_empty() {
1012            warn!("No ready connections for PING.");
1013            return;
1014        }
1015        info!("Sending PING to {} WebSocket connections.", ready.len());
1016
1017        let mut tasks = FuturesUnordered::new();
1018        for (id, ws_write_tx_opt) in ready {
1019            if let Some(tx) = ws_write_tx_opt {
1020                tasks.push(async move {
1021                    if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1022                        error!("Failed to send PING to {}: {:?}", id, e);
1023                    } else {
1024                        debug!("Sent PING to connection {}", id);
1025                    }
1026                });
1027            } else {
1028                error!("Connection {} was ready but has no write channel", id);
1029            }
1030        }
1031
1032        while tasks.next().await.is_some() {}
1033    }
1034
1035    /// Sends a WebSocket message and optionally waits for a reply.
1036    ///
1037    /// # Arguments
1038    ///
1039    /// * `payload` - The message payload to send
1040    /// * `id` - Optional request identifier, required when waiting for a reply
1041    /// * `wait_for_reply` - Whether to wait for a response to the message
1042    /// * `timeout` - Maximum duration to wait for a reply
1043    /// * `connection` - Optional specific WebSocket connection to use
1044    ///
1045    /// # Returns
1046    ///
1047    /// A receiver for the response if `wait_for_reply` is true, otherwise `None`
1048    ///
1049    /// # Errors
1050    ///
1051    /// Returns a `WebsocketError` if the connection is not ready or the send fails
1052    async fn send(
1053        &self,
1054        payload: String,
1055        id: Option<String>,
1056        wait_for_reply: bool,
1057        timeout: Duration,
1058        connection: Option<Arc<WebsocketConnection>>,
1059    ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1060        let conn = if let Some(c) = connection {
1061            c
1062        } else {
1063            self.get_connection(false).await?
1064        };
1065
1066        if !self.is_connected(Some(&conn)).await {
1067            warn!("Send attempted on a non-connected socket");
1068            return Err(WebsocketError::NotConnected);
1069        }
1070
1071        let ws_write_tx = {
1072            let conn_state = conn.state.lock().await;
1073            conn_state
1074                .ws_write_tx
1075                .clone()
1076                .ok_or(WebsocketError::NotConnected)?
1077        };
1078
1079        debug!("Sending message to WebSocket on connection {}", conn.id);
1080
1081        ws_write_tx
1082            .send(Message::Text(payload.clone().into()))
1083            .map_err(|_| WebsocketError::NotConnected)?;
1084
1085        if !wait_for_reply {
1086            return Ok(None);
1087        }
1088
1089        let request_id = id.ok_or_else(|| {
1090            error!("id is required when waiting for a reply");
1091            WebsocketError::NotConnected
1092        })?;
1093
1094        let (tx, rx) = oneshot::channel();
1095        {
1096            let mut conn_state = conn.state.lock().await;
1097            conn_state
1098                .pending_requests
1099                .insert(request_id.clone(), PendingRequest { completion: tx });
1100        }
1101
1102        let conn_clone = Arc::clone(&conn);
1103        spawn(async move {
1104            sleep(timeout).await;
1105            let mut conn_state = conn_clone.state.lock().await;
1106            if let Some(pending_req) = conn_state.pending_requests.remove(&request_id) {
1107                let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1108            }
1109        });
1110
1111        Ok(Some(rx))
1112    }
1113}
1114
1115pub struct WebsocketMessageSendOptions {
1116    pub with_api_key: bool,
1117    pub is_signed: bool,
1118}
1119
1120pub struct WebsocketApi {
1121    pub common: Arc<WebsocketCommon>,
1122    configuration: ConfigurationWebsocketApi,
1123    is_connecting: Arc<Mutex<bool>>,
1124    stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
1125}
1126
1127impl WebsocketApi {
1128    #[must_use]
1129    /// Creates a new WebSocket API instance with the given configuration and connection pool.
1130    ///
1131    /// # Arguments
1132    ///
1133    /// * `configuration` - Configuration settings for the WebSocket API
1134    /// * `connection_pool` - A vector of WebSocket connections to be used
1135    ///
1136    /// # Returns
1137    ///
1138    /// An `Arc`-wrapped `WebsocketApi` instance ready for use
1139    ///
1140    /// # Panics
1141    ///
1142    /// This function will panic if the configuration is not valid.
1143    ///
1144    /// # Examples
1145    ///
1146    ///
1147    /// let api = `WebsocketApi::new(config`, `connection_pool`);
1148    ///
1149    pub fn new(
1150        configuration: ConfigurationWebsocketApi,
1151        connection_pool: Vec<Arc<WebsocketConnection>>,
1152    ) -> Arc<Self> {
1153        let agent_clone = configuration.agent.clone();
1154        let common = WebsocketCommon::new(
1155            connection_pool,
1156            configuration.mode.clone(),
1157            usize::try_from(configuration.reconnect_delay)
1158                .expect("reconnect_delay should fit in usize"),
1159            agent_clone,
1160        );
1161
1162        Arc::new(Self {
1163            common: Arc::clone(&common),
1164            configuration,
1165            is_connecting: Arc::new(Mutex::new(false)),
1166            stream_callbacks: Mutex::new(HashMap::new()),
1167        })
1168    }
1169
1170    /// Connects to a WebSocket server with a configurable timeout and connection handling.
1171    ///
1172    /// This method attempts to establish a WebSocket connection if not already connected.
1173    /// It prevents multiple simultaneous connection attempts and supports a connection pool.
1174    ///
1175    /// # Errors
1176    ///
1177    /// Returns a `WebsocketError` if:
1178    /// - Connection fails
1179    /// - Connection times out after 10 seconds
1180    ///
1181    /// # Behavior
1182    ///
1183    /// - Checks if already connected and returns early if so
1184    /// - Prevents multiple concurrent connection attempts
1185    /// - Sets a WebSocket handler for the connection pool
1186    /// - Attempts to connect with a 10-second timeout
1187    ///
1188    /// # Returns
1189    ///
1190    /// `Ok(())` if connection is successful, otherwise a `WebsocketError`
1191    pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1192        if self.common.is_connected(None).await {
1193            info!("WebSocket connection already established");
1194            return Ok(());
1195        }
1196
1197        {
1198            let mut flag = self.is_connecting.lock().await;
1199            if *flag {
1200                info!("Already connecting...");
1201                return Ok(());
1202            }
1203            *flag = true;
1204        }
1205
1206        let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1207
1208        let handler: Arc<dyn WebsocketHandler> = self.clone();
1209        for slot in &self.common.connection_pool {
1210            slot.set_handler(handler.clone()).await;
1211        }
1212
1213        let result = select! {
1214            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1215            r = self.common.clone().connect_pool(&url) => r,
1216        };
1217
1218        {
1219            let mut flag = self.is_connecting.lock().await;
1220            *flag = false;
1221        }
1222
1223        result
1224    }
1225
1226    /// Disconnects the WebSocket connection.
1227    ///
1228    /// # Returns
1229    ///
1230    /// `Ok(())` if disconnection is successful, otherwise a `WebsocketError`
1231    ///
1232    /// # Errors
1233    ///
1234    /// Returns a `WebsocketError` if:
1235    /// - Disconnection fails
1236    /// - Connection is not established
1237    ///
1238    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1239        self.common.disconnect().await
1240    }
1241
1242    /// Checks if the WebSocket connection is currently established.
1243    ///
1244    /// # Returns
1245    ///
1246    /// `true` if the connection is active, `false` otherwise.
1247    pub async fn is_connected(&self) -> bool {
1248        self.common.is_connected(None).await
1249    }
1250
1251    /// Sends a ping to the WebSocket server to maintain the connection.
1252    ///
1253    /// This method calls the underlying connection's ping mechanism to check
1254    /// and keep the WebSocket connection alive.
1255    pub async fn ping_server(&self) {
1256        self.common.ping_server().await;
1257    }
1258
1259    /// Sends a WebSocket message with the specified method and payload.
1260    ///
1261    /// This method prepares and sends a WebSocket request with optional API key and signature.
1262    /// It handles connection status, generates a unique request ID, and processes the response.
1263    ///
1264    /// # Arguments
1265    ///
1266    /// * `method` - The WebSocket API method to be called
1267    /// * `payload` - A map of parameters to be sent with the request
1268    /// * `options` - Configuration options for message sending (API key, signing)
1269    ///
1270    /// # Returns
1271    ///
1272    /// A deserialized response of type `R` or a `WebsocketError` if the request fails
1273    ///
1274    /// # Panics
1275    ///
1276    /// Panics if:
1277    ///
1278    /// - The WebSocket is not connected
1279    /// - The request cannot be processed
1280    /// - No response is received within the timeout
1281    ///
1282    /// # Errors
1283    ///
1284    /// Returns `WebsocketError` if:
1285    /// - The WebSocket is not connected
1286    /// - The request cannot be processed
1287    /// - No response is received within the timeout
1288    pub async fn send_message<R>(
1289        &self,
1290        method: &str,
1291        mut payload: BTreeMap<String, Value>,
1292        options: WebsocketMessageSendOptions,
1293    ) -> Result<WebsocketApiResponse<R>, WebsocketError>
1294    where
1295        R: DeserializeOwned + Send + Sync + 'static,
1296    {
1297        if !self.common.is_connected(None).await {
1298            return Err(WebsocketError::NotConnected);
1299        }
1300
1301        let id = payload
1302            .get("id")
1303            .and_then(Value::as_str)
1304            .filter(|s| ID_REGEX.is_match(s))
1305            .map_or_else(random_string, String::from);
1306
1307        payload.remove("id");
1308
1309        let mut params = remove_empty_value(payload.into_iter());
1310        if options.with_api_key || options.is_signed {
1311            params.insert(
1312                "apiKey".into(),
1313                Value::String(
1314                    self.configuration
1315                        .api_key
1316                        .clone()
1317                        .expect("API key must be set"),
1318                ),
1319            );
1320        }
1321        if options.is_signed {
1322            let ts = get_timestamp();
1323            let ts_i64 = i64::try_from(ts).map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1324            params.insert(
1325                "timestamp".into(),
1326                Value::Number(serde_json::Number::from(ts_i64)),
1327            );
1328            let mut sorted_params = sort_object_params(&params);
1329            let sig = self
1330                .configuration
1331                .signature_gen
1332                .get_signature(&sorted_params)
1333                .map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1334            sorted_params.insert("signature".into(), Value::String(sig));
1335            params = sorted_params.into_iter().collect();
1336        }
1337
1338        let request = json!({
1339            "id": id,
1340            "method": method,
1341            "params": params,
1342        });
1343        debug!("Sending message to WebSocket API: {:?}", request);
1344
1345        let timeout = Duration::from_millis(self.configuration.timeout);
1346        let maybe_rx = self
1347            .common
1348            .send(
1349                serde_json::to_string(&request).unwrap(),
1350                Some(id.clone()),
1351                true,
1352                timeout,
1353                None,
1354            )
1355            .await?;
1356
1357        let msg: Value = if let Some(rx) = maybe_rx {
1358            rx.await.unwrap_or(Err(WebsocketError::Timeout))?
1359        } else {
1360            return Err(WebsocketError::NoResponse);
1361        };
1362
1363        let raw = msg
1364            .get("result")
1365            .or_else(|| msg.get("response"))
1366            .cloned()
1367            .unwrap_or(Value::Null);
1368
1369        let rate_limits = msg
1370            .get("rateLimits")
1371            .and_then(Value::as_array)
1372            .map(|arr| {
1373                arr.iter()
1374                    .filter_map(|v| serde_json::from_value(v.clone()).ok())
1375                    .collect()
1376            })
1377            .unwrap_or_default();
1378
1379        Ok(WebsocketApiResponse {
1380            raw,
1381            rate_limits,
1382            _marker: PhantomData,
1383        })
1384    }
1385
1386    /// Prepares a WebSocket URL by appending a validated time unit parameter.
1387    ///
1388    /// This method checks if a time unit is configured and validates it. If valid,
1389    /// the time unit is appended to the URL as a query parameter. If no time unit
1390    /// is specified or the validation fails, the original URL is returned.
1391    ///
1392    /// # Arguments
1393    ///
1394    /// * `ws_url` - The base WebSocket URL to be modified
1395    ///
1396    /// # Returns
1397    ///
1398    /// A modified URL with the time unit parameter, or the original URL if no
1399    /// modification is possible
1400    fn prepare_url(&self, ws_url: &str) -> String {
1401        let mut url = ws_url.to_string();
1402
1403        let time_unit = match &self.configuration.time_unit {
1404            Some(u) => u.to_string(),
1405            None => return url,
1406        };
1407
1408        match validate_time_unit(&time_unit) {
1409            Ok(Some(validated)) => {
1410                let sep = if url.contains('?') { '&' } else { '?' };
1411                url.push(sep);
1412                url.push_str("timeUnit=");
1413                url.push_str(validated);
1414            }
1415            Ok(None) => {}
1416            Err(e) => {
1417                error!("Invalid time unit provided: {:?}", e);
1418            }
1419        }
1420
1421        url
1422    }
1423}
1424
1425#[async_trait]
1426impl WebsocketHandler for WebsocketApi {
1427    /// Callback method invoked when a WebSocket connection is successfully opened.
1428    ///
1429    /// This method is called after a WebSocket connection is established. Currently,
1430    /// it does not perform any actions and serves as a placeholder for potential
1431    /// connection initialization or logging.
1432    ///
1433    /// # Arguments
1434    ///
1435    /// * `_url` - The URL of the WebSocket connection that was opened
1436    /// * `_connection` - An Arc-wrapped WebSocket connection context
1437    ///
1438    /// # Remarks
1439    ///
1440    /// This method can be overridden by implementations to add custom logic
1441    /// when a WebSocket connection is first opened.
1442    async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
1443
1444    /// Handles incoming WebSocket messages by parsing the JSON payload and processing pending requests.
1445    ///
1446    /// This method is responsible for:
1447    /// - Parsing the received WebSocket message as JSON
1448    /// - Matching the message to a pending request by its ID
1449    /// - Sending the response back to the original request's completion channel
1450    /// - Handling both successful and error responses
1451    ///
1452    /// # Arguments
1453    ///
1454    /// * `data` - The raw WebSocket message as a string
1455    /// * `connection` - The WebSocket connection context associated with the message
1456    ///
1457    /// # Behavior
1458    ///
1459    /// - If message parsing fails, logs an error and returns
1460    /// - For known request IDs, sends the response to the corresponding completion channel
1461    /// - Warns about responses for unknown or timed-out requests
1462    /// - Differentiates between successful (status < 400) and error responses
1463    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1464        let msg: Value = match serde_json::from_str(&data) {
1465            Ok(v) => v,
1466            Err(err) => {
1467                error!("Failed to parse WebSocket message {} – {}", data, err);
1468                return;
1469            }
1470        };
1471
1472        if let Some(id) = msg.get("id").and_then(Value::as_str) {
1473            let maybe_sender = {
1474                let mut conn_state = connection.state.lock().await;
1475                conn_state.pending_requests.remove(id)
1476            };
1477
1478            if let Some(PendingRequest { completion }) = maybe_sender {
1479                connection.drain_notify.notify_one();
1480                let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1481                if status >= 400 {
1482                    let error_map = msg
1483                        .get("error")
1484                        .and_then(Value::as_object)
1485                        .unwrap_or(&serde_json::Map::new())
1486                        .clone();
1487
1488                    let code = error_map
1489                        .get("code")
1490                        .and_then(Value::as_i64)
1491                        .unwrap_or(status as i64);
1492
1493                    let message = error_map
1494                        .get("msg")
1495                        .and_then(Value::as_str)
1496                        .unwrap_or("Unknown error")
1497                        .to_string();
1498
1499                    let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1500                } else {
1501                    let _ = completion.send(Ok(msg.clone()));
1502                }
1503            }
1504
1505            return;
1506        }
1507
1508        if let Some(event) = msg.get("event") {
1509            if event.get("e").is_some() {
1510                for callbacks in self.stream_callbacks.lock().await.values() {
1511                    for callback in callbacks {
1512                        callback(event);
1513                    }
1514                }
1515
1516                return;
1517            }
1518        }
1519
1520        warn!(
1521            "Received response for unknown or timed-out request: {}",
1522            data
1523        );
1524    }
1525
1526    /// Generates the URL to use for reconnecting to a WebSocket connection.
1527    ///
1528    /// # Arguments
1529    ///
1530    /// * `default_url` - The original URL to potentially modify for reconnection
1531    /// * `_connection` - The WebSocket connection context (currently unused)
1532    ///
1533    /// # Returns
1534    ///
1535    /// A `String` representing the URL to use for reconnecting
1536    async fn get_reconnect_url(
1537        &self,
1538        default_url: String,
1539        _connection: Arc<WebsocketConnection>,
1540    ) -> String {
1541        default_url
1542    }
1543}
1544
1545pub struct WebsocketStreams {
1546    pub common: Arc<WebsocketCommon>,
1547    is_connecting: Mutex<bool>,
1548    connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
1549    configuration: ConfigurationWebsocketStreams,
1550}
1551
1552impl WebsocketStreams {
1553    /// Creates a new `WebsocketStreams` instance with the given configuration and connection pool.
1554    ///
1555    /// # Arguments
1556    ///
1557    /// * `configuration` - Configuration settings for the WebSocket streams
1558    /// * `connection_pool` - A vector of WebSocket connections to use
1559    ///
1560    /// # Returns
1561    ///
1562    /// An `Arc`-wrapped `WebsocketStreams` instance
1563    ///
1564    /// # Panics
1565    ///
1566    /// Panics if the `reconnect_delay` cannot be converted to `usize`
1567    #[must_use]
1568    pub fn new(
1569        configuration: ConfigurationWebsocketStreams,
1570        connection_pool: Vec<Arc<WebsocketConnection>>,
1571    ) -> Arc<Self> {
1572        let agent_clone = configuration.agent.clone();
1573        let common = WebsocketCommon::new(
1574            connection_pool,
1575            configuration.mode.clone(),
1576            usize::try_from(configuration.reconnect_delay)
1577                .expect("reconnect_delay should fit in usize"),
1578            agent_clone,
1579        );
1580        Arc::new(Self {
1581            common,
1582            is_connecting: Mutex::new(false),
1583            connection_streams: Mutex::new(HashMap::new()),
1584            configuration,
1585        })
1586    }
1587
1588    /// Establishes a WebSocket connection for the given streams.
1589    ///
1590    /// This method attempts to connect to a WebSocket server using the connection pool.
1591    /// If a connection is already established or in progress, it returns immediately.
1592    ///
1593    /// # Arguments
1594    ///
1595    /// * `streams` - A vector of stream identifiers to connect to
1596    ///
1597    /// # Returns
1598    ///
1599    /// A `Result` indicating whether the connection was successful or an error occurred
1600    ///
1601    /// # Errors
1602    ///
1603    /// Returns a `WebsocketError` if the connection fails or times out after 10 seconds
1604    pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
1605        if self.common.is_connected(None).await {
1606            info!("WebSocket connection already established");
1607            return Ok(());
1608        }
1609
1610        {
1611            let mut flag = self.is_connecting.lock().await;
1612            if *flag {
1613                info!("Already connecting...");
1614                return Ok(());
1615            }
1616            *flag = true;
1617        }
1618
1619        let url = self.prepare_url(&streams);
1620
1621        let handler: Arc<dyn WebsocketHandler> = self.clone();
1622        for conn in &self.common.connection_pool {
1623            conn.set_handler(handler.clone()).await;
1624        }
1625
1626        let connect_res = select! {
1627            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1628            r = self.common.clone().connect_pool(&url) => r,
1629        };
1630
1631        {
1632            let mut flag = self.is_connecting.lock().await;
1633            *flag = false;
1634        }
1635
1636        connect_res
1637    }
1638
1639    /// Disconnects all WebSocket connections and clears associated state.
1640    ///
1641    /// # Returns
1642    ///
1643    /// A `Result` indicating whether the disconnection was successful or an error occurred
1644    ///
1645    /// # Errors
1646    ///
1647    /// Returns a `WebsocketError` if there are issues during the disconnection process
1648    ///
1649    /// # Side Effects
1650    ///
1651    /// - Clears stream callbacks for all connections
1652    /// - Clears pending subscriptions for all connections
1653    /// - Removes all connection stream mappings
1654    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1655        for connection in &self.common.connection_pool {
1656            let mut conn_state = connection.state.lock().await;
1657            conn_state.stream_callbacks.clear();
1658            conn_state.pending_subscriptions.clear();
1659        }
1660        self.connection_streams.lock().await.clear();
1661        self.common.disconnect().await
1662    }
1663
1664    /// Checks if the WebSocket connection is currently active.
1665    ///
1666    /// # Returns
1667    ///
1668    /// `true` if the WebSocket connection is established, `false` otherwise.
1669    pub async fn is_connected(&self) -> bool {
1670        self.common.is_connected(None).await
1671    }
1672
1673    /// Sends a ping to the WebSocket server to maintain the connection.
1674    ///
1675    /// This method delegates the ping operation to the underlying common WebSocket connection.
1676    /// It is typically used to keep the connection alive and check its status.
1677    ///
1678    /// # Side Effects
1679    ///
1680    /// Sends a ping request to the WebSocket server through the common connection.
1681    pub async fn ping_server(&self) {
1682        self.common.ping_server().await;
1683    }
1684
1685    /// Subscribes to multiple WebSocket streams, handling connection and queuing logic.
1686    ///
1687    /// # Arguments
1688    ///
1689    /// * `streams` - A vector of stream names to subscribe to
1690    /// * `id` - An optional request identifier for the subscription
1691    ///
1692    /// # Behavior
1693    ///
1694    /// - Filters out streams already subscribed
1695    /// - Assigns streams to appropriate connections
1696    /// - Handles subscription for active connections
1697    /// - Queues subscriptions for inactive connections
1698    ///
1699    /// # Side Effects
1700    ///
1701    /// - Sends subscription payloads for active connections
1702    /// - Adds pending subscriptions for inactive connections
1703    pub async fn subscribe(self: Arc<Self>, streams: Vec<String>, id: Option<String>) {
1704        let streams: Vec<String> = {
1705            let map = self.connection_streams.lock().await;
1706            streams
1707                .into_iter()
1708                .filter(|s| !map.contains_key(s))
1709                .collect()
1710        };
1711        let connection_streams = self.handle_stream_assignment(streams).await;
1712
1713        for (conn, streams) in connection_streams {
1714            if !self.common.is_connected(Some(&conn)).await {
1715                info!(
1716                    "Connection is not ready. Queuing subscription for streams: {:?}",
1717                    streams
1718                );
1719                let mut conn_state = conn.state.lock().await;
1720                conn_state.pending_subscriptions.extend(streams.clone());
1721                continue;
1722            }
1723            self.send_subscription_payload(conn.clone(), streams.clone(), id.clone());
1724        }
1725    }
1726
1727    /// Unsubscribes from specified WebSocket streams.
1728    ///
1729    /// # Arguments
1730    ///
1731    /// * `streams` - A vector of stream names to unsubscribe from
1732    /// * `id` - An optional request identifier for the unsubscription
1733    ///
1734    /// # Behavior
1735    ///
1736    /// - Validates the request identifier or generates a random one
1737    /// - Checks for active connections and subscribed streams
1738    /// - Sends unsubscribe payload for streams with active callbacks
1739    /// - Removes stream from connection streams and callbacks
1740    ///
1741    /// # Side Effects
1742    ///
1743    /// - Sends unsubscribe request to WebSocket server
1744    /// - Removes stream tracking from internal state
1745    ///
1746    /// # Async
1747    ///
1748    /// This method is asynchronous and requires `.await` when called
1749    ///
1750    /// # Panics
1751    ///
1752    /// This method may panic if the request identifier is not valid.
1753    ///
1754    pub async fn unsubscribe(&self, streams: Vec<String>, id: Option<String>) {
1755        let request_id = id
1756            .filter(|s| ID_REGEX.is_match(s))
1757            .unwrap_or_else(random_string);
1758
1759        for stream in streams {
1760            let maybe_conn = { self.connection_streams.lock().await.get(&stream).cloned() };
1761
1762            let conn = if let Some(c) = maybe_conn {
1763                if !self.common.is_connected(Some(&c)).await {
1764                    warn!(
1765                        "Stream {} not associated with an active connection.",
1766                        stream
1767                    );
1768                    continue;
1769                }
1770                c
1771            } else {
1772                warn!("Stream {} was not subscribed.", stream);
1773                continue;
1774            };
1775
1776            let callbacks = {
1777                let conn_state = conn.state.lock().await;
1778                conn_state
1779                    .stream_callbacks
1780                    .get(&stream)
1781                    .is_none_or(std::vec::Vec::is_empty)
1782            };
1783
1784            if !callbacks {
1785                continue;
1786            }
1787
1788            let payload = json!({
1789                "method": "UNSUBSCRIBE",
1790                "params": [stream.clone()],
1791                "id": request_id,
1792            });
1793
1794            info!("UNSUBSCRIBE → {:?}", payload);
1795
1796            let common = Arc::clone(&self.common);
1797            let conn_clone = Arc::clone(&conn);
1798            let msg = serde_json::to_string(&payload).unwrap();
1799            spawn(async move {
1800                let _ = common
1801                    .send(msg, None, false, Duration::ZERO, Some(conn_clone))
1802                    .await;
1803            });
1804
1805            {
1806                let mut connection_streams = self.connection_streams.lock().await;
1807                connection_streams.remove(&stream);
1808            }
1809            {
1810                let mut conn_state = conn.state.lock().await;
1811                conn_state.stream_callbacks.remove(&stream);
1812            }
1813        }
1814    }
1815
1816    /// Checks if a specific stream is currently subscribed.
1817    ///
1818    /// # Arguments
1819    ///
1820    /// * `stream` - The stream identifier to check for subscription status
1821    ///
1822    /// # Returns
1823    ///
1824    /// `true` if the stream is subscribed, `false` otherwise
1825    ///
1826    /// # Async
1827    ///
1828    /// This method is asynchronous and requires `.await` when called
1829    pub async fn is_subscribed(&self, stream: &str) -> bool {
1830        self.connection_streams.lock().await.contains_key(stream)
1831    }
1832
1833    /// Prepares a WebSocket URL for streaming with optional stream names and time unit configuration.
1834    ///
1835    /// # Arguments
1836    ///
1837    /// * `streams` - A slice of stream names to be included in the URL
1838    ///
1839    /// # Returns
1840    ///
1841    /// A fully constructed WebSocket URL with optional stream and time unit parameters
1842    ///
1843    /// # Notes
1844    ///
1845    /// - If no time unit is specified, the base URL is returned
1846    /// - Validates and appends the time unit parameter if provided and valid
1847    /// - Handles URL parameter separator based on existing query parameters
1848    fn prepare_url(&self, streams: &[String]) -> String {
1849        let mut url = format!(
1850            "{}/stream?streams={}",
1851            self.configuration.ws_url.as_deref().unwrap_or(""),
1852            streams.join("/")
1853        );
1854
1855        let time_unit = match &self.configuration.time_unit {
1856            Some(u) => u.to_string(),
1857            None => return url,
1858        };
1859
1860        match validate_time_unit(&time_unit) {
1861            Ok(Some(validated)) => {
1862                let sep = if url.contains('?') { '&' } else { '?' };
1863                url.push(sep);
1864                url.push_str("timeUnit=");
1865                url.push_str(validated);
1866            }
1867            Ok(None) => {}
1868            Err(e) => {
1869                error!("Invalid time unit provided: {:?}", e);
1870            }
1871        }
1872
1873        url
1874    }
1875
1876    /// Handles stream assignment by finding or creating WebSocket connections for a list of streams.
1877    ///
1878    /// This method attempts to assign streams to existing WebSocket connections or creates new
1879    /// connections if needed. It groups streams by their assigned connections and handles scenarios
1880    /// such as closed or pending reconnection connections.
1881    ///
1882    /// # Arguments
1883    ///
1884    /// * `streams` - A vector of stream names to be assigned
1885    ///
1886    /// # Returns
1887    ///
1888    /// A vector of tuples containing WebSocket connections and their associated streams
1889    ///
1890    /// # Errors
1891    ///
1892    /// Returns an empty result if no connections can be established for the streams
1893    async fn handle_stream_assignment(
1894        &self,
1895        streams: Vec<String>,
1896    ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
1897        let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
1898
1899        for stream in streams {
1900            let mut conn_opt = {
1901                let map = self.connection_streams.lock().await;
1902                map.get(&stream).cloned()
1903            };
1904
1905            let need_new = if let Some(conn) = &conn_opt {
1906                let state = conn.state.lock().await;
1907                state.close_initiated || state.reconnection_pending
1908            } else {
1909                true
1910            };
1911
1912            if need_new {
1913                match self.common.get_connection(true).await {
1914                    Ok(new_conn) => {
1915                        let mut map = self.connection_streams.lock().await;
1916                        map.insert(stream.clone(), new_conn.clone());
1917                        conn_opt = Some(new_conn);
1918                    }
1919                    Err(err) => {
1920                        warn!(
1921                            "No available WebSocket connection to subscribe stream `{}`: {:?}",
1922                            stream, err
1923                        );
1924                        continue;
1925                    }
1926                }
1927            }
1928
1929            if let Some(conn) = conn_opt {
1930                {
1931                    let mut conn_state = conn.state.lock().await;
1932                    conn_state
1933                        .stream_callbacks
1934                        .entry(stream.clone())
1935                        .or_default();
1936                }
1937                connection_streams.push((stream.clone(), conn));
1938            }
1939        }
1940
1941        let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
1942        for (stream, conn) in connection_streams {
1943            if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
1944                vec.push(stream);
1945            } else {
1946                groups.push((conn, vec![stream]));
1947            }
1948        }
1949
1950        groups
1951    }
1952
1953    /// Sends a WebSocket subscription payload for the specified streams.
1954    ///
1955    /// # Arguments
1956    ///
1957    /// * `connection` - The WebSocket connection to send the subscription on
1958    /// * `streams` - A vector of stream names to subscribe to
1959    /// * `id` - An optional request ID for the subscription (will be randomly generated if not provided)
1960    ///
1961    /// # Remarks
1962    ///
1963    /// This method constructs a SUBSCRIBE payload, logs it, and sends it asynchronously using the WebSocket connection.
1964    /// If serialization fails, an error is logged and the method returns without sending.
1965    fn send_subscription_payload(
1966        &self,
1967        connection: Arc<WebsocketConnection>,
1968        streams: Vec<String>,
1969        id: Option<String>,
1970    ) {
1971        let request_id = id
1972            .filter(|s| ID_REGEX.is_match(s))
1973            .unwrap_or_else(random_string);
1974
1975        let payload = json!({
1976            "method": "SUBSCRIBE",
1977            "params": streams,
1978            "id": request_id,
1979        });
1980
1981        info!("SUBSCRIBE → {:?}", payload);
1982
1983        let common = Arc::clone(&self.common);
1984        let msg = match serde_json::to_string(&payload) {
1985            Ok(s) => s,
1986            Err(e) => {
1987                error!("Failed to serialize SUBSCRIBE payload: {}", e);
1988                return;
1989            }
1990        };
1991        let conn_clone = Arc::clone(&connection);
1992
1993        spawn(async move {
1994            let _ = common
1995                .send(msg, None, false, Duration::ZERO, Some(conn_clone))
1996                .await;
1997        });
1998    }
1999}
2000
2001#[async_trait]
2002impl WebsocketHandler for WebsocketStreams {
2003    /// Handles the WebSocket connection opening by processing any pending subscriptions.
2004    ///
2005    /// This method is called when a WebSocket connection is established. It retrieves
2006    /// any pending stream subscriptions from the connection state and sends them
2007    /// immediately using the `send_subscription_payload` method.
2008    ///
2009    /// # Arguments
2010    ///
2011    /// * `_url` - The URL of the WebSocket connection (unused)
2012    /// * `connection` - The WebSocket connection that has just been opened
2013    ///
2014    /// # Remarks
2015    ///
2016    /// If there are any pending subscriptions, they are sent as a batch subscription
2017    /// payload. The method uses a lock to safely access and clear the pending subscriptions
2018    /// from the connection state.
2019    async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2020        let pending_subs: Vec<String> = {
2021            let mut conn_state = connection.state.lock().await;
2022            take(&mut conn_state.pending_subscriptions)
2023                .into_iter()
2024                .collect()
2025        };
2026
2027        if !pending_subs.is_empty() {
2028            info!("Processing queued subscriptions for connection");
2029            self.send_subscription_payload(connection.clone(), pending_subs, None);
2030        }
2031    }
2032
2033    /// Handles incoming WebSocket stream messages by parsing the JSON payload and invoking registered stream callbacks.
2034    ///
2035    /// This method processes WebSocket messages with a specific structure, extracting the stream name and data.
2036    /// It retrieves and executes any registered callbacks associated with the stream name.
2037    ///
2038    /// # Arguments
2039    ///
2040    /// * `data` - The raw WebSocket message as a JSON-formatted string
2041    /// * `connection` - The WebSocket connection through which the message was received
2042    ///
2043    /// # Behavior
2044    ///
2045    /// - Parses the JSON message
2046    /// - Extracts the stream name and data payload
2047    /// - Looks up and invokes any registered callbacks for the stream
2048    /// - Silently returns if message parsing or stream extraction fails
2049    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2050        let msg: Value = match serde_json::from_str(&data) {
2051            Ok(v) => v,
2052            Err(err) => {
2053                error!(
2054                    "Failed to parse WebSocket stream message {} – {}",
2055                    data, err
2056                );
2057                return;
2058            }
2059        };
2060
2061        let (stream_name, payload) = match (
2062            msg.get("stream").and_then(Value::as_str),
2063            msg.get("data").cloned(),
2064        ) {
2065            (Some(name), Some(data)) => (name.to_string(), data),
2066            _ => return,
2067        };
2068
2069        let callbacks = {
2070            let conn_state = connection.state.lock().await;
2071            conn_state
2072                .stream_callbacks
2073                .get(&stream_name)
2074                .cloned()
2075                .unwrap_or_else(Vec::new)
2076        };
2077
2078        for callback in callbacks {
2079            callback(&payload);
2080        }
2081    }
2082
2083    /// Retrieves the reconnection URL for a specific WebSocket connection by identifying all streams associated with that connection.
2084    ///
2085    /// # Arguments
2086    ///
2087    /// * `_default_url` - A default URL that can be used if no specific reconnection URL is determined
2088    /// * `connection` - The WebSocket connection for which to generate a reconnection URL
2089    ///
2090    /// # Returns
2091    ///
2092    /// A URL string that can be used to reconnect to the WebSocket, based on the streams associated with the given connection
2093    async fn get_reconnect_url(
2094        &self,
2095        _default_url: String,
2096        connection: Arc<WebsocketConnection>,
2097    ) -> String {
2098        let connection_streams = self.connection_streams.lock().await;
2099        let reconnect_streams = connection_streams
2100            .iter()
2101            .filter_map(|(stream, conn_arc)| {
2102                if Arc::ptr_eq(conn_arc, &connection) {
2103                    Some(stream.clone())
2104                } else {
2105                    None
2106                }
2107            })
2108            .collect::<Vec<_>>();
2109        self.prepare_url(&reconnect_streams)
2110    }
2111}
2112
/// Handle for one stream (or request id) whose incoming payloads are deserialized into `T`
/// and passed to a user callback.
///
/// Built by `create_stream_handler`, which first subscribes the stream when the base is
/// `WebsocketBase::WebsocketStreams`.
pub struct WebsocketStream<T> {
    /// Client (streams or API) whose callback maps this handle registers into.
    websocket_base: WebsocketBase,
    /// Stream name (streams base) or key (API base) used in the callback maps.
    stream_or_id: String,
    /// Callback most recently registered through `on`; `unsubscribe` takes it and removes
    /// that exact `Arc` (compared by pointer) from the base's callback list.
    callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
    /// Optional request identifier passed along to the subscribe/unsubscribe calls.
    pub id: Option<String>,
    /// Ties the handle to its payload type `T` without storing a `T`.
    _phantom: PhantomData<T>,
}
2120
2121impl<T> WebsocketStream<T>
2122where
2123    T: DeserializeOwned + Send + 'static,
2124{
2125    /// Registers a callback function for a specific event on the WebSocket stream.
2126    ///
2127    /// This method currently only supports the "message" event. When a message is received,
2128    /// the provided callback function will be invoked with the deserialized payload.
2129    ///
2130    /// # Arguments
2131    ///
2132    /// * `event` - The event type to listen for (currently only "message" is supported)
2133    /// * `callback_fn` - A function that will be called with the deserialized message payload
2134    ///
2135    /// # Errors
2136    ///
2137    /// Logs an error if the payload cannot be deserialized into the expected type
2138    ///
2139    /// # Examples
2140    ///
2141    ///
2142    /// stream.on("message", |data: `MyType`| {
2143    ///     // Handle the deserialized message
2144    /// });
2145    async fn on<F>(&self, event: &str, callback_fn: F)
2146    where
2147        F: Fn(T) + Send + Sync + 'static,
2148    {
2149        if event != "message" {
2150            return;
2151        }
2152
2153        let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2154            Arc::new(
2155                move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2156                    Ok(data) => callback_fn(data),
2157                    Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2158                },
2159            );
2160
2161        {
2162            let mut guard = self.callback.lock().await;
2163            *guard = Some(cb_wrapper.clone());
2164        }
2165
2166        match &self.websocket_base {
2167            WebsocketBase::WebsocketStreams(ws_streams) => {
2168                let conn = {
2169                    let map = ws_streams.connection_streams.lock().await;
2170                    map.get(&self.stream_or_id)
2171                        .cloned()
2172                        .expect("stream must be subscribed")
2173                };
2174
2175                {
2176                    let mut conn_state = conn.state.lock().await;
2177                    let entry = conn_state
2178                        .stream_callbacks
2179                        .entry(self.stream_or_id.clone())
2180                        .or_default();
2181
2182                    if !entry
2183                        .iter()
2184                        .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2185                    {
2186                        entry.push(cb_wrapper);
2187                    }
2188                }
2189            }
2190            WebsocketBase::WebsocketApi(ws_api) => {
2191                let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2192                let entry = stream_callbacks
2193                    .entry(self.stream_or_id.clone())
2194                    .or_default();
2195
2196                if !entry
2197                    .iter()
2198                    .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2199                {
2200                    entry.push(cb_wrapper);
2201                }
2202            }
2203        }
2204    }
2205
2206    /// Synchronously sets a message callback for the WebSocket stream on the current thread.
2207    ///
2208    /// # Arguments
2209    ///
2210    /// * `callback_fn` - A function that will be called with the deserialized message payload
2211    ///
2212    /// # Panics
2213    ///
2214    /// Panics if the thread runtime fails to be created or if the thread join fails
2215    ///
2216    /// # Examples
2217    ///
2218    ///
2219    /// let stream = `Arc::new(WebsocketStream::new())`;
2220    /// `stream.on_message(|data`: `MyType`| {
2221    ///     // Handle the deserialized message
2222    /// });
2223    ///
2224    pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2225    where
2226        T: Send + Sync,
2227        F: Fn(T) + Send + Sync + 'static,
2228    {
2229        let handler: Arc<Self> = Arc::clone(self);
2230
2231        std::thread::spawn(move || {
2232            let rt = tokio::runtime::Builder::new_current_thread()
2233                .enable_all()
2234                .build()
2235                .expect("failed to build Tokio runtime");
2236
2237            rt.block_on(handler.on("message", callback_fn));
2238        })
2239        .join()
2240        .expect("on_message thread panicked");
2241    }
2242
2243    /// Unsubscribes from the current WebSocket stream and removes the associated callback.
2244    ///
2245    /// This method performs the following actions:
2246    /// - Removes the current callback associated with the stream
2247    /// - Removes the callback from the connection's stream callbacks
2248    /// - Asynchronously unsubscribes from the stream using the WebSocket streams base
2249    ///
2250    /// # Panics
2251    ///
2252    /// Panics if the stream is not subscribed to
2253    ///
2254    /// # Notes
2255    /// - If no callback is present, no action is taken
2256    /// - Spawns an asynchronous task to handle the unsubscription process
2257    pub async fn unsubscribe(&self) {
2258        let maybe_cb = {
2259            let mut guard = self.callback.lock().await;
2260            guard.take()
2261        };
2262
2263        if let Some(cb) = maybe_cb {
2264            match &self.websocket_base {
2265                WebsocketBase::WebsocketStreams(ws_streams) => {
2266                    let conn = {
2267                        let map = ws_streams.connection_streams.lock().await;
2268                        map.get(&self.stream_or_id)
2269                            .cloned()
2270                            .expect("stream must have been subscribed")
2271                    };
2272
2273                    {
2274                        let mut conn_state = conn.state.lock().await;
2275                        if let Some(list) = conn_state.stream_callbacks.get_mut(&self.stream_or_id)
2276                        {
2277                            list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2278                        }
2279                    }
2280
2281                    let stream = self.stream_or_id.clone();
2282                    let id = self.id.clone();
2283                    let websocket_streams_base = Arc::clone(ws_streams);
2284                    spawn(async move {
2285                        websocket_streams_base.unsubscribe(vec![stream], id).await;
2286                    });
2287                }
2288                WebsocketBase::WebsocketApi(ws_api) => {
2289                    let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2290                    if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2291                        list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2292                    }
2293                }
2294            }
2295        }
2296    }
2297}
2298
2299pub async fn create_stream_handler<T>(
2300    websocket_base: WebsocketBase,
2301    stream_or_id: String,
2302    id: Option<String>,
2303) -> Arc<WebsocketStream<T>>
2304where
2305    T: DeserializeOwned + Send + 'static,
2306{
2307    match &websocket_base {
2308        WebsocketBase::WebsocketStreams(ws_streams) => {
2309            ws_streams
2310                .clone()
2311                .subscribe(vec![stream_or_id.clone()], id.clone())
2312                .await;
2313        }
2314        WebsocketBase::WebsocketApi(_) => {}
2315    }
2316
2317    Arc::new(WebsocketStream {
2318        websocket_base,
2319        stream_or_id,
2320        id,
2321        callback: Mutex::new(None),
2322        _phantom: PhantomData,
2323    })
2324}
2325
2326#[cfg(test)]
2327mod tests {
2328    use crate::TOKIO_SHARED_RT;
2329    use crate::common::utils::SignatureGenerator;
2330    use crate::common::websocket::{
2331        PendingRequest, ReconnectEntry, WebsocketApi, WebsocketBase, WebsocketCommon,
2332        WebsocketConnection, WebsocketEvent, WebsocketEventEmitter, WebsocketHandler,
2333        WebsocketMessageSendOptions, WebsocketMode, WebsocketStream, WebsocketStreams,
2334        create_stream_handler,
2335    };
2336    use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2337    use crate::errors::WebsocketError;
2338    use crate::models::TimeUnit;
2339    use async_trait::async_trait;
2340    use futures::{SinkExt, StreamExt};
2341    use regex::Regex;
2342    use serde_json::{Value, json};
2343    use std::collections::{BTreeMap, HashSet};
2344    use std::marker::PhantomData;
2345    use std::net::SocketAddr;
2346    use std::sync::{
2347        Arc,
2348        atomic::{AtomicBool, AtomicUsize, Ordering},
2349    };
2350    use tokio::net::TcpListener;
2351    use tokio::sync::{Mutex, mpsc::unbounded_channel, oneshot};
2352    use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
2353    use tokio_tungstenite::{accept_async, tungstenite, tungstenite::Message};
2354
2355    fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
2356        let events = Arc::new(Mutex::new(Vec::new()));
2357        let events_clone = events.clone();
2358        common.events.subscribe(move |event| {
2359            let events_clone = events_clone.clone();
2360            tokio::spawn(async move {
2361                events_clone.lock().await.push(event);
2362            });
2363        });
2364        events
2365    }
2366
2367    async fn create_connection(
2368        id: &str,
2369        has_writer: bool,
2370        reconnection_pending: bool,
2371        renewal_pending: bool,
2372        close_initiated: bool,
2373    ) -> Arc<WebsocketConnection> {
2374        let conn = WebsocketConnection::new(id);
2375        let mut st = conn.state.lock().await;
2376        st.reconnection_pending = reconnection_pending;
2377        st.renewal_pending = renewal_pending;
2378        st.close_initiated = close_initiated;
2379        if has_writer {
2380            let (tx, _) = unbounded_channel::<Message>();
2381            st.ws_write_tx = Some(tx);
2382        } else {
2383            st.ws_write_tx = None;
2384        }
2385        drop(st);
2386        conn
2387    }
2388
2389    fn create_websocket_api(time_unit: Option<TimeUnit>) -> Arc<WebsocketApi> {
2390        let sig_gen = SignatureGenerator::new(
2391            Some("api_secret".into()),
2392            None::<PrivateKey>,
2393            None::<String>,
2394        );
2395        let config = ConfigurationWebsocketApi {
2396            api_key: Some("api_key".into()),
2397            api_secret: Some("api_secret".into()),
2398            private_key: None,
2399            private_key_passphrase: None,
2400            ws_url: Some("wss://example.com".into()),
2401            mode: WebsocketMode::Single,
2402            reconnect_delay: 1000,
2403            signature_gen: sig_gen,
2404            timeout: 500,
2405            time_unit,
2406            agent: None,
2407        };
2408        let conn = WebsocketConnection::new("c1");
2409        WebsocketApi::new(config, vec![conn])
2410    }
2411
2412    fn create_websocket_streams(
2413        ws_url: Option<&str>,
2414        conns: Option<Vec<Arc<WebsocketConnection>>>,
2415    ) -> Arc<WebsocketStreams> {
2416        let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
2417        if conns.is_none() {
2418            connections.push(WebsocketConnection::new("c1"));
2419            connections.push(WebsocketConnection::new("c2"));
2420        } else {
2421            connections = conns.expect("Expected connections to be set");
2422        }
2423        let config = ConfigurationWebsocketStreams {
2424            ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
2425            mode: WebsocketMode::Single,
2426            reconnect_delay: 500,
2427            time_unit: None,
2428            agent: None,
2429        };
2430        WebsocketStreams::new(config, connections)
2431    }
2432
2433    mod event_emitter {
2434        use super::*;
2435
2436        #[test]
2437        fn event_emitter_subscribe_and_emit() {
2438            TOKIO_SHARED_RT.block_on(async {
2439                let emitter = WebsocketEventEmitter::new();
2440                let (tx, rx) = oneshot::channel();
2441                let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
2442                let tx_clone = tx.clone();
2443                let _sub = emitter.subscribe(move |event| {
2444                    if let Some(sender) = tx_clone.lock().unwrap().take() {
2445                        let _ = sender.send(event);
2446                    }
2447                });
2448                emitter.emit(WebsocketEvent::Open);
2449                let received = timeout(Duration::from_millis(100), rx)
2450                    .await
2451                    .expect("timed out");
2452                assert_eq!(received, Ok(WebsocketEvent::Open));
2453            });
2454        }
2455    }
2456
2457    mod websocket_common {
2458        use super::*;
2459
2460        mod initialisation {
2461            use super::*;
2462
2463            #[test]
2464            fn single_mode() {
2465                TOKIO_SHARED_RT.block_on(async {
2466                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None);
2467                    assert_eq!(common.connection_pool.len(), 1);
2468                });
2469            }
2470
2471            #[test]
2472            fn pool_mode() {
2473                TOKIO_SHARED_RT.block_on(async {
2474                    let common = WebsocketCommon::new(vec![], WebsocketMode::Pool(3), 0, None);
2475                    assert_eq!(common.connection_pool.len(), 3);
2476                });
2477            }
2478        }
2479
2480        mod spawn_reconnect_loop {
2481            use super::*;
2482
2483            #[test]
2484            fn successful_reconnect_entry_triggers_init_connect() {
2485                TOKIO_SHARED_RT.block_on(async {
2486                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2487                    let addr = listener.local_addr().unwrap();
2488                    tokio::spawn(async move {
2489                        if let Ok((stream, _)) = listener.accept().await {
2490                            let mut ws = accept_async(stream).await.unwrap();
2491                            let _ = ws.close(None).await;
2492                        }
2493                    });
2494
2495                    let conn = WebsocketConnection::new("c1");
2496                    let common =
2497                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 10, None);
2498                    let url = format!("ws://{addr}");
2499                    common
2500                        .reconnect_tx
2501                        .send(ReconnectEntry {
2502                            connection_id: "c1".into(),
2503                            url: url.clone(),
2504                            is_renewal: false,
2505                        })
2506                        .await
2507                        .unwrap();
2508
2509                    sleep(Duration::from_secs(2)).await;
2510
2511                    let st = conn.state.lock().await;
2512                    assert!(st.ws_write_tx.is_some());
2513                });
2514            }
2515
2516            #[test]
2517            fn reconnect_entry_with_unknown_id_is_ignored() {
2518                TOKIO_SHARED_RT.block_on(async {
2519                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2520                    let addr = listener.local_addr().unwrap();
2521                    tokio::spawn(async move {
2522                        if let Ok((stream, _)) = listener.accept().await {
2523                            let mut ws = accept_async(stream).await.unwrap();
2524                            let _ = ws.close(None).await;
2525                        }
2526                    });
2527
2528                    let conn = WebsocketConnection::new("c1");
2529                    let common =
2530                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 5, None);
2531                    let url = format!("ws://{addr}");
2532                    common
2533                        .reconnect_tx
2534                        .send(ReconnectEntry {
2535                            connection_id: "other".into(),
2536                            url,
2537                            is_renewal: false,
2538                        })
2539                        .await
2540                        .unwrap();
2541
2542                    sleep(Duration::from_secs(1)).await;
2543
2544                    let st = conn.state.lock().await;
2545                    assert!(st.ws_write_tx.is_none());
2546                });
2547            }
2548
2549            #[test]
2550            fn renewal_entries_bypass_initial_delay() {
2551                TOKIO_SHARED_RT.block_on(async {
2552                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2553                    let addr = listener.local_addr().unwrap();
2554                    tokio::spawn(async move {
2555                        if let Ok((stream, _)) = listener.accept().await {
2556                            let mut ws = accept_async(stream).await.unwrap();
2557                            let _ = ws.close(None).await;
2558                        }
2559                    });
2560
2561                    let conn = WebsocketConnection::new("renew");
2562                    let common =
2563                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 200, None);
2564                    let url = format!("ws://{addr}");
2565                    common
2566                        .reconnect_tx
2567                        .send(ReconnectEntry {
2568                            connection_id: "renew".into(),
2569                            url: url.clone(),
2570                            is_renewal: true,
2571                        })
2572                        .await
2573                        .unwrap();
2574
2575                    sleep(Duration::from_secs(2)).await;
2576
2577                    let st = conn.state.lock().await;
2578
2579                    assert!(st.ws_write_tx.is_some());
2580                });
2581            }
2582
2583            #[test]
2584            fn non_renewal_entries_respect_initial_delay() {
2585                TOKIO_SHARED_RT.block_on(async {
2586                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2587                    let addr = listener.local_addr().unwrap();
2588                    tokio::spawn(async move {
2589                        if let Ok((stream, _)) = listener.accept().await {
2590                            let mut ws = accept_async(stream).await.unwrap();
2591                            let _ = ws.close(None).await;
2592                        }
2593                    });
2594
2595                    let conn = WebsocketConnection::new("nonrenew");
2596                    let common =
2597                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 200, None);
2598                    let url = format!("ws://{addr}");
2599                    common
2600                        .reconnect_tx
2601                        .send(ReconnectEntry {
2602                            connection_id: "nonrenew".into(),
2603                            url: url.clone(),
2604                            is_renewal: false,
2605                        })
2606                        .await
2607                        .unwrap();
2608
2609                    sleep(Duration::from_millis(100)).await;
2610                    assert!(conn.state.lock().await.ws_write_tx.is_none());
2611
2612                    sleep(Duration::from_secs(2)).await;
2613
2614                    assert!(conn.state.lock().await.ws_write_tx.is_some());
2615                });
2616            }
2617        }
2618
2619        mod spawn_renewal_loop {
2620            use super::*;
2621
2622            #[tokio::test]
2623            async fn scheduling_renewal_does_not_panic_for_known_connection() {
2624                pause();
2625
2626                let conn = WebsocketConnection::new("known");
2627                let common =
2628                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2629                let url = "wss://example".to_string();
2630                common
2631                    .renewal_tx
2632                    .send((conn.id.clone(), url))
2633                    .await
2634                    .unwrap();
2635                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2636
2637                resume();
2638            }
2639
2640            #[tokio::test]
2641            async fn scheduling_renewal_ignored_for_unknown_connection() {
2642                pause();
2643
2644                let conn = WebsocketConnection::new("c1");
2645                let common =
2646                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2647                common
2648                    .renewal_tx
2649                    .send(("other".into(), "u".into()))
2650                    .await
2651                    .unwrap();
2652                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2653
2654                resume();
2655            }
2656        }
2657
2658        mod is_connection_ready {
2659            use super::*;
2660
2661            #[test]
2662            fn is_connection_ready() {
2663                TOKIO_SHARED_RT.block_on(async {
2664                    let conn = WebsocketConnection::new("c1");
2665                    let common =
2666                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2667                    assert!(!common.is_connection_ready(&conn, false).await);
2668                    assert!(common.is_connection_ready(&conn, true).await);
2669                });
2670            }
2671
2672            #[test]
2673            fn connection_ready_basic() {
2674                TOKIO_SHARED_RT.block_on(async {
2675                    let conn = create_connection("c1", true, false, false, false).await;
2676                    let common =
2677                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2678                    assert!(common.is_connection_ready(&conn, false).await);
2679                });
2680            }
2681
2682            #[test]
2683            fn connection_not_ready_without_writer() {
2684                TOKIO_SHARED_RT.block_on(async {
2685                    let conn = create_connection("c1", false, false, false, false).await;
2686                    let common =
2687                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2688                    assert!(!common.is_connection_ready(&conn, false).await);
2689                    assert!(common.is_connection_ready(&conn, true).await);
2690                });
2691            }
2692
2693            #[test]
2694            fn connection_not_ready_when_flagged() {
2695                TOKIO_SHARED_RT.block_on(async {
2696                    let conn1 = create_connection("c1", true, true, false, false).await;
2697                    let conn2 = create_connection("c2", true, false, true, false).await;
2698                    let conn3 = create_connection("c3", true, false, false, true).await;
2699
2700                    let common = WebsocketCommon::new(
2701                        vec![conn1.clone(), conn2.clone(), conn3.clone()],
2702                        WebsocketMode::Pool(3),
2703                        0,
2704                        None,
2705                    );
2706
2707                    assert!(!common.is_connection_ready(&conn1, false).await);
2708                    assert!(!common.is_connection_ready(&conn2, false).await);
2709                    assert!(!common.is_connection_ready(&conn3, false).await);
2710                });
2711            }
2712        }
2713
2714        mod is_connected {
2715            use super::*;
2716
2717            #[test]
2718            fn with_pool_various_connections() {
2719                TOKIO_SHARED_RT.block_on(async {
2720                    let conn_a = create_connection("a", true, false, false, false).await;
2721                    let conn_b = create_connection("b", false, false, false, false).await;
2722                    let conn_c = create_connection("c", true, true, false, false).await;
2723                    let pool = vec![conn_a.clone(), conn_b.clone(), conn_c.clone()];
2724                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None);
2725
2726                    assert!(common.is_connected(None).await);
2727                    assert!(common.is_connected(Some(&conn_a)).await);
2728                    assert!(!common.is_connected(Some(&conn_b)).await);
2729                    assert!(!common.is_connected(Some(&conn_c)).await);
2730                });
2731            }
2732
2733            #[test]
2734            fn with_pool_all_bad_connections() {
2735                TOKIO_SHARED_RT.block_on(async {
2736                    let bad1 = create_connection("c1", false, false, false, false).await;
2737                    let bad2 = create_connection("c2", true, true, false, false).await;
2738                    let bad3 = create_connection("c3", true, false, false, true).await;
2739                    let common = WebsocketCommon::new(
2740                        vec![bad1, bad2, bad3],
2741                        WebsocketMode::Pool(3),
2742                        0,
2743                        None,
2744                    );
2745
2746                    assert!(!common.is_connected(None).await);
2747                });
2748            }
2749
2750            #[test]
2751            fn with_pool_ignore_close_initiated() {
2752                TOKIO_SHARED_RT.block_on(async {
2753                    let good = create_connection("c1", true, false, false, false).await;
2754                    let closed = create_connection("c2", true, false, false, true).await;
2755                    let bad = create_connection("c3", false, false, false, false).await;
2756                    let common = WebsocketCommon::new(
2757                        vec![closed.clone(), good.clone(), bad.clone()],
2758                        WebsocketMode::Pool(3),
2759                        0,
2760                        None,
2761                    );
2762
2763                    assert!(common.is_connected(None).await);
2764                    assert!(!common.is_connected(Some(&closed)).await);
2765                });
2766            }
2767        }
2768
2769        mod get_connection {
2770            use super::*;
2771
2772            #[test]
2773            fn single_mode() {
2774                TOKIO_SHARED_RT.block_on(async {
2775                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None);
2776                    let conn = common
2777                        .get_connection(false)
2778                        .await
2779                        .expect("should get connection");
2780                    assert_eq!(conn.id, common.connection_pool[0].id);
2781                });
2782            }
2783
2784            #[test]
2785            fn pool_mode_not_ready() {
2786                TOKIO_SHARED_RT.block_on(async {
2787                    let common = WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None);
2788                    let result = common.get_connection(false).await;
2789                    assert!(matches!(
2790                        result,
2791                        Err(crate::errors::WebsocketError::NotConnected)
2792                    ));
2793                });
2794            }
2795
2796            #[test]
2797            fn pool_mode_with_ready() {
2798                TOKIO_SHARED_RT.block_on(async {
2799                    let conn1 = WebsocketConnection::new("c1");
2800                    let conn2 = WebsocketConnection::new("c2");
2801                    let (tx1, _rx1) = unbounded_channel();
2802                    {
2803                        let mut s1 = conn1.state.lock().await;
2804                        s1.ws_write_tx = Some(tx1);
2805                    }
2806                    let pool = vec![conn1.clone(), conn2.clone()];
2807                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None);
2808                    let result = common.get_connection(false).await;
2809                    assert!(result.is_ok());
2810                    let chosen = result.unwrap();
2811                    assert_eq!(chosen.id, conn1.id);
2812                });
2813            }
2814        }
2815
2816        mod close_connection_gracefully {
2817            use super::*;
2818
2819            #[tokio::test]
2820            async fn waits_for_pending_requests_then_closes() {
2821                pause();
2822
2823                let conn = WebsocketConnection::new("c1");
2824                let (tx, mut rx) = unbounded_channel::<Message>();
2825                let (req_tx, _req_rx) = oneshot::channel();
2826                {
2827                    let mut st = conn.state.lock().await;
2828                    st.pending_requests
2829                        .insert("r".to_string(), PendingRequest { completion: req_tx });
2830                }
2831                let common =
2832                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2833                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2834                advance(Duration::from_secs(1)).await;
2835                {
2836                    let mut st = conn.state.lock().await;
2837                    st.pending_requests.clear();
2838                }
2839                conn.drain_notify.notify_waiters();
2840                advance(Duration::from_secs(1)).await;
2841                close_fut.await.unwrap();
2842                match rx.try_recv() {
2843                    Ok(Message::Close(_)) => {}
2844                    other => panic!("expected Close, got {other:?}"),
2845                }
2846
2847                resume();
2848            }
2849
2850            #[tokio::test]
2851            async fn force_closes_after_timeout() {
2852                pause();
2853
2854                let conn = WebsocketConnection::new("c2");
2855                let (tx, mut rx) = unbounded_channel::<Message>();
2856                let (req_tx, _req_rx) = oneshot::channel();
2857                {
2858                    let mut st = conn.state.lock().await;
2859                    st.pending_requests.insert(
2860                        "request_id".to_string(),
2861                        PendingRequest { completion: req_tx },
2862                    );
2863                }
2864                let common =
2865                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2866                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2867                advance(Duration::from_secs(30)).await;
2868                close_fut.await.unwrap();
2869                match rx.try_recv() {
2870                    Ok(Message::Close(_)) => {}
2871                    other => panic!("expected Close on timeout, got {other:?}"),
2872                }
2873
2874                resume();
2875            }
2876        }
2877
2878        mod get_reconnect_url {
2879            use super::*;
2880
2881            struct DummyHandler {
2882                url: String,
2883            }
2884
2885            #[async_trait::async_trait]
2886            impl WebsocketHandler for DummyHandler {
2887                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
2888                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
2889                async fn get_reconnect_url(
2890                    &self,
2891                    _default_url: String,
2892                    _connection: Arc<WebsocketConnection>,
2893                ) -> String {
2894                    self.url.clone()
2895                }
2896            }
2897
2898            #[test]
2899            fn returns_default_when_no_handler() {
2900                TOKIO_SHARED_RT.block_on(async {
2901                    let conn = WebsocketConnection::new("c1");
2902                    let common =
2903                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2904                    let default = "wss://default".to_string();
2905                    let result = common.get_reconnect_url(&default, conn.clone()).await;
2906                    assert_eq!(result, default);
2907                });
2908            }
2909
2910            #[test]
2911            fn returns_handler_url_when_set() {
2912                TOKIO_SHARED_RT.block_on(async {
2913                    let conn = WebsocketConnection::new("c2");
2914                    let handler = Arc::new(DummyHandler {
2915                        url: "wss://custom".into(),
2916                    });
2917                    conn.set_handler(handler).await;
2918                    let common =
2919                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2920                    let default = "wss://default".to_string();
2921                    let result = common.get_reconnect_url(&default, conn.clone()).await;
2922                    assert_eq!(result, "wss://custom");
2923                });
2924            }
2925        }
2926
2927        mod on_open {
2928            use super::*;
2929
2930            struct DummyHandler {
2931                called: Arc<Mutex<bool>>,
2932                opened_url: Arc<Mutex<Option<String>>>,
2933            }
2934
2935            #[async_trait]
2936            impl WebsocketHandler for DummyHandler {
2937                async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
2938                    let mut flag = self.called.lock().await;
2939                    *flag = true;
2940                    let mut store = self.opened_url.lock().await;
2941                    *store = Some(url);
2942                }
2943                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
2944                async fn get_reconnect_url(
2945                    &self,
2946                    default_url: String,
2947                    _connection: Arc<WebsocketConnection>,
2948                ) -> String {
2949                    default_url
2950                }
2951            }
2952
2953            #[test]
2954            fn emits_open_and_calls_handler() {
2955                TOKIO_SHARED_RT.block_on(async {
2956                    let conn = WebsocketConnection::new("c1");
2957                    let called = Arc::new(Mutex::new(false));
2958                    let opened_url = Arc::new(Mutex::new(None));
2959                    let handler = Arc::new(DummyHandler {
2960                        called: called.clone(),
2961                        opened_url: opened_url.clone(),
2962                    });
2963
2964                    conn.set_handler(handler.clone()).await;
2965                    let common =
2966                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2967                    let events = subscribe_events(&common);
2968                    common
2969                        .on_open("wss://example.com".into(), conn.clone(), None)
2970                        .await;
2971
2972                    sleep(std::time::Duration::from_millis(10)).await;
2973
2974                    let evs = events.lock().await;
2975                    assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
2976                    assert!(*called.lock().await);
2977                    assert_eq!(
2978                        opened_url.lock().await.as_deref(),
2979                        Some("wss://example.com")
2980                    );
2981                });
2982            }
2983
2984            #[test]
2985            fn handles_renewal_pending_and_closes_old_writer() {
2986                TOKIO_SHARED_RT.block_on(async {
2987                    let conn = WebsocketConnection::new("c2");
2988                    let (old_tx, mut old_rx) = unbounded_channel::<Message>();
2989                    {
2990                        let mut st = conn.state.lock().await;
2991                        st.renewal_pending = true;
2992                    }
2993                    let common =
2994                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
2995                    common
2996                        .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
2997                        .await;
2998                    assert!(!conn.state.lock().await.renewal_pending);
2999                    match old_rx.try_recv() {
3000                        Ok(Message::Close(_)) => {}
3001                        other => panic!("expected Close, got {other:?}"),
3002                    }
3003                });
3004            }
3005        }
3006
3007        mod on_message {
3008            use super::*;
3009
3010            struct DummyHandler {
3011                called_with: Arc<Mutex<Vec<String>>>,
3012            }
3013
3014            #[async_trait]
3015            impl WebsocketHandler for DummyHandler {
3016                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
3017                async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
3018                    self.called_with.lock().await.push(data);
3019                }
3020                async fn get_reconnect_url(
3021                    &self,
3022                    default_url: String,
3023                    _connection: Arc<WebsocketConnection>,
3024                ) -> String {
3025                    default_url
3026                }
3027            }
3028
3029            #[test]
3030            fn emits_message_event_without_handler() {
3031                TOKIO_SHARED_RT.block_on(async {
3032                    let conn = WebsocketConnection::new("c1");
3033                    let common =
3034                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3035                    let events = subscribe_events(&common);
3036                    common.on_message("msg".into(), conn.clone()).await;
3037
3038                    sleep(Duration::from_millis(10)).await;
3039
3040                    let locked = events.lock().await;
3041                    assert!(
3042                        locked
3043                            .iter()
3044                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3045                    );
3046                });
3047            }
3048
3049            #[test]
3050            fn calls_handler_and_emits_message() {
3051                TOKIO_SHARED_RT.block_on(async {
3052                    let conn = WebsocketConnection::new("c2");
3053                    let called = Arc::new(Mutex::new(Vec::new()));
3054                    let handler = Arc::new(DummyHandler {
3055                        called_with: called.clone(),
3056                    });
3057                    conn.set_handler(handler.clone()).await;
3058
3059                    let common =
3060                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3061                    let events = subscribe_events(&common);
3062                    common.on_message("msg".into(), conn.clone()).await;
3063
3064                    sleep(Duration::from_millis(10)).await;
3065
3066                    let evs = events.lock().await;
3067                    assert!(
3068                        evs.iter()
3069                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3070                    );
3071                    let msgs = called.lock().await;
3072                    assert_eq!(msgs.as_slice(), &["msg".to_string()]);
3073                });
3074            }
3075        }
3076
3077        mod create_websocket {
3078            use super::*;
3079
3080            #[test]
3081            fn successful_connection() {
3082                TOKIO_SHARED_RT.block_on(async {
3083                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3084                    let addr: SocketAddr = listener.local_addr().unwrap();
3085                    tokio::spawn(async move {
3086                        if let Ok((stream, _)) = listener.accept().await {
3087                            if let Ok(mut ws_stream) = accept_async(stream).await {
3088                                let _ = ws_stream.close(None).await;
3089                            }
3090                        }
3091                    });
3092
3093                    let url = format!("ws://{addr}");
3094                    let res = WebsocketCommon::create_websocket(&url, None).await;
3095                    assert!(res.is_ok(), "Expected successful handshake, got {res:?}");
3096                });
3097            }
3098
3099            #[test]
3100            fn invalid_url_returns_handshake_error() {
3101                TOKIO_SHARED_RT.block_on(async {
3102                    let res = WebsocketCommon::create_websocket("not-a-valid-url", None).await;
3103                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3104                });
3105            }
3106
3107            #[test]
3108            fn unreachable_host_returns_handshake_error() {
3109                TOKIO_SHARED_RT.block_on(async {
3110                    let res = WebsocketCommon::create_websocket("ws://127.0.0.1:1", None).await;
3111                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3112                });
3113            }
3114        }
3115
3116        mod connect_pool {
3117            use super::*;
3118
3119            #[test]
3120            fn connects_all_in_pool() {
3121                TOKIO_SHARED_RT.block_on(async {
3122                    let pool_size = 3;
3123                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3124                    let addr = listener.local_addr().unwrap();
3125                    tokio::spawn(async move {
3126                        for _ in 0..pool_size {
3127                            if let Ok((stream, _)) = listener.accept().await {
3128                                let mut ws = accept_async(stream).await.unwrap();
3129                                let _ = ws.close(None).await;
3130                            }
3131                        }
3132                    });
3133                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
3134                        .map(|i| WebsocketConnection::new(format!("c{i}")))
3135                        .collect();
3136                    let common = WebsocketCommon::new(
3137                        conns.clone(),
3138                        WebsocketMode::Pool(pool_size),
3139                        0,
3140                        None,
3141                    );
3142                    let url = format!("ws://{addr}");
3143                    common.clone().connect_pool(&url).await.unwrap();
3144                    for conn in conns {
3145                        let st = conn.state.lock().await;
3146                        assert!(st.ws_write_tx.is_some());
3147                    }
3148                });
3149            }
3150
3151            #[test]
3152            fn fails_if_any_refused() {
3153                TOKIO_SHARED_RT.block_on(async {
3154                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3155                    let addr = listener.local_addr().unwrap();
3156                    let pool_size = 3;
3157                    tokio::spawn(async move {
3158                        for _ in 0..2 {
3159                            if let Ok((stream, _)) = listener.accept().await {
3160                                let mut ws = accept_async(stream).await.unwrap();
3161                                let _ = ws.close(None).await;
3162                            }
3163                        }
3164                    });
3165                    let mut conns = Vec::new();
3166                    let valid_url = format!("ws://{addr}");
3167                    for i in 0..2 {
3168                        conns.push(WebsocketConnection::new(format!("c{i}")));
3169                    }
3170                    conns.push(WebsocketConnection::new("bad"));
3171                    let common = WebsocketCommon::new(
3172                        conns.clone(),
3173                        WebsocketMode::Pool(pool_size),
3174                        0,
3175                        None,
3176                    );
3177                    let res = common.clone().connect_pool(&valid_url).await;
3178                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3179                });
3180            }
3181
3182            #[test]
3183            fn fails_on_invalid_url() {
3184                TOKIO_SHARED_RT.block_on(async {
3185                    let conns = vec![WebsocketConnection::new("c1")];
3186                    let common = WebsocketCommon::new(conns, WebsocketMode::Pool(1), 0, None);
3187                    let res = common.connect_pool("not-a-url").await;
3188                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3189                });
3190            }
3191
3192            #[test]
3193            fn fails_if_mixed_success_and_invalid_url() {
3194                TOKIO_SHARED_RT.block_on(async {
3195                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3196                    let addr = listener.local_addr().unwrap();
3197                    tokio::spawn(async move {
3198                        if let Ok((stream, _)) = listener.accept().await {
3199                            let mut ws = accept_async(stream).await.unwrap();
3200                            let _ = ws.close(None).await;
3201                        }
3202                    });
3203                    let good = WebsocketConnection::new("good");
3204                    let bad = WebsocketConnection::new("bad");
3205                    let common =
3206                        WebsocketCommon::new(vec![good, bad], WebsocketMode::Pool(2), 0, None);
3207                    let url = format!("ws://{addr}");
3208                    let res = common.connect_pool(&url).await;
3209                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3210                });
3211            }
3212
3213            #[test]
3214            fn init_connect_invoked_for_each() {
3215                TOKIO_SHARED_RT.block_on(async {
3216                    let pool_size = 2;
3217                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3218                    let addr = listener.local_addr().unwrap();
3219                    tokio::spawn(async move {
3220                        for _ in 0..pool_size {
3221                            if let Ok((stream, _)) = listener.accept().await {
3222                                let mut ws = accept_async(stream).await.unwrap();
3223                                let _ = ws.close(None).await;
3224                            }
3225                        }
3226                    });
3227                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
3228                        .map(|i| WebsocketConnection::new(format!("c{i}")))
3229                        .collect();
3230                    let common = WebsocketCommon::new(
3231                        conns.clone(),
3232                        WebsocketMode::Pool(pool_size),
3233                        0,
3234                        None,
3235                    );
3236                    let url = format!("ws://{addr}");
3237                    common.clone().connect_pool(&url).await.unwrap();
3238                    for conn in conns {
3239                        let st = conn.state.lock().await;
3240                        assert!(st.ws_write_tx.is_some());
3241                    }
3242                });
3243            }
3244
3245            #[test]
3246            fn single_mode_uses_first_connection() {
3247                TOKIO_SHARED_RT.block_on(async {
3248                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3249                    let addr = listener.local_addr().unwrap();
3250                    tokio::spawn(async move {
3251                        if let Ok((stream, _)) = listener.accept().await {
3252                            let mut ws = accept_async(stream).await.unwrap();
3253                            let _ = ws.close(None).await;
3254                        }
3255                    });
3256                    let conn = WebsocketConnection::new("c1");
3257                    let common =
3258                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3259                    let url = format!("ws://{addr}");
3260                    common.connect_pool(&url).await.unwrap();
3261                    let st = conn.state.lock().await;
3262                    assert!(st.ws_write_tx.is_some());
3263                });
3264            }
3265        }
3266
3267        mod init_connect {
3268            use super::*;
3269
3270            #[test]
3271            fn pool_mode_none_connection_uses_first() {
3272                TOKIO_SHARED_RT.block_on(async {
3273                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3274                    let addr = listener.local_addr().unwrap();
3275                    tokio::spawn(async move {
3276                        for _ in 0..2 {
3277                            if let Ok((stream, _)) = listener.accept().await {
3278                                let mut ws = accept_async(stream).await.unwrap();
3279                                ws.close(None).await.ok();
3280                            }
3281                        }
3282                    });
3283
3284                    let c1 = WebsocketConnection::new("c1");
3285                    let c2 = WebsocketConnection::new("c2");
3286                    let common = WebsocketCommon::new(
3287                        vec![c1.clone(), c2.clone()],
3288                        WebsocketMode::Pool(2),
3289                        0,
3290                        None,
3291                    );
3292                    let url = format!("ws://{addr}");
3293
3294                    common
3295                        .clone()
3296                        .init_connect(&url, false, None)
3297                        .await
3298                        .unwrap();
3299                    let st1 = c1.state.lock().await;
3300                    let st2 = c2.state.lock().await;
3301
3302                    assert!(st1.ws_write_tx.is_some());
3303                    assert!(st2.ws_write_tx.is_none());
3304                });
3305            }
3306
3307            #[test]
3308            fn writer_channel_can_send_text() {
3309                TOKIO_SHARED_RT.block_on(async {
3310                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3311                    let addr = listener.local_addr().unwrap();
3312                    let received = Arc::new(Mutex::new(None::<String>));
3313                    let received_clone = received.clone();
3314
3315                    tokio::spawn(async move {
3316                        if let Ok((stream, _)) = listener.accept().await {
3317                            let mut ws = accept_async(stream).await.unwrap();
3318                            if let Some(Ok(Message::Text(txt))) = ws.next().await {
3319                                *received_clone.lock().await = Some(txt.to_string());
3320                            }
3321                            ws.close(None).await.ok();
3322                        }
3323                    });
3324
3325                    let conn = WebsocketConnection::new("cw");
3326                    let common =
3327                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3328                    let url = format!("ws://{addr}");
3329                    common
3330                        .clone()
3331                        .init_connect(&url, false, Some(conn.clone()))
3332                        .await
3333                        .unwrap();
3334
3335                    let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
3336                    tx.send(Message::Text("ping".into())).unwrap();
3337
3338                    sleep(Duration::from_millis(50)).await;
3339
3340                    let lock = received.lock().await;
3341                    assert_eq!(lock.as_deref(), Some("ping"));
3342                });
3343            }
3344
3345            #[test]
3346            fn responds_to_ping_with_pong() {
3347                TOKIO_SHARED_RT.block_on(async {
3348                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3349                    let addr = listener.local_addr().unwrap();
3350
3351                    let saw_pong = Arc::new(Mutex::new(false));
3352                    let saw_pong2 = saw_pong.clone();
3353
3354                    tokio::spawn(async move {
3355                        if let Ok((stream, _)) = listener.accept().await {
3356                            let mut ws = accept_async(stream).await.unwrap();
3357                            ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
3358                            if let Some(Ok(Message::Pong(payload))) = ws.next().await {
3359                                if payload[..] == [1, 2, 3] {
3360                                    *saw_pong2.lock().await = true;
3361                                }
3362                            }
3363                            let _ = ws.close(None).await;
3364                        }
3365                    });
3366
3367                    let conn = WebsocketConnection::new("c-ping");
3368                    let common =
3369                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3370                    let url = format!("ws://{addr}");
3371                    common
3372                        .clone()
3373                        .init_connect(&url, false, Some(conn))
3374                        .await
3375                        .unwrap();
3376
3377                    sleep(Duration::from_millis(50)).await;
3378
3379                    assert!(*saw_pong.lock().await, "server should have seen a Pong");
3380                });
3381            }
3382
3383            #[test]
3384            fn handshake_error_on_invalid_url() {
3385                TOKIO_SHARED_RT.block_on(async {
3386                    let conn = WebsocketConnection::new("c-invalid");
3387                    let common =
3388                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3389                    let res = common
3390                        .clone()
3391                        .init_connect("not-a-url", false, Some(conn.clone()))
3392                        .await;
3393                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3394                });
3395            }
3396
3397            #[test]
3398            fn skip_if_writer_exists_and_not_renewal() {
3399                TOKIO_SHARED_RT.block_on(async {
3400                    let conn = WebsocketConnection::new("c-writer");
3401                    let (tx, mut rx) = unbounded_channel::<Message>();
3402                    {
3403                        let mut st = conn.state.lock().await;
3404                        st.ws_write_tx = Some(tx.clone());
3405                    }
3406                    let common =
3407                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3408                    let res = common
3409                        .clone()
3410                        .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
3411                        .await;
3412
3413                    assert!(res.is_ok());
3414                    assert!(rx.try_recv().is_err());
3415                });
3416            }
3417
3418            #[test]
3419            fn short_circuit_on_already_renewing() {
3420                TOKIO_SHARED_RT.block_on(async {
3421                    let conn = WebsocketConnection::new("c-renew");
3422                    {
3423                        let mut st = conn.state.lock().await;
3424                        st.renewal_pending = true;
3425                    }
3426                    let common =
3427                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3428                    let res = common
3429                        .clone()
3430                        .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
3431                        .await;
3432
3433                    assert!(res.is_ok());
3434                    assert!(conn.state.lock().await.ws_write_tx.is_none());
3435                });
3436            }
3437
3438            #[test]
3439            fn is_renewal_true_sets_and_clears_flag() {
3440                TOKIO_SHARED_RT.block_on(async {
3441                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3442                    let addr = listener.local_addr().unwrap();
3443                    tokio::spawn(async move {
3444                        if let Ok((stream, _)) = listener.accept().await {
3445                            let mut ws = accept_async(stream).await.unwrap();
3446                            let _ = ws.close(None).await;
3447                        }
3448                    });
3449
3450                    let conn = WebsocketConnection::new("c-new-renew");
3451                    let common =
3452                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3453                    let url = format!("ws://{addr}");
3454                    let res = common
3455                        .clone()
3456                        .init_connect(&url, true, Some(conn.clone()))
3457                        .await;
3458
3459                    assert!(res.is_ok());
3460                    let st = conn.state.lock().await;
3461                    assert!(st.ws_write_tx.is_some());
3462                    assert!(!st.renewal_pending);
3463                });
3464            }
3465
3466            #[test]
3467            fn default_connection_selected_when_none_passed() {
3468                TOKIO_SHARED_RT.block_on(async {
3469                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3470                    let addr = listener.local_addr().unwrap();
3471                    tokio::spawn(async move {
3472                        if let Ok((stream, _)) = listener.accept().await {
3473                            let mut ws = accept_async(stream).await.unwrap();
3474                            let _ = ws.close(None).await;
3475                        }
3476                    });
3477                    let conn = WebsocketConnection::new("c-default");
3478                    let common =
3479                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3480                    let url = format!("ws://{addr}");
3481                    let res = common.clone().init_connect(&url, false, None).await;
3482
3483                    assert!(res.is_ok());
3484                    assert!(conn.state.lock().await.ws_write_tx.is_some());
3485                });
3486            }
3487
3488            #[test]
3489            fn schedules_reconnect_on_abnormal_close() {
3490                TOKIO_SHARED_RT.block_on(async {
3491                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3492                    let addr = listener.local_addr().unwrap();
3493                    tokio::spawn(async move {
3494                        if let Ok((stream, _)) = listener.accept().await {
3495                            let mut ws = accept_async(stream).await.unwrap();
3496                            ws.close(Some(tungstenite::protocol::CloseFrame {
3497                                code: tungstenite::protocol::frame::coding::CloseCode::Abnormal,
3498                                reason: "oops".into(),
3499                            }))
3500                            .await
3501                            .ok();
3502                        }
3503                    });
3504                    let conn = WebsocketConnection::new("c-close");
3505                    let common =
3506                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 10, None);
3507                    let url = format!("ws://{addr}");
3508                    common
3509                        .clone()
3510                        .init_connect(&url, false, Some(conn.clone()))
3511                        .await
3512                        .unwrap();
3513
3514                    sleep(Duration::from_millis(50)).await;
3515
3516                    let st = conn.state.lock().await;
3517                    assert!(
3518                        st.reconnection_pending,
3519                        "expected reconnection_pending to be true after abnormal close"
3520                    );
3521                });
3522            }
3523        }
3524
3525        mod disconnect {
3526            use super::*;
3527
3528            #[test]
3529            fn returns_ok_when_no_connections_are_ready() {
3530                TOKIO_SHARED_RT.block_on(async {
3531                    let conn = WebsocketConnection::new("c1");
3532                    let common =
3533                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3534                    let res = common.disconnect().await;
3535
3536                    assert!(res.is_ok());
3537                    assert!(!conn.state.lock().await.close_initiated);
3538                });
3539            }
3540
3541            #[test]
3542            fn closes_all_ready_connections() {
3543                TOKIO_SHARED_RT.block_on(async {
3544                    let conn1 = WebsocketConnection::new("c1");
3545                    let conn2 = WebsocketConnection::new("c2");
3546                    let (tx1, mut rx1) = unbounded_channel::<Message>();
3547                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3548                    {
3549                        let mut s1 = conn1.state.lock().await;
3550                        s1.ws_write_tx = Some(tx1);
3551                    }
3552                    {
3553                        let mut s2 = conn2.state.lock().await;
3554                        s2.ws_write_tx = Some(tx2);
3555                    }
3556                    let common = WebsocketCommon::new(
3557                        vec![conn1.clone(), conn2.clone()],
3558                        WebsocketMode::Pool(2),
3559                        0,
3560                        None,
3561                    );
3562                    let fut = common.disconnect();
3563
3564                    sleep(Duration::from_millis(50)).await;
3565
3566                    fut.await.unwrap();
3567
3568                    assert!(conn1.state.lock().await.close_initiated);
3569                    assert!(conn2.state.lock().await.close_initiated);
3570
3571                    match (rx1.try_recv(), rx2.try_recv()) {
3572                        (Ok(Message::Close(_)), Ok(Message::Close(_))) => {}
3573                        other => panic!("expected two Closes, got {other:?}"),
3574                    }
3575                });
3576            }
3577
3578            #[test]
3579            fn does_not_mark_close_initiated_if_no_writer() {
3580                TOKIO_SHARED_RT.block_on(async {
3581                    let conn = WebsocketConnection::new("c-new");
3582                    let common =
3583                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3584                    common.disconnect().await.unwrap();
3585
3586                    assert!(!conn.state.lock().await.close_initiated);
3587                });
3588            }
3589
3590            #[test]
3591            fn mixed_pool_marks_all_and_closes_only_writers() {
3592                TOKIO_SHARED_RT.block_on(async {
3593                    let conn_w = WebsocketConnection::new("with");
3594                    let conn_wo = WebsocketConnection::new("without");
3595                    let (tx, mut rx) = unbounded_channel::<Message>();
3596                    {
3597                        let mut st = conn_w.state.lock().await;
3598                        st.ws_write_tx = Some(tx);
3599                    }
3600                    let common = WebsocketCommon::new(
3601                        vec![conn_w.clone(), conn_wo.clone()],
3602                        WebsocketMode::Pool(2),
3603                        0,
3604                        None,
3605                    );
3606                    let fut = common.disconnect();
3607
3608                    sleep(Duration::from_millis(50)).await;
3609
3610                    fut.await.unwrap();
3611
3612                    assert!(conn_w.state.lock().await.close_initiated);
3613                    assert!(conn_wo.state.lock().await.close_initiated);
3614                    assert!(matches!(rx.try_recv(), Ok(Message::Close(_))));
3615                });
3616            }
3617
3618            #[test]
3619            fn after_disconnect_not_connected() {
3620                TOKIO_SHARED_RT.block_on(async {
3621                    let conn = WebsocketConnection::new("c1");
3622                    let (tx, mut _rx) = unbounded_channel::<Message>();
3623                    {
3624                        let mut st = conn.state.lock().await;
3625                        st.ws_write_tx = Some(tx);
3626                    }
3627                    let common =
3628                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3629                    common.disconnect().await.unwrap();
3630                    assert!(!common.is_connected(Some(&conn)).await);
3631                });
3632            }
3633        }
3634
3635        mod ping_server {
3636            use super::*;
3637
3638            #[test]
3639            fn sends_ping_to_all_ready_connections() {
3640                TOKIO_SHARED_RT.block_on(async {
3641                    let mut conns = Vec::new();
3642                    for i in 0..3 {
3643                        let conn = WebsocketConnection::new(format!("c{i}"));
3644                        let (tx, rx) = unbounded_channel::<Message>();
3645                        {
3646                            let mut st = conn.state.lock().await;
3647                            st.ws_write_tx = Some(tx);
3648                        }
3649                        conns.push((conn, rx));
3650                    }
3651                    let common = WebsocketCommon::new(
3652                        conns.iter().map(|(c, _)| c.clone()).collect(),
3653                        WebsocketMode::Pool(3),
3654                        0,
3655                        None,
3656                    );
3657                    common.ping_server().await;
3658                    for (_, mut rx) in conns {
3659                        match rx.try_recv() {
3660                            Ok(Message::Ping(payload)) if payload.is_empty() => {}
3661                            other => panic!("expected empty-payload Ping, got {other:?}"),
3662                        }
3663                    }
3664                });
3665            }
3666
3667            #[test]
3668            fn skips_not_ready_and_partial() {
3669                TOKIO_SHARED_RT.block_on(async {
3670                    let ready = WebsocketConnection::new("ready");
3671                    let not_ready = WebsocketConnection::new("not-ready");
3672                    let (tx_r, mut rx_r) = unbounded_channel::<Message>();
3673                    {
3674                        let mut st = ready.state.lock().await;
3675                        st.ws_write_tx = Some(tx_r);
3676                    }
3677                    {
3678                        let mut st = not_ready.state.lock().await;
3679                        st.ws_write_tx = None;
3680                    }
3681                    let common = WebsocketCommon::new(
3682                        vec![ready.clone(), not_ready.clone()],
3683                        WebsocketMode::Pool(2),
3684                        0,
3685                        None,
3686                    );
3687                    common.ping_server().await;
3688                    match rx_r.try_recv() {
3689                        Ok(Message::Ping(payload)) if payload.is_empty() => {}
3690                        other => panic!("expected Ping on ready, got {other:?}"),
3691                    }
3692                });
3693            }
3694
3695            #[test]
3696            fn no_ping_when_flags_block() {
3697                TOKIO_SHARED_RT.block_on(async {
3698                    let conn = WebsocketConnection::new("c1");
3699                    let (tx, mut rx) = unbounded_channel::<Message>();
3700                    {
3701                        let mut st = conn.state.lock().await;
3702                        st.ws_write_tx = Some(tx);
3703                        st.reconnection_pending = true;
3704                    }
3705                    let common =
3706                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3707                    common.ping_server().await;
3708                    assert!(rx.try_recv().is_err());
3709                });
3710            }
3711        }
3712
3713        mod send {
3714            use super::*;
3715
3716            #[test]
3717            fn round_robin_send_without_specific() {
3718                TOKIO_SHARED_RT.block_on(async {
3719                    let conn1 = WebsocketConnection::new("c1");
3720                    let conn2 = WebsocketConnection::new("c2");
3721                    let (tx1, mut rx1) = unbounded_channel::<Message>();
3722                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3723                    {
3724                        let mut s1 = conn1.state.lock().await;
3725                        s1.ws_write_tx = Some(tx1);
3726                    }
3727                    {
3728                        let mut s2 = conn2.state.lock().await;
3729                        s2.ws_write_tx = Some(tx2);
3730                    }
3731                    let common = WebsocketCommon::new(
3732                        vec![conn1.clone(), conn2.clone()],
3733                        WebsocketMode::Pool(2),
3734                        0,
3735                        None,
3736                    );
3737
3738                    let res1 = common
3739                        .send("a".into(), None, false, Duration::from_secs(1), None)
3740                        .await
3741                        .unwrap();
3742                    assert!(res1.is_none());
3743
3744                    let res2 = common
3745                        .send("b".into(), None, false, Duration::from_secs(1), None)
3746                        .await
3747                        .unwrap();
3748                    assert!(res2.is_none());
3749
3750                    assert_eq!(
3751                        if let Message::Text(t) = rx1.try_recv().unwrap() {
3752                            t
3753                        } else {
3754                            panic!()
3755                        },
3756                        "a"
3757                    );
3758                    assert_eq!(
3759                        if let Message::Text(t) = rx2.try_recv().unwrap() {
3760                            t
3761                        } else {
3762                            panic!()
3763                        },
3764                        "b"
3765                    );
3766                });
3767            }
3768
3769            #[test]
3770            fn round_robin_skips_not_ready() {
3771                TOKIO_SHARED_RT.block_on(async {
3772                    let conn1 = WebsocketConnection::new("c1");
3773                    let conn2 = WebsocketConnection::new("c2");
3774                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3775                    {
3776                        let mut s1 = conn1.state.lock().await;
3777                        s1.ws_write_tx = None;
3778                    }
3779                    {
3780                        let mut s2 = conn2.state.lock().await;
3781                        s2.ws_write_tx = Some(tx2);
3782                    }
3783                    let common = WebsocketCommon::new(
3784                        vec![conn1.clone(), conn2.clone()],
3785                        WebsocketMode::Pool(2),
3786                        0,
3787                        None,
3788                    );
3789                    let res = common
3790                        .send("bar".into(), None, false, Duration::from_secs(1), None)
3791                        .await
3792                        .unwrap();
3793                    assert!(res.is_none());
3794                    match rx2.try_recv().unwrap() {
3795                        Message::Text(t) => assert_eq!(t, "bar"),
3796                        other => panic!("unexpected {other:?}"),
3797                    }
3798                });
3799            }
3800
3801            #[test]
3802            fn sync_send_on_specific_connection() {
3803                TOKIO_SHARED_RT.block_on(async {
3804                    let conn1 = WebsocketConnection::new("c1");
3805                    let conn2 = WebsocketConnection::new("c2");
3806                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3807                    {
3808                        let mut st = conn2.state.lock().await;
3809                        st.ws_write_tx = Some(tx2);
3810                    }
3811                    let common = WebsocketCommon::new(
3812                        vec![conn1.clone(), conn2.clone()],
3813                        WebsocketMode::Pool(2),
3814                        0,
3815                        None,
3816                    );
3817                    let res = common
3818                        .send(
3819                            "payload".into(),
3820                            Some("id".into()),
3821                            false,
3822                            Duration::from_secs(1),
3823                            Some(conn2.clone()),
3824                        )
3825                        .await
3826                        .unwrap();
3827                    assert!(res.is_none());
3828                    match rx2.try_recv() {
3829                        Ok(Message::Text(t)) => assert_eq!(t, "payload"),
3830                        other => panic!("expected Text, got {other:?}"),
3831                    }
3832                });
3833            }
3834
3835            #[test]
3836            fn sync_send_with_id_does_not_insert_pending() {
3837                TOKIO_SHARED_RT.block_on(async {
3838                    let conn = WebsocketConnection::new("c1");
3839                    let (tx, mut rx) = unbounded_channel::<Message>();
3840                    {
3841                        let mut st = conn.state.lock().await;
3842                        st.ws_write_tx = Some(tx);
3843                    }
3844                    let common =
3845                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3846                    let res = common
3847                        .send(
3848                            "msg".into(),
3849                            Some("id".into()),
3850                            false,
3851                            Duration::from_secs(1),
3852                            Some(conn.clone()),
3853                        )
3854                        .await
3855                        .unwrap();
3856                    assert!(res.is_none());
3857                    assert!(conn.state.lock().await.pending_requests.is_empty());
3858                    match rx.try_recv().unwrap() {
3859                        Message::Text(t) => assert_eq!(t, "msg"),
3860                        other => panic!("unexpected {other:?}"),
3861                    }
3862                });
3863            }
3864
3865            #[test]
3866            fn sync_send_error_if_not_ready() {
3867                TOKIO_SHARED_RT.block_on(async {
3868                    let conn = WebsocketConnection::new("c1");
3869                    let common =
3870                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3871                    let err = common
3872                        .send(
3873                            "msg".into(),
3874                            Some("id".into()),
3875                            false,
3876                            Duration::from_secs(1),
3877                            Some(conn.clone()),
3878                        )
3879                        .await
3880                        .unwrap_err();
3881                    assert!(matches!(err, WebsocketError::NotConnected));
3882                });
3883            }
3884
3885            #[test]
3886            fn sync_send_error_when_no_ready() {
3887                TOKIO_SHARED_RT.block_on(async {
3888                    let conn = WebsocketConnection::new("c1");
3889                    let common =
3890                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3891                    let err = common
3892                        .send("msg".into(), None, false, Duration::from_secs(1), None)
3893                        .await
3894                        .unwrap_err();
3895                    assert!(matches!(err, WebsocketError::NotConnected));
3896                });
3897            }
3898
3899            #[test]
3900            fn async_send_and_receive() {
3901                TOKIO_SHARED_RT.block_on(async {
3902                    let conn = WebsocketConnection::new("c1");
3903                    let (tx, mut rx) = unbounded_channel::<Message>();
3904                    {
3905                        let mut st = conn.state.lock().await;
3906                        st.ws_write_tx = Some(tx);
3907                    }
3908                    let common =
3909                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3910                    let fut = common
3911                        .send(
3912                            "hello".into(),
3913                            Some("id".into()),
3914                            true,
3915                            Duration::from_secs(5),
3916                            Some(conn.clone()),
3917                        )
3918                        .await
3919                        .unwrap()
3920                        .unwrap();
3921                    match rx.try_recv() {
3922                        Ok(Message::Text(t)) => assert_eq!(t, "hello"),
3923                        other => panic!("expected Text, got {other:?}"),
3924                    }
3925                    {
3926                        let mut st = conn.state.lock().await;
3927                        let pr = st.pending_requests.remove("id").unwrap();
3928                        pr.completion.send(Ok(serde_json::json!("ok"))).unwrap();
3929                    }
3930                    let resp = fut.await.unwrap().unwrap();
3931                    assert_eq!(resp, serde_json::json!("ok"));
3932                });
3933            }
3934
3935            #[test]
3936            fn async_send_default_connection() {
3937                TOKIO_SHARED_RT.block_on(async {
3938                    let conn = WebsocketConnection::new("c1");
3939                    let (tx, mut rx) = unbounded_channel::<Message>();
3940                    {
3941                        let mut st = conn.state.lock().await;
3942                        st.ws_write_tx = Some(tx);
3943                    }
3944                    let common =
3945                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3946                    let fut = common
3947                        .send(
3948                            "msg".into(),
3949                            Some("id".into()),
3950                            true,
3951                            Duration::from_secs(5),
3952                            None,
3953                        )
3954                        .await
3955                        .unwrap()
3956                        .unwrap();
3957                    match rx.try_recv() {
3958                        Ok(Message::Text(t)) => assert_eq!(t, "msg"),
3959                        _ => panic!("no text"),
3960                    }
3961                    {
3962                        let mut st = conn.state.lock().await;
3963                        let pr = st.pending_requests.remove("id").unwrap();
3964                        pr.completion.send(Ok(serde_json::json!(123))).unwrap();
3965                    }
3966                    let resp = fut.await.unwrap().unwrap();
3967                    assert_eq!(resp, serde_json::json!(123));
3968                });
3969            }
3970
3971            #[test]
3972            fn async_send_error_if_no_id() {
3973                TOKIO_SHARED_RT.block_on(async {
3974                    let conn = WebsocketConnection::new("c§");
3975                    let (tx, _rx) = unbounded_channel::<Message>();
3976                    {
3977                        let mut st = conn.state.lock().await;
3978                        st.ws_write_tx = Some(tx);
3979                    }
3980                    let common =
3981                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
3982                    let err = common
3983                        .send(
3984                            "msg".into(),
3985                            None,
3986                            true,
3987                            Duration::from_secs(1),
3988                            Some(conn.clone()),
3989                        )
3990                        .await
3991                        .unwrap_err();
3992                    assert!(matches!(err, WebsocketError::NotConnected));
3993                });
3994            }
3995
3996            #[test]
3997            fn timeout_rejects_async() {
3998                TOKIO_SHARED_RT.block_on(async {
3999                    pause();
4000                    let conn = WebsocketConnection::new("c1");
4001                    let (tx, _rx) = unbounded_channel::<Message>();
4002                    {
4003                        let mut st = conn.state.lock().await;
4004                        st.ws_write_tx = Some(tx);
4005                    }
4006                    let common =
4007                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
4008                    let fut = common
4009                        .send(
4010                            "msg".into(),
4011                            Some("id".into()),
4012                            true,
4013                            Duration::from_secs(1),
4014                            Some(conn.clone()),
4015                        )
4016                        .await
4017                        .unwrap()
4018                        .unwrap();
4019                    advance(Duration::from_secs(1)).await;
4020                    let res = fut.await.unwrap();
4021                    assert!(res.is_err(), "expected timeout error");
4022                    assert!(!conn.state.lock().await.pending_requests.contains_key("id"));
4023                });
4024            }
4025
4026            #[test]
4027            fn async_send_errors_if_no_connection_ready() {
4028                TOKIO_SHARED_RT.block_on(async {
4029                    let conn = WebsocketConnection::new("c1");
4030                    let common =
4031                        WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None);
4032                    let err = common
4033                        .send(
4034                            "msg".into(),
4035                            Some("id".into()),
4036                            true,
4037                            Duration::from_secs(1),
4038                            None,
4039                        )
4040                        .await
4041                        .unwrap_err();
4042                    assert!(matches!(err, WebsocketError::NotConnected));
4043                });
4044            }
4045        }
4046    }
4047
4048    mod websocket_api {
4049        use super::*;
4050
4051        mod initialisation {
4052            use super::*;
4053
4054            #[test]
4055            fn new_initializes_common() {
4056                TOKIO_SHARED_RT.block_on(async {
4057                    let conn = WebsocketConnection::new("id");
4058                    let pool = vec![conn.clone()];
4059
4060                    let sig_gen = SignatureGenerator::new(
4061                        Some("api_secret".to_string()),
4062                        None::<PrivateKey>,
4063                        None::<String>,
4064                    );
4065
4066                    let config = ConfigurationWebsocketApi {
4067                        api_key: Some("api_key".to_string()),
4068                        api_secret: Some("api_secret".to_string()),
4069                        private_key: None,
4070                        private_key_passphrase: None,
4071                        ws_url: Some("wss://example".to_string()),
4072                        mode: WebsocketMode::Single,
4073                        reconnect_delay: 1000,
4074                        signature_gen: sig_gen,
4075                        timeout: 500,
4076                        time_unit: None,
4077                        agent: None,
4078                    };
4079
4080                    let api = WebsocketApi::new(config, pool.clone());
4081
4082                    assert_eq!(api.common.connection_pool.len(), 1);
4083                    assert_eq!(api.common.mode, WebsocketMode::Single);
4084
4085                    let flag = *api.is_connecting.lock().await;
4086                    assert!(!flag);
4087                });
4088            }
4089        }
4090
4091        mod connect {
4092            use super::*;
4093
4094            #[test]
4095            fn connect_when_not_connected_establishes() {
4096                TOKIO_SHARED_RT.block_on(async {
4097                    let conn = WebsocketConnection::new("id");
4098                    {
4099                        let mut st = conn.state.lock().await;
4100                        st.ws_write_tx = None;
4101                    }
4102                    let sig = SignatureGenerator::new(
4103                        Some("api_secret".into()),
4104                        None::<PrivateKey>,
4105                        None::<String>,
4106                    );
4107                    let cfg = ConfigurationWebsocketApi {
4108                        api_key: Some("api_key".into()),
4109                        api_secret: Some("api_secret".to_string()),
4110                        private_key: None,
4111                        private_key_passphrase: None,
4112                        ws_url: Some("ws://doesnotexist:1".to_string()),
4113                        mode: WebsocketMode::Single,
4114                        reconnect_delay: 0,
4115                        signature_gen: sig,
4116                        timeout: 10,
4117                        time_unit: None,
4118                        agent: None,
4119                    };
4120                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4121                    let res = api.clone().connect().await;
4122                    assert!(!matches!(res, Err(WebsocketError::Timeout)));
4123                });
4124            }
4125
4126            #[test]
4127            fn already_connected_returns_ok() {
4128                TOKIO_SHARED_RT.block_on(async {
4129                    let conn = WebsocketConnection::new("id2");
4130                    let (tx, _) = unbounded_channel();
4131                    {
4132                        let mut st = conn.state.lock().await;
4133                        st.ws_write_tx = Some(tx);
4134                    }
4135                    let sig = SignatureGenerator::new(
4136                        Some("api_secret".to_string()),
4137                        None::<PrivateKey>,
4138                        None::<String>,
4139                    );
4140                    let cfg = ConfigurationWebsocketApi {
4141                        api_key: Some("api_key".to_string()),
4142                        api_secret: Some("api_secret".to_string()),
4143                        private_key: None,
4144                        private_key_passphrase: None,
4145                        ws_url: Some("ws://example.com".to_string()),
4146                        mode: WebsocketMode::Single,
4147                        reconnect_delay: 0,
4148                        signature_gen: sig,
4149                        timeout: 10,
4150                        time_unit: None,
4151                        agent: None,
4152                    };
4153                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4154                    let res = api.connect().await;
4155                    assert!(res.is_ok());
4156                });
4157            }
4158
4159            #[test]
4160            fn not_connected_returns_error() {
4161                TOKIO_SHARED_RT.block_on(async {
4162                    let conn = WebsocketConnection::new("id1");
4163                    let sig = SignatureGenerator::new(
4164                        Some("api_secret".to_string()),
4165                        None::<PrivateKey>,
4166                        None::<String>,
4167                    );
4168                    let cfg = ConfigurationWebsocketApi {
4169                        api_key: Some("api_key".to_string()),
4170                        api_secret: Some("api_secret".to_string()),
4171                        private_key: None,
4172                        private_key_passphrase: None,
4173                        ws_url: Some("ws://127.0.0.1:9".to_string()),
4174                        mode: WebsocketMode::Single,
4175                        reconnect_delay: 0,
4176                        signature_gen: sig,
4177                        timeout: 10,
4178                        time_unit: None,
4179                        agent: None,
4180                    };
4181                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4182                    let res = api.connect().await;
4183                    assert!(res.is_err());
4184                });
4185            }
4186
4187            #[test]
4188            fn concurrent_calls_both_error_or_ok() {
4189                TOKIO_SHARED_RT.block_on(async {
4190                    let conn = WebsocketConnection::new("id3");
4191                    let sig = SignatureGenerator::new(
4192                        Some("api_secret".to_string()),
4193                        None::<PrivateKey>,
4194                        None::<String>,
4195                    );
4196                    let cfg = ConfigurationWebsocketApi {
4197                        api_key: Some("api_key".to_string()),
4198                        api_secret: Some("api_secret".to_string()),
4199                        private_key: None,
4200                        private_key_passphrase: None,
4201                        ws_url: Some("wss://invalid-domain".to_string()),
4202                        mode: WebsocketMode::Single,
4203                        reconnect_delay: 0,
4204                        signature_gen: sig,
4205                        timeout: 10,
4206                        time_unit: None,
4207                        agent: None,
4208                    };
4209                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4210                    let fut1 = tokio::spawn(api.clone().connect());
4211                    let fut2 = tokio::spawn(api.clone().connect());
4212                    let r1 = fut1.await.unwrap();
4213                    let r2 = fut2.await.unwrap();
4214
4215                    assert!(r1.is_err());
4216                    assert!(r2.is_err() || r2.is_ok());
4217                });
4218            }
4219
4220            #[test]
4221            fn pool_failure_is_propagated() {
4222                TOKIO_SHARED_RT.block_on(async {
4223                    let conn = WebsocketConnection::new("w");
4224                    let sig = SignatureGenerator::new(
4225                        Some("api_secret".to_string()),
4226                        None::<PrivateKey>,
4227                        None::<String>,
4228                    );
4229                    let cfg = ConfigurationWebsocketApi {
4230                        api_key: Some("api_key".into()),
4231                        api_secret: Some("api_secret".to_string()),
4232                        private_key: None,
4233                        private_key_passphrase: None,
4234                        ws_url: Some("ws://doesnotexist:1".to_string()),
4235                        mode: WebsocketMode::Single,
4236                        reconnect_delay: 0,
4237                        signature_gen: sig,
4238                        timeout: 10,
4239                        time_unit: None,
4240                        agent: None,
4241                    };
4242                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4243                    let res = api.clone().connect().await;
4244                    match res {
4245                        Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
4246                        _ => panic!("expected handshake or timeout error"),
4247                    }
4248                });
4249            }
4250        }
4251
4252        mod send_message {
4253            use super::*;
4254
4255            #[test]
4256            fn unsigned_message() {
4257                TOKIO_SHARED_RT.block_on(async {
4258                    let api = create_websocket_api(None);
4259                    let conn = &api.common.connection_pool[0];
4260                    let (tx, mut rx) = unbounded_channel::<Message>();
4261                    {
4262                        let mut st = conn.state.lock().await;
4263                        st.ws_write_tx = Some(tx);
4264                    }
4265
4266                    let fut = tokio::spawn({
4267                        let api = api.clone();
4268                        async move {
4269                            let mut params = BTreeMap::new();
4270                            params.insert("foo".into(), Value::String("bar".into()));
4271                            api.send_message::<Value>(
4272                                "mymethod",
4273                                params,
4274                                WebsocketMessageSendOptions {
4275                                    with_api_key: false,
4276                                    is_signed: false,
4277                                },
4278                            )
4279                            .await
4280                            .unwrap()
4281                        }
4282                    });
4283
4284                    let sent = rx.recv().await.unwrap();
4285                    let txt = if let Message::Text(s) = sent {
4286                        s
4287                    } else {
4288                        panic!()
4289                    };
4290                    let req: Value = serde_json::from_str(&txt).unwrap();
4291                    assert_eq!(req["method"], "mymethod");
4292                    assert!(req["params"]["foo"] == "bar");
4293                    assert!(req["params"].get("apiKey").is_none());
4294                    assert!(req["params"].get("timestamp").is_none());
4295                    assert!(req["params"].get("signature").is_none());
4296
4297                    let id = req["id"].as_str().unwrap().to_string();
4298                    let mut st = conn.state.lock().await;
4299                    let pending = st.pending_requests.remove(&id).unwrap();
4300                    let reply = json!({
4301                        "id": id,
4302                        "result": { "x": 42 },
4303                        "rateLimits": [{ "limit": 7 }]
4304                    });
4305                    pending.completion.send(Ok(reply)).unwrap();
4306
4307                    let resp = fut.await.unwrap();
4308                    let rate_limits = resp.rate_limits.unwrap_or_default();
4309
4310                    assert!(rate_limits.is_empty());
4311                    assert_eq!(resp.raw, json!({"x": 42}));
4312                });
4313            }
4314
4315            #[test]
4316            fn with_api_key_only() {
4317                TOKIO_SHARED_RT.block_on(async {
4318                    let api = create_websocket_api(None);
4319                    let conn = &api.common.connection_pool[0];
4320                    let (tx, mut rx) = unbounded_channel::<Message>();
4321                    {
4322                        let mut st = conn.state.lock().await;
4323                        st.ws_write_tx = Some(tx);
4324                    }
4325
4326                    let fut = tokio::spawn({
4327                        let api = api.clone();
4328                        async move {
4329                            let params = BTreeMap::new();
4330                            api.send_message::<Value>(
4331                                "foo",
4332                                params,
4333                                WebsocketMessageSendOptions {
4334                                    with_api_key: true,
4335                                    is_signed: false,
4336                                },
4337                            )
4338                            .await
4339                            .unwrap()
4340                        }
4341                    });
4342
4343                    let txt = if let Message::Text(s) = rx.recv().await.unwrap() {
4344                        s
4345                    } else {
4346                        panic!()
4347                    };
4348                    let req: Value = serde_json::from_str(&txt).unwrap();
4349                    assert_eq!(req["params"]["apiKey"], "api_key");
4350
4351                    let id = req["id"].as_str().unwrap().to_string();
4352                    let mut st = conn.state.lock().await;
4353                    let pending = st.pending_requests.remove(&id).unwrap();
4354                    pending
4355                        .completion
4356                        .send(Ok(json!({
4357                            "id": id,
4358                            "result": {},
4359                            "rateLimits": []
4360                        })))
4361                        .unwrap();
4362
4363                    let resp = fut.await.unwrap();
4364
4365                    assert_eq!(resp.raw, json!({}));
4366                    assert!(st.pending_requests.is_empty());
4367                });
4368            }
4369
4370            #[test]
4371            fn signed_message_has_timestamp_and_signature() {
4372                TOKIO_SHARED_RT.block_on(async {
4373                    let api = create_websocket_api(None);
4374                    let conn = &api.common.connection_pool[0];
4375                    let (tx, mut rx) = unbounded_channel::<Message>();
4376                    {
4377                        let mut st = conn.state.lock().await;
4378                        st.ws_write_tx = Some(tx);
4379                    }
4380
4381                    let fut = tokio::spawn({
4382                        let api = api.clone();
4383                        async move {
4384                            let mut params = BTreeMap::new();
4385                            params.insert("foo".into(), Value::String("bar".into()));
4386                            api.send_message::<Value>(
4387                                "method",
4388                                params,
4389                                WebsocketMessageSendOptions {
4390                                    with_api_key: true,
4391                                    is_signed: true,
4392                                },
4393                            )
4394                            .await
4395                            .unwrap()
4396                        }
4397                    });
4398
4399                    let txt = if let Message::Text(s) = rx.recv().await.unwrap() {
4400                        s
4401                    } else {
4402                        panic!()
4403                    };
4404                    let req: Value = serde_json::from_str(&txt).unwrap();
4405                    let p = &req["params"];
4406                    assert!(p["apiKey"] == "api_key");
4407                    assert!(p["timestamp"].is_number());
4408                    assert!(p["signature"].is_string());
4409
4410                    let id = req["id"].as_str().unwrap().to_string();
4411                    let mut st = conn.state.lock().await;
4412                    let pending = st.pending_requests.remove(&id).unwrap();
4413                    pending
4414                        .completion
4415                        .send(Ok(json!({
4416                            "id": id,
4417                            "result": { "ok": true },
4418                            "rateLimits": []
4419                        })))
4420                        .unwrap();
4421
4422                    let resp = fut.await.unwrap();
4423                    assert_eq!(resp.raw, json!({ "ok": true }));
4424                });
4425            }
4426
4427            #[test]
4428            fn error_if_not_connected() {
4429                TOKIO_SHARED_RT.block_on(async {
4430                    let api = create_websocket_api(None);
4431                    let conn = &api.common.connection_pool[0];
4432                    {
4433                        let mut st = conn.state.lock().await;
4434                        st.ws_write_tx = None;
4435                    }
4436                    let params = BTreeMap::new();
4437                    let err = api
4438                        .send_message::<Value>(
4439                            "method",
4440                            params,
4441                            WebsocketMessageSendOptions {
4442                                with_api_key: false,
4443                                is_signed: false,
4444                            },
4445                        )
4446                        .await
4447                        .unwrap_err();
4448                    matches!(err, WebsocketError::NotConnected);
4449                });
4450            }
4451        }
4452
4453        mod prepare_url {
4454            use super::*;
4455
4456            #[test]
4457            fn no_time_unit() {
4458                TOKIO_SHARED_RT.block_on(async {
4459                    let api = create_websocket_api(None);
4460                    let url = "wss://example.com/ws".to_string();
4461                    assert_eq!(api.prepare_url(&url), url);
4462                });
4463            }
4464
4465            #[test]
4466            fn appends_time_unit() {
4467                TOKIO_SHARED_RT.block_on(async {
4468                    let api = create_websocket_api(Some(TimeUnit::Millisecond));
4469                    let base = "wss://example.com/ws".to_string();
4470                    let got = api.prepare_url(&base);
4471                    assert_eq!(got, format!("{base}?timeUnit=millisecond"));
4472                });
4473            }
4474
4475            #[test]
4476            fn handles_existing_query() {
4477                TOKIO_SHARED_RT.block_on(async {
4478                    let api = create_websocket_api(Some(TimeUnit::Microsecond));
4479                    let base = "wss://example.com/ws?foo=bar".to_string();
4480                    let got = api.prepare_url(&base);
4481                    assert_eq!(got, format!("{base}&timeUnit=microsecond"));
4482                });
4483            }
4484        }
4485
4486        mod on_message {
4487            use super::*;
4488
4489            fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
4490                let sig_gen = SignatureGenerator::new(
4491                    Some("api_secret".to_string()),
4492                    None::<_>,
4493                    None::<String>,
4494                );
4495                let config = ConfigurationWebsocketApi {
4496                    api_key: Some("api_key".to_string()),
4497                    api_secret: Some("api_secret".to_string()),
4498                    private_key: None,
4499                    private_key_passphrase: None,
4500                    ws_url: Some("wss://example".to_string()),
4501                    mode: WebsocketMode::Single,
4502                    reconnect_delay: 0,
4503                    signature_gen: sig_gen,
4504                    timeout: 1000,
4505                    time_unit: None,
4506                    agent: None,
4507                };
4508                let conn = WebsocketConnection::new("test");
4509                let api = WebsocketApi::new(config, vec![conn.clone()]);
4510                (api, conn)
4511            }
4512
4513            #[test]
4514            fn resolves_pending_and_removes_request() {
4515                TOKIO_SHARED_RT.block_on(async {
4516                    let (api, conn) = create_websocket_api_and_conn();
4517                    let (tx, rx) = oneshot::channel();
4518                    {
4519                        let mut st = conn.state.lock().await;
4520                        st.pending_requests
4521                            .insert("id1".to_string(), PendingRequest { completion: tx });
4522                    }
4523                    let msg = json!({"id":"id1","status":200,"foo":"bar"});
4524                    api.on_message(msg.to_string(), conn.clone()).await;
4525                    let got = rx.await.unwrap().unwrap();
4526                    assert_eq!(got, msg);
4527                    let st = conn.state.lock().await;
4528                    assert!(st.pending_requests.get("id1").is_none());
4529                });
4530            }
4531
4532            #[test]
4533            fn uses_result_when_present() {
4534                TOKIO_SHARED_RT.block_on(async {
4535                    let (api, conn) = create_websocket_api_and_conn();
4536                    let (tx, rx) = oneshot::channel();
4537                    {
4538                        let mut st = conn.state.lock().await;
4539                        st.pending_requests
4540                            .insert("id1".to_string(), PendingRequest { completion: tx });
4541                    }
4542                    let msg = json!({
4543                        "id": "id1",
4544                        "status": 200,
4545                        "response": [1,2],
4546                        "result": {"a":1}
4547                    });
4548                    api.on_message(msg.to_string(), conn.clone()).await;
4549                    let got = rx.await.unwrap().unwrap();
4550                    assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
4551                });
4552            }
4553
4554            #[test]
4555            fn uses_response_when_no_result() {
4556                TOKIO_SHARED_RT.block_on(async {
4557                    let (api, conn) = create_websocket_api_and_conn();
4558                    let (tx, rx) = oneshot::channel();
4559                    {
4560                        let mut st = conn.state.lock().await;
4561                        st.pending_requests
4562                            .insert("id1".to_string(), PendingRequest { completion: tx });
4563                    }
4564                    let msg = json!({
4565                        "id": "id1",
4566                        "status": 200,
4567                        "response": ["ok"]
4568                    });
4569                    api.on_message(msg.to_string(), conn.clone()).await;
4570                    let got = rx.await.unwrap().unwrap();
4571                    assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
4572                });
4573            }
4574
4575            #[test]
4576            fn errors_for_status_ge_400() {
4577                TOKIO_SHARED_RT.block_on(async {
4578                    let (api, conn) = create_websocket_api_and_conn();
4579                    let (tx, rx) = oneshot::channel();
4580                    {
4581                        let mut st = conn.state.lock().await;
4582                        st.pending_requests
4583                            .insert("bad".to_string(), PendingRequest { completion: tx });
4584                    }
4585                    let err_obj = json!({"code":123,"msg":"oops"});
4586                    let msg = json!({"id":"bad","status":500,"error":err_obj});
4587                    api.on_message(msg.to_string(), conn.clone()).await;
4588                    match rx.await.unwrap() {
4589                        Err(WebsocketError::ResponseError { code, message }) => {
4590                            assert_eq!(code, 123);
4591                            assert_eq!(message, "oops");
4592                        }
4593                        other => panic!("expected ResponseError, got {other:?}"),
4594                    }
4595                    let st = conn.state.lock().await;
4596                    assert!(st.pending_requests.get("bad").is_none());
4597                });
4598            }
4599
4600            #[test]
4601            fn ignores_unknown_id() {
4602                TOKIO_SHARED_RT.block_on(async {
4603                    let (api, conn) = create_websocket_api_and_conn();
4604                    let msg = json!({"id":"nope","status":200});
4605                    api.on_message(msg.to_string(), conn.clone()).await;
4606                    let st = conn.state.lock().await;
4607                    assert!(st.pending_requests.is_empty());
4608                });
4609            }
4610
4611            #[test]
4612            fn parse_error_ignored() {
4613                TOKIO_SHARED_RT.block_on(async {
4614                    let (api, conn) = create_websocket_api_and_conn();
4615                    api.on_message("not json".to_string(), conn.clone()).await;
4616                    let st = conn.state.lock().await;
4617                    assert!(st.pending_requests.is_empty());
4618                });
4619            }
4620
4621            #[test]
4622            fn error_status_sends_error() {
4623                TOKIO_SHARED_RT.block_on(async {
4624                    let (api, conn) = create_websocket_api_and_conn();
4625                    let (tx, rx) = oneshot::channel();
4626                    {
4627                        let mut st = conn.state.lock().await;
4628                        st.pending_requests
4629                            .insert("err".to_string(), PendingRequest { completion: tx });
4630                    }
4631                    let msg = json!({
4632                        "id": "err",
4633                        "status": 500,
4634                        "error": { "code": 42, "msg": "Bad!" }
4635                    });
4636                    api.on_message(msg.to_string(), conn.clone()).await;
4637                    match rx.await.unwrap() {
4638                        Err(WebsocketError::ResponseError { code, message }) => {
4639                            assert_eq!(code, 42);
4640                            assert_eq!(message, "Bad!");
4641                        }
4642                        other => panic!("expected ResponseError, got {other:?}"),
4643                    }
4644                });
4645            }
4646
4647            #[test]
4648            fn unknown_id_logs_warning_and_leaves_pending() {
4649                TOKIO_SHARED_RT.block_on(async {
4650                    let (api, conn) = create_websocket_api_and_conn();
4651                    {
4652                        let mut st = conn.state.lock().await;
4653                        st.pending_requests.insert(
4654                            "keep".to_string(),
4655                            PendingRequest {
4656                                completion: oneshot::channel().0,
4657                            },
4658                        );
4659                    }
4660                    api.on_message(
4661                        json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
4662                        conn.clone(),
4663                    )
4664                    .await;
4665                    let st = conn.state.lock().await;
4666                    assert!(st.pending_requests.contains_key("keep"));
4667                });
4668            }
4669        }
4670    }
4671
4672    mod websocket_streams {
4673        use super::*;
4674
4675        mod initialisation {
4676            use super::*;
4677
4678            #[test]
4679            fn new_initializes_fields() {
4680                TOKIO_SHARED_RT.block_on(async {
4681                    let sig_gen = SignatureGenerator::new(
4682                        Some("api_secret".to_string()),
4683                        None::<PrivateKey>,
4684                        None::<String>,
4685                    );
4686                    let config = ConfigurationWebsocketApi {
4687                        api_key: Some("api_key".to_string()),
4688                        api_secret: Some("api_secret".to_string()),
4689                        private_key: None,
4690                        private_key_passphrase: None,
4691                        ws_url: Some("wss://example".to_string()),
4692                        mode: WebsocketMode::Single,
4693                        reconnect_delay: 1000,
4694                        signature_gen: sig_gen.clone(),
4695                        timeout: 500,
4696                        time_unit: None,
4697                        agent: None,
4698                    };
4699                    let conn1 = WebsocketConnection::new("c1");
4700                    let conn2 = WebsocketConnection::new("c2");
4701                    let api = WebsocketApi::new(config.clone(), vec![conn1.clone(), conn2.clone()]);
4702
4703                    assert_eq!(api.common.connection_pool.len(), 2);
4704                    assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
4705                    assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
4706                    assert_eq!(api.configuration.ws_url, Some("wss://example".to_string()));
4707                    let flag = api.is_connecting.lock().await;
4708                    assert!(!*flag);
4709                });
4710            }
4711        }
4712
4713        mod connect {
4714            use super::*;
4715
4716            #[test]
4717            fn establishes_successfully() {
4718                TOKIO_SHARED_RT.block_on(async {
4719                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4720                    let port = listener.local_addr().unwrap().port();
4721
4722                    tokio::spawn(async move {
4723                        for _ in 0..2 {
4724                            if let Ok((stream, _)) = listener.accept().await {
4725                                let mut ws = accept_async(stream).await.unwrap();
4726                                ws.close(None).await.ok();
4727                            }
4728                        }
4729                    });
4730
4731                    let create_websocket_streams = |ws_url: &str| {
4732                        let c1 = WebsocketConnection::new("c1");
4733                        let c2 = WebsocketConnection::new("c2");
4734                        let config = ConfigurationWebsocketStreams {
4735                            ws_url: Some(ws_url.to_string()),
4736                            mode: WebsocketMode::Pool(2),
4737                            reconnect_delay: 500,
4738                            time_unit: None,
4739                            agent: None,
4740                        };
4741                        WebsocketStreams::new(config, vec![c1, c2])
4742                    };
4743
4744                    let url = format!("ws://127.0.0.1:{port}");
4745                    let ws = create_websocket_streams(&url);
4746
4747                    let res = ws.connect(vec!["stream1".into()]).await;
4748                    assert!(res.is_ok());
4749                });
4750            }
4751
4752            #[test]
4753            fn refused_returns_error() {
4754                TOKIO_SHARED_RT.block_on(async {
4755                    let ws = create_websocket_streams(Some("ws://127.0.0.1:9"), None);
4756                    let res = ws.connect(vec!["stream1".into()]).await;
4757                    assert!(res.is_err());
4758                });
4759            }
4760
4761            #[test]
4762            fn invalid_url_returns_error() {
4763                TOKIO_SHARED_RT.block_on(async {
4764                    let ws = create_websocket_streams(Some("not-a-url"), None);
4765                    let res = ws.connect(vec!["s".into()]).await;
4766                    assert!(res.is_err());
4767                });
4768            }
4769        }
4770
4771        mod disconnect {
4772            use super::*;
4773
4774            #[test]
4775            fn disconnect_clears_state_and_streams() {
4776                TOKIO_SHARED_RT.block_on(async {
4777                    let ws = create_websocket_streams(None, None);
4778                    let conn = &ws.common.connection_pool[0];
4779                    {
4780                        let mut state = conn.state.lock().await;
4781                        state.stream_callbacks.insert("s1".to_string(), Vec::new());
4782                        state.pending_subscriptions.push_back("s2".to_string());
4783                    }
4784                    {
4785                        let mut map = ws.connection_streams.lock().await;
4786                        map.insert("s3".to_string(), Arc::clone(conn));
4787                    }
4788
4789                    let res = ws.disconnect().await;
4790                    assert!(res.is_ok());
4791
4792                    let state = conn.state.lock().await;
4793                    assert!(state.stream_callbacks.is_empty());
4794                    assert!(state.pending_subscriptions.is_empty());
4795
4796                    let map = ws.connection_streams.lock().await;
4797                    assert!(map.is_empty());
4798                });
4799            }
4800        }
4801
4802        mod subscribe {
4803            use super::*;
4804
4805            #[test]
4806            fn empty_list_does_nothing() {
4807                TOKIO_SHARED_RT.block_on(async {
4808                    let ws = create_websocket_streams(None, None);
4809                    ws.clone().subscribe(Vec::new(), None).await;
4810                    let map = ws.connection_streams.lock().await;
4811                    assert!(map.is_empty());
4812                });
4813            }
4814
4815            #[test]
4816            fn queue_when_not_ready() {
4817                TOKIO_SHARED_RT.block_on(async {
4818                    let ws = create_websocket_streams(None, None);
4819                    let conn = ws.common.connection_pool[0].clone();
4820                    ws.clone().subscribe(vec!["s1".into()], None).await;
4821                    let state = conn.state.lock().await;
4822                    let pending: Vec<String> =
4823                        state.pending_subscriptions.iter().cloned().collect();
4824                    assert_eq!(pending, vec!["s1".to_string()]);
4825                });
4826            }
4827
4828            #[test]
4829            fn only_one_subscription_per_stream() {
4830                TOKIO_SHARED_RT.block_on(async {
4831                    let ws = create_websocket_streams(None, None);
4832                    let conn = ws.common.connection_pool[0].clone();
4833                    ws.clone().subscribe(vec!["s1".into()], None).await;
4834                    ws.clone().subscribe(vec!["s1".into()], None).await;
4835                    let state = conn.state.lock().await;
4836                    let pending: Vec<String> =
4837                        state.pending_subscriptions.iter().cloned().collect();
4838                    assert_eq!(pending, vec!["s1".to_string()]);
4839                });
4840            }
4841
4842            #[test]
4843            fn multiple_streams_assigned() {
4844                TOKIO_SHARED_RT.block_on(async {
4845                    let ws = create_websocket_streams(None, None);
4846                    ws.clone()
4847                        .subscribe(vec!["s1".into(), "s2".into()], None)
4848                        .await;
4849                    let map = ws.connection_streams.lock().await;
4850                    assert!(map.contains_key("s1"));
4851                    assert!(map.contains_key("s2"));
4852                });
4853            }
4854
4855            #[test]
4856            fn existing_stream_not_reassigned() {
4857                TOKIO_SHARED_RT.block_on(async {
4858                    let ws = create_websocket_streams(None, None);
4859                    ws.clone().subscribe(vec!["s1".into()], None).await;
4860                    let first_id = {
4861                        let map = ws.connection_streams.lock().await;
4862                        map.get("s1").unwrap().id.clone()
4863                    };
4864                    ws.clone()
4865                        .subscribe(vec!["s1".into(), "s2".into()], None)
4866                        .await;
4867                    let map = ws.connection_streams.lock().await;
4868                    let second_id = map.get("s1").unwrap().id.clone();
4869                    assert_eq!(first_id, second_id);
4870                    assert!(map.contains_key("s2"));
4871                });
4872            }
4873        }
4874
4875        mod unsubscribe {
4876            use super::*;
4877
4878            #[test]
4879            fn removes_stream_with_no_callbacks() {
4880                TOKIO_SHARED_RT.block_on(async {
4881                    let ws = create_websocket_streams(None, None);
4882                    let conn = ws.common.connection_pool[0].clone();
4883
4884                    {
4885                        let (tx, _rx) = unbounded_channel::<Message>();
4886                        let mut st = conn.state.lock().await;
4887                        st.ws_write_tx = Some(tx);
4888                    }
4889
4890                    {
4891                        let mut map = ws.connection_streams.lock().await;
4892                        map.insert("s1".to_string(), conn.clone());
4893                    }
4894                    {
4895                        let mut st = conn.state.lock().await;
4896                        st.stream_callbacks.insert("s1".to_string(), Vec::new());
4897                    }
4898
4899                    ws.unsubscribe(vec!["s1".to_string()], None).await;
4900
4901                    assert!(!ws.connection_streams.lock().await.contains_key("s1"));
4902                    assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
4903                });
4904            }
4905
4906            #[test]
4907            fn preserves_stream_with_callbacks() {
4908                TOKIO_SHARED_RT.block_on(async {
4909                    let ws = create_websocket_streams(None, None);
4910                    let conn = ws.common.connection_pool[1].clone();
4911
4912                    {
4913                        let mut map = ws.connection_streams.lock().await;
4914                        map.insert("s2".to_string(), conn.clone());
4915                    }
4916                    {
4917                        let mut state = conn.state.lock().await;
4918                        state
4919                            .stream_callbacks
4920                            .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
4921                    }
4922
4923                    ws.unsubscribe(vec!["s2".to_string()], None).await;
4924
4925                    assert!(ws.connection_streams.lock().await.contains_key("s2"));
4926                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
4927                });
4928            }
4929
4930            #[test]
4931            fn does_not_send_if_callbacks_exist() {
4932                TOKIO_SHARED_RT.block_on(async {
4933                    let ws = create_websocket_streams(None, None);
4934                    let conn = ws.common.connection_pool[0].clone();
4935                    {
4936                        let mut map = ws.connection_streams.lock().await;
4937                        map.insert("s1".to_string(), conn.clone());
4938                    }
4939                    {
4940                        let mut state = conn.state.lock().await;
4941                        state.stream_callbacks.insert(
4942                            "s1".to_string(),
4943                            vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
4944                        );
4945                    }
4946                    ws.unsubscribe(vec!["s1".into()], None).await;
4947                    assert!(ws.connection_streams.lock().await.contains_key("s1"));
4948                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
4949                });
4950            }
4951
4952            #[test]
4953            fn warns_if_not_associated() {
4954                TOKIO_SHARED_RT.block_on(async {
4955                    let ws = create_websocket_streams(None, None);
4956                    ws.unsubscribe(vec!["nope".into()], None).await;
4957                });
4958            }
4959
4960            #[test]
4961            fn empty_list_does_nothing() {
4962                TOKIO_SHARED_RT.block_on(async {
4963                    let ws = create_websocket_streams(None, None);
4964                    let before = ws.connection_streams.lock().await.len();
4965                    ws.unsubscribe(Vec::<String>::new(), None).await;
4966                    let after = ws.connection_streams.lock().await.len();
4967                    assert_eq!(before, after);
4968                });
4969            }
4970
4971            #[test]
4972            fn invalid_custom_id_falls_back() {
4973                TOKIO_SHARED_RT.block_on(async {
4974                    let ws = create_websocket_streams(None, None);
4975                    let conn = ws.common.connection_pool[0].clone();
4976                    {
4977                        let mut map = ws.connection_streams.lock().await;
4978                        map.insert("foo".to_string(), conn.clone());
4979                    }
4980                    {
4981                        let mut state = conn.state.lock().await;
4982                        let (tx, _rx) = unbounded_channel();
4983                        state.ws_write_tx = Some(tx);
4984                        state.stream_callbacks.insert("foo".to_string(), Vec::new());
4985                    }
4986                    ws.unsubscribe(vec!["foo".into()], Some("bad-id".into()))
4987                        .await;
4988                    assert!(!ws.connection_streams.lock().await.contains_key("foo"));
4989                });
4990            }
4991
4992            #[test]
4993            fn removes_even_without_write_channel() {
4994                TOKIO_SHARED_RT.block_on(async {
4995                    let ws = create_websocket_streams(None, None);
4996                    let conn = ws.common.connection_pool[0].clone();
4997                    {
4998                        let mut map = ws.connection_streams.lock().await;
4999                        map.insert("x".to_string(), conn.clone());
5000                    }
5001                    {
5002                        let mut state = conn.state.lock().await;
5003                        let (tx, _rx) = unbounded_channel();
5004                        state.ws_write_tx = Some(tx);
5005                        state.stream_callbacks.insert("x".to_string(), Vec::new());
5006                    }
5007                    ws.unsubscribe(vec!["x".into()], None).await;
5008                    assert!(!ws.connection_streams.lock().await.contains_key("x"));
5009                });
5010            }
5011        }
5012
5013        mod is_subscribed {
5014            use super::*;
5015
5016            #[test]
5017            fn returns_false_when_not_subscribed() {
5018                TOKIO_SHARED_RT.block_on(async {
5019                    let ws = create_websocket_streams(None, None);
5020                    assert!(!ws.is_subscribed("unknown").await);
5021                });
5022            }
5023
5024            #[test]
5025            fn returns_true_when_subscribed() {
5026                TOKIO_SHARED_RT.block_on(async {
5027                    let ws = create_websocket_streams(None, None);
5028                    let conn = ws.common.connection_pool[0].clone();
5029                    {
5030                        let mut map = ws.connection_streams.lock().await;
5031                        map.insert("stream1".to_string(), conn);
5032                    }
5033                    assert!(ws.is_subscribed("stream1").await);
5034                });
5035            }
5036        }
5037
5038        mod prepare_url {
5039            use super::*;
5040
5041            #[test]
5042            fn without_time_unit_returns_base_url() {
5043                TOKIO_SHARED_RT.block_on(async {
5044                    let conns = vec![
5045                        WebsocketConnection::new("c1"),
5046                        WebsocketConnection::new("c2"),
5047                    ];
5048                    let config = ConfigurationWebsocketStreams {
5049                        ws_url: Some("wss://example".to_string()),
5050                        mode: WebsocketMode::Single,
5051                        reconnect_delay: 100,
5052                        time_unit: None,
5053                        agent: None,
5054                    };
5055                    let ws = WebsocketStreams::new(config, conns);
5056                    let url = ws.prepare_url(&["s1".into(), "s2".into()]);
5057                    assert_eq!(url, "wss://example/stream?streams=s1/s2");
5058                });
5059            }
5060
5061            #[test]
5062            fn with_time_unit_appends_parameter() {
5063                TOKIO_SHARED_RT.block_on(async {
5064                    let conns = vec![WebsocketConnection::new("c1")];
5065                    let config = ConfigurationWebsocketStreams {
5066                        ws_url: Some("wss://example".to_string()),
5067                        mode: WebsocketMode::Single,
5068                        reconnect_delay: 100,
5069                        time_unit: Some(TimeUnit::Millisecond),
5070                        agent: None,
5071                    };
5072                    let ws = WebsocketStreams::new(config, conns);
5073                    let url = ws.prepare_url(&["a".into()]);
5074                    assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
5075                });
5076            }
5077
5078            #[test]
5079            fn multiple_streams_and_time_unit() {
5080                TOKIO_SHARED_RT.block_on(async {
5081                    let conns = vec![WebsocketConnection::new("c1")];
5082                    let config = ConfigurationWebsocketStreams {
5083                        ws_url: Some("wss://example".to_string()),
5084                        mode: WebsocketMode::Single,
5085                        reconnect_delay: 100,
5086                        time_unit: Some(TimeUnit::Microsecond),
5087                        agent: None,
5088                    };
5089                    let ws = WebsocketStreams::new(config, conns);
5090                    let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()]);
5091                    assert_eq!(
5092                        url,
5093                        "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
5094                    );
5095                });
5096            }
5097        }
5098
5099        mod handle_stream_assignment {
5100            use super::*;
5101
5102            #[test]
5103            fn assigns_new_streams_to_connections() {
5104                TOKIO_SHARED_RT.block_on(async {
5105                    let ws = create_websocket_streams(None, None);
5106                    let groups = ws
5107                        .clone()
5108                        .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5109                        .await;
5110                    let mut seen_streams = HashSet::new();
5111                    for (_conn, streams) in &groups {
5112                        for s in streams {
5113                            seen_streams.insert(s);
5114                        }
5115                    }
5116                    assert_eq!(
5117                        seen_streams,
5118                        ["s1".to_string(), "s2".to_string()].iter().collect()
5119                    );
5120                    assert_eq!(groups.len(), 1);
5121                });
5122            }
5123
5124            #[test]
5125            fn reuses_existing_connection_for_duplicate_stream() {
5126                TOKIO_SHARED_RT.block_on(async {
5127                    let ws = create_websocket_streams(None, None);
5128                    let _ = ws.clone().handle_stream_assignment(vec!["s1".into()]).await;
5129                    let groups = ws
5130                        .clone()
5131                        .handle_stream_assignment(vec!["s1".into(), "s3".into()])
5132                        .await;
5133                    let mut all_streams = Vec::new();
5134                    for (_conn, streams) in groups {
5135                        all_streams.extend(streams);
5136                    }
5137                    all_streams.sort();
5138                    assert_eq!(all_streams, vec!["s1".to_string(), "s3".to_string()]);
5139                });
5140            }
5141
5142            #[test]
5143            fn empty_stream_list_returns_empty() {
5144                TOKIO_SHARED_RT.block_on(async {
5145                    let ws = create_websocket_streams(None, None);
5146                    let groups = ws.clone().handle_stream_assignment(vec![]).await;
5147                    assert!(groups.is_empty());
5148                });
5149            }
5150
5151            #[test]
5152            fn closed_or_reconnecting_forces_reassignment_of_stream() {
5153                TOKIO_SHARED_RT.block_on(async {
5154                    let ws = create_websocket_streams(None, None);
5155                    let mut groups = ws.clone().handle_stream_assignment(vec!["s1".into()]).await;
5156                    let (conn, _) = groups.pop().unwrap();
5157                    {
5158                        let mut st = conn.state.lock().await;
5159                        st.close_initiated = true;
5160                    }
5161                    let groups2 = ws.clone().handle_stream_assignment(vec!["s2".into()]).await;
5162                    assert_eq!(groups2.len(), 1);
5163                    let (_new_conn, streams) = &groups2[0];
5164                    assert_eq!(streams, &vec!["s2".to_string()]);
5165                });
5166            }
5167
5168            #[test]
5169            fn no_available_connections_falls_back_to_one() {
5170                TOKIO_SHARED_RT.block_on(async {
5171                    let ws = create_websocket_streams(None, Some(vec![]));
5172                    let assigned = ws.handle_stream_assignment(vec!["foo".into()]).await;
5173                    assert_eq!(assigned.len(), 1);
5174                    let (_conn, streams) = &assigned[0];
5175                    assert_eq!(streams.as_slice(), &["foo".to_string()]);
5176                });
5177            }
5178
5179            #[test]
5180            fn single_connection_groups_multiple_streams() {
5181                TOKIO_SHARED_RT.block_on(async {
5182                    let conn = WebsocketConnection::new("c1");
5183                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5184                    let assigned = ws
5185                        .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5186                        .await;
5187                    assert_eq!(assigned.len(), 1);
5188                    let (assigned_conn, streams) = &assigned[0];
5189                    assert!(Arc::ptr_eq(assigned_conn, &conn));
5190                    assert_eq!(streams.len(), 2);
5191                    assert!(streams.contains(&"s1".to_string()));
5192                    assert!(streams.contains(&"s2".to_string()));
5193                });
5194            }
5195
5196            #[test]
5197            fn reuse_existing_healthy_connection() {
5198                TOKIO_SHARED_RT.block_on(async {
5199                    let conn = WebsocketConnection::new("c");
5200                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5201                    let _ = ws.handle_stream_assignment(vec!["s1".into()]).await;
5202                    let second = ws.handle_stream_assignment(vec!["s1".into()]).await;
5203                    assert_eq!(second.len(), 1);
5204                    let (assigned_conn, streams) = &second[0];
5205                    assert!(Arc::ptr_eq(assigned_conn, &conn));
5206                    assert_eq!(streams.as_slice(), &["s1".to_string()]);
5207                });
5208            }
5209
5210            #[test]
5211            fn mix_new_and_assigned_streams() {
5212                TOKIO_SHARED_RT.block_on(async {
5213                    let conn = WebsocketConnection::new("c");
5214                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5215                    let _ = ws
5216                        .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5217                        .await;
5218                    let mixed = ws
5219                        .handle_stream_assignment(vec!["s2".into(), "s3".into()])
5220                        .await;
5221                    assert_eq!(mixed.len(), 1);
5222                    let (assigned_conn, streams) = &mixed[0];
5223                    assert!(Arc::ptr_eq(assigned_conn, &conn));
5224                    let mut got = streams.clone();
5225                    got.sort();
5226                    assert_eq!(got, vec!["s2".to_string(), "s3".to_string()]);
5227                });
5228            }
5229        }
5230
5231        mod send_subscription_payload {
5232            use super::*;
5233
5234            #[test]
5235            fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
5236                TOKIO_SHARED_RT.block_on(async {
5237                    let ws: Arc<WebsocketStreams> =
5238                        create_websocket_streams(Some("ws://example.com"), None);
5239                    let conn = &ws.common.connection_pool[0];
5240                    let (tx, mut rx) = unbounded_channel();
5241                    {
5242                        let mut st = conn.state.lock().await;
5243                        st.ws_write_tx = Some(tx);
5244                    }
5245                    ws.send_subscription_payload(
5246                        conn.clone(),
5247                        vec!["s1".to_string()],
5248                        Some("badid".to_string()),
5249                    );
5250                    let msg = rx.recv().await.expect("no message sent");
5251                    if let Message::Text(txt) = msg {
5252                        let v: serde_json::Value = serde_json::from_str(&txt).unwrap();
5253                        assert_eq!(v["method"], "SUBSCRIBE");
5254                        let id = v["id"].as_str().unwrap();
5255                        assert_ne!(id, "badid");
5256                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id));
5257                    } else {
5258                        panic!("unexpected message: {msg:?}");
5259                    }
5260                });
5261            }
5262
5263            #[test]
5264            fn subscribe_payload_with_and_without_custom_id() {
5265                TOKIO_SHARED_RT.block_on(async {
5266                    let ws: Arc<WebsocketStreams> =
5267                        create_websocket_streams(Some("ws://unused"), None);
5268                    let conn = &ws.common.connection_pool[0];
5269                    let (tx, mut rx) = unbounded_channel();
5270                    {
5271                        let mut st = conn.state.lock().await;
5272                        st.ws_write_tx = Some(tx);
5273                    }
5274                    ws.send_subscription_payload(
5275                        conn.clone(),
5276                        vec!["a".to_string(), "b".to_string()],
5277                        Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string()),
5278                    );
5279                    let msg1 = rx.recv().await.unwrap();
5280                    ws.send_subscription_payload(conn.clone(), vec!["x".to_string()], None);
5281                    let msg2 = rx.recv().await.unwrap();
5282
5283                    if let Message::Text(txt1) = msg1 {
5284                        let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
5285                        assert_eq!(v1["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
5286                        assert_eq!(
5287                            v1["params"].as_array().unwrap(),
5288                            &vec![serde_json::json!("a"), serde_json::json!("b")]
5289                        );
5290                    } else {
5291                        panic!()
5292                    }
5293
5294                    if let Message::Text(txt2) = msg2 {
5295                        let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
5296                        assert_eq!(v2["method"], "SUBSCRIBE");
5297                        let params = v2["params"].as_array().unwrap();
5298                        assert_eq!(params.len(), 1);
5299                        assert_eq!(params[0], "x");
5300                        let id2 = v2["id"].as_str().unwrap();
5301                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id2));
5302                    } else {
5303                        panic!()
5304                    }
5305                });
5306            }
5307        }
5308
5309        mod on_open {
5310            use super::*;
5311
5312            #[test]
5313            fn sends_pending_subscriptions() {
5314                TOKIO_SHARED_RT.block_on(async {
5315                    let ws: Arc<WebsocketStreams> =
5316                        create_websocket_streams(Some("ws://example.com"), None);
5317                    let conn = &ws.common.connection_pool[0];
5318                    let (tx, mut rx) = unbounded_channel();
5319                    {
5320                        let mut st = conn.state.lock().await;
5321                        st.ws_write_tx = Some(tx);
5322                        st.pending_subscriptions.push_back("foo".to_string());
5323                        st.pending_subscriptions.push_back("bar".to_string());
5324                    }
5325                    ws.on_open("ws://example.com".to_string(), conn.clone())
5326                        .await;
5327                    let msg = rx.recv().await.expect("no subscription sent");
5328                    if let Message::Text(txt) = msg {
5329                        let v: Value = serde_json::from_str(&txt).unwrap();
5330                        assert_eq!(v["method"], "SUBSCRIBE");
5331                        let params = v["params"].as_array().unwrap();
5332                        assert_eq!(
5333                            params,
5334                            &vec![Value::String("foo".into()), Value::String("bar".into())]
5335                        );
5336                    } else {
5337                        panic!("unexpected message: {msg:?}");
5338                    }
5339                    let st_after = conn.state.lock().await;
5340                    assert!(st_after.pending_subscriptions.is_empty());
5341                });
5342            }
5343
5344            #[test]
5345            fn with_no_pending_subscriptions_sends_nothing() {
5346                TOKIO_SHARED_RT.block_on(async {
5347                    let ws: Arc<WebsocketStreams> =
5348                        create_websocket_streams(Some("ws://example.com"), None);
5349                    let conn = &ws.common.connection_pool[0];
5350                    let (tx, mut rx) = unbounded_channel();
5351                    {
5352                        let mut st = conn.state.lock().await;
5353                        st.ws_write_tx = Some(tx);
5354                    }
5355                    ws.on_open("ws://example.com".to_string(), conn.clone())
5356                        .await;
5357                    assert!(rx.try_recv().is_err(), "unexpected message sent");
5358                });
5359            }
5360
5361            #[test]
5362            fn clears_pending_without_write_channel() {
5363                TOKIO_SHARED_RT.block_on(async {
5364                    let ws: Arc<WebsocketStreams> =
5365                        create_websocket_streams(Some("ws://example.com"), None);
5366                    let conn = &ws.common.connection_pool[0];
5367                    {
5368                        let mut st = conn.state.lock().await;
5369                        st.pending_subscriptions.push_back("solo".to_string());
5370                    }
5371                    ws.on_open("ws://example.com".to_string(), conn.clone())
5372                        .await;
5373                    let st_after = conn.state.lock().await;
5374                    assert!(st_after.pending_subscriptions.is_empty());
5375                });
5376            }
5377        }
5378
5379        mod on_message {
5380            use super::*;
5381
5382            #[test]
5383            fn invokes_registered_callback() {
5384                TOKIO_SHARED_RT.block_on(async {
5385                    let ws: Arc<WebsocketStreams> =
5386                        create_websocket_streams(Some("ws://example.com"), None);
5387                    let conn = &ws.common.connection_pool[0];
5388                    let called = Arc::new(AtomicBool::new(false));
5389                    let called_clone = called.clone();
5390
5391                    {
5392                        let mut st = conn.state.lock().await;
5393                        st.stream_callbacks
5394                            .entry("stream1".to_string())
5395                            .or_default()
5396                            .push(
5397                                (Box::new(move |_: &Value| {
5398                                    called_clone.store(true, Ordering::SeqCst);
5399                                })
5400                                    as Box<dyn Fn(&Value) + Send + Sync>)
5401                                    .into(),
5402                            );
5403                    }
5404
5405                    let msg = json!({
5406                        "stream": "stream1",
5407                        "data": { "key": "value" }
5408                    })
5409                    .to_string();
5410
5411                    ws.on_message(msg, conn.clone()).await;
5412
5413                    assert!(called.load(Ordering::SeqCst));
5414                });
5415            }
5416
5417            #[test]
5418            fn invokes_all_registered_callbacks() {
5419                TOKIO_SHARED_RT.block_on(async {
5420                    let ws: Arc<WebsocketStreams> =
5421                        create_websocket_streams(Some("ws://example.com"), None);
5422                    let conn = &ws.common.connection_pool[0];
5423                    let counter = Arc::new(AtomicUsize::new(0));
5424
5425                    {
5426                        let mut st = conn.state.lock().await;
5427                        let entry = st.stream_callbacks.entry("s".into()).or_default();
5428                        let c1 = counter.clone();
5429                        entry.push(
5430                            (Box::new(move |_: &Value| {
5431                                c1.fetch_add(1, Ordering::SeqCst);
5432                            }) as Box<dyn Fn(&Value) + Send + Sync>)
5433                                .into(),
5434                        );
5435                        let c2 = counter.clone();
5436                        entry.push(
5437                            (Box::new(move |_: &Value| {
5438                                c2.fetch_add(1, Ordering::SeqCst);
5439                            }) as Box<dyn Fn(&Value) + Send + Sync>)
5440                                .into(),
5441                        );
5442                    }
5443
5444                    let msg = json!({"stream":"s","data":42}).to_string();
5445                    ws.on_message(msg, conn.clone()).await;
5446
5447                    assert_eq!(counter.load(Ordering::SeqCst), 2);
5448                });
5449            }
5450
5451            #[test]
5452            fn handles_null_data_field() {
5453                TOKIO_SHARED_RT.block_on(async {
5454                    let ws: Arc<WebsocketStreams> =
5455                        create_websocket_streams(Some("ws://example.com"), None);
5456                    let conn = &ws.common.connection_pool[0];
5457                    let called = Arc::new(AtomicUsize::new(0));
5458                    {
5459                        let mut st = conn.state.lock().await;
5460                        st.stream_callbacks.entry("n".into()).or_default().push(
5461                            (Box::new({
5462                                let c = called.clone();
5463                                move |data: &Value| {
5464                                    if data.is_null() {
5465                                        c.fetch_add(1, Ordering::SeqCst);
5466                                    }
5467                                }
5468                            }) as Box<dyn Fn(&Value) + Send + Sync>)
5469                                .into(),
5470                        );
5471                    }
5472                    let msg = json!({"stream":"n","data":null}).to_string();
5473                    ws.on_message(msg, conn.clone()).await;
5474                    assert_eq!(called.load(Ordering::SeqCst), 1);
5475                });
5476            }
5477
5478            #[test]
5479            fn with_invalid_json_does_not_panic() {
5480                TOKIO_SHARED_RT.block_on(async {
5481                    let ws: Arc<WebsocketStreams> =
5482                        create_websocket_streams(Some("ws://example.com"), None);
5483                    let conn = &ws.common.connection_pool[0];
5484                    let bad = "not a json";
5485                    ws.on_message(bad.to_string(), conn.clone()).await;
5486                });
5487            }
5488
5489            #[test]
5490            fn without_stream_field_does_nothing() {
5491                TOKIO_SHARED_RT.block_on(async {
5492                    let ws: Arc<WebsocketStreams> =
5493                        create_websocket_streams(Some("ws://example.com"), None);
5494                    let conn = &ws.common.connection_pool[0];
5495                    let msg = json!({ "data": { "foo": 1 } }).to_string();
5496                    ws.on_message(msg, conn.clone()).await;
5497                });
5498            }
5499
5500            #[test]
5501            fn with_unregistered_stream_does_not_panic() {
5502                TOKIO_SHARED_RT.block_on(async {
5503                    let ws: Arc<WebsocketStreams> =
5504                        create_websocket_streams(Some("ws://example.com"), None);
5505                    let conn = &ws.common.connection_pool[0];
5506                    let msg = json!({
5507                        "stream": "nope",
5508                        "data": { "foo": 1 }
5509                    })
5510                    .to_string();
5511                    ws.on_message(msg, conn.clone()).await;
5512                });
5513            }
5514        }
5515
5516        mod get_reconnect_url {
5517            use super::*;
5518
5519            #[test]
5520            fn single_stream_reconnect_url() {
5521                TOKIO_SHARED_RT.block_on(async {
5522                    let ws: Arc<WebsocketStreams> =
5523                        create_websocket_streams(Some("ws://example.com"), None);
5524                    let c0 = ws.common.connection_pool[0].clone();
5525                    {
5526                        let mut map = ws.connection_streams.lock().await;
5527                        map.insert("s1".to_string(), c0.clone());
5528                    }
5529                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
5530                    assert_eq!(url, "ws://example.com/stream?streams=s1");
5531                });
5532            }
5533
5534            #[test]
5535            fn multiple_streams_same_connection() {
5536                TOKIO_SHARED_RT.block_on(async {
5537                    let ws: Arc<WebsocketStreams> =
5538                        create_websocket_streams(Some("ws://example.com"), None);
5539                    let c0 = ws.common.connection_pool[0].clone();
5540                    {
5541                        let mut map = ws.connection_streams.lock().await;
5542                        map.insert("a".to_string(), c0.clone());
5543                        map.insert("b".to_string(), c0.clone());
5544                    }
5545                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
5546                    let suffix = url
5547                        .strip_prefix("ws://example.com/stream?streams=")
5548                        .unwrap();
5549                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
5550                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
5551                    assert_eq!(set, ["a", "b"].iter().copied().collect());
5552                });
5553            }
5554
5555            #[test]
5556            fn reconnect_url_with_time_unit() {
5557                TOKIO_SHARED_RT.block_on(async {
5558                    let mut ws: Arc<WebsocketStreams> =
5559                        create_websocket_streams(Some("ws://example.com"), None);
5560                    Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
5561                        Some(TimeUnit::Microsecond);
5562                    let c0 = ws.common.connection_pool[0].clone();
5563                    {
5564                        let mut map = ws.connection_streams.lock().await;
5565                        map.insert("x".to_string(), c0.clone());
5566                    }
5567                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
5568                    assert_eq!(
5569                        url,
5570                        "ws://example.com/stream?streams=x&timeUnit=microsecond"
5571                    );
5572                });
5573            }
5574        }
5575    }
5576
5577    mod websocket_stream {
5578        use super::*;
5579
5580        mod on {
5581            use super::*;
5582
5583            #[test]
5584            fn registers_callback_and_stream_callback_for_websocket_streams() {
5585                TOKIO_SHARED_RT.block_on(async {
5586                    let ws_base = create_websocket_streams(Some("example.com"), None);
5587                    let stream_name = "s1".to_string();
5588                    let conn = ws_base.common.connection_pool[0].clone();
5589                    {
5590                        let mut map = ws_base.connection_streams.lock().await;
5591                        map.insert(stream_name.clone(), conn.clone());
5592                    }
5593                    {
5594                        let mut state = conn.state.lock().await;
5595                        state
5596                            .stream_callbacks
5597                            .insert(stream_name.clone(), Vec::new());
5598                    }
5599                    let stream = Arc::new(WebsocketStream::<Value> {
5600                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5601                        stream_or_id: stream_name.clone(),
5602                        callback: Mutex::new(None),
5603                        id: None,
5604                        _phantom: PhantomData,
5605                    });
5606                    let called = Arc::new(Mutex::new(false));
5607                    let called_clone = called.clone();
5608                    stream
5609                        .on("message", move |v: Value| {
5610                            let mut lock = called_clone.blocking_lock();
5611                            *lock = v == Value::String("x".into());
5612                        })
5613                        .await;
5614                    let cb_guard = stream.callback.lock().await;
5615                    assert!(cb_guard.is_some());
5616                    let cbs = {
5617                        let state = conn.state.lock().await;
5618                        state.stream_callbacks.get(&stream_name).unwrap().clone()
5619                    };
5620                    assert_eq!(cbs.len(), 1);
5621                });
5622            }
5623
5624            #[test]
5625            fn message_twice_registers_two_wrappers_for_websocket_streams() {
5626                TOKIO_SHARED_RT.block_on(async {
5627                    let ws_base = create_websocket_streams(Some("example.com"), None);
5628                    let stream_name = "s2".to_string();
5629                    let conn = ws_base.common.connection_pool[0].clone();
5630                    {
5631                        let mut map = ws_base.connection_streams.lock().await;
5632                        map.insert(stream_name.clone(), conn.clone());
5633                    }
5634                    {
5635                        let mut state = conn.state.lock().await;
5636                        state
5637                            .stream_callbacks
5638                            .insert(stream_name.clone(), Vec::new());
5639                    }
5640                    let stream = Arc::new(WebsocketStream::<Value> {
5641                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5642                        stream_or_id: stream_name.clone(),
5643                        callback: Mutex::new(None),
5644                        id: None,
5645                        _phantom: PhantomData,
5646                    });
5647                    stream.on("message", |_| {}).await;
5648                    stream.on("message", |_| {}).await;
5649                    let state = conn.state.lock().await;
5650                    let callbacks = state.stream_callbacks.get(&stream_name).unwrap();
5651                    assert_eq!(callbacks.len(), 2);
5652                });
5653            }
5654
5655            #[test]
5656            fn ignores_non_message_event_for_websocket_streams() {
5657                TOKIO_SHARED_RT.block_on(async {
5658                    let ws_base = create_websocket_streams(Some("example.com"), None);
5659                    let stream = Arc::new(WebsocketStream::<Value> {
5660                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5661                        stream_or_id: "s".into(),
5662                        callback: Mutex::new(None),
5663                        id: None,
5664                        _phantom: PhantomData,
5665                    });
5666                    stream.on("open", |_| {}).await;
5667                    let guard = stream.callback.lock().await;
5668                    assert!(guard.is_none());
5669                });
5670            }
5671
5672            #[test]
5673            fn registers_callback_and_stream_callback_for_websocket_api() {
5674                TOKIO_SHARED_RT.block_on(async {
5675                    let ws_base = create_websocket_api(None);
5676
5677                    {
5678                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5679                        stream_callbacks.insert("id1".to_string(), Vec::new());
5680                    }
5681
5682                    let stream = Arc::new(WebsocketStream::<Value> {
5683                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5684                        stream_or_id: "id1".to_string(),
5685                        callback: Mutex::new(None),
5686                        id: None,
5687                        _phantom: PhantomData,
5688                    });
5689
5690                    let called = Arc::new(Mutex::new(false));
5691                    let called_clone = called.clone();
5692                    stream
5693                        .on("message", move |v: Value| {
5694                            let mut lock = called_clone.blocking_lock();
5695                            *lock = v == Value::String("x".into());
5696                        })
5697                        .await;
5698
5699                    let cb_guard = stream.callback.lock().await;
5700                    assert!(cb_guard.is_some());
5701
5702                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5703                    let callbacks = stream_callbacks.get("id1").unwrap();
5704                    assert_eq!(callbacks.len(), 1);
5705                });
5706            }
5707
5708            #[test]
5709            fn message_twice_registers_two_wrappers_for_websocket_api() {
5710                TOKIO_SHARED_RT.block_on(async {
5711                    let ws_base = create_websocket_api(None);
5712
5713                    let stream = Arc::new(WebsocketStream::<Value> {
5714                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5715                        stream_or_id: "id2".to_string(),
5716                        callback: Mutex::new(None),
5717                        id: None,
5718                        _phantom: PhantomData,
5719                    });
5720
5721                    stream.on("message", |_| {}).await;
5722                    stream.on("message", |_| {}).await;
5723
5724                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5725                    let callbacks = stream_callbacks.get("id2").unwrap();
5726                    assert_eq!(callbacks.len(), 2);
5727                });
5728            }
5729
5730            #[test]
5731            fn ignores_non_message_event_for_websocket_api() {
5732                TOKIO_SHARED_RT.block_on(async {
5733                    let ws_base = create_websocket_api(None);
5734
5735                    let stream = Arc::new(WebsocketStream::<Value> {
5736                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5737                        stream_or_id: "id3".into(),
5738                        callback: Mutex::new(None),
5739                        id: None,
5740                        _phantom: PhantomData,
5741                    });
5742
5743                    stream.on("open", |_| {}).await;
5744
5745                    let guard = stream.callback.lock().await;
5746                    assert!(guard.is_none());
5747
5748                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5749                    assert!(stream_callbacks.get("id3").is_none());
5750                    assert!(stream_callbacks.is_empty());
5751                });
5752            }
5753        }
5754
5755        mod on_message {
5756            use super::*;
5757
5758            #[test]
5759            fn on_message_registers_callback_for_websocket_streams() {
5760                TOKIO_SHARED_RT.block_on(async {
5761                    let ws_base = create_websocket_streams(Some("example.com"), None);
5762                    let stream_name = "s".to_string();
5763                    let conn = ws_base.common.connection_pool[0].clone();
5764                    {
5765                        let mut map = ws_base.connection_streams.lock().await;
5766                        map.insert(stream_name.clone(), conn.clone());
5767                    }
5768                    {
5769                        let mut state = conn.state.lock().await;
5770                        state
5771                            .stream_callbacks
5772                            .insert(stream_name.clone(), Vec::new());
5773                    }
5774                    let stream = Arc::new(WebsocketStream::<Value> {
5775                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5776                        stream_or_id: stream_name.clone(),
5777                        callback: Mutex::new(None),
5778                        id: None,
5779                        _phantom: PhantomData,
5780                    });
5781                    stream.on_message(|_v| {});
5782                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
5783                    assert_eq!(callbacks.len(), 1);
5784                });
5785            }
5786
5787            #[test]
5788            fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
5789                TOKIO_SHARED_RT.block_on(async {
5790                    let ws_base = create_websocket_streams(Some("example.com"), None);
5791                    let stream_name = "s".to_string();
5792                    let conn = ws_base.common.connection_pool[0].clone();
5793                    {
5794                        let mut map = ws_base.connection_streams.lock().await;
5795                        map.insert(stream_name.clone(), conn.clone());
5796                    }
5797                    {
5798                        let mut state = conn.state.lock().await;
5799                        state
5800                            .stream_callbacks
5801                            .insert(stream_name.clone(), Vec::new());
5802                    }
5803                    let stream = Arc::new(WebsocketStream::<Value> {
5804                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5805                        stream_or_id: stream_name.clone(),
5806                        callback: Mutex::new(None),
5807                        id: None,
5808                        _phantom: PhantomData,
5809                    });
5810                    stream.on_message(|_v| {});
5811                    stream.on_message(|_v| {});
5812                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
5813                    assert_eq!(callbacks.len(), 2);
5814                });
5815            }
5816
5817            #[test]
5818            fn on_message_registers_callback_for_websocket_api() {
5819                TOKIO_SHARED_RT.block_on(async {
5820                    let ws_base = create_websocket_api(None);
5821                    let identifier = "id1".to_string();
5822
5823                    let stream = Arc::new(WebsocketStream::<Value> {
5824                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5825                        stream_or_id: identifier.clone(),
5826                        callback: Mutex::new(None),
5827                        id: None,
5828                        _phantom: PhantomData,
5829                    });
5830
5831                    stream.on_message(|_v: Value| {});
5832
5833                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5834                    let callbacks = stream_callbacks.get(&identifier).unwrap();
5835                    assert_eq!(callbacks.len(), 1);
5836                });
5837            }
5838
5839            #[test]
5840            fn on_message_twice_registers_two_callbacks_for_websocket_api() {
5841                TOKIO_SHARED_RT.block_on(async {
5842                    let ws_base = create_websocket_api(None);
5843                    let identifier = "id2".to_string();
5844
5845                    let stream = Arc::new(WebsocketStream::<Value> {
5846                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5847                        stream_or_id: identifier.clone(),
5848                        callback: Mutex::new(None),
5849                        id: None,
5850                        _phantom: PhantomData,
5851                    });
5852
5853                    stream.on_message(|_v: Value| {});
5854                    stream.on_message(|_v: Value| {});
5855
5856                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5857                    let callbacks = stream_callbacks.get(&identifier).unwrap();
5858                    assert_eq!(callbacks.len(), 2);
5859                });
5860            }
5861        }
5862
5863        mod unsubscribe {
5864            use super::*;
5865
5866            #[test]
5867            fn without_callback_does_nothing() {
5868                TOKIO_SHARED_RT.block_on(async {
5869                    let ws_base = create_websocket_streams(Some("example.com"), None);
5870                    let stream_name = "s1".to_string();
5871                    let conn = ws_base.common.connection_pool[0].clone();
5872                    {
5873                        let mut map = ws_base.connection_streams.lock().await;
5874                        map.insert(stream_name.clone(), conn.clone());
5875                    }
5876                    let mut state = conn.state.lock().await;
5877                    state.stream_callbacks.insert(stream_name.clone(), vec![]);
5878                    drop(state);
5879                    let stream = Arc::new(WebsocketStream::<Value> {
5880                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5881                        stream_or_id: stream_name.clone(),
5882                        callback: Mutex::new(None),
5883                        id: None,
5884                        _phantom: PhantomData,
5885                    });
5886                    stream.unsubscribe().await;
5887                    let state = conn.state.lock().await;
5888                    assert!(state.stream_callbacks.contains_key(&stream_name));
5889                });
5890            }
5891
5892            #[test]
5893            fn removes_registered_callback_and_clears_state() {
5894                TOKIO_SHARED_RT.block_on(async {
5895                    let ws_base = create_websocket_streams(Some("example.com"), None);
5896                    let stream_name = "s2".to_string();
5897                    let conn = ws_base.common.connection_pool[0].clone();
5898                    {
5899                        let mut map = ws_base.connection_streams.lock().await;
5900                        map.insert(stream_name.clone(), conn.clone());
5901                    }
5902                    {
5903                        let mut state = conn.state.lock().await;
5904                        state
5905                            .stream_callbacks
5906                            .insert(stream_name.clone(), Vec::new());
5907                    }
5908                    let stream = Arc::new(WebsocketStream::<Value> {
5909                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5910                        stream_or_id: stream_name.clone(),
5911                        callback: Mutex::new(None),
5912                        id: None,
5913                        _phantom: PhantomData,
5914                    });
5915                    stream.on("message", |_| {}).await;
5916                    {
5917                        let guard = stream.callback.lock().await;
5918                        assert!(guard.is_some());
5919                    }
5920                    stream.unsubscribe().await;
5921                    sleep(Duration::from_millis(10)).await;
5922                    let guard = stream.callback.lock().await;
5923                    assert!(guard.is_none());
5924                    let state = conn.state.lock().await;
5925                    assert!(
5926                        state
5927                            .stream_callbacks
5928                            .get(&stream_name)
5929                            .is_none_or(std::vec::Vec::is_empty)
5930                    );
5931                });
5932            }
5933
5934            #[test]
5935            fn without_callback_does_nothing_for_websocket_api() {
5936                TOKIO_SHARED_RT.block_on(async {
5937                    let ws_base = create_websocket_api(None);
5938                    let identifier = "id1".to_string();
5939
5940                    {
5941                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5942                        stream_callbacks.insert(identifier.clone(), Vec::new());
5943                    }
5944
5945                    let stream = Arc::new(WebsocketStream::<Value> {
5946                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5947                        stream_or_id: identifier.clone(),
5948                        callback: Mutex::new(None),
5949                        id: None,
5950                        _phantom: PhantomData,
5951                    });
5952
5953                    stream.unsubscribe().await;
5954
5955                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5956                    assert!(stream_callbacks.contains_key(&identifier));
5957                    let callbacks = stream_callbacks.get(&identifier).unwrap();
5958                    assert!(callbacks.is_empty());
5959                });
5960            }
5961
5962            #[test]
5963            fn removes_registered_callback_and_clears_state_for_websocket_api() {
5964                TOKIO_SHARED_RT.block_on(async {
5965                    let ws_base = create_websocket_api(None);
5966                    let identifier = "id2".to_string();
5967
5968                    {
5969                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5970                        stream_callbacks.insert(identifier.clone(), Vec::new());
5971                    }
5972
5973                    let stream = Arc::new(WebsocketStream::<Value> {
5974                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5975                        stream_or_id: identifier.clone(),
5976                        callback: Mutex::new(None),
5977                        id: None,
5978                        _phantom: PhantomData,
5979                    });
5980
5981                    stream.on("message", |_| {}).await;
5982
5983                    {
5984                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
5985                        let callbacks = stream_callbacks
5986                            .get(&identifier)
5987                            .expect("Entry for 'id2' should exist");
5988                        assert_eq!(callbacks.len(), 1);
5989                    }
5990
5991                    stream.unsubscribe().await;
5992
5993                    {
5994                        let guard = stream.callback.lock().await;
5995                        assert!(guard.is_none());
5996                    }
5997
5998                    {
5999                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
6000                        let callbacks = stream_callbacks
6001                            .get(&identifier)
6002                            .expect("Entry for 'id2' should still exist");
6003                        assert!(callbacks.is_empty());
6004                    }
6005                });
6006            }
6007        }
6008    }
6009
6010    mod create_stream_handler {
6011        use super::*;
6012
6013        #[test]
6014        fn create_stream_handler_without_id_registers_stream() {
6015            TOKIO_SHARED_RT.block_on(async {
6016                let ws = create_websocket_streams(Some("ws://example.com"), None);
6017                let stream_name = "foo".to_string();
6018                let handler = create_stream_handler::<serde_json::Value>(
6019                    WebsocketBase::WebsocketStreams(ws.clone()),
6020                    stream_name.clone(),
6021                    None,
6022                )
6023                .await;
6024                assert_eq!(handler.stream_or_id, stream_name);
6025                assert!(handler.id.is_none());
6026                let map = ws.connection_streams.lock().await;
6027                assert!(map.contains_key(&stream_name));
6028            });
6029        }
6030
6031        #[test]
6032        fn create_stream_handler_with_custom_id_registers_stream_and_id() {
6033            TOKIO_SHARED_RT.block_on(async {
6034                let ws = create_websocket_streams(Some("ws://example.com"), None);
6035                let stream_name = "bar".to_string();
6036                let custom_id = Some("my-custom-id".to_string());
6037                let handler = create_stream_handler::<serde_json::Value>(
6038                    WebsocketBase::WebsocketStreams(ws.clone()),
6039                    stream_name.clone(),
6040                    custom_id.clone(),
6041                )
6042                .await;
6043                assert_eq!(handler.stream_or_id, stream_name);
6044                assert_eq!(handler.id, custom_id);
6045                let map = ws.connection_streams.lock().await;
6046                assert!(map.contains_key(&stream_name));
6047            });
6048        }
6049
6050        #[test]
6051        fn create_stream_handler_without_id_registers_api_stream() {
6052            TOKIO_SHARED_RT.block_on(async {
6053                let ws_base = create_websocket_api(None);
6054                let identifier = "foo-api".to_string();
6055
6056                let handler = create_stream_handler::<Value>(
6057                    WebsocketBase::WebsocketApi(ws_base.clone()),
6058                    identifier.clone(),
6059                    None,
6060                )
6061                .await;
6062
6063                assert_eq!(handler.stream_or_id, identifier);
6064                assert!(handler.id.is_none());
6065            });
6066        }
6067
6068        #[test]
6069        fn create_stream_handler_with_custom_id_registers_api_stream_and_id() {
6070            TOKIO_SHARED_RT.block_on(async {
6071                let ws_base = create_websocket_api(None);
6072                let identifier = "bar-api".to_string();
6073                let custom_id = Some("custom-123".to_string());
6074
6075                let handler = create_stream_handler::<Value>(
6076                    WebsocketBase::WebsocketApi(ws_base.clone()),
6077                    identifier.clone(),
6078                    custom_id.clone(),
6079                )
6080                .await;
6081
6082                assert_eq!(handler.stream_or_id, identifier);
6083                assert_eq!(handler.id, custom_id);
6084            });
6085        }
6086    }
6087}