Skip to main content

binance_sdk/common/
websocket.rs

1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, future::try_join_all, stream::FuturesUnordered};
4use http::header::USER_AGENT;
5use serde::de::DeserializeOwned;
6use serde_json::{Value, json};
7use std::{
8    collections::{BTreeMap, HashMap, VecDeque},
9    io::Read,
10    marker::PhantomData,
11    mem::take,
12    sync::{
13        Arc,
14        atomic::{AtomicBool, AtomicUsize, Ordering},
15    },
16    time::Duration,
17};
18use tokio::{
19    net::TcpStream,
20    select, spawn,
21    sync::{
22        Mutex, Notify,
23        mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
24        oneshot,
25    },
26    task::JoinHandle,
27    time::{sleep, timeout},
28};
29use tokio_tungstenite::{
30    Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
31    tungstenite::{
32        Message,
33        client::IntoClientRequest,
34        protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
35    },
36};
37use tokio_util::time::DelayQueue;
38use tracing::{debug, error, info, warn};
39
40use super::{
41    config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
42    errors::{WebsocketConnectionFailureReason, WebsocketError},
43    models::{StreamId, WebsocketApiResponse, WebsocketEvent, WebsocketMode},
44    utils::{build_websocket_api_message, normalize_stream_id, random_string, validate_time_unit},
45};
46
47pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;
48
49const MAX_CONN_DURATION: Duration = Duration::from_secs(23 * 60 * 60);
50
51pub struct Subscription {
52    handle: JoinHandle<()>,
53}
54
55impl Subscription {
56    /// Cancels the ongoing WebSocket event subscription and stops the event processing task.
57    ///
58    /// This method aborts the background task responsible for receiving and processing
59    /// WebSocket events, effectively unsubscribing from further event notifications.
60    ///
61    /// # Examples
62    ///
63    ///
64    /// let emitter = `WebsocketEventEmitter::new()`;
65    /// let subscription = emitter.subscribe(|event| {
66    ///     // Handle WebSocket event
67    /// });
68    /// `subscription.unsubscribe()`; // Stop receiving events
69    ///
70    pub fn unsubscribe(self) {
71        self.handle.abort();
72    }
73}
74
75#[derive(Clone)]
76pub enum WebsocketBase {
77    WebsocketApi(Arc<WebsocketApi>),
78    WebsocketStreams(Arc<WebsocketStreams>),
79}
80
81pub struct WebsocketEventEmitter {
82    subscribers: Arc<std::sync::Mutex<Vec<UnboundedSender<WebsocketEvent>>>>,
83}
84
85impl Default for WebsocketEventEmitter {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl WebsocketEventEmitter {
92    #[must_use]
93    pub fn new() -> Self {
94        Self {
95            subscribers: Arc::new(std::sync::Mutex::new(Vec::new())),
96        }
97    }
98
99    /// Subscribes to WebSocket events and returns a `Subscription` that allows event processing.
100    ///
101    /// This method creates an unbounded channel for receiving WebSocket events and
102    /// spawns an asynchronous task to process these events using the provided callback function.
103    ///
104    /// # Arguments
105    ///
106    /// * `callback` - A mutable function that will be called for each received WebSocket event.
107    ///   The callback must be thread-safe and have a static lifetime.
108    ///
109    /// # Returns
110    ///
111    /// A `Subscription` that can be used to unsubscribe and stop event processing.
112    ///
113    /// # Examples
114    ///
115    ///
116    /// let emitter = `WebsocketEventEmitter::new()`;
117    /// let subscription = emitter.subscribe(|event| {
118    ///     // Handle WebSocket event
119    ///     println!("Received event: {:?}", event);
120    /// });
121    ///
122    /// // Later, when no longer needed
123    /// `subscription.unsubscribe()`;
124    ///
125    pub fn subscribe<F>(&self, mut callback: F) -> Subscription
126    where
127        F: FnMut(WebsocketEvent) + Send + 'static,
128    {
129        let (tx, mut rx) = unbounded_channel();
130        let mut guard = match self.subscribers.lock() {
131            Ok(guard) => guard,
132            Err(poisoned) => poisoned.into_inner(),
133        };
134        guard.push(tx);
135        drop(guard);
136
137        let handle = spawn(async move {
138            while let Some(event) = rx.recv().await {
139                callback(event);
140            }
141        });
142        Subscription { handle }
143    }
144
145    /// Emits a WebSocket event to all registered subscribers.
146    ///
147    /// This method sends the given event to all active subscribers. If a subscriber
148    /// has been dropped without unsubscribing, a warning is logged and the subscriber
149    /// is removed from the list.
150    ///
151    /// # Arguments
152    ///
153    /// * `event` - The WebSocket event to be emitted to all subscribers.
154    pub fn emit(&self, event: &WebsocketEvent) {
155        let mut guard = match self.subscribers.lock() {
156            Ok(guard) => guard,
157            Err(poisoned) => poisoned.into_inner(),
158        };
159
160        guard.retain(|tx| {
161            if tx.send(event.clone()).is_ok() {
162                true
163            } else {
164                warn!("subscriber dropped without unsubscribing");
165                false
166            }
167        });
168    }
169}
170
171/// A trait defining the lifecycle and behavior of a WebSocket connection.
172///
173/// This trait provides methods for handling WebSocket connection events,
174/// including connection opening, message handling, and reconnection URL retrieval.
175///
176/// # Methods
177///
178/// * `on_open`: Called when a WebSocket connection is established
179/// * `on_message`: Called when a message is received over the WebSocket
180/// * `get_reconnect_url`: Determines the URL to use for reconnecting
181///
182/// # Thread Safety
183///
184/// Implementors must be safely shareable across threads, as indicated by the `Send + Sync + 'static` bounds.
185#[async_trait]
186pub trait WebsocketHandler: Send + Sync + 'static {
187    async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
188    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
189    async fn get_reconnect_url(
190        &self,
191        default_url: String,
192        connection: Arc<WebsocketConnection>,
193    ) -> String;
194}
195
196pub struct PendingRequest {
197    pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
198}
199
200#[derive(Clone)]
201pub struct WebsocketSessionLogonReq {
202    pub method: String,
203    pub payload: BTreeMap<String, Value>,
204    pub options: WebsocketMessageSendOptions,
205}
206
207pub struct WebsocketConnectionState {
208    pub reconnection_pending: bool,
209    pub renewal_pending: bool,
210    pub close_initiated: bool,
211    pub pending_requests: HashMap<String, PendingRequest>,
212    pub pending_subscriptions: VecDeque<String>,
213    pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
214    pub is_session_logged_on: bool,
215    pub session_logon_req: Option<WebsocketSessionLogonReq>,
216    pub url_path: Option<String>,
217    pub handler: Option<Arc<dyn WebsocketHandler>>,
218    pub ws_write_tx: Option<UnboundedSender<Message>>,
219}
220
221impl Default for WebsocketConnectionState {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
227impl WebsocketConnectionState {
228    #[must_use]
229    pub fn new() -> Self {
230        Self {
231            reconnection_pending: false,
232            renewal_pending: false,
233            close_initiated: false,
234            pending_requests: HashMap::new(),
235            pending_subscriptions: VecDeque::new(),
236            stream_callbacks: HashMap::new(),
237            is_session_logged_on: false,
238            session_logon_req: None,
239            url_path: None,
240            handler: None,
241            ws_write_tx: None,
242        }
243    }
244}
245
246pub struct WebsocketConnection {
247    pub id: String,
248    pub drain_notify: Notify,
249    pub state: Mutex<WebsocketConnectionState>,
250}
251
252impl WebsocketConnection {
253    pub fn new(id: impl Into<String>) -> Arc<Self> {
254        Arc::new(Self {
255            id: id.into(),
256            drain_notify: Notify::new(),
257            state: Mutex::new(WebsocketConnectionState::new()),
258        })
259    }
260
261    pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
262        let mut conn_state = self.state.lock().await;
263        conn_state.handler = Some(handler);
264    }
265}
266
267struct ReconnectEntry {
268    connection_id: String,
269    url: String,
270    is_renewal: bool,
271}
272
273pub struct WebsocketCommon {
274    pub events: WebsocketEventEmitter,
275    mode: WebsocketMode,
276    round_robin_index: AtomicUsize,
277    connection_pool: Vec<Arc<WebsocketConnection>>,
278    reconnect_tx: Sender<ReconnectEntry>,
279    renewal_tx: Sender<(String, String)>,
280    reconnect_delay: usize,
281    agent: Option<AgentConnector>,
282    user_agent: Option<String>,
283}
284
285impl WebsocketCommon {
286    #[must_use]
287    pub fn new(
288        mut initial_pool: Vec<Arc<WebsocketConnection>>,
289        mode: WebsocketMode,
290        reconnect_delay: usize,
291        agent: Option<AgentConnector>,
292        user_agent: Option<String>,
293    ) -> Arc<Self> {
294        if initial_pool.is_empty() {
295            for _ in 0..mode.pool_size() {
296                let id = random_string();
297                initial_pool.push(WebsocketConnection::new(id));
298            }
299        }
300
301        let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
302        let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
303
304        let common = Arc::new(Self {
305            events: WebsocketEventEmitter::new(),
306            mode,
307            round_robin_index: AtomicUsize::new(0),
308            connection_pool: initial_pool,
309            reconnect_tx,
310            renewal_tx,
311            reconnect_delay,
312            agent,
313            user_agent,
314        });
315
316        Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
317        Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
318
319        common
320    }
321
322    /// Spawns an asynchronous loop to handle websocket reconnection attempts
323    ///
324    /// This method manages reconnection logic for websocket connections, including:
325    /// - Scheduling reconnects with a configurable delay
326    /// - Finding the appropriate connection in the connection pool
327    /// - Attempting to reinitialize the connection
328    /// - Logging reconnection failures or warnings
329    ///
330    /// # Arguments
331    /// * `common` - A shared reference to the `WebsocketCommon` instance
332    /// * `reconnect_rx` - A receiver channel for reconnection entries
333    ///
334    /// # Behavior
335    /// - Waits for reconnection entries from the channel
336    /// - Applies a configurable delay before attempting reconnection
337    /// - Attempts to reinitialize the connection with the provided URL
338    /// - Handles and logs any reconnection errors
339    fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
340        spawn(async move {
341            while let Some(entry) = reconnect_rx.recv().await {
342                info!("Scheduling reconnect for id {}", entry.connection_id);
343
344                if !entry.is_renewal {
345                    sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
346                }
347
348                if let Some(conn_arc) = common
349                    .connection_pool
350                    .iter()
351                    .find(|c| c.id == entry.connection_id)
352                    .cloned()
353                {
354                    let common_clone = Arc::clone(&common);
355                    if let Err(err) = common_clone
356                        .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
357                        .await
358                    {
359                        error!(
360                            "Reconnect failed for {} → {}: {:?}",
361                            entry.connection_id, entry.url, err
362                        );
363                    }
364
365                    sleep(Duration::from_secs(1)).await;
366                } else {
367                    warn!("No connection {} found for reconnect", entry.connection_id);
368                }
369            }
370        });
371    }
372
373    /// Spawns an asynchronous loop to manage connection renewals
374    ///
375    /// This method handles the periodic renewal of websocket connections by:
376    /// - Maintaining a delay queue for connection expiration
377    /// - Receiving renewal requests for specific connections
378    /// - Triggering reconnection when a connection reaches its maximum duration
379    /// - Attempting to find and renew connections in the connection pool
380    ///
381    /// # Behavior
382    /// - Listens for renewal requests on a channel
383    /// - Tracks connection expiration using a delay queue
384    /// - Initiates reconnection process when a connection expires
385    /// - Handles and logs any renewal failures
386    fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
387        let common = Arc::clone(common);
388        spawn(async move {
389            let mut dq = DelayQueue::new();
390            let mut renewal_rx = renewal_rx;
391
392            loop {
393                select! {
394                    Some((conn_id, url)) = renewal_rx.recv() => {
395                        debug!("Scheduling renewal for {}", conn_id);
396                        dq.insert((conn_id, url), MAX_CONN_DURATION);
397                    }
398
399                    Some(expired) = dq.next() => {
400                        let (conn_id, default_url) = expired.into_inner();
401
402                        if let Some(conn_arc) = common
403                            .connection_pool
404                            .iter()
405                            .find(|c| c.id == conn_id)
406                            .cloned()
407                        {
408                            debug!("Renewing connection {}", conn_id);
409                            let url = common
410                                .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
411                                .await;
412                            if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
413                                connection_id: conn_id.clone(),
414                                url,
415                                is_renewal: true,
416                            }).await {
417                                error!(
418                                    "Failed to enqueue renewal for {}: {:?}",
419                                    conn_id, e
420                                );
421                            }
422                        } else {
423                            warn!("No connection {} found for renewal", conn_id);
424                        }
425                    }
426                }
427            }
428        });
429    }
430
431    /// Checks if a WebSocket connection is ready for use.
432    ///
433    /// # Arguments
434    ///
435    /// * `connection` - The WebSocket connection to check
436    /// * `allow_non_established` - If true, allows connections that are not fully established
437    ///
438    /// # Returns
439    ///
440    /// `true` if the connection is ready, `false` otherwise
441    ///
442    /// # Behavior
443    ///
444    /// A connection is considered ready if:
445    /// - It has a write channel (unless `allow_non_established` is true)
446    /// - No reconnection is pending
447    /// - No close has been initiated
448    pub async fn is_connection_ready(
449        &self,
450        connection: &WebsocketConnection,
451        allow_non_established: bool,
452    ) -> bool {
453        let conn_state = connection.state.lock().await;
454        (allow_non_established || conn_state.ws_write_tx.is_some())
455            && !conn_state.reconnection_pending
456            && !conn_state.close_initiated
457    }
458
459    /// Checks if a WebSocket connection is established.
460    ///
461    /// # Arguments
462    ///
463    /// * `connection` - Optional specific WebSocket connection to check
464    ///
465    /// # Returns
466    ///
467    /// `true` if a connection is ready and established, `false` otherwise
468    ///
469    /// # Behavior
470    ///
471    /// - If a specific connection is provided, checks only that connection
472    /// - If no connection is provided, checks all connections in the pool
473    /// - A connection is considered established if it is ready and not in a non-established state
474    async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
475        if let Some(conn_arc) = connection {
476            return self.is_connection_ready(conn_arc, false).await;
477        }
478
479        for conn_arc in &self.connection_pool {
480            if self.is_connection_ready(conn_arc, false).await {
481                return true;
482            }
483        }
484
485        false
486    }
487
488    /// Retrieves available WebSocket connections from the connection pool.
489    ///
490    /// # Arguments
491    ///
492    /// * `allow_non_established` - If `true`, includes connections that are not fully established
493    /// * `url_path` - Optional URL path to filter connections
494    ///
495    /// # Returns
496    ///
497    /// A vector of `Arc<WebsocketConnection>` that are ready based on the `allow_non_established` flag
498    ///
499    /// # Behavior
500    ///
501    /// - For single connection mode, returns the first connection
502    /// - For multi-connection mode, filters connections based on readiness
503    /// - Uses `is_connection_ready` to determine connection availability
504    async fn get_available_connections(
505        &self,
506        allow_non_established: bool,
507        url_path: Option<&str>,
508    ) -> Vec<Arc<WebsocketConnection>> {
509        if matches!(self.mode, WebsocketMode::Single) && url_path.is_none() {
510            return vec![Arc::clone(&self.connection_pool[0])];
511        }
512
513        let mut ready = Vec::new();
514        for conn in &self.connection_pool {
515            if self.is_connection_ready(conn, allow_non_established).await {
516                ready.push(Arc::clone(conn));
517            }
518        }
519
520        ready
521    }
522
523    /// Retrieves a WebSocket connection from the connection pool.
524    ///
525    /// # Arguments
526    ///
527    /// * `allow_non_established` - If `true`, allows selecting a connection that is not fully established
528    /// * `url_path` - Optional URL path to filter connections
529    ///
530    /// # Returns
531    ///
532    /// An `Arc` to a `WebsocketConnection` from the pool, selected using round-robin strategy
533    ///
534    /// # Errors
535    ///
536    /// Returns `WebsocketError::NotConnected` if no suitable connection is available
537    ///
538    /// # Behavior
539    ///
540    /// - For single connection mode, returns the first connection
541    /// - For multi-connection mode, selects a ready connection using round-robin
542    /// - Filters connections based on `allow_non_established` parameter
543    async fn get_connection(
544        &self,
545        allow_non_established: bool,
546        url_path: Option<&str>,
547    ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
548        let candidates = self
549            .get_available_connections(allow_non_established, url_path)
550            .await;
551
552        let mut ready = Vec::new();
553        for conn in candidates {
554            if let Some(path) = url_path {
555                let st = conn.state.lock().await;
556                if st.url_path.as_deref() != Some(path) {
557                    continue;
558                }
559            }
560            ready.push(conn);
561        }
562
563        if ready.is_empty() {
564            return Err(WebsocketError::NotConnected);
565        }
566
567        let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
568
569        Ok(Arc::clone(&ready[idx]))
570    }
571
572    /// Gracefully closes a WebSocket connection by waiting for pending requests to complete.
573    ///
574    /// # Arguments
575    ///
576    /// * `ws_write_tx_to_close` - Sender channel for sending close message
577    /// * `connection` - Shared reference to the WebSocket connection
578    ///
579    /// # Behavior
580    ///
581    /// - Waits up to 30 seconds for all pending requests to complete
582    /// - Logs debug and warning messages during the closing process
583    /// - Sends a normal close frame to the WebSocket
584    ///
585    /// # Returns
586    ///
587    /// `Ok(())` if connection closes successfully, otherwise a `WebsocketError`
588    async fn close_connection_gracefully(
589        &self,
590        ws_write_tx_to_close: UnboundedSender<Message>,
591        connection: Arc<WebsocketConnection>,
592    ) -> Result<(), WebsocketError> {
593        debug!("Waiting for pending requests to complete before disconnecting.");
594
595        let drain = async {
596            loop {
597                {
598                    let conn_state = connection.state.lock().await;
599                    if conn_state.pending_requests.is_empty() {
600                        debug!("All pending requests completed, proceeding to close.");
601                        break;
602                    }
603                }
604                connection.drain_notify.notified().await;
605            }
606        };
607
608        if timeout(Duration::from_secs(30), drain).await.is_err() {
609            warn!("Timeout waiting for pending requests; forcing close.");
610        }
611
612        info!("Closing WebSocket connection for {}", connection.id);
613        let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
614            code: CloseCode::Normal,
615            reason: "".into(),
616        })));
617
618        Ok(())
619    }
620
621    /// Retrieves the URL to use for reconnecting to the WebSocket.
622    ///
623    /// # Arguments
624    ///
625    /// * `default_url` - The default URL to use if no custom reconnect URL is provided
626    /// * `connection` - A shared reference to the WebSocket connection
627    ///
628    /// # Returns
629    ///
630    /// The URL to use for reconnecting, either from a custom handler or the default URL
631    ///
632    /// # Behavior
633    ///
634    /// - Checks if a connection handler is available
635    /// - If a handler exists, calls its `get_reconnect_url` method
636    /// - Otherwise, returns the default URL
637    async fn get_reconnect_url(
638        &self,
639        default_url: &str,
640        connection: Arc<WebsocketConnection>,
641    ) -> String {
642        if let Some(handler) = {
643            let conn_state = connection.state.lock().await;
644            conn_state.handler.clone()
645        } {
646            return handler
647                .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
648                .await;
649        }
650
651        default_url.to_string()
652    }
653
654    /// Handles the WebSocket connection opening event.
655    ///
656    /// This method is called when a WebSocket connection is successfully established. It performs
657    /// the following key actions:
658    /// - Invokes the connection handler's `on_open` method if a handler is present
659    /// - Logs connection information
660    /// - Handles connection renewal and close scenarios
661    /// - Emits a WebSocket open event
662    ///
663    /// # Arguments
664    ///
665    /// * `url` - The URL of the WebSocket server
666    /// * `connection` - A shared reference to the WebSocket connection
667    /// * `old_ws_writer` - Optional previous WebSocket writer for graceful connection handling
668    ///
669    /// # Behavior
670    ///
671    /// - If a connection handler exists, calls its `on_open` method
672    /// - Checks for pending renewal or close states
673    /// - Closes the previous connection if renewal is in progress
674    /// - Emits an open event if the connection is successfully established
675    async fn on_open(
676        &self,
677        url: String,
678        connection: Arc<WebsocketConnection>,
679        old_ws_writer: Option<UnboundedSender<Message>>,
680    ) {
681        if let Some(handler) = {
682            let conn_state = connection.state.lock().await;
683            conn_state.handler.clone()
684        } {
685            handler.on_open(url.clone(), Arc::clone(&connection)).await;
686        }
687
688        let conn_id = &connection.id;
689        info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
690
691        {
692            let mut conn_state = connection.state.lock().await;
693
694            if conn_state.renewal_pending {
695                conn_state.renewal_pending = false;
696                drop(conn_state);
697                if let Some(tx) = old_ws_writer {
698                    info!("Connection renewal in progress; closing previous connection.");
699                    let _ = self
700                        .close_connection_gracefully(tx, Arc::clone(&connection))
701                        .await;
702                }
703                return;
704            }
705
706            if conn_state.close_initiated {
707                drop(conn_state);
708                if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
709                    info!("Close initiated; closing connection.");
710                    let _ = self
711                        .close_connection_gracefully(tx, Arc::clone(&connection))
712                        .await;
713                }
714                return;
715            }
716
717            self.events.emit(&WebsocketEvent::Open);
718        }
719    }
720
721    /// Handles an incoming WebSocket message
722    ///
723    /// # Arguments
724    ///
725    /// * `msg` - The received message as a string
726    /// * `connection` - A shared reference to the WebSocket connection
727    ///
728    /// # Behavior
729    ///
730    /// - If a connection handler exists, awaits its `on_message` method
731    /// - Emits a `WebsocketEvent::Message` event with the received message
732    async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
733        let handler = connection.state.lock().await.handler.clone();
734        if let Some(handler) = handler {
735            handler
736                .on_message(msg.clone(), Arc::clone(&connection))
737                .await;
738        }
739        self.events.emit(&WebsocketEvent::Message(msg));
740    }
741
742    /// Creates a WebSocket connection with optional configuration and agent
743    ///
744    /// # Arguments
745    ///
746    /// * `url` - The WebSocket server URL to connect to
747    /// * `agent` - Optional agent connector for configuring the connection
748    /// * `user_agent` - Optional custom user agent string
749    ///
750    /// # Returns
751    ///
752    /// A `Result` containing the established WebSocket stream or a `WebsocketError`
753    ///
754    /// # Errors
755    ///
756    /// Returns a `WebsocketError` if:
757    /// - The WebSocket handshake fails
758    /// - The connection times out after 10 seconds
759    ///
760    /// # Behavior
761    ///
762    /// Attempts to establish a WebSocket connection with a configurable timeout,
763    /// supporting optional TLS, custom user agent, and connection connectors
764    async fn create_websocket(
765        url: &str,
766        agent: Option<AgentConnector>,
767        user_agent: Option<String>,
768    ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
769        let mut req = url
770            .into_client_request()
771            .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
772
773        if let Some(ua) = user_agent {
774            req.headers_mut().insert(USER_AGENT, ua.parse().unwrap());
775        }
776
777        let ws_config: Option<WebSocketConfig> = None;
778        let disable_nagle = false;
779        let connector: Option<Connector> = agent.map(|dbg| dbg.0);
780
781        let timeout_duration = Duration::from_secs(10);
782        let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
783        match timeout(timeout_duration, handshake).await {
784            Ok(Ok((ws_stream, response))) => {
785                debug!("WebSocket connected: {:?}", response);
786                Ok(ws_stream)
787            }
788            Ok(Err(e)) => {
789                let msg = e.to_string();
790                error!("WebSocket handshake failed: {}", msg);
791                Err(WebsocketError::Handshake(msg))
792            }
793            Err(_) => {
794                error!(
795                    "WebSocket connection timed out after {}s",
796                    timeout_duration.as_secs()
797                );
798                Err(WebsocketError::Timeout)
799            }
800        }
801    }
802
803    /// Connects to a WebSocket URL for all connections in the connection pool concurrently
804    ///
805    /// # Arguments
806    ///
807    /// * `url` - The WebSocket server URL to connect to
808    /// * `connections` - Optional specific connections to use, otherwise uses the entire pool
809    ///
810    /// # Returns
811    ///
812    /// A `Result` indicating whether all connections were successfully established
813    ///
814    /// # Errors
815    ///
816    /// Returns a `WebsocketError` if any connection in the pool fails to establish
817    ///
818    /// # Behavior
819    ///
820    /// Attempts to initialize a WebSocket connection for each connection in the pool
821    /// concurrently, logging successes and failures for each connection attempt
822    async fn connect_pool(
823        self: Arc<Self>,
824        url: &str,
825        connections: Option<Vec<Arc<WebsocketConnection>>>,
826    ) -> Result<(), WebsocketError> {
827        let pool: Vec<Arc<WebsocketConnection>> = match connections {
828            Some(v) => v,
829            None => self.connection_pool.clone(),
830        };
831
832        let mut tasks = FuturesUnordered::new();
833
834        for conn in pool {
835            let common = Arc::clone(&self);
836            let url = url.to_owned();
837
838            tasks.push(async move {
839                match common.init_connect(&url, false, Some(conn)).await {
840                    Ok(()) => {
841                        info!("Successfully connected to {}", url);
842                        Ok(())
843                    }
844                    Err(err) => {
845                        error!("Failed to connect to {}: {:?}", url, err);
846                        Err(err)
847                    }
848                }
849            });
850        }
851
852        while let Some(result) = tasks.next().await {
853            result?;
854        }
855
856        Ok(())
857    }
858
859    /// Initializes a WebSocket connection for a specific connection in the pool
860    ///
861    /// # Arguments
862    ///
863    /// * `url` - The WebSocket server URL to connect to
864    /// * `is_renewal` - Flag indicating whether this is a connection renewal attempt
865    /// * `connection` - Optional specific WebSocket connection to use, otherwise selects from the pool
866    ///
867    /// # Returns
868    ///
869    /// A `Result` indicating whether the connection was successfully established
870    ///
871    /// # Errors
872    ///
873    /// Returns a `WebsocketError` if the connection fails to initialize or establish
874    ///
875    /// # Behavior
876    ///
877    /// Handles connection establishment, splitting read/write streams, spawning reader/writer tasks,
878    /// and managing connection state including renewal, reconnection, and error handling
879    async fn init_connect(
880        self: Arc<Self>,
881        url: &str,
882        is_renewal: bool,
883        connection: Option<Arc<WebsocketConnection>>,
884    ) -> Result<(), WebsocketError> {
885        let conn = connection.unwrap_or(self.get_connection(true, None).await?);
886
887        {
888            let mut conn_state = conn.state.lock().await;
889            if conn_state.renewal_pending && is_renewal {
890                info!("Renewal in progress {}→{}", conn.id, url);
891                return Ok(());
892            }
893            if conn_state.ws_write_tx.is_some() && !is_renewal && !conn_state.reconnection_pending {
894                info!("Exists {}; skipping {}", conn.id, url);
895                return Ok(());
896            }
897            if is_renewal {
898                conn_state.renewal_pending = true;
899            }
900
901            conn_state.is_session_logged_on = false;
902        }
903
904        let ws = Self::create_websocket(url, self.agent.clone(), self.user_agent.clone())
905            .await
906            .map_err(|e| {
907                error!("Handshake failed {}: {:?}", url, e);
908                e
909            })?;
910
911        info!("Established {} → {}", conn.id, url);
912
913        if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
914            error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
915        }
916
917        let (write_half, mut read_half) = ws.split();
918        let (tx, mut rx) = unbounded_channel::<Message>();
919
920        let old_writer = {
921            let mut conn_state = conn.state.lock().await;
922            conn_state.reconnection_pending = false;
923            conn_state.ws_write_tx.replace(tx.clone())
924        };
925
926        {
927            let wconn = conn.clone();
928            let common_clone = self.clone();
929            let writer_url = url.to_string();
930
931            spawn(async move {
932                let mut sink = write_half;
933                while let Some(msg) = rx.recv().await {
934                    if let Err(e) = sink.send(msg).await {
935                        let failure_reason =
936                            WebsocketConnectionFailureReason::from_tungstenite_error(&e);
937
938                        error!(
939                            "Write error on {}: {:?}, classified as {:?}",
940                            wconn.id, e, failure_reason
941                        );
942
943                        // Apply same reconnection logic as reader errors
944                        let mut conn_state = wconn.state.lock().await;
945                        if !conn_state.close_initiated
946                            && !is_renewal
947                            && failure_reason.should_reconnect()
948                        {
949                            info!(
950                                "Writer connection {} has recoverable error, attempting reconnection: {:?}",
951                                wconn.id, failure_reason
952                            );
953                            conn_state.reconnection_pending = true;
954                            conn_state.is_session_logged_on = false;
955                            drop(conn_state);
956                            let reconnect_url = common_clone
957                                .get_reconnect_url(&writer_url, Arc::clone(&wconn))
958                                .await;
959
960                            let _ = common_clone
961                                .reconnect_tx
962                                .send(ReconnectEntry {
963                                    connection_id: wconn.id.clone(),
964                                    url: reconnect_url,
965                                    is_renewal: false,
966                                })
967                                .await;
968                        } else {
969                            warn!(
970                                "Writer connection {} has permanent error, will not reconnect: {:?}",
971                                wconn.id, failure_reason
972                            );
973                        }
974
975                        break;
976                    }
977                }
978                debug!("Writer {} exit", wconn.id);
979            });
980        }
981
982        {
983            let common = self.clone();
984            let conn = conn.clone();
985            let url = url.to_string();
986            spawn(async move {
987                common.on_open(url, conn, old_writer).await;
988            });
989        }
990
991        {
992            let common = self.clone();
993            let reader_conn = conn.clone();
994            let read_url = url.to_string();
995
996            spawn(async move {
997                let mut stream_end_reason = None;
998                while let Some(item) = read_half.next().await {
999                    match item {
1000                        Ok(Message::Text(msg)) => {
1001                            common
1002                                .on_message(msg.to_string(), Arc::clone(&reader_conn))
1003                                .await;
1004                        }
1005                        Ok(Message::Binary(bin)) => {
1006                            let mut decoder = ZlibDecoder::new(&bin[..]);
1007                            let mut decompressed = String::new();
1008                            if let Err(err) = decoder.read_to_string(&mut decompressed) {
1009                                error!("Binary message decompress failed: {:?}", err);
1010                                continue;
1011                            }
1012                            common
1013                                .on_message(decompressed, Arc::clone(&reader_conn))
1014                                .await;
1015                        }
1016                        Ok(Message::Ping(payload)) => {
1017                            info!("PING received from server on {}", reader_conn.id);
1018                            common.events.emit(&WebsocketEvent::Ping);
1019                            if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
1020                                let _ = tx.send(Message::Pong(payload));
1021                                info!(
1022                                    "Responded PONG to server's PING message on {}",
1023                                    reader_conn.id
1024                                );
1025                            }
1026                        }
1027                        Ok(Message::Pong(_)) => {
1028                            info!("Received PONG from server on {}", reader_conn.id);
1029                            common.events.emit(&WebsocketEvent::Pong);
1030                        }
1031                        Ok(Message::Close(frame)) => {
1032                            let (code, reason) = frame
1033                                .map_or((1000, String::new()), |CloseFrame { code, reason }| {
1034                                    (code.into(), reason.to_string())
1035                                });
1036                            common
1037                                .events
1038                                .emit(&WebsocketEvent::Close(code, reason.clone()));
1039
1040                            // Classify the close reason
1041                            let user_initiated = {
1042                                let conn_state = reader_conn.state.lock().await;
1043                                conn_state.close_initiated
1044                            };
1045
1046                            let failure_reason = WebsocketConnectionFailureReason::from_close_code(
1047                                code,
1048                                user_initiated,
1049                            );
1050                            stream_end_reason = Some(failure_reason);
1051
1052                            info!(
1053                                "Connection {} received close frame: code={}, reason='{}', classified as {:?}",
1054                                reader_conn.id, code, reason, failure_reason
1055                            );
1056
1057                            let mut conn_state = reader_conn.state.lock().await;
1058                            if !conn_state.close_initiated
1059                                && !is_renewal
1060                                && failure_reason.should_reconnect()
1061                            {
1062                                info!(
1063                                    "Connection {} received close frame with reconnectable failure: {:?}",
1064                                    reader_conn.id, failure_reason
1065                                );
1066                                conn_state.reconnection_pending = true;
1067                                conn_state.is_session_logged_on = false;
1068                                drop(conn_state);
1069                                let reconnect_url = common
1070                                    .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1071                                    .await;
1072
1073                                let _ = common
1074                                    .reconnect_tx
1075                                    .send(ReconnectEntry {
1076                                        connection_id: reader_conn.id.clone(),
1077                                        url: reconnect_url,
1078                                        is_renewal: false,
1079                                    })
1080                                    .await;
1081                            } else {
1082                                warn!(
1083                                    "Connection {} received close frame with non-reconnectable failure: {:?}",
1084                                    reader_conn.id, failure_reason
1085                                );
1086
1087                                // Emit detailed error for permanent failures
1088                                common.events.emit(&WebsocketEvent::Error(format!(
1089                                    "[CRITICAL] Connection {} permanently failed: {:?}",
1090                                    reader_conn.id, failure_reason
1091                                )));
1092                            }
1093
1094                            break;
1095                        }
1096                        Err(e) => {
1097                            // Classify the error type for reconnection decision
1098                            let failure_reason =
1099                                WebsocketConnectionFailureReason::from_tungstenite_error(&e);
1100
1101                            stream_end_reason = Some(failure_reason);
1102                            error!(
1103                                "WebSocket error on {}: {:?}, classified as {:?}",
1104                                reader_conn.id, e, failure_reason
1105                            );
1106
1107                            common.events.emit(&WebsocketEvent::Error(e.to_string()));
1108
1109                            // Apply the same reconnection logic as Close frames
1110                            let mut conn_state = reader_conn.state.lock().await;
1111                            if !conn_state.close_initiated
1112                                && !is_renewal
1113                                && failure_reason.should_reconnect()
1114                            {
1115                                info!(
1116                                    "Connection {} has recoverable error, attempting reconnection: {:?}",
1117                                    reader_conn.id, failure_reason
1118                                );
1119                                conn_state.reconnection_pending = true;
1120                                conn_state.is_session_logged_on = false;
1121                                drop(conn_state);
1122                                let reconnect_url = common
1123                                    .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1124                                    .await;
1125
1126                                let _ = common
1127                                    .reconnect_tx
1128                                    .send(ReconnectEntry {
1129                                        connection_id: reader_conn.id.clone(),
1130                                        url: reconnect_url,
1131                                        is_renewal: false,
1132                                    })
1133                                    .await;
1134                            } else {
1135                                warn!(
1136                                    "Connection {} has permanent error, will not reconnect: {:?}",
1137                                    reader_conn.id, failure_reason
1138                                );
1139
1140                                // Emit critical error for non-reconnectable failures
1141                                common.events.emit(&WebsocketEvent::Error(format!(
1142                                    "[CRITICAL] Connection {} permanently failed: {:?}",
1143                                    reader_conn.id, failure_reason
1144                                )));
1145                            }
1146
1147                            break;
1148                        }
1149                        _ => {}
1150                    }
1151                }
1152
1153                // Handle case where stream ends unexpectedly (e.g., network disconnection)
1154                info!("WebSocket stream ended for connection {}", reader_conn.id);
1155
1156                // Handle possibly unexpected stream end with same logic as other errors
1157                let failure_reason =
1158                    stream_end_reason.unwrap_or(WebsocketConnectionFailureReason::StreamEnded);
1159
1160                info!(
1161                    "WebSocket stream ended for connection {}, classified as {:?}",
1162                    reader_conn.id, failure_reason
1163                );
1164
1165                let mut conn_state = reader_conn.state.lock().await;
1166                if !conn_state.close_initiated && !is_renewal && failure_reason.should_reconnect() {
1167                    info!(
1168                        "Connection {} stream ended unexpectedly, attempting reconnection",
1169                        reader_conn.id
1170                    );
1171                    conn_state.reconnection_pending = true;
1172                    conn_state.is_session_logged_on = false;
1173                    conn_state.ws_write_tx = None;
1174                    drop(conn_state);
1175                    let reconnect_url = common
1176                        .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1177                        .await;
1178
1179                    let _ = common
1180                        .reconnect_tx
1181                        .send(ReconnectEntry {
1182                            connection_id: reader_conn.id.clone(),
1183                            url: reconnect_url,
1184                            is_renewal: false,
1185                        })
1186                        .await;
1187                } else {
1188                    debug!(
1189                        "Connection {} stream ended normally (close_initiated={}, is_renewal={})",
1190                        reader_conn.id, conn_state.close_initiated, is_renewal
1191                    );
1192                }
1193
1194                debug!("Reader actor for {} exiting", reader_conn.id);
1195            });
1196        }
1197
1198        Ok(())
1199    }
1200    /// Gracefully disconnects all active WebSocket connections.
1201    ///
1202    /// This method attempts to close all connections in the connection pool within a 30-second timeout.
1203    /// It marks each connection as close-initiated and attempts to close them gracefully.
1204    ///
1205    /// # Returns
1206    ///
1207    /// - `Ok(())` if all connections are successfully closed
1208    /// - `Err(WebsocketError)` if there are errors during disconnection or a timeout occurs
1209    ///
1210    /// # Errors
1211    ///
1212    /// Returns `WebsocketError::Timeout` if disconnection takes longer than 30 seconds
1213    ///
1214    async fn disconnect(&self) -> Result<(), WebsocketError> {
1215        if !self.is_connected(None).await {
1216            warn!("No active connection to close.");
1217            return Ok(());
1218        }
1219
1220        let mut shutdowns = FuturesUnordered::new();
1221        for conn in &self.connection_pool {
1222            {
1223                let mut conn_state = conn.state.lock().await;
1224                conn_state.close_initiated = true;
1225                if let Some(tx) = &conn_state.ws_write_tx {
1226                    shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
1227                }
1228            }
1229        }
1230
1231        let close_all = async {
1232            while let Some(result) = shutdowns.next().await {
1233                result?;
1234            }
1235            Ok::<(), WebsocketError>(())
1236        };
1237
1238        match timeout(Duration::from_secs(30), close_all).await {
1239            Ok(Ok(())) => {
1240                info!("Disconnected all WebSocket connections successfully.");
1241                for conn in &self.connection_pool {
1242                    let mut st = conn.state.lock().await;
1243                    st.is_session_logged_on = false;
1244                    st.session_logon_req = None;
1245                }
1246                Ok(())
1247            }
1248            Ok(Err(err)) => {
1249                error!("Error while disconnecting: {:?}", err);
1250                Err(err)
1251            }
1252            Err(_) => {
1253                error!("Timed out while disconnecting WebSocket connections.");
1254                Err(WebsocketError::Timeout)
1255            }
1256        }
1257    }
1258
1259    /// Sends a PING message to all ready WebSocket connections.
1260    ///
1261    /// This method iterates through the connection pool, identifies ready connections,
1262    /// and sends a PING message to each of them. It logs the number of connections
1263    /// being pinged and handles any send errors individually.
1264    ///
1265    /// # Behavior
1266    ///
1267    /// - Skips connections that are not ready
1268    /// - Logs a warning if no connections are ready
1269    /// - Sends PING messages concurrently
1270    /// - Logs debug/error messages for each PING attempt
1271    async fn ping_server(&self) {
1272        let mut ready = Vec::new();
1273        for conn in &self.connection_pool {
1274            if self.is_connection_ready(conn, false).await {
1275                let id = conn.id.clone();
1276                let ws_write_tx = {
1277                    let conn_state = conn.state.lock().await;
1278                    conn_state.ws_write_tx.clone()
1279                };
1280                ready.push((id, ws_write_tx));
1281            }
1282        }
1283
1284        if ready.is_empty() {
1285            warn!("No ready connections for PING.");
1286            return;
1287        }
1288        info!("Sending PING to {} WebSocket connections.", ready.len());
1289
1290        let mut tasks = FuturesUnordered::new();
1291        for (id, ws_write_tx_opt) in ready {
1292            if let Some(tx) = ws_write_tx_opt {
1293                tasks.push(async move {
1294                    if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1295                        error!("Failed to send PING to {}: {:?}", id, e);
1296                    } else {
1297                        debug!("Sent PING to connection {}", id);
1298                    }
1299                });
1300            } else {
1301                error!("Connection {} was ready but has no write channel", id);
1302            }
1303        }
1304
1305        while tasks.next().await.is_some() {}
1306    }
1307
1308    /// Sends a WebSocket message and optionally waits for a reply.
1309    ///
1310    /// # Arguments
1311    ///
1312    /// * `payload` - The message payload to send
1313    /// * `id` - Optional request identifier, required when waiting for a reply
1314    /// * `wait_for_reply` - Whether to wait for a response to the message
1315    /// * `timeout` - Maximum duration to wait for a reply
1316    /// * `connection` - Optional specific WebSocket connection to use
1317    ///
1318    /// # Returns
1319    ///
1320    /// A receiver for the response if `wait_for_reply` is true, otherwise `None`
1321    ///
1322    /// # Errors
1323    ///
1324    /// Returns a `WebsocketError` if the connection is not ready or the send fails
1325    async fn send(
1326        &self,
1327        payload: String,
1328        id: Option<String>,
1329        wait_for_reply: bool,
1330        timeout: Duration,
1331        connection: Option<Arc<WebsocketConnection>>,
1332    ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1333        let conn = if let Some(c) = connection {
1334            c
1335        } else {
1336            self.get_connection(false, None).await?
1337        };
1338
1339        if !self.is_connected(Some(&conn)).await {
1340            warn!("Send attempted on a non-connected socket");
1341            return Err(WebsocketError::NotConnected);
1342        }
1343
1344        let ws_write_tx = {
1345            let conn_state = conn.state.lock().await;
1346            conn_state
1347                .ws_write_tx
1348                .clone()
1349                .ok_or(WebsocketError::NotConnected)?
1350        };
1351
1352        let pending_setup = if wait_for_reply {
1353            let request_id = id.ok_or_else(|| {
1354                error!("id is required when waiting for a reply");
1355                WebsocketError::NotConnected
1356            })?;
1357
1358            let (tx, rx) = oneshot::channel();
1359            {
1360                let mut conn_state = conn.state.lock().await;
1361                conn_state
1362                    .pending_requests
1363                    .insert(request_id.clone(), PendingRequest { completion: tx });
1364            }
1365
1366            Some((request_id, rx))
1367        } else {
1368            None
1369        };
1370
1371        debug!("Sending message to WebSocket on connection {}", conn.id);
1372
1373        if ws_write_tx
1374            .send(Message::Text(payload.clone().into()))
1375            .is_err()
1376        {
1377            if let Some((request_id, _)) = &pending_setup {
1378                let mut conn_state = conn.state.lock().await;
1379                conn_state.pending_requests.remove(request_id);
1380            }
1381            return Err(WebsocketError::NotConnected);
1382        }
1383
1384        let rx = if let Some((request_id, rx)) = pending_setup {
1385            let conn_clone = Arc::clone(&conn);
1386            let timeout_id = request_id.clone();
1387            spawn(async move {
1388                sleep(timeout).await;
1389                let mut conn_state = conn_clone.state.lock().await;
1390                if let Some(pending_req) = conn_state.pending_requests.remove(&timeout_id) {
1391                    let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1392                }
1393            });
1394            Some(rx)
1395        } else {
1396            None
1397        };
1398
1399        Ok(rx)
1400    }
1401}
1402
/// Per-request flags that shape how a WebSocket API message is built and routed.
///
/// All flags start unset (`Default`); the consuming builder methods below switch them on.
#[derive(Debug, Default, Clone)]
pub struct WebsocketMessageSendOptions {
    /// Whether the request should carry the API key (interpreted by the message builder).
    pub with_api_key: bool,
    /// Whether the request should be signed (interpreted by the message builder).
    pub is_signed: bool,
    /// Marks a session logon request; such requests are sent on every available connection.
    pub is_session_logon: Option<bool>,
    /// Marks a session logout request; such requests are sent on every available connection.
    pub is_session_logout: Option<bool>,
}

impl WebsocketMessageSendOptions {
    /// Returns options with every flag unset (identical to `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `with_api_key` to `true`.
    #[must_use]
    pub fn with_api_key(self) -> Self {
        Self {
            with_api_key: true,
            ..self
        }
    }

    /// Sets `is_signed` to `true`.
    #[must_use]
    pub fn signed(self) -> Self {
        Self {
            is_signed: true,
            ..self
        }
    }

    /// Sets `is_session_logon` to `Some(true)`.
    #[must_use]
    pub fn session_logon(self) -> Self {
        Self {
            is_session_logon: Some(true),
            ..self
        }
    }

    /// Sets `is_session_logout` to `Some(true)`.
    #[must_use]
    pub fn session_logout(self) -> Self {
        Self {
            is_session_logout: Some(true),
            ..self
        }
    }
}
1441
1442#[derive(Debug)]
1443pub enum SendWebsocketMessageResult<R> {
1444    Single(WebsocketApiResponse<R>),
1445    Multiple(Vec<WebsocketApiResponse<R>>),
1446}
1447
1448impl<R> IntoIterator for SendWebsocketMessageResult<R> {
1449    type Item = WebsocketApiResponse<R>;
1450    type IntoIter = std::vec::IntoIter<Self::Item>;
1451
1452    fn into_iter(self) -> Self::IntoIter {
1453        match self {
1454            SendWebsocketMessageResult::Single(resp) => vec![resp].into_iter(),
1455            SendWebsocketMessageResult::Multiple(v) => v.into_iter(),
1456        }
1457    }
1458}
1459
/// Client for the request/response WebSocket API, built on a shared connection pool.
pub struct WebsocketApi {
    /// Shared connection-pool state used to connect, disconnect, ping and send.
    pub common: Arc<WebsocketCommon>,
    /// API configuration (URL, request timeout, session re-logon policy, ...).
    configuration: ConfigurationWebsocketApi,
    /// Set while `connect` is running, so concurrent `connect` calls return early.
    is_connecting: Arc<Mutex<bool>>,
    /// Registered callbacks, keyed by string. NOTE(review): presumably keyed by stream /
    /// event name; confirm against the code that registers them.
    stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
}
1466
1467impl WebsocketApi {
1468    #[must_use]
1469    /// Creates a new WebSocket API instance with the given configuration and connection pool.
1470    ///
1471    /// # Arguments
1472    ///
1473    /// * `configuration` - Configuration settings for the WebSocket API
1474    /// * `connection_pool` - A vector of WebSocket connections to be used
1475    ///
1476    /// # Returns
1477    ///
1478    /// An `Arc`-wrapped `WebsocketApi` instance ready for use
1479    ///
1480    /// # Panics
1481    ///
1482    /// This function will panic if the configuration is not valid.
1483    ///
1484    /// # Examples
1485    ///
1486    ///
1487    /// let api = `WebsocketApi::new(config`, `connection_pool`);
1488    ///
1489    pub fn new(
1490        configuration: ConfigurationWebsocketApi,
1491        connection_pool: Vec<Arc<WebsocketConnection>>,
1492    ) -> Arc<Self> {
1493        let agent_clone = configuration.agent.clone();
1494        let user_agent_clone = configuration.user_agent.clone();
1495        let common = WebsocketCommon::new(
1496            connection_pool,
1497            configuration.mode.clone(),
1498            usize::try_from(configuration.reconnect_delay)
1499                .expect("reconnect_delay should fit in usize"),
1500            agent_clone,
1501            Some(user_agent_clone),
1502        );
1503
1504        Arc::new(Self {
1505            common: Arc::clone(&common),
1506            configuration,
1507            is_connecting: Arc::new(Mutex::new(false)),
1508            stream_callbacks: Mutex::new(HashMap::new()),
1509        })
1510    }
1511
1512    /// Connects to a WebSocket server with a configurable timeout and connection handling.
1513    ///
1514    /// This method attempts to establish a WebSocket connection if not already connected.
1515    /// It prevents multiple simultaneous connection attempts and supports a connection pool.
1516    ///
1517    /// # Errors
1518    ///
1519    /// Returns a `WebsocketError` if:
1520    /// - Connection fails
1521    /// - Connection times out after 10 seconds
1522    ///
1523    /// # Behavior
1524    ///
1525    /// - Checks if already connected and returns early if so
1526    /// - Prevents multiple concurrent connection attempts
1527    /// - Sets a WebSocket handler for the connection pool
1528    /// - Attempts to connect with a 10-second timeout
1529    ///
1530    /// # Returns
1531    ///
1532    /// `Ok(())` if connection is successful, otherwise a `WebsocketError`
1533    pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1534        if self.common.is_connected(None).await {
1535            info!("WebSocket connection already established");
1536            return Ok(());
1537        }
1538
1539        {
1540            let mut flag = self.is_connecting.lock().await;
1541            if *flag {
1542                info!("Already connecting...");
1543                return Ok(());
1544            }
1545            *flag = true;
1546        }
1547
1548        let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1549
1550        let handler: Arc<dyn WebsocketHandler> = self.clone();
1551        for slot in &self.common.connection_pool {
1552            slot.set_handler(handler.clone()).await;
1553        }
1554
1555        let result = select! {
1556            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1557            r = self.common.clone().connect_pool(&url, None) => r,
1558        };
1559
1560        {
1561            let mut flag = self.is_connecting.lock().await;
1562            *flag = false;
1563        }
1564
1565        result
1566    }
1567
1568    /// Disconnects the WebSocket connection.
1569    ///
1570    /// # Returns
1571    ///
1572    /// `Ok(())` if disconnection is successful, otherwise a `WebsocketError`
1573    ///
1574    /// # Errors
1575    ///
1576    /// Returns a `WebsocketError` if:
1577    /// - Disconnection fails
1578    /// - Connection is not established
1579    ///
1580    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1581        self.common.disconnect().await
1582    }
1583
1584    /// Checks if the WebSocket connection is currently established.
1585    ///
1586    /// # Returns
1587    ///
1588    /// `true` if the connection is active, `false` otherwise.
1589    pub async fn is_connected(&self) -> bool {
1590        self.common.is_connected(None).await
1591    }
1592
1593    /// Sends a ping to the WebSocket server to maintain the connection.
1594    ///
1595    /// This method calls the underlying connection's ping mechanism to check
1596    /// and keep the WebSocket connection alive.
1597    pub async fn ping_server(&self) {
1598        self.common.ping_server().await;
1599    }
1600
1601    /// Sends a WebSocket message with the specified method and payload.
1602    ///
1603    /// This method prepares and sends a WebSocket request with optional API key and signature.
1604    /// It handles connection status, generates a unique request ID, and processes the response.
1605    ///
1606    /// # Arguments
1607    ///
1608    /// * `method` - The WebSocket API method to be called
1609    /// * `payload` - A map of parameters to be sent with the request
1610    /// * `options` - Configuration options for message sending (API key, signing)
1611    ///
1612    /// # Returns
1613    ///
1614    /// A deserialized response of type `R` or a `WebsocketError` if the request fails
1615    ///
1616    /// # Panics
1617    ///
1618    /// Panics if:
1619    ///
1620    /// - The WebSocket is not connected
1621    /// - The request cannot be processed
1622    /// - No response is received within the timeout
1623    ///
1624    /// # Errors
1625    ///
1626    /// Returns `WebsocketError` if:
1627    /// - The WebSocket is not connected
1628    /// - The request cannot be processed
1629    /// - No response is received within the timeout
1630    pub async fn send_message<R>(
1631        &self,
1632        method: &str,
1633        payload: BTreeMap<String, Value>,
1634        options: WebsocketMessageSendOptions,
1635    ) -> Result<SendWebsocketMessageResult<R>, WebsocketError>
1636    where
1637        R: DeserializeOwned + Send + Sync + 'static,
1638    {
1639        if !self.common.is_connected(None).await {
1640            return Err(WebsocketError::NotConnected);
1641        }
1642
1643        let do_multi =
1644            options.is_session_logon.unwrap_or(false) || options.is_session_logout.unwrap_or(false);
1645
1646        let connections = if do_multi {
1647            self.common.get_available_connections(false, None).await
1648        } else {
1649            vec![self.common.get_connection(false, None).await?]
1650        };
1651
1652        let skip_auth = if do_multi {
1653            false
1654        } else {
1655            let connection = &connections[0];
1656            let conn_state = connection.state.lock().await;
1657            self.configuration.auto_session_relogon && conn_state.is_session_logged_on
1658        };
1659
1660        let payload_clone = payload.clone();
1661
1662        let (id, request) =
1663            build_websocket_api_message(&self.configuration, method, payload, &options, skip_auth);
1664        let raw_payload = serde_json::to_string(&request).unwrap();
1665        debug!("Sending message to WebSocket API: {:?}", request);
1666
1667        let timeout = Duration::from_millis(self.configuration.timeout);
1668
1669        let mut receivers = Vec::with_capacity(connections.len());
1670        for connection in &connections {
1671            let opt_rx = self
1672                .common
1673                .send(
1674                    raw_payload.clone(),
1675                    Some(id.clone()),
1676                    true,
1677                    timeout,
1678                    Some(connection.clone()),
1679                )
1680                .await?;
1681            receivers.push((connection.clone(), opt_rx));
1682        }
1683
1684        let mut raw_msgs = Vec::with_capacity(receivers.len());
1685        for (_conn, opt_rx) in receivers {
1686            let rx = opt_rx.ok_or(WebsocketError::NoResponse)?;
1687            let msg = rx.await.unwrap_or(Err(WebsocketError::Timeout))?;
1688            raw_msgs.push(msg);
1689        }
1690
1691        let mut responses = Vec::with_capacity(raw_msgs.len());
1692        for msg in raw_msgs {
1693            let raw = msg
1694                .get("result")
1695                .or_else(|| msg.get("response"))
1696                .cloned()
1697                .unwrap_or(Value::Null);
1698
1699            let rate_limits = msg
1700                .get("rateLimits")
1701                .and_then(Value::as_array)
1702                .map(|arr| {
1703                    arr.iter()
1704                        .filter_map(|v| serde_json::from_value(v.clone()).ok())
1705                        .collect()
1706                })
1707                .unwrap_or_default();
1708
1709            responses.push(WebsocketApiResponse {
1710                raw,
1711                rate_limits,
1712                _marker: PhantomData,
1713            });
1714        }
1715
1716        if do_multi && self.configuration.auto_session_relogon {
1717            for connection in &connections {
1718                let mut state = connection.state.lock().await;
1719                if options.is_session_logon.unwrap_or(false) {
1720                    state.is_session_logged_on = true;
1721                    state.session_logon_req = Some(WebsocketSessionLogonReq {
1722                        method: method.to_string(),
1723                        payload: payload_clone.clone(),
1724                        options: options.clone(),
1725                    });
1726                } else {
1727                    state.is_session_logged_on = false;
1728                    state.session_logon_req = None;
1729                }
1730            }
1731        }
1732
1733        Ok(if responses.len() == 1 && !do_multi {
1734            SendWebsocketMessageResult::Single(responses.into_iter().next().unwrap())
1735        } else {
1736            SendWebsocketMessageResult::Multiple(responses)
1737        })
1738    }
1739
1740    /// Prepares a WebSocket URL by appending a validated time unit parameter.
1741    ///
1742    /// This method checks if a time unit is configured and validates it. If valid,
1743    /// the time unit is appended to the URL as a query parameter. If no time unit
1744    /// is specified or the validation fails, the original URL is returned.
1745    ///
1746    /// # Arguments
1747    ///
1748    /// * `ws_url` - The base WebSocket URL to be modified
1749    ///
1750    /// # Returns
1751    ///
1752    /// A modified URL with the time unit parameter, or the original URL if no
1753    /// modification is possible
1754    fn prepare_url(&self, ws_url: &str) -> String {
1755        let mut url = ws_url.to_string();
1756
1757        let time_unit = match &self.configuration.time_unit {
1758            Some(u) => u.to_string(),
1759            None => return url,
1760        };
1761
1762        match validate_time_unit(&time_unit) {
1763            Ok(Some(validated)) => {
1764                let sep = if url.contains('?') { '&' } else { '?' };
1765                url.push(sep);
1766                url.push_str("timeUnit=");
1767                url.push_str(validated);
1768            }
1769            Ok(None) => {}
1770            Err(e) => {
1771                error!("Invalid time unit provided: {:?}", e);
1772            }
1773        }
1774
1775        url
1776    }
1777}
1778
1779#[async_trait]
1780impl WebsocketHandler for WebsocketApi {
1781    /// Handles the WebSocket connection opening event, attempting to re-establish a session logon if needed.
1782    ///
1783    /// This method checks if a session logon request exists and has not already been logged on.
1784    /// If conditions are met, it attempts to send a session re-logon message and update the connection state.
1785    ///
1786    /// # Arguments
1787    ///
1788    /// * `_url` - The WebSocket connection URL (unused)
1789    /// * `connection` - The WebSocket connection context
1790    ///
1791    /// # Behavior
1792    ///
1793    /// - Checks for an existing session logon request
1794    /// - Verifies the session is not already logged on
1795    /// - Attempts to send a re-logon message
1796    /// - Updates connection state upon successful re-logon
1797    /// - Logs errors if re-logon dispatch fails
1798    async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
1799        let session_req = {
1800            let conn_state = connection.state.lock().await;
1801            conn_state.session_logon_req.clone()
1802        };
1803
1804        let Some(req) = session_req else {
1805            return;
1806        };
1807
1808        let already_logged_on = {
1809            let conn_state = connection.state.lock().await;
1810            conn_state.is_session_logged_on
1811        };
1812
1813        if already_logged_on {
1814            debug!(
1815                "Connection {} already logged on, skipping re-logon",
1816                connection.id
1817            );
1818            return;
1819        }
1820
1821        let conn = connection.clone();
1822        let common = Arc::clone(&self.common);
1823        let configuration = self.configuration.clone();
1824        let method = req.method.clone();
1825        let payload = req.payload.clone();
1826        let options = req.options.clone();
1827
1828        spawn(async move {
1829            let (id, json_msg) =
1830                build_websocket_api_message(&configuration, &method, payload, &options, false);
1831
1832            let raw_message = match serde_json::to_string(&json_msg) {
1833                Ok(msg) => msg,
1834                Err(e) => {
1835                    warn!(
1836                        "Failed to serialize session logon message for connection {}: {}",
1837                        conn.id, e
1838                    );
1839                    return;
1840                }
1841            };
1842
1843            debug!(
1844                "Session re-logon on connection {}: {}",
1845                conn.id, raw_message
1846            );
1847
1848            let rx = match common
1849                .send(
1850                    raw_message,
1851                    Some(id.clone()),
1852                    true,
1853                    Duration::from_millis(configuration.timeout),
1854                    Some(conn.clone()),
1855                )
1856                .await
1857            {
1858                Ok(Some(rx)) => rx,
1859                Ok(None) => {
1860                    warn!(
1861                        "Session re-logon dispatch returned None for connection {}",
1862                        conn.id
1863                    );
1864                    return;
1865                }
1866                Err(e) => {
1867                    warn!(
1868                        "Session re-logon dispatch failed on connection {}: {}",
1869                        conn.id, e
1870                    );
1871                    return;
1872                }
1873            };
1874
1875            let Ok(result) = timeout(Duration::from_millis(configuration.timeout), rx).await else {
1876                warn!("Session re-logon timed out on connection {}", conn.id);
1877                return;
1878            };
1879
1880            let final_result = match result {
1881                Ok(final_result) => final_result,
1882                Err(e) => {
1883                    warn!(
1884                        "Session re-logon receiver error on connection {}: {}",
1885                        conn.id, e
1886                    );
1887                    return;
1888                }
1889            };
1890
1891            let payload = match final_result {
1892                Ok(payload) => payload,
1893                Err(e) => {
1894                    warn!(
1895                        "Session re-logon payload error on connection {}: {}",
1896                        conn.id, e
1897                    );
1898                    return;
1899                }
1900            };
1901
1902            debug!(
1903                "Session re-logon succeeded on connection {}: {}",
1904                conn.id, payload
1905            );
1906            let mut conn_state = conn.state.lock().await;
1907            conn_state.is_session_logged_on = true;
1908        });
1909    }
1910
1911    /// Handles incoming WebSocket messages by parsing the JSON payload and processing pending requests.
1912    ///
1913    /// This method is responsible for:
1914    /// - Parsing the received WebSocket message as JSON
1915    /// - Matching the message to a pending request by its ID
1916    /// - Sending the response back to the original request's completion channel
1917    /// - Handling both successful and error responses
1918    ///
1919    /// # Arguments
1920    ///
1921    /// * `data` - The raw WebSocket message as a string
1922    /// * `connection` - The WebSocket connection context associated with the message
1923    ///
1924    /// # Behavior
1925    ///
1926    /// - If message parsing fails, logs an error and returns
1927    /// - For known request IDs, sends the response to the corresponding completion channel
1928    /// - Warns about responses for unknown or timed-out requests
1929    /// - Differentiates between successful (status < 400) and error responses
1930    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1931        let msg: Value = match serde_json::from_str(&data) {
1932            Ok(v) => v,
1933            Err(err) => {
1934                error!("Failed to parse WebSocket message {} – {}", data, err);
1935                return;
1936            }
1937        };
1938
1939        if let Some(id) = msg.get("id").and_then(Value::as_str) {
1940            let maybe_sender = {
1941                let mut conn_state = connection.state.lock().await;
1942                conn_state.pending_requests.remove(id)
1943            };
1944
1945            if let Some(PendingRequest { completion }) = maybe_sender {
1946                connection.drain_notify.notify_one();
1947                let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1948                if status >= 400 {
1949                    let error_map = msg
1950                        .get("error")
1951                        .and_then(Value::as_object)
1952                        .unwrap_or(&serde_json::Map::new())
1953                        .clone();
1954
1955                    let code = error_map
1956                        .get("code")
1957                        .and_then(Value::as_i64)
1958                        .unwrap_or(status.try_into().unwrap());
1959
1960                    let message = error_map
1961                        .get("msg")
1962                        .and_then(Value::as_str)
1963                        .unwrap_or("Unknown error")
1964                        .to_string();
1965
1966                    let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1967                } else {
1968                    let _ = completion.send(Ok(msg.clone()));
1969                }
1970            }
1971
1972            return;
1973        }
1974
1975        if let Some(event) = msg.get("event") {
1976            if let Some(event_type) = event.get("e").and_then(Value::as_str) {
1977                if event_type == "serverShutdown" {
1978                    warn!(
1979                        "Received serverShutdown event on connection {}",
1980                        connection.id
1981                    );
1982
1983                    let mut conn_state = connection.state.lock().await;
1984
1985                    if !conn_state.renewal_pending && !conn_state.close_initiated {
1986                        conn_state.renewal_pending = true;
1987
1988                        let url = conn_state.url_path.clone().unwrap_or_default();
1989
1990                        drop(conn_state);
1991
1992                        if let Err(e) = self
1993                            .common
1994                            .reconnect_tx
1995                            .send(ReconnectEntry {
1996                                connection_id: connection.id.clone(),
1997                                url,
1998                                is_renewal: true,
1999                            })
2000                            .await
2001                        {
2002                            error!("Failed to enqueue serverShutdown renewal: {:?}", e);
2003                        }
2004                    }
2005
2006                    return;
2007                }
2008            }
2009        }
2010
2011        if let Some(event) = msg.get("event") {
2012            if event.get("e").is_some() {
2013                for callbacks in self.stream_callbacks.lock().await.values() {
2014                    for callback in callbacks {
2015                        callback(event);
2016                    }
2017                }
2018
2019                return;
2020            }
2021        }
2022
2023        warn!(
2024            "Received response for unknown or timed-out request: {}",
2025            data
2026        );
2027    }
2028
2029    /// Generates the URL to use for reconnecting to a WebSocket connection.
2030    ///
2031    /// # Arguments
2032    ///
2033    /// * `default_url` - The original URL to potentially modify for reconnection
2034    /// * `_connection` - The WebSocket connection context (currently unused)
2035    ///
2036    /// # Returns
2037    ///
2038    /// A `String` representing the URL to use for reconnecting
2039    async fn get_reconnect_url(
2040        &self,
2041        default_url: String,
2042        _connection: Arc<WebsocketConnection>,
2043    ) -> String {
2044        default_url
2045    }
2046}
2047
/// Manages WebSocket stream subscriptions on top of a pool of connections.
///
/// Streams are subscribed and unsubscribed by name, optionally scoped by a URL
/// path. Each subscribed stream is tracked against the connection it was
/// assigned to.
pub struct WebsocketStreams {
    /// Shared connection manager; holds the connection pool and performs the
    /// connect / send / disconnect work.
    pub common: Arc<WebsocketCommon>,
    /// Passed to `normalize_stream_id` when building request IDs.
    /// NOTE(review): presumably forces numeric IDs when set — confirm against
    /// `normalize_stream_id`.
    pub stream_id_is_strictly_number: AtomicBool,
    /// Optional URL paths. When non-empty, each path is served by its own slice
    /// of `mode.pool_size()` connections from the pool.
    url_paths: Vec<String>,
    /// Set while `connect` is running, so concurrent calls return early.
    is_connecting: Mutex<bool>,
    /// Maps a stream key (the stream name, prefixed with `<url_path>::` when a
    /// non-empty path is given) to the connection the stream is subscribed on.
    connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
    /// Configuration this streams client was created with.
    configuration: ConfigurationWebsocketStreams,
}
2056
2057impl WebsocketStreams {
2058    /// Creates a new `WebsocketStreams` instance with the given configuration and connection pool.
2059    ///
2060    /// # Arguments
2061    ///
2062    /// * `configuration` - Configuration settings for the WebSocket streams
2063    /// * `connection_pool` - A vector of WebSocket connections to use
2064    /// * `url_paths` - A vector of URL paths for the streams
2065    ///
2066    /// # Returns
2067    ///
2068    /// An `Arc`-wrapped `WebsocketStreams` instance
2069    ///
2070    /// # Panics
2071    ///
2072    /// Panics if the `reconnect_delay` cannot be converted to `usize`
2073    #[must_use]
2074    pub fn new(
2075        configuration: ConfigurationWebsocketStreams,
2076        mut connection_pool: Vec<Arc<WebsocketConnection>>,
2077        url_paths: Vec<String>,
2078    ) -> Arc<Self> {
2079        if !url_paths.is_empty() {
2080            let base_pool_size = configuration.mode.pool_size();
2081            let expected = base_pool_size * url_paths.len();
2082
2083            while connection_pool.len() < expected {
2084                connection_pool.push(WebsocketConnection::new(random_string()));
2085            }
2086        }
2087
2088        let agent_clone = configuration.agent.clone();
2089        let user_agent_clone = configuration.user_agent.clone();
2090        let common = WebsocketCommon::new(
2091            connection_pool,
2092            configuration.mode.clone(),
2093            usize::try_from(configuration.reconnect_delay)
2094                .expect("reconnect_delay should fit in usize"),
2095            agent_clone,
2096            Some(user_agent_clone),
2097        );
2098        Arc::new(Self {
2099            common,
2100            is_connecting: Mutex::new(false),
2101            connection_streams: Mutex::new(HashMap::new()),
2102            configuration,
2103            stream_id_is_strictly_number: AtomicBool::new(false),
2104            url_paths,
2105        })
2106    }
2107
2108    /// Establishes a WebSocket connection for the given streams.
2109    ///
2110    /// This method attempts to connect to a WebSocket server using the connection pool.
2111    /// If a connection is already established or in progress, it returns immediately.
2112    ///
2113    /// # Arguments
2114    ///
2115    /// * `streams` - A vector of stream identifiers to connect to
2116    ///
2117    /// # Returns
2118    ///
2119    /// A `Result` indicating whether the connection was successful or an error occurred
2120    ///
2121    /// # Errors
2122    ///
2123    /// Returns a `WebsocketError` if the connection fails or times out after 10 seconds
2124    pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
2125        if self.common.is_connected(None).await {
2126            info!("WebSocket connection already established");
2127            return Ok(());
2128        }
2129
2130        {
2131            let mut flag = self.is_connecting.lock().await;
2132            if *flag {
2133                info!("Already connecting...");
2134                return Ok(());
2135            }
2136            *flag = true;
2137        }
2138
2139        let handler: Arc<dyn WebsocketHandler> = self.clone();
2140        for conn in &self.common.connection_pool {
2141            conn.set_handler(handler.clone()).await;
2142        }
2143
2144        let base_pool_size = self.configuration.mode.pool_size();
2145
2146        let connect_fut = async {
2147            if self.url_paths.is_empty() {
2148                let url = self.prepare_url(&streams, None);
2149                self.common.clone().connect_pool(&url, None).await
2150            } else {
2151                let mut futures = Vec::with_capacity(self.url_paths.len());
2152
2153                for (i, path) in self.url_paths.iter().enumerate() {
2154                    let start = i * base_pool_size;
2155
2156                    let subset: Vec<Arc<WebsocketConnection>> = self
2157                        .common
2158                        .connection_pool
2159                        .iter()
2160                        .skip(start)
2161                        .take(base_pool_size)
2162                        .cloned()
2163                        .collect();
2164
2165                    if subset.len() != base_pool_size {
2166                        return Err(WebsocketError::ServerError(format!(
2167                            "connection_pool too small for url_paths: need {} per path, got {} for path index {}",
2168                            base_pool_size,
2169                            subset.len(),
2170                            i
2171                        )));
2172                    }
2173
2174                    for c in &subset {
2175                        let mut st = c.state.lock().await;
2176                        st.url_path = Some(path.clone());
2177                    }
2178
2179                    let url = self.prepare_url(&streams, Some(path.as_str()));
2180                    let common = self.common.clone();
2181
2182                    futures.push(async move { common.connect_pool(&url, Some(subset)).await });
2183                }
2184
2185                try_join_all(futures).await?;
2186                Ok(())
2187            }
2188        };
2189
2190        let connect_res = select! {
2191            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
2192            r = connect_fut => r,
2193        };
2194
2195        {
2196            let mut flag = self.is_connecting.lock().await;
2197            *flag = false;
2198        }
2199
2200        connect_res
2201    }
2202
2203    /// Disconnects all WebSocket connections and clears associated state.
2204    ///
2205    /// # Returns
2206    ///
2207    /// A `Result` indicating whether the disconnection was successful or an error occurred
2208    ///
2209    /// # Errors
2210    ///
2211    /// Returns a `WebsocketError` if there are issues during the disconnection process
2212    ///
2213    /// # Side Effects
2214    ///
2215    /// - Clears stream callbacks for all connections
2216    /// - Clears pending subscriptions for all connections
2217    /// - Removes all connection stream mappings
2218    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
2219        for connection in &self.common.connection_pool {
2220            let mut conn_state = connection.state.lock().await;
2221            conn_state.stream_callbacks.clear();
2222            conn_state.pending_subscriptions.clear();
2223        }
2224        self.connection_streams.lock().await.clear();
2225        self.common.disconnect().await
2226    }
2227
2228    /// Checks if the WebSocket connection is currently active.
2229    ///
2230    /// # Returns
2231    ///
2232    /// `true` if the WebSocket connection is established, `false` otherwise.
2233    pub async fn is_connected(&self) -> bool {
2234        self.common.is_connected(None).await
2235    }
2236
2237    /// Sends a ping to the WebSocket server to maintain the connection.
2238    ///
2239    /// This method delegates the ping operation to the underlying common WebSocket connection.
2240    /// It is typically used to keep the connection alive and check its status.
2241    ///
2242    /// # Side Effects
2243    ///
2244    /// Sends a ping request to the WebSocket server through the common connection.
2245    pub async fn ping_server(&self) {
2246        self.common.ping_server().await;
2247    }
2248
2249    /// Subscribes to multiple WebSocket streams, handling connection and queuing logic.
2250    ///
2251    /// # Arguments
2252    ///
2253    /// * `streams` - A vector of stream names to subscribe to
2254    /// * `id` - An optional request identifier for the subscription
2255    /// * `url_path` - An optional URL path for the subscription
2256    ///
2257    /// # Behavior
2258    ///
2259    /// - Filters out streams already subscribed
2260    /// - Assigns streams to appropriate connections
2261    /// - Handles subscription for active connections
2262    /// - Queues subscriptions for inactive connections
2263    ///
2264    /// # Side Effects
2265    ///
2266    /// - Sends subscription payloads for active connections
2267    /// - Adds pending subscriptions for inactive connections
2268    pub async fn subscribe(
2269        self: Arc<Self>,
2270        streams: Vec<String>,
2271        id: Option<StreamId>,
2272        url_path: Option<&str>,
2273    ) {
2274        let streams: Vec<String> = {
2275            let map = self.connection_streams.lock().await;
2276            streams
2277                .into_iter()
2278                .filter(|s| {
2279                    let key = self.stream_key(s, url_path);
2280                    !map.contains_key(&key)
2281                })
2282                .collect()
2283        };
2284
2285        if streams.is_empty() {
2286            return;
2287        }
2288
2289        let connection_streams = self.handle_stream_assignment(streams, url_path).await;
2290
2291        for (conn, assigned_streams) in connection_streams {
2292            if !self.common.is_connected(Some(&conn)).await {
2293                info!(
2294                    "Connection {} is not ready. Queuing subscription for streams: {:?}",
2295                    conn.id, assigned_streams
2296                );
2297
2298                let mut conn_state = conn.state.lock().await;
2299                conn_state
2300                    .pending_subscriptions
2301                    .extend(assigned_streams.iter().cloned());
2302
2303                continue;
2304            }
2305
2306            self.send_subscription_payload(&conn, &assigned_streams, id.clone());
2307        }
2308    }
2309
2310    /// Unsubscribes from specified WebSocket streams.
2311    ///
2312    /// # Arguments
2313    ///
2314    /// * `streams` - A vector of stream names to unsubscribe from
2315    /// * `id` - An optional request identifier for the unsubscription
2316    /// * `url_path` - An optional URL path for the unsubscription
2317    ///
2318    /// # Behavior
2319    ///
2320    /// - Validates the request identifier or generates a random one
2321    /// - Checks for active connections and subscribed streams
2322    /// - Sends unsubscribe payload for streams with active callbacks
2323    /// - Removes stream from connection streams and callbacks
2324    ///
2325    /// # Side Effects
2326    ///
2327    /// - Sends unsubscribe request to WebSocket server
2328    /// - Removes stream tracking from internal state
2329    ///
2330    /// # Async
2331    ///
2332    /// This method is asynchronous and requires `.await` when called
2333    ///
2334    /// # Panics
2335    ///
2336    /// This method may panic if the request identifier is not valid.
2337    ///
2338    pub async fn unsubscribe(
2339        &self,
2340        streams: Vec<String>,
2341        id: Option<StreamId>,
2342        url_path: Option<&str>,
2343    ) {
2344        let request_id = normalize_stream_id(
2345            id.clone(),
2346            self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2347        );
2348
2349        for stream in streams {
2350            let key = self.stream_key(&stream, url_path);
2351            let maybe_conn = { self.connection_streams.lock().await.get(&key).cloned() };
2352
2353            let conn = match maybe_conn {
2354                Some(c) => {
2355                    if !self.common.is_connected(Some(&c)).await {
2356                        warn!(
2357                            "Stream {} not associated with an active connection.",
2358                            stream
2359                        );
2360                        continue;
2361                    }
2362                    c
2363                }
2364                None => {
2365                    warn!("Stream {} was not subscribed.", stream);
2366                    continue;
2367                }
2368            };
2369
2370            let has_callbacks = {
2371                let conn_state = conn.state.lock().await;
2372                conn_state
2373                    .stream_callbacks
2374                    .get(&key)
2375                    .is_some_and(|v| !v.is_empty())
2376            };
2377
2378            if has_callbacks {
2379                continue;
2380            }
2381
2382            let payload = json!({
2383                "method": "UNSUBSCRIBE",
2384                "params": [stream.clone()],
2385                "id": request_id,
2386            });
2387
2388            info!("UNSUBSCRIBE → {:?}", payload);
2389
2390            let common = Arc::clone(&self.common);
2391            let conn_clone = Arc::clone(&conn);
2392            let msg = serde_json::to_string(&payload).unwrap();
2393            spawn(async move {
2394                let _ = common
2395                    .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2396                    .await;
2397            });
2398
2399            {
2400                let mut connection_streams = self.connection_streams.lock().await;
2401                connection_streams.remove(&key);
2402            }
2403            {
2404                let mut conn_state = conn.state.lock().await;
2405                conn_state.stream_callbacks.remove(&key);
2406            }
2407        }
2408    }
2409
2410    /// Checks if a specific stream is currently subscribed.
2411    ///
2412    /// # Arguments
2413    ///
2414    /// * `stream` - The stream identifier to check for subscription status
2415    ///
2416    /// # Returns
2417    ///
2418    /// `true` if the stream is subscribed, `false` otherwise
2419    ///
2420    /// # Async
2421    ///
2422    /// This method is asynchronous and requires `.await` when called
2423    pub async fn is_subscribed(&self, stream: &str) -> bool {
2424        let map = self.connection_streams.lock().await;
2425
2426        if map.contains_key(stream) {
2427            return true;
2428        }
2429
2430        let suffix = format!("::{}", stream);
2431        map.keys().any(|k| k.ends_with(&suffix))
2432    }
2433
2434    /// Generates a unique key for a stream based on its name and optional URL path.
2435    ///
2436    /// # Arguments
2437    ///
2438    /// * `stream` - The name of the stream
2439    /// * `url_path` - An optional URL path associated with the stream
2440    ///
2441    /// # Returns
2442    ///
2443    /// A `String` representing the unique key for the stream
2444    ///
2445    fn stream_key(&self, stream: &str, url_path: Option<&str>) -> String {
2446        match url_path {
2447            Some(p) if !p.is_empty() => format!("{p}::{stream}"),
2448            _ => stream.to_string(),
2449        }
2450    }
2451
2452    /// Prepares a WebSocket URL for streaming with optional stream names and time unit configuration.
2453    ///
2454    /// # Arguments
2455    ///
2456    /// * `streams` - A slice of stream names to be included in the URL
2457    /// * `url_path` - An optional path to append to the base WebSocket URL
2458    ///
2459    /// # Returns
2460    ///
2461    /// A fully constructed WebSocket URL with optional stream and time unit parameters
2462    ///
2463    /// # Notes
2464    ///
2465    /// - If no time unit is specified, the base URL is returned
2466    /// - Validates and appends the time unit parameter if provided and valid
2467    /// - Handles URL parameter separator based on existing query parameters
2468    fn prepare_url(&self, streams: &[String], url_path: Option<&str>) -> String {
2469        let mut url = format!(
2470            "{}/stream?streams={}",
2471            match url_path {
2472                Some(path) => format!(
2473                    "{}/{}",
2474                    self.configuration.ws_url.as_deref().unwrap_or(""),
2475                    path
2476                ),
2477                None => self
2478                    .configuration
2479                    .ws_url
2480                    .as_deref()
2481                    .unwrap_or("")
2482                    .to_string(),
2483            },
2484            streams.join("/")
2485        );
2486
2487        let time_unit = match &self.configuration.time_unit {
2488            Some(u) => u.to_string(),
2489            None => return url,
2490        };
2491
2492        match validate_time_unit(&time_unit) {
2493            Ok(Some(validated)) => {
2494                let sep = if url.contains('?') { '&' } else { '?' };
2495                url.push(sep);
2496                url.push_str("timeUnit=");
2497                url.push_str(validated);
2498            }
2499            Ok(None) => {}
2500            Err(e) => {
2501                error!("Invalid time unit provided: {:?}", e);
2502            }
2503        }
2504
2505        url
2506    }
2507
2508    /// Handles stream assignment by finding or creating WebSocket connections for a list of streams.
2509    ///
2510    /// This method attempts to assign streams to existing WebSocket connections or creates new
2511    /// connections if needed. It groups streams by their assigned connections and handles scenarios
2512    /// such as closed or pending reconnection connections.
2513    ///
2514    /// # Arguments
2515    ///
2516    /// * `streams` - A vector of stream names to be assigned
2517    /// * `url_path` - An optional URL path associated with the streams
2518    ///
2519    /// # Returns
2520    ///
2521    /// A vector of tuples containing WebSocket connections and their associated streams
2522    ///
2523    /// # Errors
2524    ///
2525    /// Returns an empty result if no connections can be established for the streams
2526    async fn handle_stream_assignment(
2527        &self,
2528        streams: Vec<String>,
2529        url_path: Option<&str>,
2530    ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
2531        let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
2532
2533        for stream in streams {
2534            let key = self.stream_key(&stream, url_path);
2535
2536            let mut conn_opt = {
2537                let map = self.connection_streams.lock().await;
2538                map.get(&key).cloned()
2539            };
2540
2541            let need_new = if let Some(conn) = &conn_opt {
2542                let state = conn.state.lock().await;
2543                state.close_initiated || state.reconnection_pending
2544            } else {
2545                true
2546            };
2547
2548            if need_new {
2549                match self.common.get_connection(true, url_path).await {
2550                    Ok(new_conn) => {
2551                        let mut map = self.connection_streams.lock().await;
2552                        map.insert(key.clone(), new_conn.clone());
2553                        conn_opt = Some(new_conn);
2554                    }
2555                    Err(err) => {
2556                        warn!(
2557                            "No available WebSocket connection to subscribe stream `{}` (key `{}`): {:?}",
2558                            stream, key, err
2559                        );
2560                        continue;
2561                    }
2562                }
2563            }
2564
2565            if let Some(conn) = conn_opt {
2566                {
2567                    let mut conn_state = conn.state.lock().await;
2568                    conn_state.stream_callbacks.entry(key.clone()).or_default();
2569                }
2570                connection_streams.push((stream, conn));
2571            }
2572        }
2573
2574        let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
2575        for (stream, conn) in connection_streams {
2576            if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
2577                vec.push(stream);
2578            } else {
2579                groups.push((conn, vec![stream]));
2580            }
2581        }
2582
2583        groups
2584    }
2585
2586    /// Sends a WebSocket subscription payload for the specified streams.
2587    ///
2588    /// # Arguments
2589    ///
2590    /// * `connection` - The WebSocket connection to send the subscription on
2591    /// * `streams` - A vector of stream names to subscribe to
2592    /// * `id` - An optional request ID for the subscription (will be randomly generated if not provided)
2593    ///
2594    /// # Remarks
2595    ///
2596    /// This method constructs a SUBSCRIBE payload, logs it, and sends it asynchronously using the WebSocket connection.
2597    /// If serialization fails, an error is logged and the method returns without sending.
2598    fn send_subscription_payload(
2599        &self,
2600        connection: &Arc<WebsocketConnection>,
2601        streams: &Vec<String>,
2602        id: Option<StreamId>,
2603    ) {
2604        let request_id = normalize_stream_id(
2605            id.clone(),
2606            self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2607        );
2608
2609        let payload = json!({
2610            "method": "SUBSCRIBE",
2611            "params": streams,
2612            "id": request_id,
2613        });
2614
2615        info!("SUBSCRIBE → {:?}", payload);
2616
2617        let common = Arc::clone(&self.common);
2618        let msg = match serde_json::to_string(&payload) {
2619            Ok(s) => s,
2620            Err(e) => {
2621                error!("Failed to serialize SUBSCRIBE payload: {}", e);
2622                return;
2623            }
2624        };
2625        let conn_clone = Arc::clone(connection);
2626
2627        spawn(async move {
2628            let _ = common
2629                .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2630                .await;
2631        });
2632    }
2633}
2634
2635#[async_trait]
2636impl WebsocketHandler for WebsocketStreams {
2637    /// Handles the WebSocket connection opening by processing any pending subscriptions.
2638    ///
2639    /// This method is called when a WebSocket connection is established. It retrieves
2640    /// any pending stream subscriptions from the connection state and sends them
2641    /// immediately using the `send_subscription_payload` method.
2642    ///
2643    /// # Arguments
2644    ///
2645    /// * `_url` - The URL of the WebSocket connection (unused)
2646    /// * `connection` - The WebSocket connection that has just been opened
2647    ///
2648    /// # Remarks
2649    ///
2650    /// If there are any pending subscriptions, they are sent as a batch subscription
2651    /// payload. The method uses a lock to safely access and clear the pending subscriptions
2652    /// from the connection state.
2653    async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2654        let pending_subs: Vec<String> = {
2655            let mut conn_state = connection.state.lock().await;
2656            take(&mut conn_state.pending_subscriptions)
2657                .into_iter()
2658                .collect()
2659        };
2660
2661        if !pending_subs.is_empty() {
2662            info!("Processing queued subscriptions for connection");
2663            self.send_subscription_payload(&connection, &pending_subs, None);
2664        }
2665    }
2666
2667    /// Handles incoming WebSocket stream messages by parsing the JSON payload and invoking registered stream callbacks.
2668    ///
2669    /// This method processes WebSocket messages with a specific structure, extracting the stream name and data.
2670    /// It retrieves and executes any registered callbacks associated with the stream name.
2671    ///
2672    /// # Arguments
2673    ///
2674    /// * `data` - The raw WebSocket message as a JSON-formatted string
2675    /// * `connection` - The WebSocket connection through which the message was received
2676    ///
2677    /// # Behavior
2678    ///
2679    /// - Parses the JSON message
2680    /// - Extracts the stream name and data payload
2681    /// - Looks up and invokes any registered callbacks for the stream
2682    /// - Silently returns if message parsing or stream extraction fails
2683    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2684        let msg: Value = match serde_json::from_str(&data) {
2685            Ok(v) => v,
2686            Err(err) => {
2687                error!(
2688                    "Failed to parse WebSocket stream message {} – {}",
2689                    data, err
2690                );
2691                return;
2692            }
2693        };
2694
2695        let (stream_name, payload) = match (
2696            msg.get("stream").and_then(Value::as_str),
2697            msg.get("data").cloned(),
2698        ) {
2699            (Some(name), Some(data)) => (name.to_string(), data),
2700            _ => return,
2701        };
2702
2703        let callbacks = {
2704            let conn_state = connection.state.lock().await;
2705            let key = self.stream_key(&stream_name, conn_state.url_path.as_deref());
2706            conn_state
2707                .stream_callbacks
2708                .get(&key)
2709                .cloned()
2710                .unwrap_or_else(Vec::new)
2711        };
2712
2713        for callback in callbacks {
2714            callback(&payload);
2715        }
2716    }
2717
2718    /// Retrieves the reconnection URL for a specific WebSocket connection by identifying all streams associated with that connection.
2719    ///
2720    /// # Arguments
2721    ///
2722    /// * `_default_url` - A default URL that can be used if no specific reconnection URL is determined
2723    /// * `connection` - The WebSocket connection for which to generate a reconnection URL
2724    ///
2725    /// # Returns
2726    ///
2727    /// A URL string that can be used to reconnect to the WebSocket, based on the streams associated with the given connection
2728    async fn get_reconnect_url(
2729        &self,
2730        _default_url: String,
2731        connection: Arc<WebsocketConnection>,
2732    ) -> String {
2733        let connection_streams = self.connection_streams.lock().await;
2734        let reconnect_streams = connection_streams
2735            .iter()
2736            .filter_map(|(key, conn_arc)| {
2737                if Arc::ptr_eq(conn_arc, &connection) {
2738                    let stream = match key.split_once("::") {
2739                        Some((_prefix, rest)) => rest.to_string(),
2740                        None => key.clone(),
2741                    };
2742                    Some(stream)
2743                } else {
2744                    None
2745                }
2746            })
2747            .collect::<Vec<_>>();
2748
2749        let url_path = {
2750            let st = connection.state.lock().await;
2751            st.url_path.as_deref().map(std::string::ToString::to_string)
2752        };
2753
2754        self.prepare_url(&reconnect_streams, url_path.as_deref())
2755    }
2756}
2757
/// Typed handle for a single stream subscription, or for a stream identifier on the
/// WebSocket API.
///
/// Instances are built by `create_stream_handler`. Payloads delivered to the
/// registered callback are deserialized into `T`.
pub struct WebsocketStream<T> {
    /// The base this handle is attached to: `WebsocketStreams` or `WebsocketApi`.
    websocket_base: WebsocketBase,
    /// Stream name for `WebsocketStreams`, or the key into the API's `stream_callbacks`
    /// for `WebsocketApi`.
    stream_or_id: String,
    /// Optional URL path the stream was subscribed under. It is combined with
    /// `stream_or_id` via `stream_key` to locate the stream's connection.
    url_path: Option<String>,
    /// The callback wrapper most recently stored by `on`. It is kept so `unsubscribe`
    /// can remove exactly that wrapper by `Arc` pointer identity.
    callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
    /// Optional request ID forwarded to the subscribe/unsubscribe requests.
    pub id: Option<StreamId>,
    /// Marks the payload type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
2766
2767impl<T> WebsocketStream<T>
2768where
2769    T: DeserializeOwned + Send + 'static,
2770{
2771    /// Registers a callback function for a specific event on the WebSocket stream.
2772    ///
2773    /// This method currently only supports the "message" event. When a message is received,
2774    /// the provided callback function will be invoked with the deserialized payload.
2775    ///
2776    /// # Arguments
2777    ///
2778    /// * `event` - The event type to listen for (currently only "message" is supported)
2779    /// * `callback_fn` - A function that will be called with the deserialized message payload
2780    ///
2781    /// # Errors
2782    ///
2783    /// Logs an error if the payload cannot be deserialized into the expected type
2784    ///
2785    /// # Examples
2786    ///
2787    ///
2788    /// stream.on("message", |data: `MyType`| {
2789    ///     // Handle the deserialized message
2790    /// });
2791    async fn on<F>(&self, event: &str, callback_fn: F)
2792    where
2793        F: Fn(T) + Send + Sync + 'static,
2794    {
2795        if event != "message" {
2796            return;
2797        }
2798
2799        let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2800            Arc::new(
2801                move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2802                    Ok(data) => callback_fn(data),
2803                    Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2804                },
2805            );
2806
2807        {
2808            let mut guard = self.callback.lock().await;
2809            *guard = Some(cb_wrapper.clone());
2810        }
2811
2812        match &self.websocket_base {
2813            WebsocketBase::WebsocketStreams(ws_streams) => {
2814                let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2815                let conn = {
2816                    let map = ws_streams.connection_streams.lock().await;
2817                    map.get(&key).cloned().expect("stream must be subscribed")
2818                };
2819
2820                {
2821                    let mut conn_state = conn.state.lock().await;
2822                    let entry = conn_state.stream_callbacks.entry(key).or_default();
2823
2824                    if !entry
2825                        .iter()
2826                        .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2827                    {
2828                        entry.push(cb_wrapper);
2829                    }
2830                }
2831            }
2832            WebsocketBase::WebsocketApi(ws_api) => {
2833                let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2834                let entry = stream_callbacks
2835                    .entry(self.stream_or_id.clone())
2836                    .or_default();
2837
2838                if !entry
2839                    .iter()
2840                    .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2841                {
2842                    entry.push(cb_wrapper);
2843                }
2844            }
2845        }
2846    }
2847
2848    /// Synchronously sets a message callback for the WebSocket stream on the current thread.
2849    ///
2850    /// # Arguments
2851    ///
2852    /// * `callback_fn` - A function that will be called with the deserialized message payload
2853    ///
2854    /// # Panics
2855    ///
2856    /// Panics if the thread runtime fails to be created or if the thread join fails
2857    ///
2858    /// # Examples
2859    ///
2860    ///
2861    /// let stream = `Arc::new(WebsocketStream::new())`;
2862    /// `stream.on_message(|data`: `MyType`| {
2863    ///     // Handle the deserialized message
2864    /// });
2865    ///
2866    pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2867    where
2868        T: Send + Sync,
2869        F: Fn(T) + Send + Sync + 'static,
2870    {
2871        let handler: Arc<Self> = Arc::clone(self);
2872
2873        std::thread::spawn(move || {
2874            let rt = tokio::runtime::Builder::new_current_thread()
2875                .enable_all()
2876                .build()
2877                .expect("failed to build Tokio runtime");
2878
2879            rt.block_on(handler.on("message", callback_fn));
2880        })
2881        .join()
2882        .expect("on_message thread panicked");
2883    }
2884
2885    /// Unsubscribes from the current WebSocket stream and removes the associated callback.
2886    ///
2887    /// This method performs the following actions:
2888    /// - Removes the current callback associated with the stream
2889    /// - Removes the callback from the connection's stream callbacks
2890    /// - Asynchronously unsubscribes from the stream using the WebSocket streams base
2891    ///
2892    /// # Panics
2893    ///
2894    /// Panics if the stream is not subscribed to
2895    ///
2896    /// # Notes
2897    /// - If no callback is present, no action is taken
2898    /// - Spawns an asynchronous task to handle the unsubscription process
2899    pub async fn unsubscribe(&self) {
2900        let maybe_cb = {
2901            let mut guard = self.callback.lock().await;
2902            guard.take()
2903        };
2904
2905        if let Some(cb) = maybe_cb {
2906            match &self.websocket_base {
2907                WebsocketBase::WebsocketStreams(ws_streams) => {
2908                    let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2909                    let conn = {
2910                        let map = ws_streams.connection_streams.lock().await;
2911                        map.get(&key)
2912                            .cloned()
2913                            .expect("stream must have been subscribed")
2914                    };
2915
2916                    {
2917                        let mut conn_state = conn.state.lock().await;
2918                        if let Some(list) = conn_state.stream_callbacks.get_mut(&key) {
2919                            list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2920                        }
2921                    }
2922
2923                    let stream = self.stream_or_id.clone();
2924                    let id = self.id.clone();
2925                    let url_path = self.url_path.clone();
2926                    let websocket_streams_base = Arc::clone(ws_streams);
2927                    spawn(async move {
2928                        websocket_streams_base
2929                            .unsubscribe(vec![stream], id, url_path.as_deref())
2930                            .await;
2931                    });
2932                }
2933                WebsocketBase::WebsocketApi(ws_api) => {
2934                    let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2935                    if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2936                        list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2937                    }
2938                }
2939            }
2940        }
2941    }
2942}
2943
2944/// Creates a new WebSocket stream handler for the specified stream or ID.
2945/// This function subscribes to the stream if the WebSocket base is of type `WebsocketStreams`.
2946///
2947/// # Arguments
2948///
2949/// * `websocket_base` - The base WebSocket instance (either `WebsocketStreams` or `WebsocketApi`)
2950/// * `stream_or_id` - The stream name or identifier to subscribe to
2951/// * `id` - An optional request identifier for the subscription
2952/// * `url_path` - An optional URL path for the subscription
2953///
2954/// # Returns
2955///
2956/// A new `WebsocketStream` instance.
2957///
2958pub async fn create_stream_handler<T>(
2959    websocket_base: WebsocketBase,
2960    stream_or_id: String,
2961    id: Option<StreamId>,
2962    url_path: Option<String>,
2963) -> Arc<WebsocketStream<T>>
2964where
2965    T: DeserializeOwned + Send + 'static,
2966{
2967    match &websocket_base {
2968        WebsocketBase::WebsocketStreams(ws_streams) => {
2969            ws_streams
2970                .clone()
2971                .subscribe(vec![stream_or_id.clone()], id.clone(), url_path.as_deref())
2972                .await;
2973        }
2974        WebsocketBase::WebsocketApi(_) => {}
2975    }
2976
2977    Arc::new(WebsocketStream {
2978        websocket_base,
2979        stream_or_id,
2980        url_path,
2981        id,
2982        callback: Mutex::new(None),
2983        _phantom: PhantomData,
2984    })
2985}
2986
2987#[cfg(test)]
2988mod tests {
2989    use crate::TOKIO_SHARED_RT;
2990    use crate::common::utils::{SignatureGenerator, build_user_agent};
2991    use crate::common::websocket::{
2992        PendingRequest, ReconnectEntry, SendWebsocketMessageResult, WebsocketApi, WebsocketBase,
2993        WebsocketCommon, WebsocketConnection, WebsocketEvent, WebsocketEventEmitter,
2994        WebsocketHandler, WebsocketMessageSendOptions, WebsocketMode, WebsocketSessionLogonReq,
2995        WebsocketStream, WebsocketStreams, create_stream_handler,
2996    };
2997    use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2998    use crate::errors::{WebsocketConnectionFailureReason, WebsocketError};
2999    use crate::models::{StreamId, TimeUnit};
3000    use async_trait::async_trait;
3001    use futures::{SinkExt, StreamExt};
3002    use http::header::USER_AGENT;
3003    use regex::Regex;
3004    use serde_json::{Value, json};
3005    use std::collections::{BTreeMap, HashSet};
3006    use std::marker::PhantomData;
3007    use std::net::SocketAddr;
3008    use std::sync::{
3009        Arc,
3010        atomic::{AtomicBool, AtomicUsize, Ordering},
3011    };
3012    use tokio::net::TcpListener;
3013    use tokio::sync::{
3014        Mutex,
3015        mpsc::{Receiver, unbounded_channel},
3016        oneshot,
3017    };
3018    use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
3019    use tokio_tungstenite::{accept_async, accept_hdr_async, tungstenite, tungstenite::Message};
3020    use tungstenite::handshake::server::Request;
3021
3022    /// RAII guard that aborts a spawned task when dropped.
3023    ///
3024    /// Used by tests to bound the lifetime of mock listener / spawned helper tasks
3025    /// to the scope of the test, so they don't leak into `TOKIO_SHARED_RT` and
3026    /// add nondeterministic background load to subsequent tests.
3027    struct AbortOnDrop(tokio::task::JoinHandle<()>);
3028
3029    impl Drop for AbortOnDrop {
3030        fn drop(&mut self) {
3031            self.0.abort();
3032        }
3033    }
3034
3035    /// Spawn a mock WebSocket listener that accepts a single connection and
3036    /// holds it open until the client disconnects, then exits. The returned
3037    /// guard aborts the task when dropped.
3038    fn spawn_mock_ws_listener(listener: TcpListener) -> AbortOnDrop {
3039        AbortOnDrop(tokio::spawn(async move {
3040            if let Ok((stream, _)) = listener.accept().await {
3041                let Ok(mut ws) = accept_async(stream).await else {
3042                    return;
3043                };
3044                while ws.next().await.is_some() {}
3045            }
3046        }))
3047    }
3048
3049    fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
3050        let events = Arc::new(Mutex::new(Vec::new()));
3051        let events_clone = events.clone();
3052        common.events.subscribe(move |event| {
3053            let events_clone = events_clone.clone();
3054            tokio::spawn(async move {
3055                events_clone.lock().await.push(event);
3056            });
3057        });
3058        events
3059    }
3060
3061    async fn create_connection(
3062        id: &str,
3063        has_writer: bool,
3064        reconnection_pending: bool,
3065        renewal_pending: bool,
3066        close_initiated: bool,
3067    ) -> Arc<WebsocketConnection> {
3068        let conn = WebsocketConnection::new(id);
3069        let mut st = conn.state.lock().await;
3070        st.reconnection_pending = reconnection_pending;
3071        st.renewal_pending = renewal_pending;
3072        st.close_initiated = close_initiated;
3073        if has_writer {
3074            let (tx, _) = unbounded_channel::<Message>();
3075            st.ws_write_tx = Some(tx);
3076        } else {
3077            st.ws_write_tx = None;
3078        }
3079        drop(st);
3080        conn
3081    }
3082
3083    fn create_websocket_api(
3084        time_unit: Option<TimeUnit>,
3085        mode: Option<WebsocketMode>,
3086        auto_session_relogon: Option<bool>,
3087    ) -> Arc<WebsocketApi> {
3088        let mode = mode.unwrap_or(WebsocketMode::Single);
3089        let auto_session_relogon = auto_session_relogon.unwrap_or(true);
3090        let sig_gen = SignatureGenerator::new(
3091            Some("api_secret".into()),
3092            None::<PrivateKey>,
3093            None::<String>,
3094        );
3095        let config = ConfigurationWebsocketApi {
3096            api_key: Some("api_key".into()),
3097            api_secret: Some("api_secret".into()),
3098            private_key: None,
3099            private_key_passphrase: None,
3100            ws_url: Some("wss://example.com".into()),
3101            mode,
3102            reconnect_delay: 1000,
3103            signature_gen: sig_gen,
3104            timeout: 500,
3105            time_unit,
3106            auto_session_relogon,
3107            agent: None,
3108            user_agent: build_user_agent("product"),
3109        };
3110        let conn1 = WebsocketConnection::new("c1");
3111        let conn2 = WebsocketConnection::new("c2");
3112        WebsocketApi::new(config, vec![conn1, conn2])
3113    }
3114
3115    fn create_websocket_streams(
3116        ws_url: Option<&str>,
3117        conns: Option<Vec<Arc<WebsocketConnection>>>,
3118        url_paths: Option<Vec<String>>,
3119    ) -> Arc<WebsocketStreams> {
3120        let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
3121        let url_paths = url_paths.unwrap_or_default();
3122        if conns.is_none() {
3123            connections.push(WebsocketConnection::new("c1"));
3124            connections.push(WebsocketConnection::new("c2"));
3125        } else {
3126            connections = conns.expect("Expected connections to be set");
3127        }
3128        let config = ConfigurationWebsocketStreams {
3129            ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
3130            mode: WebsocketMode::Single,
3131            reconnect_delay: 500,
3132            time_unit: None,
3133            agent: None,
3134            user_agent: build_user_agent("product"),
3135        };
3136        WebsocketStreams::new(config, connections, url_paths)
3137    }
3138
3139    fn subscribe_to_emitter(emitter: &WebsocketEventEmitter) -> Receiver<WebsocketEvent> {
3140        let (test_tx, test_rx) = tokio::sync::mpsc::channel(16);
3141        let _sub = emitter.subscribe(move |evt| {
3142            let _ = test_tx.try_send(evt);
3143        });
3144        test_rx
3145    }
3146
3147    async fn expect_websocket_event(rx: &mut Receiver<WebsocketEvent>) -> WebsocketEvent {
3148        timeout(Duration::from_millis(200), rx.recv())
3149            .await
3150            .expect("timed out waiting for event")
3151            .expect("subscriber channel closed")
3152    }
3153
3154    async fn eventually_async<F, Fut>(max_wait: Duration, mut f: F) -> bool
3155    where
3156        F: FnMut() -> Fut,
3157        Fut: std::future::Future<Output = bool>,
3158    {
3159        let start = tokio::time::Instant::now();
3160        while start.elapsed() < max_wait {
3161            if f().await {
3162                return true;
3163            }
3164            sleep(Duration::from_millis(20)).await;
3165        }
3166        false
3167    }
3168
3169    mod event_emitter {
3170        use super::*;
3171
3172        #[test]
3173        fn event_emitter_subscribe_and_emit() {
3174            TOKIO_SHARED_RT.block_on(async {
3175                let emitter = WebsocketEventEmitter::new();
3176                let (tx, rx) = oneshot::channel();
3177                let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
3178                let tx_clone = tx.clone();
3179                let _sub = emitter.subscribe(move |event| {
3180                    if let Some(sender) = tx_clone.lock().unwrap().take() {
3181                        let _ = sender.send(event);
3182                    }
3183                });
3184                emitter.emit(&WebsocketEvent::Open);
3185                let received = timeout(Duration::from_millis(100), rx)
3186                    .await
3187                    .expect("timed out");
3188                assert_eq!(received, Ok(WebsocketEvent::Open));
3189            });
3190        }
3191
3192        #[test]
3193        fn single_subscriber_gets_event() {
3194            TOKIO_SHARED_RT.block_on(async {
3195                let emitter = WebsocketEventEmitter::new();
3196                let mut rx = subscribe_to_emitter(&emitter);
3197
3198                let e1 = WebsocketEvent::Open;
3199                emitter.emit(&e1);
3200
3201                let got = expect_websocket_event(&mut rx).await;
3202                assert_eq!(got, e1);
3203            });
3204        }
3205
3206        #[test]
3207        fn multiple_subscribers_get_event() {
3208            TOKIO_SHARED_RT.block_on(async {
3209                let emitter = WebsocketEventEmitter::new();
3210                let mut rx1 = subscribe_to_emitter(&emitter);
3211                let mut rx2 = subscribe_to_emitter(&emitter);
3212
3213                let e = WebsocketEvent::Message("hello".into());
3214                emitter.emit(&e);
3215
3216                assert_eq!(expect_websocket_event(&mut rx1).await, e.clone());
3217                assert_eq!(expect_websocket_event(&mut rx2).await, e);
3218            });
3219        }
3220
3221        #[test]
3222        fn closed_subscribers_are_pruned() {
3223            TOKIO_SHARED_RT.block_on(async {
3224                let emitter = WebsocketEventEmitter::new();
3225                let rx1 = subscribe_to_emitter(&emitter);
3226                let mut rx2 = subscribe_to_emitter(&emitter);
3227                drop(rx1);
3228
3229                let e = WebsocketEvent::Pong;
3230                emitter.emit(&e);
3231
3232                assert_eq!(expect_websocket_event(&mut rx2).await, e);
3233            });
3234        }
3235
3236        #[test]
3237        fn prune_on_error_does_not_hang() {
3238            TOKIO_SHARED_RT.block_on(async {
3239                let emitter = WebsocketEventEmitter::new();
3240                let rx = subscribe_to_emitter(&emitter);
3241                drop(rx);
3242
3243                let e = WebsocketEvent::Close(1000, "bye".into());
3244                emitter.emit(&e);
3245            });
3246        }
3247    }
3248
3249    mod websocket_common {
3250        use super::*;
3251
3252        mod initialisation {
3253            use super::*;
3254
3255            #[test]
3256            fn single_mode() {
3257                TOKIO_SHARED_RT.block_on(async {
3258                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3259                    assert_eq!(common.connection_pool.len(), 1);
3260                });
3261            }
3262
3263            #[test]
3264            fn pool_mode() {
3265                TOKIO_SHARED_RT.block_on(async {
3266                    let common =
3267                        WebsocketCommon::new(vec![], WebsocketMode::Pool(3), 0, None, None);
3268                    assert_eq!(common.connection_pool.len(), 3);
3269                });
3270            }
3271        }
3272
3273        mod spawn_reconnect_loop {
3274            use super::*;
3275
3276            #[test]
3277            fn successful_reconnect_entry_triggers_init_connect() {
3278                TOKIO_SHARED_RT.block_on(async {
3279                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3280                    let addr = listener.local_addr().unwrap();
3281                    tokio::spawn(async move {
3282                        if let Ok((stream, _)) = listener.accept().await {
3283                            let mut ws = accept_async(stream).await.unwrap();
3284                            sleep(Duration::from_secs(5)).await;
3285                            let _ = ws.close(None).await;
3286                        }
3287                    });
3288
3289                    let conn = WebsocketConnection::new("c1");
3290                    let common = WebsocketCommon::new(
3291                        vec![conn.clone()],
3292                        WebsocketMode::Single,
3293                        10,
3294                        None,
3295                        None,
3296                    );
3297                    let url = format!("ws://{addr}");
3298                    common
3299                        .reconnect_tx
3300                        .send(ReconnectEntry {
3301                            connection_id: "c1".into(),
3302                            url: url.clone(),
3303                            is_renewal: false,
3304                        })
3305                        .await
3306                        .unwrap();
3307
3308                    let mut ok = false;
3309                    for _ in 0..100 {
3310                        if conn.state.lock().await.ws_write_tx.is_some() {
3311                            ok = true;
3312                            break;
3313                        }
3314                        sleep(Duration::from_millis(50)).await;
3315                    }
3316                    assert!(ok, "expected ws_write_tx to be Some after reconnect");
3317                });
3318            }
3319
3320            #[test]
3321            fn reconnect_entry_with_unknown_id_is_ignored() {
3322                TOKIO_SHARED_RT.block_on(async {
3323                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3324                    let addr = listener.local_addr().unwrap();
3325                    tokio::spawn(async move {
3326                        if let Ok((stream, _)) = listener.accept().await {
3327                            let mut ws = accept_async(stream).await.unwrap();
3328                            let _ = ws.close(None).await;
3329                        }
3330                    });
3331
3332                    let conn = WebsocketConnection::new("c1");
3333                    let common = WebsocketCommon::new(
3334                        vec![conn.clone()],
3335                        WebsocketMode::Single,
3336                        5,
3337                        None,
3338                        None,
3339                    );
3340                    let url = format!("ws://{addr}");
3341                    common
3342                        .reconnect_tx
3343                        .send(ReconnectEntry {
3344                            connection_id: "other".into(),
3345                            url,
3346                            is_renewal: false,
3347                        })
3348                        .await
3349                        .unwrap();
3350
3351                    sleep(Duration::from_secs(1)).await;
3352
3353                    let st = conn.state.lock().await;
3354                    assert!(st.ws_write_tx.is_none());
3355                });
3356            }
3357
3358            #[test]
3359            fn renewal_entries_bypass_initial_delay() {
3360                TOKIO_SHARED_RT.block_on(async {
3361                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3362                    let addr = listener.local_addr().unwrap();
3363                    tokio::spawn(async move {
3364                        if let Ok((stream, _)) = listener.accept().await {
3365                            let mut ws = accept_async(stream).await.unwrap();
3366                            let _ = ws.close(None).await;
3367                        }
3368                    });
3369
3370                    let conn = WebsocketConnection::new("renew");
3371                    let common = WebsocketCommon::new(
3372                        vec![conn.clone()],
3373                        WebsocketMode::Single,
3374                        200,
3375                        None,
3376                        None,
3377                    );
3378                    let url = format!("ws://{addr}");
3379                    common
3380                        .reconnect_tx
3381                        .send(ReconnectEntry {
3382                            connection_id: "renew".into(),
3383                            url: url.clone(),
3384                            is_renewal: true,
3385                        })
3386                        .await
3387                        .unwrap();
3388
3389                    sleep(Duration::from_secs(2)).await;
3390
3391                    let st = conn.state.lock().await;
3392
3393                    assert!(st.ws_write_tx.is_some());
3394                });
3395            }
3396
3397            #[test]
3398            fn non_renewal_entries_respect_initial_delay() {
3399                TOKIO_SHARED_RT.block_on(async {
3400                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3401                    let addr = listener.local_addr().unwrap();
3402                    let _listener_guard = spawn_mock_ws_listener(listener);
3403
3404                    let reconnect_delay_ms = 1500;
3405                    let conn = WebsocketConnection::new("nonrenew");
3406                    let common = WebsocketCommon::new(
3407                        vec![conn.clone()],
3408                        WebsocketMode::Single,
3409                        reconnect_delay_ms,
3410                        None,
3411                        None,
3412                    );
3413                    let url = format!("ws://{addr}");
3414                    let queued_at = tokio::time::Instant::now();
3415                    common
3416                        .reconnect_tx
3417                        .send(ReconnectEntry {
3418                            connection_id: "nonrenew".into(),
3419                            url: url.clone(),
3420                            is_renewal: false,
3421                        })
3422                        .await
3423                        .unwrap();
3424
3425                    sleep(Duration::from_millis(100)).await;
3426                    if queued_at.elapsed()
3427                        < Duration::from_millis(reconnect_delay_ms as u64)
3428                            .saturating_sub(Duration::from_millis(200))
3429                    {
3430                        assert!(conn.state.lock().await.ws_write_tx.is_none());
3431                    }
3432
3433                    let mut ok = false;
3434                    for _ in 0..200 {
3435                        if conn.state.lock().await.ws_write_tx.is_some() {
3436                            ok = true;
3437                            break;
3438                        }
3439                        sleep(Duration::from_millis(50)).await;
3440                    }
3441                    assert!(
3442                        ok,
3443                        "expected ws_write_tx to be Some after reconnect delay elapsed"
3444                    );
3445                });
3446            }
3447        }
3448
3449        mod spawn_renewal_loop {
3450            use super::*;
3451
3452            #[tokio::test]
3453            async fn scheduling_renewal_does_not_panic_for_known_connection() {
3454                pause();
3455
3456                let conn = WebsocketConnection::new("known");
3457                let common =
3458                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3459                let url = "wss://example".to_string();
3460                common
3461                    .renewal_tx
3462                    .send((conn.id.clone(), url))
3463                    .await
3464                    .unwrap();
3465                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3466                resume();
3467            }
3468
3469            #[tokio::test]
3470            async fn scheduling_renewal_ignored_for_unknown_connection() {
3471                pause();
3472
3473                let conn = WebsocketConnection::new("c1");
3474                let common =
3475                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3476                common
3477                    .renewal_tx
3478                    .send(("other".into(), "u".into()))
3479                    .await
3480                    .unwrap();
3481                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3482
3483                resume();
3484            }
3485        }
3486
3487        mod reconnect_regressions {
3488            use super::*;
3489
3490            #[test]
3491            fn init_connect_is_not_skipped_when_reconnection_pending() {
3492                TOKIO_SHARED_RT.block_on(async {
3493                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3494                    let addr = listener.local_addr().unwrap();
3495
3496                    tokio::spawn(async move {
3497                        if let Ok((stream, _)) = listener.accept().await {
3498                            tokio::spawn(async move {
3499                                let mut ws = accept_async(stream).await.unwrap();
3500                                sleep(Duration::from_millis(500)).await;
3501                                let _ = ws.close(None).await;
3502                            });
3503                        }
3504                    });
3505
3506                    let conn = WebsocketConnection::new("c-reconnect");
3507                    {
3508                        let mut st = conn.state.lock().await;
3509                        let (tx, _) = unbounded_channel::<Message>();
3510                        st.ws_write_tx = Some(tx);
3511                        st.reconnection_pending = true;
3512                    }
3513
3514                    let common = WebsocketCommon::new(
3515                        vec![conn.clone()],
3516                        WebsocketMode::Single,
3517                        0,
3518                        None,
3519                        None,
3520                    );
3521
3522                    let url = format!("ws://{addr}");
3523                    common
3524                        .clone()
3525                        .init_connect(&url, false, Some(conn.clone()))
3526                        .await
3527                        .unwrap();
3528
3529                    let ok = eventually_async(Duration::from_secs(2), || {
3530                        let conn = conn.clone();
3531                        async move {
3532                            let st = conn.state.lock().await;
3533                            st.ws_write_tx.is_some() && !st.reconnection_pending
3534                        }
3535                    })
3536                    .await;
3537
3538                    assert!(
3539                        ok,
3540                        "expected writer installed and reconnection_pending cleared"
3541                    );
3542                });
3543            }
3544
3545            #[test]
3546            fn pending_request_is_resolved_on_socket_drop() {
3547                TOKIO_SHARED_RT.block_on(async {
3548                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3549                    let addr = listener.local_addr().unwrap();
3550
3551                    tokio::spawn(async move {
3552                        if let Ok((stream, _)) = listener.accept().await {
3553                            let mut ws = accept_async(stream).await.unwrap();
3554                            let _ = ws.next().await;
3555                            let _ = ws.close(None).await;
3556                        }
3557                    });
3558
3559                    let conn = WebsocketConnection::new("c1");
3560                    let common = WebsocketCommon::new(
3561                        vec![conn.clone()],
3562                        WebsocketMode::Single,
3563                        0,
3564                        None,
3565                        None,
3566                    );
3567
3568                    let url = format!("ws://{addr}");
3569                    common
3570                        .clone()
3571                        .init_connect(&url, false, Some(conn.clone()))
3572                        .await
3573                        .unwrap();
3574
3575                    let rx = common
3576                        .send(
3577                            "{\"id\":\"req-1\",\"method\":\"PING\"}".to_string(),
3578                            Some("req-1".to_string()),
3579                            true,
3580                            Duration::from_millis(150),
3581                            Some(conn.clone()),
3582                        )
3583                        .await
3584                        .unwrap()
3585                        .expect("expected oneshot receiver");
3586
3587                    let res = timeout(Duration::from_secs(2), rx)
3588                        .await
3589                        .expect("did not resolve pending request")
3590                        .expect("oneshot cancelled");
3591
3592                    assert!(matches!(res, Err(WebsocketError::Timeout)));
3593
3594                    let ok = eventually_async(Duration::from_secs(1), || {
3595                        let conn = conn.clone();
3596                        async move { conn.state.lock().await.pending_requests.is_empty() }
3597                    })
3598                    .await;
3599
3600                    assert!(ok, "pending_requests should be drained");
3601                });
3602            }
3603        }
3604
3605        mod is_connection_ready {
3606            use super::*;
3607
3608            #[test]
3609            fn is_connection_ready() {
3610                TOKIO_SHARED_RT.block_on(async {
3611                    let conn = WebsocketConnection::new("c1");
3612                    let common = WebsocketCommon::new(
3613                        vec![conn.clone()],
3614                        WebsocketMode::Single,
3615                        0,
3616                        None,
3617                        None,
3618                    );
3619                    assert!(!common.is_connection_ready(&conn, false).await);
3620                    assert!(common.is_connection_ready(&conn, true).await);
3621                });
3622            }
3623
3624            #[test]
3625            fn connection_ready_basic() {
3626                TOKIO_SHARED_RT.block_on(async {
3627                    let conn = create_connection("c1", true, false, false, false).await;
3628                    let common = WebsocketCommon::new(
3629                        vec![conn.clone()],
3630                        WebsocketMode::Single,
3631                        0,
3632                        None,
3633                        None,
3634                    );
3635                    assert!(common.is_connection_ready(&conn, false).await);
3636                });
3637            }
3638
3639            #[test]
3640            fn connection_not_ready_without_writer() {
3641                TOKIO_SHARED_RT.block_on(async {
3642                    let conn = create_connection("c1", false, false, false, false).await;
3643                    let common = WebsocketCommon::new(
3644                        vec![conn.clone()],
3645                        WebsocketMode::Single,
3646                        0,
3647                        None,
3648                        None,
3649                    );
3650                    assert!(!common.is_connection_ready(&conn, false).await);
3651                    assert!(common.is_connection_ready(&conn, true).await);
3652                });
3653            }
3654
3655            #[test]
3656            fn connection_not_ready_when_flagged() {
3657                TOKIO_SHARED_RT.block_on(async {
3658                    let conn1 = create_connection("c1", true, true, false, false).await;
3659                    let conn2 = create_connection("c2", true, false, true, false).await;
3660                    let conn3 = create_connection("c3", true, false, false, true).await;
3661
3662                    let common = WebsocketCommon::new(
3663                        vec![conn1.clone(), conn2.clone(), conn3.clone()],
3664                        WebsocketMode::Pool(3),
3665                        0,
3666                        None,
3667                        None,
3668                    );
3669
3670                    assert!(!common.is_connection_ready(&conn1, false).await);
3671                    assert!(common.is_connection_ready(&conn2, false).await);
3672                    assert!(!common.is_connection_ready(&conn3, false).await);
3673                });
3674            }
3675        }
3676
3677        mod is_connected {
3678            use super::*;
3679
3680            #[test]
3681            fn with_pool_various_connections() {
3682                TOKIO_SHARED_RT.block_on(async {
3683                    let conn_a = create_connection("a", true, false, false, false).await;
3684                    let conn_b = create_connection("b", false, false, false, false).await;
3685                    let conn_c = create_connection("c", true, true, false, false).await;
3686                    let pool = vec![conn_a.clone(), conn_b.clone(), conn_c.clone()];
3687                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None, None);
3688
3689                    assert!(common.is_connected(None).await);
3690                    assert!(common.is_connected(Some(&conn_a)).await);
3691                    assert!(!common.is_connected(Some(&conn_b)).await);
3692                    assert!(!common.is_connected(Some(&conn_c)).await);
3693                });
3694            }
3695
3696            #[test]
3697            fn with_pool_all_bad_connections() {
3698                TOKIO_SHARED_RT.block_on(async {
3699                    let bad1 = create_connection("c1", false, false, false, false).await;
3700                    let bad2 = create_connection("c2", true, true, false, false).await;
3701                    let bad3 = create_connection("c3", true, false, false, true).await;
3702                    let common = WebsocketCommon::new(
3703                        vec![bad1, bad2, bad3],
3704                        WebsocketMode::Pool(3),
3705                        0,
3706                        None,
3707                        None,
3708                    );
3709
3710                    assert!(!common.is_connected(None).await);
3711                });
3712            }
3713
3714            #[test]
3715            fn with_pool_ignore_close_initiated() {
3716                TOKIO_SHARED_RT.block_on(async {
3717                    let good = create_connection("c1", true, false, false, false).await;
3718                    let closed = create_connection("c2", true, false, false, true).await;
3719                    let bad = create_connection("c3", false, false, false, false).await;
3720                    let common = WebsocketCommon::new(
3721                        vec![closed.clone(), good.clone(), bad.clone()],
3722                        WebsocketMode::Pool(3),
3723                        0,
3724                        None,
3725                        None,
3726                    );
3727
3728                    assert!(common.is_connected(None).await);
3729                    assert!(!common.is_connected(Some(&closed)).await);
3730                });
3731            }
3732        }
3733
3734        mod get_available_connections {
3735            use super::*;
3736
3737            #[test]
3738            fn single_mode() {
3739                TOKIO_SHARED_RT.block_on(async {
3740                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3741                    let connections = common.get_available_connections(false, None).await;
3742                    assert_eq!(connections[0].id, common.connection_pool[0].id);
3743                });
3744            }
3745
3746            #[test]
3747            fn single_mode_with_url_path_does_not_force_first_connection() {
3748                TOKIO_SHARED_RT.block_on(async {
3749                    let conn1 = WebsocketConnection::new("c1");
3750                    let conn2 = WebsocketConnection::new("c2");
3751
3752                    let (tx2, _rx2) = unbounded_channel();
3753                    {
3754                        let mut s2 = conn2.state.lock().await;
3755                        s2.ws_write_tx = Some(tx2);
3756                        s2.url_path = Some("path1".to_string());
3757                    }
3758
3759                    {
3760                        let mut s1 = conn1.state.lock().await;
3761                        s1.url_path = Some("path1".to_string());
3762                    }
3763
3764                    let pool = vec![conn1.clone(), conn2.clone()];
3765                    let common = WebsocketCommon::new(pool, WebsocketMode::Single, 0, None, None);
3766
3767                    let connections = common.get_available_connections(false, Some("path1")).await;
3768
3769                    assert_eq!(connections.len(), 1);
3770                    assert_eq!(connections[0].id, "c2");
3771                });
3772            }
3773
3774            #[test]
3775            fn pool_mode_not_ready() {
3776                TOKIO_SHARED_RT.block_on(async {
3777                    let common =
3778                        WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
3779                    let connections = common.get_available_connections(false, None).await;
3780                    assert!(connections.is_empty());
3781                });
3782            }
3783
3784            #[test]
3785            fn pool_mode_with_ready() {
3786                TOKIO_SHARED_RT.block_on(async {
3787                    let conn1 = WebsocketConnection::new("c1");
3788                    let conn2 = WebsocketConnection::new("c2");
3789                    let (tx1, _rx1) = unbounded_channel();
3790                    {
3791                        let mut s1 = conn1.state.lock().await;
3792                        s1.ws_write_tx = Some(tx1);
3793                    }
3794                    let pool = vec![conn1.clone(), conn2.clone()];
3795                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3796                    let connections = common.get_available_connections(false, None).await;
3797                    assert!(connections.len() == 1);
3798                });
3799            }
3800        }
3801
3802        mod get_connection {
3803            use super::*;
3804
3805            #[test]
3806            fn single_mode() {
3807                TOKIO_SHARED_RT.block_on(async {
3808                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3809                    let conn = common
3810                        .get_connection(false, None)
3811                        .await
3812                        .expect("should get connection");
3813                    assert_eq!(conn.id, common.connection_pool[0].id);
3814                });
3815            }
3816
3817            #[test]
3818            fn single_mode_with_url_path_selects_matching_ready_connection() {
3819                TOKIO_SHARED_RT.block_on(async {
3820                    let conn1 = WebsocketConnection::new("c1");
3821                    let conn2 = WebsocketConnection::new("c2");
3822
3823                    let (tx1, _rx1) = unbounded_channel();
3824                    {
3825                        let mut s1 = conn1.state.lock().await;
3826                        s1.ws_write_tx = Some(tx1);
3827                        s1.url_path = Some("path2".to_string());
3828                    }
3829
3830                    let (tx2, _rx2) = unbounded_channel();
3831                    {
3832                        let mut s2 = conn2.state.lock().await;
3833                        s2.ws_write_tx = Some(tx2);
3834                        s2.url_path = Some("path1".to_string());
3835                    }
3836
3837                    let pool = vec![conn1.clone(), conn2.clone()];
3838                    let common = WebsocketCommon::new(pool, WebsocketMode::Single, 0, None, None);
3839
3840                    let chosen = common
3841                        .get_connection(false, Some("path1"))
3842                        .await
3843                        .expect("should get connection");
3844
3845                    assert_eq!(chosen.id, "c2");
3846                });
3847            }
3848
3849            #[test]
3850            fn pool_mode_not_ready() {
3851                TOKIO_SHARED_RT.block_on(async {
3852                    let common =
3853                        WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
3854                    let result = common.get_connection(false, None).await;
3855                    assert!(matches!(
3856                        result,
3857                        Err(crate::errors::WebsocketError::NotConnected)
3858                    ));
3859                });
3860            }
3861
3862            #[test]
3863            fn pool_mode_with_ready() {
3864                TOKIO_SHARED_RT.block_on(async {
3865                    let conn1 = WebsocketConnection::new("c1");
3866                    let conn2 = WebsocketConnection::new("c2");
3867                    let (tx1, _rx1) = unbounded_channel();
3868                    {
3869                        let mut s1 = conn1.state.lock().await;
3870                        s1.ws_write_tx = Some(tx1);
3871                    }
3872                    let pool = vec![conn1.clone(), conn2.clone()];
3873                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3874                    let result = common.get_connection(false, None).await;
3875                    assert!(result.is_ok());
3876                    let chosen = result.unwrap();
3877                    assert_eq!(chosen.id, conn1.id);
3878                });
3879            }
3880
3881            #[test]
3882            fn pool_mode_with_url_path_filters_connections() {
3883                TOKIO_SHARED_RT.block_on(async {
3884                    let conn1 = WebsocketConnection::new("c1");
3885                    let conn2 = WebsocketConnection::new("c2");
3886
3887                    let (tx1, _rx1) = unbounded_channel();
3888                    {
3889                        let mut s1 = conn1.state.lock().await;
3890                        s1.ws_write_tx = Some(tx1);
3891                        s1.url_path = Some("path1".to_string());
3892                    }
3893
3894                    let (tx2, _rx2) = unbounded_channel();
3895                    {
3896                        let mut s2 = conn2.state.lock().await;
3897                        s2.ws_write_tx = Some(tx2);
3898                        s2.url_path = Some("path2".to_string());
3899                    }
3900
3901                    let pool = vec![conn1.clone(), conn2.clone()];
3902                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3903
3904                    let chosen = common
3905                        .get_connection(false, Some("path2"))
3906                        .await
3907                        .expect("should pick ready connection for path2");
3908
3909                    assert_eq!(chosen.id, "c2");
3910                });
3911            }
3912
3913            #[test]
3914            fn url_path_no_match_returns_not_connected() {
3915                TOKIO_SHARED_RT.block_on(async {
3916                    let conn1 = WebsocketConnection::new("c1");
3917                    let (tx1, _rx1) = unbounded_channel();
3918                    {
3919                        let mut s1 = conn1.state.lock().await;
3920                        s1.ws_write_tx = Some(tx1);
3921                        s1.url_path = Some("path1".to_string());
3922                    }
3923
3924                    let pool = vec![conn1.clone()];
3925                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(1), 0, None, None);
3926
3927                    let result = common.get_connection(false, Some("path2")).await;
3928                    assert!(matches!(
3929                        result,
3930                        Err(crate::errors::WebsocketError::NotConnected)
3931                    ));
3932                });
3933            }
3934        }
3935
3936        mod close_connection_gracefully {
3937            use super::*;
3938
3939            #[tokio::test]
3940            async fn waits_for_pending_requests_then_closes() {
3941                pause();
3942
3943                let conn = WebsocketConnection::new("c1");
3944                let (tx, mut rx) = unbounded_channel::<Message>();
3945                let (req_tx, _req_rx) = oneshot::channel();
3946                {
3947                    let mut st = conn.state.lock().await;
3948                    st.pending_requests
3949                        .insert("r".to_string(), PendingRequest { completion: req_tx });
3950                }
3951                let common =
3952                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3953                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3954                advance(Duration::from_secs(1)).await;
3955                {
3956                    let mut st = conn.state.lock().await;
3957                    st.pending_requests.clear();
3958                }
3959                conn.drain_notify.notify_waiters();
3960                advance(Duration::from_secs(1)).await;
3961                close_fut.await.unwrap();
3962                match rx.try_recv() {
3963                    Ok(Message::Close(_)) => {}
3964                    other => panic!("expected Close, got {other:?}"),
3965                }
3966
3967                resume();
3968            }
3969
3970            #[tokio::test]
3971            async fn force_closes_after_timeout() {
3972                pause();
3973
3974                let conn = WebsocketConnection::new("c2");
3975                let (tx, mut rx) = unbounded_channel::<Message>();
3976                let (req_tx, _req_rx) = oneshot::channel();
3977                {
3978                    let mut st = conn.state.lock().await;
3979                    st.pending_requests.insert(
3980                        "request_id".to_string(),
3981                        PendingRequest { completion: req_tx },
3982                    );
3983                }
3984                let common =
3985                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3986                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3987                advance(Duration::from_secs(30)).await;
3988                close_fut.await.unwrap();
3989                match rx.try_recv() {
3990                    Ok(Message::Close(_)) => {}
3991                    other => panic!("expected Close on timeout, got {other:?}"),
3992                }
3993
3994                resume();
3995            }
3996        }
3997
3998        mod get_reconnect_url {
3999            use super::*;
4000
4001            struct DummyHandler {
4002                url: String,
4003            }
4004
4005            #[async_trait::async_trait]
4006            impl WebsocketHandler for DummyHandler {
4007                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
4008                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
4009                async fn get_reconnect_url(
4010                    &self,
4011                    _default_url: String,
4012                    _connection: Arc<WebsocketConnection>,
4013                ) -> String {
4014                    self.url.clone()
4015                }
4016            }
4017
4018            #[test]
4019            fn returns_default_when_no_handler() {
4020                TOKIO_SHARED_RT.block_on(async {
4021                    let conn = WebsocketConnection::new("c1");
4022                    let common = WebsocketCommon::new(
4023                        vec![conn.clone()],
4024                        WebsocketMode::Single,
4025                        0,
4026                        None,
4027                        None,
4028                    );
4029                    let default = "wss://default".to_string();
4030                    let result = common.get_reconnect_url(&default, conn.clone()).await;
4031                    assert_eq!(result, default);
4032                });
4033            }
4034
4035            #[test]
4036            fn returns_handler_url_when_set() {
4037                TOKIO_SHARED_RT.block_on(async {
4038                    let conn = WebsocketConnection::new("c2");
4039                    let handler = Arc::new(DummyHandler {
4040                        url: "wss://custom".into(),
4041                    });
4042                    conn.set_handler(handler).await;
4043                    let common = WebsocketCommon::new(
4044                        vec![conn.clone()],
4045                        WebsocketMode::Single,
4046                        0,
4047                        None,
4048                        None,
4049                    );
4050                    let default = "wss://default".to_string();
4051                    let result = common.get_reconnect_url(&default, conn.clone()).await;
4052                    assert_eq!(result, "wss://custom");
4053                });
4054            }
4055        }
4056
4057        mod on_open {
4058            use super::*;
4059
4060            struct DummyHandler {
4061                called: Arc<Mutex<bool>>,
4062                opened_url: Arc<Mutex<Option<String>>>,
4063            }
4064
4065            #[async_trait]
4066            impl WebsocketHandler for DummyHandler {
4067                async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
4068                    let mut flag = self.called.lock().await;
4069                    *flag = true;
4070                    let mut store = self.opened_url.lock().await;
4071                    *store = Some(url);
4072                }
4073                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
4074                async fn get_reconnect_url(
4075                    &self,
4076                    default_url: String,
4077                    _connection: Arc<WebsocketConnection>,
4078                ) -> String {
4079                    default_url
4080                }
4081            }
4082
4083            #[test]
4084            fn emits_open_and_calls_handler() {
4085                TOKIO_SHARED_RT.block_on(async {
4086                    let conn = WebsocketConnection::new("c1");
4087                    let called = Arc::new(Mutex::new(false));
4088                    let opened_url = Arc::new(Mutex::new(None));
4089                    let handler = Arc::new(DummyHandler {
4090                        called: called.clone(),
4091                        opened_url: opened_url.clone(),
4092                    });
4093
4094                    conn.set_handler(handler.clone()).await;
4095                    let common = WebsocketCommon::new(
4096                        vec![conn.clone()],
4097                        WebsocketMode::Single,
4098                        0,
4099                        None,
4100                        None,
4101                    );
4102                    let events = subscribe_events(&common);
4103                    common
4104                        .on_open("wss://example.com".into(), conn.clone(), None)
4105                        .await;
4106
4107                    sleep(std::time::Duration::from_millis(10)).await;
4108
4109                    let evs = events.lock().await;
4110                    assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
4111                    assert!(*called.lock().await);
4112                    assert_eq!(
4113                        opened_url.lock().await.as_deref(),
4114                        Some("wss://example.com")
4115                    );
4116                });
4117            }
4118
4119            #[test]
4120            fn handles_renewal_pending_and_closes_old_writer() {
4121                TOKIO_SHARED_RT.block_on(async {
4122                    let conn = WebsocketConnection::new("c2");
4123                    let (old_tx, mut old_rx) = unbounded_channel::<Message>();
4124                    {
4125                        let mut st = conn.state.lock().await;
4126                        st.renewal_pending = true;
4127                    }
4128                    let common = WebsocketCommon::new(
4129                        vec![conn.clone()],
4130                        WebsocketMode::Single,
4131                        0,
4132                        None,
4133                        None,
4134                    );
4135                    common
4136                        .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
4137                        .await;
4138                    assert!(!conn.state.lock().await.renewal_pending);
4139                    match old_rx.try_recv() {
4140                        Ok(Message::Close(_)) => {}
4141                        other => panic!("expected Close, got {other:?}"),
4142                    }
4143                });
4144            }
4145        }
4146
4147        mod on_message {
4148            use super::*;
4149
4150            struct DummyHandler {
4151                called_with: Arc<Mutex<Vec<String>>>,
4152            }
4153
4154            #[async_trait]
4155            impl WebsocketHandler for DummyHandler {
4156                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
4157                async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
4158                    self.called_with.lock().await.push(data);
4159                }
4160                async fn get_reconnect_url(
4161                    &self,
4162                    default_url: String,
4163                    _connection: Arc<WebsocketConnection>,
4164                ) -> String {
4165                    default_url
4166                }
4167            }
4168
4169            #[test]
4170            fn emits_message_event_without_handler() {
4171                TOKIO_SHARED_RT.block_on(async {
4172                    let conn = WebsocketConnection::new("c1");
4173                    let common = WebsocketCommon::new(
4174                        vec![conn.clone()],
4175                        WebsocketMode::Single,
4176                        0,
4177                        None,
4178                        None,
4179                    );
4180                    let events = subscribe_events(&common);
4181                    common.on_message("msg".into(), conn.clone()).await;
4182
4183                    sleep(Duration::from_millis(10)).await;
4184
4185                    let locked = events.lock().await;
4186                    assert!(
4187                        locked
4188                            .iter()
4189                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4190                    );
4191                });
4192            }
4193
4194            #[test]
4195            fn calls_handler_and_emits_message() {
4196                TOKIO_SHARED_RT.block_on(async {
4197                    let conn = WebsocketConnection::new("c2");
4198                    let called = Arc::new(Mutex::new(Vec::new()));
4199                    let handler = Arc::new(DummyHandler {
4200                        called_with: called.clone(),
4201                    });
4202                    conn.set_handler(handler.clone()).await;
4203
4204                    let common = WebsocketCommon::new(
4205                        vec![conn.clone()],
4206                        WebsocketMode::Single,
4207                        0,
4208                        None,
4209                        None,
4210                    );
4211                    let events = subscribe_events(&common);
4212                    common.on_message("msg".into(), conn.clone()).await;
4213
4214                    sleep(Duration::from_millis(10)).await;
4215
4216                    let evs = events.lock().await;
4217                    assert!(
4218                        evs.iter()
4219                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4220                    );
4221                    let msgs = called.lock().await;
4222                    assert_eq!(msgs.as_slice(), &["msg".to_string()]);
4223                });
4224            }
4225
4226            #[test]
4227            fn preserves_message_order() {
4228                TOKIO_SHARED_RT.block_on(async {
4229                    let conn = WebsocketConnection::new("c3");
4230                    let called = Arc::new(Mutex::new(Vec::new()));
4231                    let handler = Arc::new(DummyHandler {
4232                        called_with: called.clone(),
4233                    });
4234                    conn.set_handler(handler.clone()).await;
4235
4236                    let common = WebsocketCommon::new(
4237                        vec![conn.clone()],
4238                        WebsocketMode::Single,
4239                        0,
4240                        None,
4241                        None,
4242                    );
4243
4244                    for i in 0..20 {
4245                        common.on_message(format!("msg_{i}"), conn.clone()).await;
4246                    }
4247
4248                    let msgs = called.lock().await;
4249                    let expected: Vec<String> = (0..20).map(|i| format!("msg_{i}")).collect();
4250                    assert_eq!(msgs.as_slice(), expected.as_slice());
4251                });
4252            }
4253
4254            #[test]
4255            fn preserves_order_with_slow_handler() {
4256                use std::sync::atomic::{AtomicU32, Ordering};
4257
4258                struct SlowHandler {
4259                    received: Arc<Mutex<Vec<String>>>,
4260                    concurrent_count: Arc<AtomicU32>,
4261                    max_concurrent: Arc<AtomicU32>,
4262                }
4263
4264                #[async_trait]
4265                impl WebsocketHandler for SlowHandler {
4266                    async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
4267                    async fn on_message(
4268                        &self,
4269                        data: String,
4270                        _connection: Arc<WebsocketConnection>,
4271                    ) {
4272                        let prev = self.concurrent_count.fetch_add(1, Ordering::SeqCst);
4273                        self.max_concurrent.fetch_max(prev + 1, Ordering::SeqCst);
4274                        sleep(Duration::from_millis(1)).await;
4275                        self.received.lock().await.push(data);
4276                        self.concurrent_count.fetch_sub(1, Ordering::SeqCst);
4277                    }
4278                    async fn get_reconnect_url(
4279                        &self,
4280                        default_url: String,
4281                        _connection: Arc<WebsocketConnection>,
4282                    ) -> String {
4283                        default_url
4284                    }
4285                }
4286
4287                TOKIO_SHARED_RT.block_on(async {
4288                    let conn = WebsocketConnection::new("c4");
4289                    let received = Arc::new(Mutex::new(Vec::new()));
4290                    let concurrent_count = Arc::new(AtomicU32::new(0));
4291                    let max_concurrent = Arc::new(AtomicU32::new(0));
4292                    let handler = Arc::new(SlowHandler {
4293                        received: received.clone(),
4294                        concurrent_count: concurrent_count.clone(),
4295                        max_concurrent: max_concurrent.clone(),
4296                    });
4297                    conn.set_handler(handler.clone()).await;
4298
4299                    let common = WebsocketCommon::new(
4300                        vec![conn.clone()],
4301                        WebsocketMode::Single,
4302                        0,
4303                        None,
4304                        None,
4305                    );
4306
4307                    for i in 0..10 {
4308                        common.on_message(format!("msg_{i}"), conn.clone()).await;
4309                    }
4310
4311                    let msgs = received.lock().await;
4312                    let expected: Vec<String> = (0..10).map(|i| format!("msg_{i}")).collect();
4313                    assert_eq!(msgs.as_slice(), expected.as_slice());
4314                    assert_eq!(
4315                        max_concurrent.load(Ordering::SeqCst),
4316                        1,
4317                        "messages must be processed sequentially, not concurrently"
4318                    );
4319                });
4320            }
4321        }
4322
4323        mod create_websocket {
4324            use super::*;
4325
4326            #[test]
4327            fn successful_connection() {
4328                TOKIO_SHARED_RT.block_on(async {
4329                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4330                    let addr: SocketAddr = listener.local_addr().unwrap();
4331
4332                    let expected_ua = build_user_agent("product");
4333                    let expected_ua_clone = expected_ua.clone();
4334
4335                    tokio::spawn(async move {
4336                        if let Ok((stream, _)) = listener.accept().await {
4337                            let callback = |req: &Request, resp| {
4338                                let got = req
4339                                    .headers()
4340                                    .get(USER_AGENT)
4341                                    .expect("no USER_AGENT header in WS handshake")
4342                                    .to_str()
4343                                    .expect("invalid USER_AGENT header");
4344                                assert_eq!(got, expected_ua_clone, "User-Agent mismatch");
4345                                Ok(resp)
4346                            };
4347                            let _ = accept_hdr_async(stream, callback).await.unwrap();
4348                        }
4349                    });
4350
4351                    let url = format!("ws://{addr}");
4352                    let res =
4353                        WebsocketCommon::create_websocket(&url, None, Some(expected_ua)).await;
4354                    assert!(res.is_ok(), "handshake failed: {res:?}");
4355                });
4356            }
4357
4358            #[test]
4359            fn invalid_url_returns_handshake_error() {
4360                TOKIO_SHARED_RT.block_on(async {
4361                    let res =
4362                        WebsocketCommon::create_websocket("not-a-valid-url", None, None).await;
4363                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4364                });
4365            }
4366
4367            #[test]
4368            fn unreachable_host_returns_handshake_error() {
4369                TOKIO_SHARED_RT.block_on(async {
4370                    let res =
4371                        WebsocketCommon::create_websocket("ws://127.0.0.1:1", None, None).await;
4372                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4373                });
4374            }
4375        }
4376
4377        mod connect_pool {
4378            use super::*;
4379
4380            #[test]
4381            fn connects_all_in_pool() {
4382                TOKIO_SHARED_RT.block_on(async {
4383                    let pool_size = 3;
4384                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4385                    let addr = listener.local_addr().unwrap();
4386                    tokio::spawn(async move {
4387                        for _ in 0..pool_size {
4388                            if let Ok((stream, _)) = listener.accept().await {
4389                                tokio::spawn(async move {
4390                                    let mut ws = accept_async(stream).await.unwrap();
4391                                    sleep(Duration::from_millis(500)).await;
4392                                    let _ = ws.close(None).await;
4393                                });
4394                            }
4395                        }
4396                    });
4397                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
4398                        .map(|i| WebsocketConnection::new(format!("c{i}")))
4399                        .collect();
4400                    let common = WebsocketCommon::new(
4401                        conns.clone(),
4402                        WebsocketMode::Pool(pool_size),
4403                        0,
4404                        None,
4405                        None,
4406                    );
4407                    let url = format!("ws://{addr}");
4408                    common.clone().connect_pool(&url, None).await.unwrap();
4409                    for conn in conns {
4410                        let mut ok = false;
4411                        for _ in 0..100 {
4412                            if conn.state.lock().await.ws_write_tx.is_some() {
4413                                ok = true;
4414                                break;
4415                            }
4416                            sleep(Duration::from_millis(50)).await;
4417                        }
4418                        assert!(ok, "expected ws_write_tx Some after connect");
4419                    }
4420                });
4421            }
4422
4423            #[test]
4424            fn fails_if_any_refused() {
4425                TOKIO_SHARED_RT.block_on(async {
4426                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4427                    let addr = listener.local_addr().unwrap();
4428                    let pool_size = 3;
4429                    tokio::spawn(async move {
4430                        for _ in 0..2 {
4431                            if let Ok((stream, _)) = listener.accept().await {
4432                                let mut ws = accept_async(stream).await.unwrap();
4433                                let _ = ws.close(None).await;
4434                            }
4435                        }
4436                    });
4437                    let mut conns = Vec::new();
4438                    let valid_url = format!("ws://{addr}");
4439                    for i in 0..2 {
4440                        conns.push(WebsocketConnection::new(format!("c{i}")));
4441                    }
4442                    conns.push(WebsocketConnection::new("bad"));
4443                    let common = WebsocketCommon::new(
4444                        conns.clone(),
4445                        WebsocketMode::Pool(pool_size),
4446                        0,
4447                        None,
4448                        None,
4449                    );
4450                    let res = common.clone().connect_pool(&valid_url, None).await;
4451                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4452                });
4453            }
4454
4455            #[test]
4456            fn fails_on_invalid_url() {
4457                TOKIO_SHARED_RT.block_on(async {
4458                    let conns = vec![WebsocketConnection::new("c1")];
4459                    let common = WebsocketCommon::new(conns, WebsocketMode::Pool(1), 0, None, None);
4460                    let res = common.connect_pool("not-a-url", None).await;
4461                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4462                });
4463            }
4464
4465            #[test]
4466            fn fails_if_mixed_success_and_invalid_url() {
4467                TOKIO_SHARED_RT.block_on(async {
4468                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4469                    let addr = listener.local_addr().unwrap();
4470                    tokio::spawn(async move {
4471                        if let Ok((stream, _)) = listener.accept().await {
4472                            let mut ws = accept_async(stream).await.unwrap();
4473                            let _ = ws.close(None).await;
4474                        }
4475                    });
4476                    let good = WebsocketConnection::new("good");
4477                    let bad = WebsocketConnection::new("bad");
4478                    let common = WebsocketCommon::new(
4479                        vec![good, bad],
4480                        WebsocketMode::Pool(2),
4481                        0,
4482                        None,
4483                        None,
4484                    );
4485                    let url = format!("ws://{addr}");
4486                    let res = common.connect_pool(&url, None).await;
4487                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4488                });
4489            }
4490
4491            #[test]
4492            fn init_connect_invoked_for_each() {
4493                TOKIO_SHARED_RT.block_on(async {
4494                    let pool_size = 2;
4495                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4496                    let addr = listener.local_addr().unwrap();
4497                    tokio::spawn(async move {
4498                        for _ in 0..pool_size {
4499                            if let Ok((stream, _)) = listener.accept().await {
4500                                tokio::spawn(async move {
4501                                    let mut ws = accept_async(stream).await.unwrap();
4502                                    sleep(Duration::from_millis(500)).await;
4503                                    let _ = ws.close(None).await;
4504                                });
4505                            }
4506                        }
4507                    });
4508                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
4509                        .map(|i| WebsocketConnection::new(format!("c{i}")))
4510                        .collect();
4511                    let common = WebsocketCommon::new(
4512                        conns.clone(),
4513                        WebsocketMode::Pool(pool_size),
4514                        0,
4515                        None,
4516                        None,
4517                    );
4518                    let url = format!("ws://{addr}");
4519                    common.clone().connect_pool(&url, None).await.unwrap();
4520                    for conn in conns {
4521                        let mut ok = false;
4522                        for _ in 0..100 {
4523                            if conn.state.lock().await.ws_write_tx.is_some() {
4524                                ok = true;
4525                                break;
4526                            }
4527                            sleep(Duration::from_millis(25)).await;
4528                        }
4529                        assert!(ok, "expected ws_write_tx Some for {}", conn.id);
4530                    }
4531                });
4532            }
4533
4534            #[test]
4535            fn single_mode_uses_first_connection() {
4536                TOKIO_SHARED_RT.block_on(async {
4537                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4538                    let addr = listener.local_addr().unwrap();
4539                    let _listener_guard = spawn_mock_ws_listener(listener);
4540                    let conn = WebsocketConnection::new("c1");
4541                    let common = WebsocketCommon::new(
4542                        vec![conn.clone()],
4543                        WebsocketMode::Single,
4544                        0,
4545                        None,
4546                        None,
4547                    );
4548                    let url = format!("ws://{addr}");
4549                    common.connect_pool(&url, None).await.unwrap();
4550                    let ok = eventually_async(Duration::from_secs(5), || {
4551                        let conn = conn.clone();
4552                        async move { conn.state.lock().await.ws_write_tx.is_some() }
4553                    })
4554                    .await;
4555
4556                    assert!(ok, "single mode did not select first connection");
4557                });
4558            }
4559
4560            #[test]
4561            fn empty_subset_is_ok_and_connects_none() {
4562                TOKIO_SHARED_RT.block_on(async {
4563                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4564                    let addr = listener.local_addr().unwrap();
4565
4566                    tokio::spawn(async move {
4567                        let _ = addr;
4568                    });
4569
4570                    let c1 = WebsocketConnection::new("c1");
4571                    let c2 = WebsocketConnection::new("c2");
4572
4573                    let common = WebsocketCommon::new(
4574                        vec![c1.clone(), c2.clone()],
4575                        WebsocketMode::Pool(2),
4576                        0,
4577                        None,
4578                        None,
4579                    );
4580
4581                    let url = format!("ws://{addr}");
4582                    common
4583                        .clone()
4584                        .connect_pool(&url, Some(vec![]))
4585                        .await
4586                        .unwrap();
4587
4588                    assert!(c1.state.lock().await.ws_write_tx.is_none());
4589                    assert!(c2.state.lock().await.ws_write_tx.is_none());
4590                });
4591            }
4592        }
4593
4594        mod init_connect {
4595            use super::*;
4596            use std::panic;
4597            use tokio::sync::mpsc::{channel, error::TryRecvError};
4598
4599            #[test]
4600            fn pool_mode_none_connection_uses_first() {
4601                TOKIO_SHARED_RT.block_on(async {
4602                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4603                    let addr = listener.local_addr().unwrap();
4604                    tokio::spawn(async move {
4605                        for _ in 0..2 {
4606                            if let Ok((stream, _)) = listener.accept().await {
4607                                let mut ws = accept_async(stream).await.unwrap();
4608                                ws.close(None).await.ok();
4609                            }
4610                        }
4611                    });
4612
4613                    let c1 = WebsocketConnection::new("c1");
4614                    let c2 = WebsocketConnection::new("c2");
4615                    let common = WebsocketCommon::new(
4616                        vec![c1.clone(), c2.clone()],
4617                        WebsocketMode::Pool(2),
4618                        0,
4619                        None,
4620                        None,
4621                    );
4622                    let url = format!("ws://{addr}");
4623
4624                    common
4625                        .clone()
4626                        .init_connect(&url, false, None)
4627                        .await
4628                        .unwrap();
4629
4630                    let ok = eventually_async(Duration::from_secs(5), || {
4631                        let conn1 = c1.clone();
4632                        async move { conn1.state.lock().await.ws_write_tx.is_some() }
4633                    })
4634                    .await;
4635
4636                    assert!(ok, "first connection was never selected");
4637                    assert!(c2.state.lock().await.ws_write_tx.is_none());
4638                });
4639            }
4640
4641            #[test]
4642            fn writer_channel_can_send_text() {
4643                TOKIO_SHARED_RT.block_on(async {
4644                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4645                    let addr = listener.local_addr().unwrap();
4646                    let received = Arc::new(Mutex::new(None::<String>));
4647                    let received_clone = received.clone();
4648
4649                    tokio::spawn(async move {
4650                        if let Ok((stream, _)) = listener.accept().await {
4651                            let mut ws = accept_async(stream).await.unwrap();
4652                            if let Some(Ok(Message::Text(txt))) = ws.next().await {
4653                                *received_clone.lock().await = Some(txt.to_string());
4654                            }
4655                            ws.close(None).await.ok();
4656                        }
4657                    });
4658
4659                    let conn = WebsocketConnection::new("cw");
4660                    let common = WebsocketCommon::new(
4661                        vec![conn.clone()],
4662                        WebsocketMode::Single,
4663                        0,
4664                        None,
4665                        None,
4666                    );
4667                    let url = format!("ws://{addr}");
4668                    common
4669                        .clone()
4670                        .init_connect(&url, false, Some(conn.clone()))
4671                        .await
4672                        .unwrap();
4673
4674                    let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
4675                    tx.send(Message::Text("ping".into())).unwrap();
4676
4677                    sleep(Duration::from_millis(50)).await;
4678
4679                    let lock = received.lock().await;
4680                    assert_eq!(lock.as_deref(), Some("ping"));
4681                });
4682            }
4683
4684            #[test]
4685            fn does_not_skip_when_reconnection_pending_even_if_writer_exists() {
4686                TOKIO_SHARED_RT.block_on(async {
4687                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4688                    let addr = listener.local_addr().unwrap();
4689                    let _listener_guard = spawn_mock_ws_listener(listener);
4690
4691                    let conn = WebsocketConnection::new("c-reconnect");
4692                    {
4693                        let mut st = conn.state.lock().await;
4694                        let (tx, _) = unbounded_channel::<Message>();
4695                        st.ws_write_tx = Some(tx);
4696                        st.reconnection_pending = true;
4697                    }
4698
4699                    let common = WebsocketCommon::new(
4700                        vec![conn.clone()],
4701                        WebsocketMode::Single,
4702                        0,
4703                        None,
4704                        None,
4705                    );
4706
4707                    let url = format!("ws://{addr}");
4708                    common
4709                        .clone()
4710                        .init_connect(&url, false, Some(conn.clone()))
4711                        .await
4712                        .unwrap();
4713
4714                    let st = conn.state.lock().await;
4715                    assert!(
4716                        st.ws_write_tx.is_some(),
4717                        "writer should be set after connect"
4718                    );
4719                    assert!(
4720                        !st.reconnection_pending,
4721                        "reconnection_pending should be cleared after successful connect"
4722                    );
4723                });
4724            }
4725
4726            #[test]
4727            fn responds_to_ping_with_pong() {
4728                TOKIO_SHARED_RT.block_on(async {
4729                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4730                    let addr = listener.local_addr().unwrap();
4731
4732                    let saw_pong = Arc::new(Mutex::new(false));
4733                    let saw_pong2 = saw_pong.clone();
4734
4735                    tokio::spawn(async move {
4736                        if let Ok((stream, _)) = listener.accept().await {
4737                            let mut ws = accept_async(stream).await.unwrap();
4738                            ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
4739                            if let Some(Ok(Message::Pong(payload))) = ws.next().await {
4740                                if payload[..] == [1, 2, 3] {
4741                                    *saw_pong2.lock().await = true;
4742                                }
4743                            }
4744                            let _ = ws.close(None).await;
4745                        }
4746                    });
4747
4748                    let conn = WebsocketConnection::new("c-ping");
4749                    let common = WebsocketCommon::new(
4750                        vec![conn.clone()],
4751                        WebsocketMode::Single,
4752                        0,
4753                        None,
4754                        None,
4755                    );
4756                    let url = format!("ws://{addr}");
4757                    common
4758                        .clone()
4759                        .init_connect(&url, false, Some(conn))
4760                        .await
4761                        .unwrap();
4762
4763                    sleep(Duration::from_millis(50)).await;
4764
4765                    assert!(*saw_pong.lock().await, "server should have seen a Pong");
4766                });
4767            }
4768
4769            #[test]
4770            fn handshake_error_on_invalid_url() {
4771                TOKIO_SHARED_RT.block_on(async {
4772                    let conn = WebsocketConnection::new("c-invalid");
4773                    let common = WebsocketCommon::new(
4774                        vec![conn.clone()],
4775                        WebsocketMode::Single,
4776                        0,
4777                        None,
4778                        None,
4779                    );
4780                    let res = common
4781                        .clone()
4782                        .init_connect("not-a-url", false, Some(conn.clone()))
4783                        .await;
4784                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4785                });
4786            }
4787
4788            #[test]
4789            fn skip_if_writer_exists_and_not_renewal() {
4790                TOKIO_SHARED_RT.block_on(async {
4791                    let conn = WebsocketConnection::new("c-writer");
4792                    let (tx, mut rx) = unbounded_channel::<Message>();
4793                    {
4794                        let mut st = conn.state.lock().await;
4795                        st.ws_write_tx = Some(tx.clone());
4796                    }
4797                    let common = WebsocketCommon::new(
4798                        vec![conn.clone()],
4799                        WebsocketMode::Single,
4800                        0,
4801                        None,
4802                        None,
4803                    );
4804                    let res = common
4805                        .clone()
4806                        .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
4807                        .await;
4808
4809                    assert!(res.is_ok());
4810                    assert!(rx.try_recv().is_err());
4811                });
4812            }
4813
4814            #[test]
4815            fn short_circuit_on_already_renewing() {
4816                TOKIO_SHARED_RT.block_on(async {
4817                    let conn = WebsocketConnection::new("c-renew");
4818                    {
4819                        let mut st = conn.state.lock().await;
4820                        st.renewal_pending = true;
4821                    }
4822                    let common = WebsocketCommon::new(
4823                        vec![conn.clone()],
4824                        WebsocketMode::Single,
4825                        0,
4826                        None,
4827                        None,
4828                    );
4829                    let res = common
4830                        .clone()
4831                        .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
4832                        .await;
4833
4834                    assert!(res.is_ok());
4835                    assert!(conn.state.lock().await.ws_write_tx.is_none());
4836                });
4837            }
4838
4839            #[test]
4840            fn is_renewal_true_sets_and_clears_flag() {
4841                struct GatedHandler {
4842                    gate: Mutex<Option<oneshot::Receiver<()>>>,
4843                }
4844                #[async_trait]
4845                impl WebsocketHandler for GatedHandler {
4846                    async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {
4847                        if let Some(rx) = self.gate.lock().await.take() {
4848                            let _ = rx.await;
4849                        }
4850                    }
4851                    async fn on_message(
4852                        &self,
4853                        _data: String,
4854                        _connection: Arc<WebsocketConnection>,
4855                    ) {
4856                    }
4857                    async fn get_reconnect_url(
4858                        &self,
4859                        url: String,
4860                        _connection: Arc<WebsocketConnection>,
4861                    ) -> String {
4862                        url
4863                    }
4864                }
4865
4866                TOKIO_SHARED_RT.block_on(async {
4867                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4868                    let addr = listener.local_addr().unwrap();
4869                    let _listener_guard = spawn_mock_ws_listener(listener);
4870
4871                    let conn = WebsocketConnection::new("c-new-renew");
4872                    let (gate_tx, gate_rx) = oneshot::channel();
4873                    conn.set_handler(Arc::new(GatedHandler {
4874                        gate: Mutex::new(Some(gate_rx)),
4875                    }))
4876                    .await;
4877
4878                    let common = WebsocketCommon::new(
4879                        vec![conn.clone()],
4880                        WebsocketMode::Single,
4881                        0,
4882                        None,
4883                        None,
4884                    );
4885                    let url = format!("ws://{addr}");
4886                    let res = common
4887                        .clone()
4888                        .init_connect(&url, true, Some(conn.clone()))
4889                        .await;
4890
4891                    assert!(res.is_ok());
4892
4893                    {
4894                        let st = conn.state.lock().await;
4895                        assert!(st.ws_write_tx.is_some(), "writer should be set");
4896                        assert!(
4897                            st.renewal_pending,
4898                            "renewal_pending must be true until on_open"
4899                        );
4900                    }
4901
4902                    let _ = gate_tx.send(());
4903
4904                    let ok = eventually_async(Duration::from_secs(2), || {
4905                        let conn = conn.clone();
4906                        async move { !conn.state.lock().await.renewal_pending }
4907                    })
4908                    .await;
4909                    assert!(ok, "renewal_pending should be cleared in on_open");
4910                    let st = conn.state.lock().await;
4911                    assert!(
4912                        !st.renewal_pending,
4913                        "renewal_pending should be cleared in on_open"
4914                    );
4915                });
4916            }
4917
4918            #[test]
4919            fn does_not_schedule_reconnect_on_renewal() {
4920                TOKIO_SHARED_RT.block_on(async {
4921                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4922                    let addr = listener.local_addr().unwrap();
4923                    let server_handle = tokio::spawn(async move {
4924                        if let Ok((stream, _)) = listener.accept().await {
4925                            let _ws = accept_async(stream).await.unwrap();
4926                            if let Ok((stream, _)) = listener.accept().await {
4927                                let _ws = accept_async(stream).await.unwrap();
4928                            }
4929                            sleep(Duration::from_secs(10)).await;
4930                        }
4931                    });
4932
4933                    let conn = WebsocketConnection::new("c-needs-renew");
4934                    let (reconnect_tx, mut reconnect_rx) = channel::<ReconnectEntry>(1);
4935                    let (renewal_tx, _renewal_rx) = channel::<(String, String)>(1);
4936                    let common = Arc::new(WebsocketCommon {
4937                        events: WebsocketEventEmitter::new(),
4938                        mode: WebsocketMode::Single,
4939                        round_robin_index: AtomicUsize::new(0),
4940                        connection_pool: vec![conn.clone()],
4941                        reconnect_tx,
4942                        renewal_tx,
4943                        reconnect_delay: 0,
4944                        agent: None,
4945                        user_agent: None,
4946                    });
4947                    let url = format!("ws://{addr}");
4948                    let res = common
4949                        .clone()
4950                        .init_connect(&url, false, Some(conn.clone()))
4951                        .await;
4952
4953                    assert!(res.is_ok());
4954
4955                    {
4956                        let st = conn.state.lock().await;
4957                        assert!(st.ws_write_tx.is_some(), "writer should be set");
4958                    }
4959
4960                    common
4961                        .clone()
4962                        .init_connect(&url, true, Some(conn.clone()))
4963                        .await
4964                        .expect("Renewal init_connect should succeed");
4965
4966                    sleep(Duration::from_millis(1000)).await;
4967
4968                    match reconnect_rx.try_recv() {
4969                        Err(TryRecvError::Empty) => {}
4970                        Ok(_) => panic!("Received reconnection request on renewal"),
4971                        Err(TryRecvError::Disconnected) => {
4972                            panic!("Sender for reconnection_rx disconnected")
4973                        }
4974                    }
4975                    server_handle.abort();
4976                });
4977            }
4978
4979            #[test]
4980            fn default_connection_selected_when_none_passed() {
4981                TOKIO_SHARED_RT.block_on(async {
4982                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4983                    let addr = listener.local_addr().unwrap();
4984                    let _listener_guard = spawn_mock_ws_listener(listener);
4985                    let conn = WebsocketConnection::new("c-default");
4986                    let common = WebsocketCommon::new(
4987                        vec![conn.clone()],
4988                        WebsocketMode::Single,
4989                        0,
4990                        None,
4991                        None,
4992                    );
4993                    let url = format!("ws://{addr}");
4994                    let res = common.clone().init_connect(&url, false, None).await;
4995
4996                    assert!(res.is_ok());
4997                    let ok = eventually_async(Duration::from_secs(5), || {
4998                        let conn = conn.clone();
4999                        async move { conn.state.lock().await.ws_write_tx.is_some() }
5000                    })
5001                    .await;
5002
5003                    assert!(ok, "default connection was never selected");
5004                });
5005            }
5006
5007            #[test]
5008            fn schedules_reconnect_on_abnormal_close() {
5009                TOKIO_SHARED_RT.block_on(async {
5010                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
5011                    let addr = listener.local_addr().unwrap();
5012                    tokio::spawn(async move {
5013                        if let Ok((stream, _)) = listener.accept().await {
5014                            let ws = accept_async(stream).await.unwrap();
5015
5016                            // Explicitly sending an AbnormalClose (1006) is not allowed by the websocket
5017                            // protocol and will be converted into Protocol (1002), which the client receives.
5018                            //
5019                            // Abruptly dropping the connection simulates a real abnormal close, which is
5020                            // handled as an error in the websocket protocol
5021                            drop(ws);
5022                        }
5023                    });
5024                    let conn = WebsocketConnection::new("c-close");
5025                    let common = WebsocketCommon::new(
5026                        vec![conn.clone()],
5027                        WebsocketMode::Single,
5028                        5_000,
5029                        None,
5030                        None,
5031                    );
5032                    let url = format!("ws://{addr}");
5033                    common
5034                        .clone()
5035                        .init_connect(&url, false, Some(conn.clone()))
5036                        .await
5037                        .unwrap();
5038
5039                    sleep(Duration::from_millis(50)).await;
5040
5041                    let st = conn.state.lock().await;
5042                    assert!(
5043                        st.reconnection_pending,
5044                        "expected reconnection_pending to be true after abnormal close"
5045                    );
5046                    assert!(
5047                        st.ws_write_tx.is_none(),
5048                        "ws_write_tx should be cleared when scheduling a reconnect"
5049                    );
5050                });
5051            }
5052        }
5053
5054        mod disconnect {
5055            use super::*;
5056
5057            #[test]
5058            fn returns_ok_when_no_connections_are_ready() {
5059                TOKIO_SHARED_RT.block_on(async {
5060                    let conn = WebsocketConnection::new("c1");
5061                    let common = WebsocketCommon::new(
5062                        vec![conn.clone()],
5063                        WebsocketMode::Single,
5064                        0,
5065                        None,
5066                        None,
5067                    );
5068                    let res = common.disconnect().await;
5069
5070                    assert!(res.is_ok());
5071                    assert!(!conn.state.lock().await.close_initiated);
5072                });
5073            }
5074
5075            #[test]
5076            fn closes_all_ready_connections() {
5077                TOKIO_SHARED_RT.block_on(async {
5078                    let conn1 = WebsocketConnection::new("c1");
5079                    let conn2 = WebsocketConnection::new("c2");
5080                    let (tx1, mut rx1) = unbounded_channel::<Message>();
5081                    let (tx2, mut rx2) = unbounded_channel::<Message>();
5082                    {
5083                        let mut s1 = conn1.state.lock().await;
5084                        s1.ws_write_tx = Some(tx1);
5085                    }
5086                    {
5087                        let mut s2 = conn2.state.lock().await;
5088                        s2.ws_write_tx = Some(tx2);
5089                    }
5090                    let common = WebsocketCommon::new(
5091                        vec![conn1.clone(), conn2.clone()],
5092                        WebsocketMode::Pool(2),
5093                        0,
5094                        None,
5095                        None,
5096                    );
5097                    let fut = common.disconnect();
5098
5099                    sleep(Duration::from_millis(50)).await;
5100
5101                    fut.await.unwrap();
5102
5103                    assert!(conn1.state.lock().await.close_initiated);
5104                    assert!(conn2.state.lock().await.close_initiated);
5105
5106                    {
5107                        let st = conn1.state.lock().await;
5108                        assert!(!st.is_session_logged_on, "conn1 should be logged out");
5109                        assert!(st.session_logon_req.is_none(), "conn1 req cleared");
5110                    }
5111                    {
5112                        let st = conn2.state.lock().await;
5113                        assert!(!st.is_session_logged_on, "conn2 should be logged out");
5114                        assert!(st.session_logon_req.is_none(), "conn2 req cleared");
5115                    }
5116
5117                    match (rx1.try_recv(), rx2.try_recv()) {
5118                        (Ok(Message::Close(_)), Ok(Message::Close(_))) => {}
5119                        other => panic!("expected two Close frames, got {other:?}"),
5120                    }
5121                });
5122            }
5123
5124            #[test]
5125            fn does_not_mark_close_initiated_if_no_writer() {
5126                TOKIO_SHARED_RT.block_on(async {
5127                    let conn = WebsocketConnection::new("c-new");
5128                    let common = WebsocketCommon::new(
5129                        vec![conn.clone()],
5130                        WebsocketMode::Single,
5131                        0,
5132                        None,
5133                        None,
5134                    );
5135                    common.disconnect().await.unwrap();
5136
5137                    assert!(!conn.state.lock().await.close_initiated);
5138                });
5139            }
5140
5141            #[test]
5142            fn mixed_pool_marks_all_and_closes_only_writers() {
5143                TOKIO_SHARED_RT.block_on(async {
5144                    let conn_w = WebsocketConnection::new("with");
5145                    let conn_wo = WebsocketConnection::new("without");
5146                    let (tx, mut rx) = unbounded_channel::<Message>();
5147                    {
5148                        let mut st = conn_w.state.lock().await;
5149                        st.ws_write_tx = Some(tx);
5150                    }
5151                    let common = WebsocketCommon::new(
5152                        vec![conn_w.clone(), conn_wo.clone()],
5153                        WebsocketMode::Pool(2),
5154                        0,
5155                        None,
5156                        None,
5157                    );
5158                    let fut = common.disconnect();
5159
5160                    sleep(Duration::from_millis(50)).await;
5161
5162                    fut.await.unwrap();
5163
5164                    assert!(conn_w.state.lock().await.close_initiated);
5165                    assert!(conn_wo.state.lock().await.close_initiated);
5166                    assert!(matches!(rx.try_recv(), Ok(Message::Close(_))));
5167                });
5168            }
5169
5170            #[test]
5171            fn after_disconnect_not_connected() {
5172                TOKIO_SHARED_RT.block_on(async {
5173                    let conn = WebsocketConnection::new("c1");
5174                    let (tx, mut _rx) = unbounded_channel::<Message>();
5175                    {
5176                        let mut st = conn.state.lock().await;
5177                        st.ws_write_tx = Some(tx);
5178                    }
5179                    let common = WebsocketCommon::new(
5180                        vec![conn.clone()],
5181                        WebsocketMode::Single,
5182                        0,
5183                        None,
5184                        None,
5185                    );
5186                    common.disconnect().await.unwrap();
5187                    assert!(!common.is_connected(Some(&conn)).await);
5188                });
5189            }
5190        }
5191
5192        mod ping_server {
5193            use super::*;
5194
5195            #[test]
5196            fn sends_ping_to_all_ready_connections() {
5197                TOKIO_SHARED_RT.block_on(async {
5198                    let mut conns = Vec::new();
5199                    for i in 0..3 {
5200                        let conn = WebsocketConnection::new(format!("c{i}"));
5201                        let (tx, rx) = unbounded_channel::<Message>();
5202                        {
5203                            let mut st = conn.state.lock().await;
5204                            st.ws_write_tx = Some(tx);
5205                        }
5206                        conns.push((conn, rx));
5207                    }
5208                    let common = WebsocketCommon::new(
5209                        conns.iter().map(|(c, _)| c.clone()).collect(),
5210                        WebsocketMode::Pool(3),
5211                        0,
5212                        None,
5213                        None,
5214                    );
5215                    common.ping_server().await;
5216                    for (_, mut rx) in conns {
5217                        match rx.try_recv() {
5218                            Ok(Message::Ping(payload)) if payload.is_empty() => {}
5219                            other => panic!("expected empty-payload Ping, got {other:?}"),
5220                        }
5221                    }
5222                });
5223            }
5224
5225            #[test]
5226            fn skips_not_ready_and_partial() {
5227                TOKIO_SHARED_RT.block_on(async {
5228                    let ready = WebsocketConnection::new("ready");
5229                    let not_ready = WebsocketConnection::new("not-ready");
5230                    let (tx_r, mut rx_r) = unbounded_channel::<Message>();
5231                    {
5232                        let mut st = ready.state.lock().await;
5233                        st.ws_write_tx = Some(tx_r);
5234                    }
5235                    {
5236                        let mut st = not_ready.state.lock().await;
5237                        st.ws_write_tx = None;
5238                    }
5239                    let common = WebsocketCommon::new(
5240                        vec![ready.clone(), not_ready.clone()],
5241                        WebsocketMode::Pool(2),
5242                        0,
5243                        None,
5244                        None,
5245                    );
5246                    common.ping_server().await;
5247                    match rx_r.try_recv() {
5248                        Ok(Message::Ping(payload)) if payload.is_empty() => {}
5249                        other => panic!("expected Ping on ready, got {other:?}"),
5250                    }
5251                });
5252            }
5253
5254            #[test]
5255            fn no_ping_when_flags_block() {
5256                TOKIO_SHARED_RT.block_on(async {
5257                    let conn = WebsocketConnection::new("c1");
5258                    let (tx, mut rx) = unbounded_channel::<Message>();
5259                    {
5260                        let mut st = conn.state.lock().await;
5261                        st.ws_write_tx = Some(tx);
5262                        st.reconnection_pending = true;
5263                    }
5264                    let common = WebsocketCommon::new(
5265                        vec![conn.clone()],
5266                        WebsocketMode::Single,
5267                        0,
5268                        None,
5269                        None,
5270                    );
5271                    common.ping_server().await;
5272                    assert!(rx.try_recv().is_err());
5273                });
5274            }
5275        }
5276
5277        mod send {
5278            use super::*;
5279
5280            #[test]
5281            fn round_robin_send_without_specific() {
5282                TOKIO_SHARED_RT.block_on(async {
5283                    let conn1 = WebsocketConnection::new("c1");
5284                    let conn2 = WebsocketConnection::new("c2");
5285                    let (tx1, mut rx1) = unbounded_channel::<Message>();
5286                    let (tx2, mut rx2) = unbounded_channel::<Message>();
5287                    {
5288                        let mut s1 = conn1.state.lock().await;
5289                        s1.ws_write_tx = Some(tx1);
5290                    }
5291                    {
5292                        let mut s2 = conn2.state.lock().await;
5293                        s2.ws_write_tx = Some(tx2);
5294                    }
5295                    let common = WebsocketCommon::new(
5296                        vec![conn1.clone(), conn2.clone()],
5297                        WebsocketMode::Pool(2),
5298                        0,
5299                        None,
5300                        None,
5301                    );
5302
5303                    let res1 = common
5304                        .send("a".into(), None, false, Duration::from_secs(1), None)
5305                        .await
5306                        .unwrap();
5307                    assert!(res1.is_none());
5308
5309                    let res2 = common
5310                        .send("b".into(), None, false, Duration::from_secs(1), None)
5311                        .await
5312                        .unwrap();
5313                    assert!(res2.is_none());
5314
5315                    assert_eq!(
5316                        if let Message::Text(t) = rx1.try_recv().unwrap() {
5317                            t
5318                        } else {
5319                            panic!()
5320                        },
5321                        "a"
5322                    );
5323                    assert_eq!(
5324                        if let Message::Text(t) = rx2.try_recv().unwrap() {
5325                            t
5326                        } else {
5327                            panic!()
5328                        },
5329                        "b"
5330                    );
5331                });
5332            }
5333
5334            #[test]
5335            fn round_robin_skips_not_ready() {
5336                TOKIO_SHARED_RT.block_on(async {
5337                    let conn1 = WebsocketConnection::new("c1");
5338                    let conn2 = WebsocketConnection::new("c2");
5339                    let (tx2, mut rx2) = unbounded_channel::<Message>();
5340                    {
5341                        let mut s1 = conn1.state.lock().await;
5342                        s1.ws_write_tx = None;
5343                    }
5344                    {
5345                        let mut s2 = conn2.state.lock().await;
5346                        s2.ws_write_tx = Some(tx2);
5347                    }
5348                    let common = WebsocketCommon::new(
5349                        vec![conn1.clone(), conn2.clone()],
5350                        WebsocketMode::Pool(2),
5351                        0,
5352                        None,
5353                        None,
5354                    );
5355                    let res = common
5356                        .send("bar".into(), None, false, Duration::from_secs(1), None)
5357                        .await
5358                        .unwrap();
5359                    assert!(res.is_none());
5360                    match rx2.try_recv().unwrap() {
5361                        Message::Text(t) => assert_eq!(t, "bar"),
5362                        other => panic!("unexpected {other:?}"),
5363                    }
5364                });
5365            }
5366
5367            #[test]
5368            fn sync_send_on_specific_connection() {
5369                TOKIO_SHARED_RT.block_on(async {
5370                    let conn1 = WebsocketConnection::new("c1");
5371                    let conn2 = WebsocketConnection::new("c2");
5372                    let (tx2, mut rx2) = unbounded_channel::<Message>();
5373                    {
5374                        let mut st = conn2.state.lock().await;
5375                        st.ws_write_tx = Some(tx2);
5376                    }
5377                    let common = WebsocketCommon::new(
5378                        vec![conn1.clone(), conn2.clone()],
5379                        WebsocketMode::Pool(2),
5380                        0,
5381                        None,
5382                        None,
5383                    );
5384                    let res = common
5385                        .send(
5386                            "payload".into(),
5387                            Some("id".into()),
5388                            false,
5389                            Duration::from_secs(1),
5390                            Some(conn2.clone()),
5391                        )
5392                        .await
5393                        .unwrap();
5394                    assert!(res.is_none());
5395                    match rx2.try_recv() {
5396                        Ok(Message::Text(t)) => assert_eq!(t, "payload"),
5397                        other => panic!("expected Text, got {other:?}"),
5398                    }
5399                });
5400            }
5401
5402            #[test]
5403            fn sync_send_with_id_does_not_insert_pending() {
5404                TOKIO_SHARED_RT.block_on(async {
5405                    let conn = WebsocketConnection::new("c1");
5406                    let (tx, mut rx) = unbounded_channel::<Message>();
5407                    {
5408                        let mut st = conn.state.lock().await;
5409                        st.ws_write_tx = Some(tx);
5410                    }
5411                    let common = WebsocketCommon::new(
5412                        vec![conn.clone()],
5413                        WebsocketMode::Single,
5414                        0,
5415                        None,
5416                        None,
5417                    );
5418                    let res = common
5419                        .send(
5420                            "msg".into(),
5421                            Some("id".into()),
5422                            false,
5423                            Duration::from_secs(1),
5424                            Some(conn.clone()),
5425                        )
5426                        .await
5427                        .unwrap();
5428                    assert!(res.is_none());
5429                    assert!(conn.state.lock().await.pending_requests.is_empty());
5430                    match rx.try_recv().unwrap() {
5431                        Message::Text(t) => assert_eq!(t, "msg"),
5432                        other => panic!("unexpected {other:?}"),
5433                    }
5434                });
5435            }
5436
5437            #[test]
5438            fn sync_send_error_if_not_ready() {
5439                TOKIO_SHARED_RT.block_on(async {
5440                    let conn = WebsocketConnection::new("c1");
5441                    let common = WebsocketCommon::new(
5442                        vec![conn.clone()],
5443                        WebsocketMode::Single,
5444                        0,
5445                        None,
5446                        None,
5447                    );
5448                    let err = common
5449                        .send(
5450                            "msg".into(),
5451                            Some("id".into()),
5452                            false,
5453                            Duration::from_secs(1),
5454                            Some(conn.clone()),
5455                        )
5456                        .await
5457                        .unwrap_err();
5458                    assert!(matches!(err, WebsocketError::NotConnected));
5459                });
5460            }
5461
5462            #[test]
5463            fn sync_send_error_when_no_ready() {
5464                TOKIO_SHARED_RT.block_on(async {
5465                    let conn = WebsocketConnection::new("c1");
5466                    let common = WebsocketCommon::new(
5467                        vec![conn.clone()],
5468                        WebsocketMode::Single,
5469                        0,
5470                        None,
5471                        None,
5472                    );
5473                    let err = common
5474                        .send("msg".into(), None, false, Duration::from_secs(1), None)
5475                        .await
5476                        .unwrap_err();
5477                    assert!(matches!(err, WebsocketError::NotConnected));
5478                });
5479            }
5480
5481            #[test]
5482            fn async_send_and_receive() {
5483                TOKIO_SHARED_RT.block_on(async {
5484                    let conn = WebsocketConnection::new("c1");
5485                    let (tx, mut rx) = unbounded_channel::<Message>();
5486                    {
5487                        let mut st = conn.state.lock().await;
5488                        st.ws_write_tx = Some(tx);
5489                    }
5490                    let common = WebsocketCommon::new(
5491                        vec![conn.clone()],
5492                        WebsocketMode::Single,
5493                        0,
5494                        None,
5495                        None,
5496                    );
5497                    let fut = common
5498                        .send(
5499                            "hello".into(),
5500                            Some("id".into()),
5501                            true,
5502                            Duration::from_secs(5),
5503                            Some(conn.clone()),
5504                        )
5505                        .await
5506                        .unwrap()
5507                        .unwrap();
5508                    match rx.try_recv() {
5509                        Ok(Message::Text(t)) => assert_eq!(t, "hello"),
5510                        other => panic!("expected Text, got {other:?}"),
5511                    }
5512                    {
5513                        let mut st = conn.state.lock().await;
5514                        let pr = st.pending_requests.remove("id").unwrap();
5515                        pr.completion.send(Ok(serde_json::json!("ok"))).unwrap();
5516                    }
5517                    let resp = fut.await.unwrap().unwrap();
5518                    assert_eq!(resp, serde_json::json!("ok"));
5519                });
5520            }
5521
5522            #[test]
5523            fn async_send_default_connection() {
5524                TOKIO_SHARED_RT.block_on(async {
5525                    let conn = WebsocketConnection::new("c1");
5526                    let (tx, mut rx) = unbounded_channel::<Message>();
5527                    {
5528                        let mut st = conn.state.lock().await;
5529                        st.ws_write_tx = Some(tx);
5530                    }
5531                    let common = WebsocketCommon::new(
5532                        vec![conn.clone()],
5533                        WebsocketMode::Single,
5534                        0,
5535                        None,
5536                        None,
5537                    );
5538                    let fut = common
5539                        .send(
5540                            "msg".into(),
5541                            Some("id".into()),
5542                            true,
5543                            Duration::from_secs(5),
5544                            None,
5545                        )
5546                        .await
5547                        .unwrap()
5548                        .unwrap();
5549                    match rx.try_recv() {
5550                        Ok(Message::Text(t)) => assert_eq!(t, "msg"),
5551                        _ => panic!("no text"),
5552                    }
5553                    {
5554                        let mut st = conn.state.lock().await;
5555                        let pr = st.pending_requests.remove("id").unwrap();
5556                        pr.completion.send(Ok(serde_json::json!(123))).unwrap();
5557                    }
5558                    let resp = fut.await.unwrap().unwrap();
5559                    assert_eq!(resp, serde_json::json!(123));
5560                });
5561            }
5562
5563            #[test]
5564            fn async_send_failure_removes_pending_entry() {
5565                // Regression test: when wait_for_reply=true, send() registers a
5566                // pending_requests entry BEFORE writing to ws_write_tx so that a
5567                // racing response can never be lost. If the write itself fails
5568                // afterwards, the pending entry must be removed so it doesn't
5569                // leak (otherwise close_connection_gracefully would block on it
5570                // until the 30s drain timeout, and on_message could match a
5571                // future request that reuses the id).
5572                TOKIO_SHARED_RT.block_on(async {
5573                    let conn = WebsocketConnection::new("c1");
5574                    let (tx, rx) = unbounded_channel::<Message>();
5575                    drop(rx);
5576                    {
5577                        let mut st = conn.state.lock().await;
5578                        st.ws_write_tx = Some(tx);
5579                    }
5580                    let common = WebsocketCommon::new(
5581                        vec![conn.clone()],
5582                        WebsocketMode::Single,
5583                        0,
5584                        None,
5585                        None,
5586                    );
5587                    let err = common
5588                        .send(
5589                            "msg".into(),
5590                            Some("id".into()),
5591                            true,
5592                            Duration::from_secs(1),
5593                            Some(conn.clone()),
5594                        )
5595                        .await
5596                        .unwrap_err();
5597                    assert!(matches!(err, WebsocketError::NotConnected));
5598                    assert!(
5599                        conn.state.lock().await.pending_requests.is_empty(),
5600                        "pending_requests must be empty after a failed send"
5601                    );
5602                });
5603            }
5604
5605            #[test]
5606            fn async_send_error_if_no_id() {
5607                TOKIO_SHARED_RT.block_on(async {
5608                    let conn = WebsocketConnection::new("c§");
5609                    let (tx, _rx) = unbounded_channel::<Message>();
5610                    {
5611                        let mut st = conn.state.lock().await;
5612                        st.ws_write_tx = Some(tx);
5613                    }
5614                    let common = WebsocketCommon::new(
5615                        vec![conn.clone()],
5616                        WebsocketMode::Single,
5617                        0,
5618                        None,
5619                        None,
5620                    );
5621                    let err = common
5622                        .send(
5623                            "msg".into(),
5624                            None,
5625                            true,
5626                            Duration::from_secs(1),
5627                            Some(conn.clone()),
5628                        )
5629                        .await
5630                        .unwrap_err();
5631                    assert!(matches!(err, WebsocketError::NotConnected));
5632                });
5633            }
5634
5635            #[test]
5636            fn timeout_rejects_async() {
5637                TOKIO_SHARED_RT.block_on(async {
5638                    pause();
5639                    let conn = WebsocketConnection::new("c1");
5640                    let (tx, _rx) = unbounded_channel::<Message>();
5641                    {
5642                        let mut st = conn.state.lock().await;
5643                        st.ws_write_tx = Some(tx);
5644                    }
5645                    let common = WebsocketCommon::new(
5646                        vec![conn.clone()],
5647                        WebsocketMode::Single,
5648                        0,
5649                        None,
5650                        None,
5651                    );
5652                    let fut = common
5653                        .send(
5654                            "msg".into(),
5655                            Some("id".into()),
5656                            true,
5657                            Duration::from_secs(1),
5658                            Some(conn.clone()),
5659                        )
5660                        .await
5661                        .unwrap()
5662                        .unwrap();
5663                    advance(Duration::from_secs(1)).await;
5664                    let res = fut.await.unwrap();
5665                    assert!(res.is_err(), "expected timeout error");
5666                    assert!(!conn.state.lock().await.pending_requests.contains_key("id"));
5667                });
5668            }
5669
5670            #[test]
5671            fn async_send_errors_if_no_connection_ready() {
5672                TOKIO_SHARED_RT.block_on(async {
5673                    let conn = WebsocketConnection::new("c1");
5674                    let common = WebsocketCommon::new(
5675                        vec![conn.clone()],
5676                        WebsocketMode::Single,
5677                        0,
5678                        None,
5679                        None,
5680                    );
5681                    let err = common
5682                        .send(
5683                            "msg".into(),
5684                            Some("id".into()),
5685                            true,
5686                            Duration::from_secs(1),
5687                            None,
5688                        )
5689                        .await
5690                        .unwrap_err();
5691                    assert!(matches!(err, WebsocketError::NotConnected));
5692                });
5693            }
5694        }
5695    }
5696
5697    mod websocket_api {
5698        use super::*;
5699
5700        mod initialisation {
5701            use super::*;
5702
5703            #[test]
5704            fn new_initializes_common() {
5705                TOKIO_SHARED_RT.block_on(async {
5706                    let conn = WebsocketConnection::new("id");
5707                    let pool = vec![conn.clone()];
5708
5709                    let sig_gen = SignatureGenerator::new(
5710                        Some("api_secret".to_string()),
5711                        None::<PrivateKey>,
5712                        None::<String>,
5713                    );
5714
5715                    let config = ConfigurationWebsocketApi {
5716                        api_key: Some("api_key".to_string()),
5717                        api_secret: Some("api_secret".to_string()),
5718                        private_key: None,
5719                        private_key_passphrase: None,
5720                        ws_url: Some("wss://example".to_string()),
5721                        mode: WebsocketMode::Single,
5722                        reconnect_delay: 1000,
5723                        signature_gen: sig_gen,
5724                        timeout: 500,
5725                        time_unit: None,
5726                        auto_session_relogon: false,
5727                        agent: None,
5728                        user_agent: build_user_agent("product"),
5729                    };
5730
5731                    let api = WebsocketApi::new(config, pool.clone());
5732
5733                    assert_eq!(api.common.connection_pool.len(), 1);
5734                    assert_eq!(api.common.mode, WebsocketMode::Single);
5735
5736                    let flag = *api.is_connecting.lock().await;
5737                    assert!(!flag);
5738                });
5739            }
5740        }
5741
5742        mod connect {
5743            use super::*;
5744
5745            #[test]
5746            fn connect_when_not_connected_establishes() {
5747                TOKIO_SHARED_RT.block_on(async {
5748                    let conn = WebsocketConnection::new("id");
5749                    {
5750                        let mut st = conn.state.lock().await;
5751                        st.ws_write_tx = None;
5752                    }
5753                    let sig = SignatureGenerator::new(
5754                        Some("api_secret".into()),
5755                        None::<PrivateKey>,
5756                        None::<String>,
5757                    );
5758                    let cfg = ConfigurationWebsocketApi {
5759                        api_key: Some("api_key".into()),
5760                        api_secret: Some("api_secret".to_string()),
5761                        private_key: None,
5762                        private_key_passphrase: None,
5763                        ws_url: Some("ws://doesnotexist:1".to_string()),
5764                        mode: WebsocketMode::Single,
5765                        reconnect_delay: 0,
5766                        signature_gen: sig,
5767                        timeout: 10,
5768                        time_unit: None,
5769                        auto_session_relogon: false,
5770                        agent: None,
5771                        user_agent: build_user_agent("product"),
5772                    };
5773                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5774                    let res = api.clone().connect().await;
5775                    assert!(!matches!(res, Err(WebsocketError::Timeout)));
5776                });
5777            }
5778
5779            #[test]
5780            fn already_connected_returns_ok() {
5781                TOKIO_SHARED_RT.block_on(async {
5782                    let conn = WebsocketConnection::new("id2");
5783                    let (tx, _) = unbounded_channel();
5784                    {
5785                        let mut st = conn.state.lock().await;
5786                        st.ws_write_tx = Some(tx);
5787                    }
5788                    let sig = SignatureGenerator::new(
5789                        Some("api_secret".to_string()),
5790                        None::<PrivateKey>,
5791                        None::<String>,
5792                    );
5793                    let cfg = ConfigurationWebsocketApi {
5794                        api_key: Some("api_key".to_string()),
5795                        api_secret: Some("api_secret".to_string()),
5796                        private_key: None,
5797                        private_key_passphrase: None,
5798                        ws_url: Some("ws://example.com".to_string()),
5799                        mode: WebsocketMode::Single,
5800                        reconnect_delay: 0,
5801                        signature_gen: sig,
5802                        timeout: 10,
5803                        time_unit: None,
5804                        auto_session_relogon: false,
5805                        agent: None,
5806                        user_agent: build_user_agent("product"),
5807                    };
5808                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5809                    let res = api.connect().await;
5810                    assert!(res.is_ok());
5811                });
5812            }
5813
5814            #[test]
5815            fn not_connected_returns_error() {
5816                TOKIO_SHARED_RT.block_on(async {
5817                    let conn = WebsocketConnection::new("id1");
5818                    let sig = SignatureGenerator::new(
5819                        Some("api_secret".to_string()),
5820                        None::<PrivateKey>,
5821                        None::<String>,
5822                    );
5823                    let cfg = ConfigurationWebsocketApi {
5824                        api_key: Some("api_key".to_string()),
5825                        api_secret: Some("api_secret".to_string()),
5826                        private_key: None,
5827                        private_key_passphrase: None,
5828                        ws_url: Some("ws://127.0.0.1:9".to_string()),
5829                        mode: WebsocketMode::Single,
5830                        reconnect_delay: 0,
5831                        signature_gen: sig,
5832                        timeout: 10,
5833                        time_unit: None,
5834                        auto_session_relogon: false,
5835                        agent: None,
5836                        user_agent: build_user_agent("product"),
5837                    };
5838                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5839                    let res = api.connect().await;
5840                    assert!(res.is_err());
5841                });
5842            }
5843
5844            #[test]
5845            fn concurrent_calls_both_error_or_ok() {
5846                TOKIO_SHARED_RT.block_on(async {
5847                    let conn = WebsocketConnection::new("id3");
5848                    let sig = SignatureGenerator::new(
5849                        Some("api_secret".to_string()),
5850                        None::<PrivateKey>,
5851                        None::<String>,
5852                    );
5853                    let cfg = ConfigurationWebsocketApi {
5854                        api_key: Some("api_key".to_string()),
5855                        api_secret: Some("api_secret".to_string()),
5856                        private_key: None,
5857                        private_key_passphrase: None,
5858                        ws_url: Some("wss://invalid-domain".to_string()),
5859                        mode: WebsocketMode::Single,
5860                        reconnect_delay: 0,
5861                        signature_gen: sig,
5862                        timeout: 10,
5863                        time_unit: None,
5864                        auto_session_relogon: false,
5865                        agent: None,
5866                        user_agent: build_user_agent("product"),
5867                    };
5868                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5869                    let fut1 = tokio::spawn(api.clone().connect());
5870                    let fut2 = tokio::spawn(api.clone().connect());
5871                    let r1 = fut1.await.unwrap();
5872                    let r2 = fut2.await.unwrap();
5873
5874                    assert!(r1.is_err());
5875                    assert!(r2.is_err() || r2.is_ok());
5876                });
5877            }
5878
5879            #[test]
5880            fn pool_failure_is_propagated() {
5881                TOKIO_SHARED_RT.block_on(async {
5882                    let conn = WebsocketConnection::new("w");
5883                    let sig = SignatureGenerator::new(
5884                        Some("api_secret".to_string()),
5885                        None::<PrivateKey>,
5886                        None::<String>,
5887                    );
5888                    let cfg = ConfigurationWebsocketApi {
5889                        api_key: Some("api_key".into()),
5890                        api_secret: Some("api_secret".to_string()),
5891                        private_key: None,
5892                        private_key_passphrase: None,
5893                        ws_url: Some("ws://doesnotexist:1".to_string()),
5894                        mode: WebsocketMode::Single,
5895                        reconnect_delay: 0,
5896                        signature_gen: sig,
5897                        timeout: 10,
5898                        time_unit: None,
5899                        auto_session_relogon: false,
5900                        agent: None,
5901                        user_agent: build_user_agent("product"),
5902                    };
5903                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5904                    let res = api.clone().connect().await;
5905                    match res {
5906                        Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
5907                        _ => panic!("expected handshake or timeout error"),
5908                    }
5909                });
5910            }
5911        }
5912
5913        mod send_message {
5914            use super::*;
5915
5916            #[test]
5917            fn unsigned_message() {
5918                TOKIO_SHARED_RT.block_on(async {
5919                    let api = create_websocket_api(None, None, None);
5920                    let conn = &api.common.connection_pool[0];
5921                    let (tx, mut rx) = unbounded_channel::<Message>();
5922                    {
5923                        let mut st = conn.state.lock().await;
5924                        st.ws_write_tx = Some(tx);
5925                    }
5926
5927                    let fut = tokio::spawn({
5928                        let api = api.clone();
5929                        async move {
5930                            let mut params = BTreeMap::new();
5931                            params.insert("foo".into(), Value::String("bar".into()));
5932                            let send_res = api
5933                                .send_message::<Value>(
5934                                    "method",
5935                                    params,
5936                                    WebsocketMessageSendOptions {
5937                                        with_api_key: false,
5938                                        is_signed: false,
5939                                        ..Default::default()
5940                                    },
5941                                )
5942                                .await
5943                                .unwrap();
5944
5945                            match send_res {
5946                                SendWebsocketMessageResult::Single(resp) => resp,
5947                                SendWebsocketMessageResult::Multiple(_) => {
5948                                    panic!("expected single response")
5949                                }
5950                            }
5951                        }
5952                    });
5953
5954                    let Message::Text(txt) = rx.recv().await.unwrap() else {
5955                        panic!()
5956                    };
5957                    let req: Value = serde_json::from_str(&txt).unwrap();
5958                    assert_eq!(req["method"], "method");
5959                    assert_eq!(req["params"]["foo"], "bar");
5960                    assert!(req["params"].get("apiKey").is_none());
5961                    assert!(req["params"].get("timestamp").is_none());
5962                    assert!(req["params"].get("signature").is_none());
5963
5964                    let id = req["id"].as_str().unwrap().to_string();
5965                    let mut st = conn.state.lock().await;
5966                    let pending = st.pending_requests.remove(&id).unwrap();
5967                    let reply = json!({
5968                        "id": id,
5969                        "result": { "x": 42 },
5970                        "rateLimits": [{ "limit": 7 }]
5971                    });
5972                    pending.completion.send(Ok(reply)).unwrap();
5973
5974                    let resp = fut.await.unwrap();
5975                    let rate_limits = resp.rate_limits.unwrap_or_default();
5976
5977                    assert!(rate_limits.is_empty());
5978                    assert_eq!(resp.raw, json!({"x": 42}));
5979                });
5980            }
5981
5982            #[test]
5983            fn with_api_key_only() {
5984                TOKIO_SHARED_RT.block_on(async {
5985                    let api = create_websocket_api(None, None, None);
5986                    let conn = &api.common.connection_pool[0];
5987                    let (tx, mut rx) = unbounded_channel::<Message>();
5988                    {
5989                        let mut st = conn.state.lock().await;
5990                        st.ws_write_tx = Some(tx);
5991                    }
5992
5993                    let fut = tokio::spawn({
5994                        let api = api.clone();
5995                        async move {
5996                            let params = BTreeMap::new();
5997                            let send_res = api
5998                                .send_message::<Value>(
5999                                    "method",
6000                                    params,
6001                                    WebsocketMessageSendOptions {
6002                                        with_api_key: true,
6003                                        is_signed: false,
6004                                        ..Default::default()
6005                                    },
6006                                )
6007                                .await
6008                                .unwrap();
6009
6010                            match send_res {
6011                                SendWebsocketMessageResult::Single(resp) => resp,
6012                                SendWebsocketMessageResult::Multiple(_) => {
6013                                    panic!("expected single response")
6014                                }
6015                            }
6016                        }
6017                    });
6018
6019                    let Message::Text(txt) = rx.recv().await.unwrap() else {
6020                        panic!()
6021                    };
6022                    let req: Value = serde_json::from_str(&txt).unwrap();
6023                    assert_eq!(req["params"]["apiKey"], "api_key");
6024
6025                    let id = req["id"].as_str().unwrap().to_string();
6026                    let mut st = conn.state.lock().await;
6027                    let pending = st.pending_requests.remove(&id).unwrap();
6028                    pending
6029                        .completion
6030                        .send(Ok(json!({
6031                            "id": id,
6032                            "result": {},
6033                            "rateLimits": []
6034                        })))
6035                        .unwrap();
6036
6037                    let resp = fut.await.unwrap();
6038
6039                    assert_eq!(resp.raw, json!({}));
6040                    assert!(st.pending_requests.is_empty());
6041                });
6042            }
6043
6044            #[test]
6045            fn signed_message_has_timestamp_and_signature() {
6046                TOKIO_SHARED_RT.block_on(async {
6047                    let api = create_websocket_api(None, None, None);
6048                    let conn = &api.common.connection_pool[0];
6049                    let (tx, mut rx) = unbounded_channel::<Message>();
6050                    {
6051                        let mut st = conn.state.lock().await;
6052                        st.ws_write_tx = Some(tx);
6053                    }
6054
6055                    let fut = tokio::spawn({
6056                        let api = api.clone();
6057                        async move {
6058                            let mut params = BTreeMap::new();
6059                            params.insert("foo".into(), Value::String("bar".into()));
6060                            let send_res = api
6061                                .send_message::<Value>(
6062                                    "method",
6063                                    params,
6064                                    WebsocketMessageSendOptions {
6065                                        with_api_key: true,
6066                                        is_signed: true,
6067                                        ..Default::default()
6068                                    },
6069                                )
6070                                .await
6071                                .unwrap();
6072
6073                            match send_res {
6074                                SendWebsocketMessageResult::Single(resp) => resp,
6075                                SendWebsocketMessageResult::Multiple(_) => {
6076                                    panic!("expected single response")
6077                                }
6078                            }
6079                        }
6080                    });
6081
6082                    let Message::Text(txt) = rx.recv().await.unwrap() else {
6083                        panic!()
6084                    };
6085                    let req: Value = serde_json::from_str(&txt).unwrap();
6086                    let p = &req["params"];
6087                    assert_eq!(p["apiKey"], "api_key");
6088                    assert!(p["timestamp"].is_number());
6089                    assert!(p["signature"].is_string());
6090
6091                    let id = req["id"].as_str().unwrap().to_string();
6092                    let mut st = conn.state.lock().await;
6093                    let pending = st.pending_requests.remove(&id).unwrap();
6094                    pending
6095                        .completion
6096                        .send(Ok(json!({
6097                            "id": id,
6098                            "result": { "ok": true },
6099                            "rateLimits": []
6100                        })))
6101                        .unwrap();
6102
6103                    let resp = fut.await.unwrap();
6104                    assert_eq!(resp.raw, json!({ "ok": true }));
6105                });
6106            }
6107
6108            #[test]
6109            fn multi_session_logon() {
6110                TOKIO_SHARED_RT.block_on(async {
6111                    let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);
6112                    let conn0 = &api.common.connection_pool[0];
6113                    let conn1 = &api.common.connection_pool[1];
6114
6115                    let (tx0, mut rx0) = unbounded_channel::<Message>();
6116                    let (tx1, mut rx1) = unbounded_channel::<Message>();
6117                    {
6118                        let mut st0 = conn0.state.lock().await;
6119                        st0.ws_write_tx = Some(tx0);
6120                    }
6121                    {
6122                        let mut st1 = conn1.state.lock().await;
6123                        st1.ws_write_tx = Some(tx1);
6124                    }
6125
6126                    let fut = tokio::spawn({
6127                        let api = api.clone();
6128                        async move {
6129                            let params = BTreeMap::new();
6130                            let send_res = api
6131                                .send_message::<Value>(
6132                                    "method",
6133                                    params,
6134                                    WebsocketMessageSendOptions {
6135                                        is_session_logon: Some(true),
6136                                        ..Default::default()
6137                                    },
6138                                )
6139                                .await
6140                                .unwrap();
6141
6142                            match send_res {
6143                                SendWebsocketMessageResult::Multiple(v) => v,
6144                                SendWebsocketMessageResult::Single(_) => {
6145                                    panic!("expected multiple responses")
6146                                }
6147                            }
6148                        }
6149                    });
6150
6151                    let Message::Text(txt0) = rx0.recv().await.unwrap() else {
6152                        panic!()
6153                    };
6154                    let Message::Text(txt1) = rx1.recv().await.unwrap() else {
6155                        panic!()
6156                    };
6157                    let req0: Value = serde_json::from_str(&txt0).unwrap();
6158                    let req1: Value = serde_json::from_str(&txt1).unwrap();
6159                    assert_eq!(req0["method"], "method");
6160                    assert_eq!(req1["method"], "method");
6161                    let id = req0["id"].as_str().unwrap().to_string();
6162                    assert_eq!(req1["id"].as_str().unwrap(), &id);
6163
6164                    {
6165                        let mut st0 = conn0.state.lock().await;
6166                        let pending0 = st0.pending_requests.remove(&id).unwrap();
6167                        pending0
6168                            .completion
6169                            .send(Ok(json!({
6170                                "id": id,
6171                                "result": { "ok": true },
6172                                "rateLimits": []
6173                            })))
6174                            .unwrap();
6175                    }
6176                    {
6177                        let mut st1 = conn1.state.lock().await;
6178                        let pending1 = st1.pending_requests.remove(&id).unwrap();
6179                        pending1
6180                            .completion
6181                            .send(Ok(json!({
6182                                "id": id,
6183                                "result": { "ok": true },
6184                                "rateLimits": []
6185                            })))
6186                            .unwrap();
6187                    }
6188
6189                    let results = fut.await.unwrap();
6190                    assert_eq!(results.len(), 2);
6191
6192                    for conn in &api.common.connection_pool {
6193                        let st = conn.state.lock().await;
6194                        assert!(st.is_session_logged_on, "should be logged out");
6195                        assert!(st.session_logon_req.is_some(), "req cleared");
6196
6197                        // let req = st
6198                        //     .session_logon_req
6199                        //     .as_ref()
6200                        //     .expect("session_logon_req should be Some(_)");
6201                        // assert_eq!(req.method, "method");
6202                        // let mut expected = BTreeMap::new();
6203                        // expected.insert("ok".to_string(), Value::Bool(true));
6204                        // assert_eq!(
6205                        //     req.payload, expected,
6206                        //     "stored payload should be {{ \"ok\": true }}"
6207                        // );
6208                        // assert!(
6209                        //     req.options.is_session_logon.unwrap_or(false),
6210                        //     "expected options.is_session_logon = true"
6211                        // );
6212                    }
6213                });
6214            }
6215
6216            #[test]
6217            fn multi_session_logout() {
6218                TOKIO_SHARED_RT.block_on(async {
6219                    let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);
6220
6221                    for conn in &api.common.connection_pool {
6222                        let (tx, _rx) = unbounded_channel::<Message>();
6223                        let mut st = conn.state.lock().await;
6224                        st.ws_write_tx = Some(tx);
6225                        st.is_session_logged_on = true;
6226                        st.session_logon_req = Some(WebsocketSessionLogonReq {
6227                            method: "method".into(),
6228                            payload: BTreeMap::new(),
6229                            options: WebsocketMessageSendOptions::default(),
6230                        });
6231                    }
6232
6233                    let mut rxs = Vec::new();
6234                    for conn in &api.common.connection_pool {
6235                        let rx = {
6236                            let (tx, rx) = unbounded_channel::<Message>();
6237                            conn.state.lock().await.ws_write_tx = Some(tx);
6238                            rx
6239                        };
6240                        rxs.push(rx);
6241                    }
6242
6243                    let fut = tokio::spawn({
6244                        let api = api.clone();
6245                        async move {
6246                            let send_res = api
6247                                .send_message::<Value>(
6248                                    "method",
6249                                    BTreeMap::new(),
6250                                    WebsocketMessageSendOptions {
6251                                        is_signed: false,
6252                                        with_api_key: false,
6253                                        is_session_logout: Some(true),
6254                                        ..Default::default()
6255                                    },
6256                                )
6257                                .await
6258                                .unwrap();
6259
6260                            match send_res {
6261                                SendWebsocketMessageResult::Multiple(v) => v,
6262                                SendWebsocketMessageResult::Single(_) => panic!("expected multi"),
6263                            }
6264                        }
6265                    });
6266
6267                    let mut ids = Vec::new();
6268                    for mut rx in rxs {
6269                        let Message::Text(txt) = rx.recv().await.unwrap() else {
6270                            panic!()
6271                        };
6272                        let req: Value = serde_json::from_str(&txt).unwrap();
6273                        assert_eq!(req["method"], "method");
6274                        ids.push(req["id"].as_str().unwrap().to_string());
6275                    }
6276
6277                    assert_eq!(ids[0], ids[1]);
6278
6279                    for conn in &api.common.connection_pool {
6280                        let id = &ids[0];
6281                        let mut st = conn.state.lock().await;
6282                        let pending = st.pending_requests.remove(id).unwrap();
6283                        pending
6284                            .completion
6285                            .send(Ok(json!({
6286                                "id": id,
6287                                "result": {},
6288                                "rateLimits": []
6289                            })))
6290                            .unwrap();
6291                    }
6292
6293                    let results = fut.await.unwrap();
6294                    assert_eq!(results.len(), 2);
6295
6296                    for conn in &api.common.connection_pool {
6297                        let st = conn.state.lock().await;
6298                        assert!(!st.is_session_logged_on, "should be logged out");
6299                        assert!(st.session_logon_req.is_none(), "req cleared");
6300                    }
6301                });
6302            }
6303
6304            #[test]
6305            fn skip_signature_when_logged_on_and_auto_relogon() {
6306                TOKIO_SHARED_RT.block_on(async {
6307                    let api = create_websocket_api(None, Some(WebsocketMode::Single), None);
6308                    let conn = &api.common.connection_pool[0];
6309                    {
6310                        let mut st = conn.state.lock().await;
6311                        st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6312                        st.is_session_logged_on = true;
6313                    }
6314
6315                    let mut rx;
6316                    {
6317                        let mut st = conn.state.lock().await;
6318                        let (tx, new_rx) = unbounded_channel::<Message>();
6319                        st.ws_write_tx = Some(tx);
6320                        rx = new_rx;
6321                    }
6322
6323                    let fut = tokio::spawn({
6324                        let api = api.clone();
6325                        async move {
6326                            let send_res = api
6327                                .send_message::<Value>(
6328                                    "method",
6329                                    BTreeMap::new(),
6330                                    WebsocketMessageSendOptions {
6331                                        is_signed: true,
6332                                        ..Default::default()
6333                                    },
6334                                )
6335                                .await
6336                                .unwrap();
6337
6338                            match send_res {
6339                                SendWebsocketMessageResult::Single(resp) => resp,
6340                                SendWebsocketMessageResult::Multiple(_) => {
6341                                    panic!("expected single")
6342                                }
6343                            }
6344                        }
6345                    });
6346
6347                    let Message::Text(txt) = rx.recv().await.unwrap() else {
6348                        panic!()
6349                    };
6350                    let req: Value = serde_json::from_str(&txt).unwrap();
6351                    let p = &req["params"];
6352                    assert!(p.get("timestamp").is_some());
6353                    assert!(p.get("signature").is_none());
6354
6355                    let id = req["id"].as_str().unwrap();
6356                    let mut st = conn.state.lock().await;
6357                    let pending = st.pending_requests.remove(id).unwrap();
6358                    pending
6359                        .completion
6360                        .send(Ok(json!({
6361                            "id": id,
6362                            "result": {},
6363                            "rateLimits": []
6364                        })))
6365                        .unwrap();
6366
6367                    let resp = fut.await.unwrap();
6368                    assert_eq!(resp.raw, json!({}));
6369                });
6370            }
6371
6372            #[test]
6373            fn include_signature_when_logged_on_and_no_auto_relogon() {
6374                TOKIO_SHARED_RT.block_on(async {
6375                    let api = create_websocket_api(None, Some(WebsocketMode::Single), Some(false));
6376                    let conn = &api.common.connection_pool[0];
6377                    {
6378                        let mut st = conn.state.lock().await;
6379                        st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6380                        st.is_session_logged_on = true;
6381                    }
6382
6383                    let mut rx;
6384                    {
6385                        let mut st = conn.state.lock().await;
6386                        let (tx, new_rx) = unbounded_channel::<Message>();
6387                        st.ws_write_tx = Some(tx);
6388                        rx = new_rx;
6389                    }
6390
6391                    let fut = tokio::spawn({
6392                        let api = api.clone();
6393                        async move {
6394                            let send_res = api
6395                                .send_message::<Value>(
6396                                    "method",
6397                                    BTreeMap::new(),
6398                                    WebsocketMessageSendOptions {
6399                                        is_signed: true,
6400                                        ..Default::default()
6401                                    },
6402                                )
6403                                .await
6404                                .unwrap();
6405
6406                            match send_res {
6407                                SendWebsocketMessageResult::Single(resp) => resp,
6408                                SendWebsocketMessageResult::Multiple(_) => {
6409                                    panic!("expected single")
6410                                }
6411                            }
6412                        }
6413                    });
6414
6415                    let Message::Text(txt) = rx.recv().await.unwrap() else {
6416                        panic!()
6417                    };
6418                    let req: Value = serde_json::from_str(&txt).unwrap();
6419                    let p = &req["params"];
6420                    assert!(p.get("timestamp").is_some());
6421                    assert!(p.get("signature").is_some());
6422
6423                    let id = req["id"].as_str().unwrap();
6424                    let mut st = conn.state.lock().await;
6425                    let pending = st.pending_requests.remove(id).unwrap();
6426                    pending
6427                        .completion
6428                        .send(Ok(json!({
6429                            "id": id,
6430                            "result": {},
6431                            "rateLimits": []
6432                        })))
6433                        .unwrap();
6434
6435                    let resp = fut.await.unwrap();
6436                    assert_eq!(resp.raw, json!({}));
6437                });
6438            }
6439
6440            #[test]
6441            fn error_if_not_connected() {
6442                TOKIO_SHARED_RT.block_on(async {
6443                    let api = create_websocket_api(None, None, None);
6444                    let conn = &api.common.connection_pool[0];
6445                    {
6446                        let mut st = conn.state.lock().await;
6447                        st.ws_write_tx = None;
6448                    }
6449                    let params = BTreeMap::new();
6450                    let err = api
6451                        .send_message::<Value>(
6452                            "method",
6453                            params,
6454                            WebsocketMessageSendOptions {
6455                                with_api_key: false,
6456                                is_signed: false,
6457                                ..Default::default()
6458                            },
6459                        )
6460                        .await
6461                        .unwrap_err();
6462                    matches!(err, WebsocketError::NotConnected);
6463                });
6464            }
6465        }
6466
6467        mod prepare_url {
6468            use super::*;
6469
6470            #[test]
6471            fn no_time_unit() {
6472                TOKIO_SHARED_RT.block_on(async {
6473                    let api = create_websocket_api(None, None, None);
6474                    let url = "wss://example.com/ws".to_string();
6475                    assert_eq!(api.prepare_url(&url), url);
6476                });
6477            }
6478
6479            #[test]
6480            fn appends_time_unit() {
6481                TOKIO_SHARED_RT.block_on(async {
6482                    let api = create_websocket_api(Some(TimeUnit::Millisecond), None, None);
6483                    let base = "wss://example.com/ws".to_string();
6484                    let got = api.prepare_url(&base);
6485                    assert_eq!(got, format!("{base}?timeUnit=millisecond"));
6486                });
6487            }
6488
6489            #[test]
6490            fn handles_existing_query() {
6491                TOKIO_SHARED_RT.block_on(async {
6492                    let api = create_websocket_api(Some(TimeUnit::Microsecond), None, None);
6493                    let base = "wss://example.com/ws?foo=bar".to_string();
6494                    let got = api.prepare_url(&base);
6495                    assert_eq!(got, format!("{base}&timeUnit=microsecond"));
6496                });
6497            }
6498        }
6499
6500        mod on_open {
6501            use super::*;
6502
6503            fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6504                let sig_gen = SignatureGenerator::new(
6505                    Some("api_secret".to_string()),
6506                    None::<_>,
6507                    None::<String>,
6508                );
6509                let config = ConfigurationWebsocketApi {
6510                    api_key: Some("api_key".to_string()),
6511                    api_secret: Some("api_secret".to_string()),
6512                    private_key: None,
6513                    private_key_passphrase: None,
6514                    ws_url: Some("wss://example".to_string()),
6515                    mode: WebsocketMode::Single,
6516                    reconnect_delay: 0,
6517                    signature_gen: sig_gen,
6518                    timeout: 1000,
6519                    time_unit: None,
6520                    auto_session_relogon: true,
6521                    agent: None,
6522                    user_agent: build_user_agent("product"),
6523                };
6524                let conn = WebsocketConnection::new("test-conn");
6525                let api = WebsocketApi::new(config, vec![conn.clone()]);
6526                (api, conn)
6527            }
6528
6529            #[test]
6530            fn session_relogon_on_open() {
6531                TOKIO_SHARED_RT.block_on(async {
6532                    let (api, conn) = create_websocket_api_and_conn();
6533
6534                    let req = WebsocketSessionLogonReq {
6535                        method: "method".into(),
6536                        payload: {
6537                            let mut m = BTreeMap::new();
6538                            m.insert("foo".into(), Value::String("bar".into()));
6539                            m
6540                        },
6541                        options: WebsocketMessageSendOptions {
6542                            with_api_key: true,
6543                            is_signed: true,
6544                            is_session_logon: Some(true),
6545                            ..Default::default()
6546                        },
6547                    };
6548
6549                    let (tx, mut rx) = unbounded_channel::<Message>();
6550                    {
6551                        let mut st = conn.state.lock().await;
6552                        st.session_logon_req = Some(req.clone());
6553                        st.is_session_logged_on = false;
6554                        st.ws_write_tx = Some(tx);
6555                    }
6556
6557                    api.on_open("wss://example".to_string(), conn.clone()).await;
6558
6559                    let Message::Text(raw) = rx.recv().await.unwrap() else {
6560                        panic!("expected a Text message");
6561                    };
6562                    let msg: Value = serde_json::from_str(&raw).unwrap();
6563                    assert_eq!(msg["method"], "method");
6564                    assert_eq!(msg["params"]["foo"], "bar");
6565
6566                    let id = msg["id"].as_str().unwrap().to_string();
6567                    {
6568                        let mut st = conn.state.lock().await;
6569                        let pending = st.pending_requests.remove(&id).expect("pending request");
6570                        pending
6571                            .completion
6572                            .send(Ok(json!({
6573                                "id": id,
6574                                "result": {},
6575                                "rateLimits": []
6576                            })))
6577                            .unwrap();
6578                    }
6579
6580                    sleep(Duration::from_millis(10)).await;
6581
6582                    let st = conn.state.lock().await;
6583                    assert!(st.is_session_logged_on, "should now be logged on");
6584                });
6585            }
6586
6587            #[test]
6588            fn no_relogon_if_already_logged_on() {
6589                TOKIO_SHARED_RT.block_on(async {
6590                    let (api, conn) = create_websocket_api_and_conn();
6591
6592                    let req = WebsocketSessionLogonReq {
6593                        method: "method".into(),
6594                        payload: BTreeMap::new(),
6595                        options: WebsocketMessageSendOptions {
6596                            is_session_logon: Some(true),
6597                            ..Default::default()
6598                        },
6599                    };
6600
6601                    let (tx, mut rx) = unbounded_channel::<Message>();
6602                    {
6603                        let mut st = conn.state.lock().await;
6604                        st.session_logon_req = Some(req);
6605                        st.is_session_logged_on = true;
6606                        st.ws_write_tx = Some(tx);
6607                    }
6608
6609                    api.on_open("wss://example".to_string(), conn.clone()).await;
6610
6611                    assert!(rx.try_recv().is_err(), "no re‐logon when already on");
6612
6613                    let st = conn.state.lock().await;
6614                    assert!(st.is_session_logged_on);
6615                });
6616            }
6617
6618            #[test]
6619            fn session_relogon_fails_gracefully() {
6620                TOKIO_SHARED_RT.block_on(async {
6621                    let (api, conn) = create_websocket_api_and_conn();
6622
6623                    let req = WebsocketSessionLogonReq {
6624                        method: "method".into(),
6625                        payload: {
6626                            let mut m = BTreeMap::new();
6627                            m.insert("x".into(), Value::Number(1.into()));
6628                            m
6629                        },
6630                        options: WebsocketMessageSendOptions {
6631                            is_session_logon: Some(true),
6632                            ..Default::default()
6633                        },
6634                    };
6635                    {
6636                        let mut st = conn.state.lock().await;
6637                        st.session_logon_req = Some(req);
6638                        st.is_session_logged_on = false;
6639                        st.ws_write_tx = None;
6640                    }
6641
6642                    api.on_open("wss://example".into(), conn.clone()).await;
6643
6644                    let st = conn.state.lock().await;
6645                    assert!(
6646                        !st.is_session_logged_on,
6647                        "should remain logged‐off on failure"
6648                    );
6649                });
6650            }
6651
6652            #[test]
6653            fn session_relogon_noop_when_no_req() {
6654                TOKIO_SHARED_RT.block_on(async {
6655                    let (api, conn) = create_websocket_api_and_conn();
6656
6657                    {
6658                        let mut st = conn.state.lock().await;
6659                        st.session_logon_req = None;
6660                        st.is_session_logged_on = false;
6661                        st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6662                    }
6663
6664                    api.on_open("wss://example".into(), conn.clone()).await;
6665
6666                    let st = conn.state.lock().await;
6667                    assert!(!st.is_session_logged_on, "still logged‐off");
6668                });
6669            }
6670        }
6671
6672        mod on_message {
6673            use super::*;
6674
6675            fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6676                let sig_gen = SignatureGenerator::new(
6677                    Some("api_secret".to_string()),
6678                    None::<_>,
6679                    None::<String>,
6680                );
6681                let config = ConfigurationWebsocketApi {
6682                    api_key: Some("api_key".to_string()),
6683                    api_secret: Some("api_secret".to_string()),
6684                    private_key: None,
6685                    private_key_passphrase: None,
6686                    ws_url: Some("wss://example".to_string()),
6687                    mode: WebsocketMode::Single,
6688                    reconnect_delay: 0,
6689                    signature_gen: sig_gen,
6690                    timeout: 1000,
6691                    time_unit: None,
6692                    auto_session_relogon: false,
6693                    agent: None,
6694                    user_agent: build_user_agent("product"),
6695                };
6696                let conn = WebsocketConnection::new("test");
6697                let api = WebsocketApi::new(config, vec![conn.clone()]);
6698                (api, conn)
6699            }
6700
6701            #[test]
6702            fn resolves_pending_and_removes_request() {
6703                TOKIO_SHARED_RT.block_on(async {
6704                    let (api, conn) = create_websocket_api_and_conn();
6705                    let (tx, rx) = oneshot::channel();
6706                    {
6707                        let mut st = conn.state.lock().await;
6708                        st.pending_requests
6709                            .insert("id1".to_string(), PendingRequest { completion: tx });
6710                    }
6711                    let msg = json!({"id":"id1","status":200,"foo":"bar"});
6712                    api.on_message(msg.to_string(), conn.clone()).await;
6713                    let got = rx.await.unwrap().unwrap();
6714                    assert_eq!(got, msg);
6715                    let st = conn.state.lock().await;
6716                    assert!(!st.pending_requests.contains_key("id1"));
6717                });
6718            }
6719
6720            #[test]
6721            fn uses_result_when_present() {
6722                TOKIO_SHARED_RT.block_on(async {
6723                    let (api, conn) = create_websocket_api_and_conn();
6724                    let (tx, rx) = oneshot::channel();
6725                    {
6726                        let mut st = conn.state.lock().await;
6727                        st.pending_requests
6728                            .insert("id1".to_string(), PendingRequest { completion: tx });
6729                    }
6730                    let msg = json!({
6731                        "id": "id1",
6732                        "status": 200,
6733                        "response": [1,2],
6734                        "result": {"a":1}
6735                    });
6736                    api.on_message(msg.to_string(), conn.clone()).await;
6737                    let got = rx.await.unwrap().unwrap();
6738                    assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
6739                });
6740            }
6741
6742            #[test]
6743            fn uses_response_when_no_result() {
6744                TOKIO_SHARED_RT.block_on(async {
6745                    let (api, conn) = create_websocket_api_and_conn();
6746                    let (tx, rx) = oneshot::channel();
6747                    {
6748                        let mut st = conn.state.lock().await;
6749                        st.pending_requests
6750                            .insert("id1".to_string(), PendingRequest { completion: tx });
6751                    }
6752                    let msg = json!({
6753                        "id": "id1",
6754                        "status": 200,
6755                        "response": ["ok"]
6756                    });
6757                    api.on_message(msg.to_string(), conn.clone()).await;
6758                    let got = rx.await.unwrap().unwrap();
6759                    assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
6760                });
6761            }
6762
6763            #[test]
6764            fn errors_for_status_ge_400() {
6765                TOKIO_SHARED_RT.block_on(async {
6766                    let (api, conn) = create_websocket_api_and_conn();
6767                    let (tx, rx) = oneshot::channel();
6768                    {
6769                        let mut st = conn.state.lock().await;
6770                        st.pending_requests
6771                            .insert("bad".to_string(), PendingRequest { completion: tx });
6772                    }
6773                    let err_obj = json!({"code":123,"msg":"oops"});
6774                    let msg = json!({"id":"bad","status":500,"error":err_obj});
6775                    api.on_message(msg.to_string(), conn.clone()).await;
6776                    match rx.await.unwrap() {
6777                        Err(WebsocketError::ResponseError { code, message }) => {
6778                            assert_eq!(code, 123);
6779                            assert_eq!(message, "oops");
6780                        }
6781                        other => panic!("expected ResponseError, got {other:?}"),
6782                    }
6783                    let st = conn.state.lock().await;
6784                    assert!(!st.pending_requests.contains_key("bad"));
6785                });
6786            }
6787
6788            #[test]
6789            fn ignores_unknown_id() {
6790                TOKIO_SHARED_RT.block_on(async {
6791                    let (api, conn) = create_websocket_api_and_conn();
6792                    let msg = json!({"id":"nope","status":200});
6793                    api.on_message(msg.to_string(), conn.clone()).await;
6794                    let st = conn.state.lock().await;
6795                    assert!(st.pending_requests.is_empty());
6796                });
6797            }
6798
6799            #[test]
6800            fn parse_error_ignored() {
6801                TOKIO_SHARED_RT.block_on(async {
6802                    let (api, conn) = create_websocket_api_and_conn();
6803                    api.on_message("not json".to_string(), conn.clone()).await;
6804                    let st = conn.state.lock().await;
6805                    assert!(st.pending_requests.is_empty());
6806                });
6807            }
6808
6809            #[test]
6810            fn error_status_sends_error() {
6811                TOKIO_SHARED_RT.block_on(async {
6812                    let (api, conn) = create_websocket_api_and_conn();
6813                    let (tx, rx) = oneshot::channel();
6814                    {
6815                        let mut st = conn.state.lock().await;
6816                        st.pending_requests
6817                            .insert("err".to_string(), PendingRequest { completion: tx });
6818                    }
6819                    let msg = json!({
6820                        "id": "err",
6821                        "status": 500,
6822                        "error": { "code": 42, "msg": "Bad!" }
6823                    });
6824                    api.on_message(msg.to_string(), conn.clone()).await;
6825                    match rx.await.unwrap() {
6826                        Err(WebsocketError::ResponseError { code, message }) => {
6827                            assert_eq!(code, 42);
6828                            assert_eq!(message, "Bad!");
6829                        }
6830                        other => panic!("expected ResponseError, got {other:?}"),
6831                    }
6832                });
6833            }
6834
6835            #[test]
6836            fn unknown_id_logs_warning_and_leaves_pending() {
6837                TOKIO_SHARED_RT.block_on(async {
6838                    let (api, conn) = create_websocket_api_and_conn();
6839                    {
6840                        let mut st = conn.state.lock().await;
6841                        st.pending_requests.insert(
6842                            "keep".to_string(),
6843                            PendingRequest {
6844                                completion: oneshot::channel().0,
6845                            },
6846                        );
6847                    }
6848                    api.on_message(
6849                        json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
6850                        conn.clone(),
6851                    )
6852                    .await;
6853                    let st = conn.state.lock().await;
6854                    assert!(st.pending_requests.contains_key("keep"));
6855                });
6856            }
6857
6858            #[test]
6859            fn server_shutdown_enqueues_reconnect() {
6860                TOKIO_SHARED_RT.block_on(async {
6861                    let (api, conn) = create_websocket_api_and_conn();
6862
6863                    let msg = json!({
6864                        "event": { "e": "serverShutdown" }
6865                    });
6866
6867                    api.on_message(msg.to_string(), conn.clone()).await;
6868
6869                    let st = conn.state.lock().await;
6870                    assert!(st.renewal_pending);
6871                });
6872            }
6873
6874            #[test]
6875            fn server_shutdown_ignored_if_renewal_pending() {
6876                TOKIO_SHARED_RT.block_on(async {
6877                    let (api, conn) = create_websocket_api_and_conn();
6878
6879                    {
6880                        let mut st = conn.state.lock().await;
6881                        st.renewal_pending = true;
6882                    }
6883
6884                    let msg = json!({
6885                        "event": { "e": "serverShutdown" }
6886                    });
6887
6888                    api.on_message(msg.to_string(), conn.clone()).await;
6889
6890                    let st = conn.state.lock().await;
6891
6892                    assert!(st.renewal_pending);
6893                });
6894            }
6895
6896            #[test]
6897            fn server_shutdown_ignored_if_close_initiated() {
6898                TOKIO_SHARED_RT.block_on(async {
6899                    let (api, conn) = create_websocket_api_and_conn();
6900
6901                    {
6902                        let mut st = conn.state.lock().await;
6903                        st.close_initiated = true;
6904                    }
6905
6906                    let msg = json!({
6907                        "event": { "e": "serverShutdown" }
6908                    });
6909
6910                    api.on_message(msg.to_string(), conn.clone()).await;
6911
6912                    let st = conn.state.lock().await;
6913
6914                    assert!(!st.renewal_pending);
6915                });
6916            }
6917
6918            #[test]
6919            fn server_shutdown_does_not_touch_pending_requests() {
6920                TOKIO_SHARED_RT.block_on(async {
6921                    let (api, conn) = create_websocket_api_and_conn();
6922
6923                    {
6924                        let mut st = conn.state.lock().await;
6925                        st.pending_requests.insert(
6926                            "keep".to_string(),
6927                            PendingRequest {
6928                                completion: oneshot::channel().0,
6929                            },
6930                        );
6931                    }
6932
6933                    let msg = json!({
6934                        "event": { "e": "serverShutdown" }
6935                    });
6936
6937                    api.on_message(msg.to_string(), conn.clone()).await;
6938
6939                    let st = conn.state.lock().await;
6940
6941                    assert!(st.pending_requests.contains_key("keep"));
6942                    assert!(st.renewal_pending);
6943                });
6944            }
6945        }
6946    }
6947
6948    mod websocket_streams {
6949        use super::*;
6950
6951        mod initialisation {
6952            use super::*;
6953
6954            #[test]
6955            fn new_initializes_fields() {
6956                TOKIO_SHARED_RT.block_on(async {
6957                    let config = ConfigurationWebsocketStreams {
6958                        ws_url: Some("wss://example".to_string()),
6959                        mode: WebsocketMode::Pool(2),
6960                        reconnect_delay: 500,
6961                        time_unit: None,
6962                        agent: None,
6963                        user_agent: build_user_agent("product"),
6964                    };
6965                    let conn1 = WebsocketConnection::new("c1");
6966                    let conn2 = WebsocketConnection::new("c2");
6967                    let api = WebsocketStreams::new(
6968                        config.clone(),
6969                        vec![conn1.clone(), conn2.clone()],
6970                        vec![],
6971                    );
6972
6973                    assert_eq!(api.common.connection_pool.len(), 2);
6974                    assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
6975                    assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
6976                    assert_eq!(api.configuration.ws_url, Some("wss://example".to_string()));
6977                    let flag = api.is_connecting.lock().await;
6978                    assert!(!*flag);
6979                });
6980            }
6981
6982            #[test]
6983            fn new_expands_pool_when_url_paths_present() {
6984                TOKIO_SHARED_RT.block_on(async {
6985                    let config = ConfigurationWebsocketStreams {
6986                        ws_url: Some("wss://example".to_string()),
6987                        mode: WebsocketMode::Pool(2),
6988                        reconnect_delay: 500,
6989                        time_unit: None,
6990                        agent: None,
6991                        user_agent: build_user_agent("product"),
6992                    };
6993
6994                    let conn1 = WebsocketConnection::new("c1");
6995                    let conn2 = WebsocketConnection::new("c2");
6996
6997                    let api = WebsocketStreams::new(
6998                        config,
6999                        vec![conn1.clone(), conn2.clone()],
7000                        vec!["path1".to_string(), "path2".to_string()],
7001                    );
7002
7003                    assert_eq!(api.common.connection_pool.len(), 4);
7004                    assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
7005                    assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
7006                });
7007            }
7008
7009            #[test]
7010            fn new_does_not_expand_pool_when_already_sized_for_url_paths() {
7011                TOKIO_SHARED_RT.block_on(async {
7012                    let config = ConfigurationWebsocketStreams {
7013                        ws_url: Some("wss://example".to_string()),
7014                        mode: WebsocketMode::Pool(2),
7015                        reconnect_delay: 500,
7016                        time_unit: None,
7017                        agent: None,
7018                        user_agent: build_user_agent("product"),
7019                    };
7020
7021                    let conns = vec![
7022                        WebsocketConnection::new("c1"),
7023                        WebsocketConnection::new("c2"),
7024                        WebsocketConnection::new("c3"),
7025                        WebsocketConnection::new("c4"),
7026                    ];
7027
7028                    let api = WebsocketStreams::new(
7029                        config,
7030                        conns.clone(),
7031                        vec!["path1".to_string(), "path2".to_string()],
7032                    );
7033
7034                    assert_eq!(api.common.connection_pool.len(), 4);
7035                    for (i, c) in conns.iter().enumerate() {
7036                        assert!(Arc::ptr_eq(&api.common.connection_pool[i], c));
7037                    }
7038                });
7039            }
7040        }
7041
7042        mod connect {
7043            use super::*;
7044
7045            #[test]
7046            fn establishes_successfully() {
7047                TOKIO_SHARED_RT.block_on(async {
7048                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
7049                    let port = listener.local_addr().unwrap().port();
7050
7051                    tokio::spawn(async move {
7052                        for _ in 0..2 {
7053                            if let Ok((stream, _)) = listener.accept().await {
7054                                let mut ws = accept_async(stream).await.unwrap();
7055                                ws.close(None).await.ok();
7056                            }
7057                        }
7058                    });
7059
7060                    let create_websocket_streams = |ws_url: &str| {
7061                        let c1 = WebsocketConnection::new("c1");
7062                        let c2 = WebsocketConnection::new("c2");
7063                        let config = ConfigurationWebsocketStreams {
7064                            ws_url: Some(ws_url.to_string()),
7065                            mode: WebsocketMode::Pool(2),
7066                            reconnect_delay: 500,
7067                            time_unit: None,
7068                            agent: None,
7069                            user_agent: build_user_agent("product"),
7070                        };
7071                        WebsocketStreams::new(config, vec![c1, c2], vec![])
7072                    };
7073
7074                    let url = format!("ws://127.0.0.1:{port}");
7075                    let ws = create_websocket_streams(&url);
7076
7077                    let res = ws.connect(vec!["stream1".into()]).await;
7078                    assert!(res.is_ok());
7079                });
7080            }
7081
7082            #[test]
7083            fn establishes_successfully_with_url_paths() {
7084                TOKIO_SHARED_RT.block_on(async {
7085                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
7086                    let addr = listener.local_addr().unwrap();
7087
7088                    tokio::spawn(async move {
7089                        for _ in 0..4 {
7090                            if let Ok((stream, _)) = listener.accept().await {
7091                                let mut ws = accept_async(stream).await.unwrap();
7092                                ws.close(None).await.ok();
7093                            }
7094                        }
7095                    });
7096
7097                    let config = ConfigurationWebsocketStreams {
7098                        ws_url: Some(format!("ws://{}", addr)),
7099                        mode: WebsocketMode::Pool(2),
7100                        reconnect_delay: 500,
7101                        time_unit: None,
7102                        agent: None,
7103                        user_agent: build_user_agent("product"),
7104                    };
7105
7106                    let ws = WebsocketStreams::new(
7107                        config,
7108                        vec![],
7109                        vec!["path1".to_string(), "path2".to_string()],
7110                    );
7111
7112                    let res = ws.clone().connect(vec!["stream1".into()]).await;
7113                    assert!(res.is_ok());
7114                    assert_eq!(ws.common.connection_pool.len(), 4);
7115                });
7116            }
7117
7118            #[test]
7119            fn connect_sets_url_path_on_connections_when_url_paths_present() {
7120                TOKIO_SHARED_RT.block_on(async {
7121                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
7122                    let addr = listener.local_addr().unwrap();
7123
7124                    tokio::spawn(async move {
7125                        for _ in 0..4 {
7126                            if let Ok((stream, _)) = listener.accept().await {
7127                                let mut ws = accept_async(stream).await.unwrap();
7128                                ws.close(None).await.ok();
7129                            }
7130                        }
7131                    });
7132
7133                    let config = ConfigurationWebsocketStreams {
7134                        ws_url: Some(format!("ws://{}", addr)),
7135                        mode: WebsocketMode::Pool(2),
7136                        reconnect_delay: 500,
7137                        time_unit: None,
7138                        agent: None,
7139                        user_agent: build_user_agent("product"),
7140                    };
7141
7142                    let ws = WebsocketStreams::new(
7143                        config,
7144                        vec![],
7145                        vec!["path1".to_string(), "path2".to_string()],
7146                    );
7147
7148                    ws.clone().connect(vec!["stream1".into()]).await.unwrap();
7149
7150                    let pool_size = ws.configuration.mode.pool_size();
7151
7152                    for (i, conn) in ws.common.connection_pool.iter().enumerate() {
7153                        let expected = if i < pool_size { "path1" } else { "path2" };
7154                        let st = conn.state.lock().await;
7155                        assert_eq!(st.url_path.as_deref(), Some(expected));
7156                    }
7157                });
7158            }
7159
7160            #[test]
7161            fn refused_returns_error() {
7162                TOKIO_SHARED_RT.block_on(async {
7163                    let ws = create_websocket_streams(Some("ws://127.0.0.1:9"), None, None);
7164                    let res = ws.connect(vec!["stream1".into()]).await;
7165                    assert!(res.is_err());
7166                });
7167            }
7168
7169            #[test]
7170            fn invalid_url_returns_error() {
7171                TOKIO_SHARED_RT.block_on(async {
7172                    let ws = create_websocket_streams(Some("not-a-url"), None, None);
7173                    let res = ws.connect(vec!["s".into()]).await;
7174                    assert!(res.is_err());
7175                });
7176            }
7177        }
7178
7179        mod disconnect {
7180            use super::*;
7181
7182            #[test]
7183            fn disconnect_clears_state_and_streams() {
7184                TOKIO_SHARED_RT.block_on(async {
7185                    let ws = create_websocket_streams(None, None, None);
7186                    let conn = &ws.common.connection_pool[0];
7187                    {
7188                        let mut state = conn.state.lock().await;
7189                        state.stream_callbacks.insert("s1".to_string(), Vec::new());
7190                        state.pending_subscriptions.push_back("s2".to_string());
7191                    }
7192                    {
7193                        let mut map = ws.connection_streams.lock().await;
7194                        map.insert("s3".to_string(), Arc::clone(conn));
7195                    }
7196
7197                    let res = ws.disconnect().await;
7198                    assert!(res.is_ok());
7199
7200                    let state = conn.state.lock().await;
7201                    assert!(state.stream_callbacks.is_empty());
7202                    assert!(state.pending_subscriptions.is_empty());
7203
7204                    let map = ws.connection_streams.lock().await;
7205                    assert!(map.is_empty());
7206                });
7207            }
7208        }
7209
7210        mod subscribe {
7211            use super::*;
7212
7213            #[test]
7214            fn empty_list_does_nothing() {
7215                TOKIO_SHARED_RT.block_on(async {
7216                    let ws = create_websocket_streams(None, None, None);
7217                    ws.clone().subscribe(Vec::new(), None, None).await;
7218                    let map = ws.connection_streams.lock().await;
7219                    assert!(map.is_empty());
7220                });
7221            }
7222
7223            #[test]
7224            fn queue_when_not_ready() {
7225                TOKIO_SHARED_RT.block_on(async {
7226                    let ws = create_websocket_streams(None, None, None);
7227                    let conn = ws.common.connection_pool[0].clone();
7228                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
7229                    let state = conn.state.lock().await;
7230                    let pending: Vec<String> =
7231                        state.pending_subscriptions.iter().cloned().collect();
7232                    assert_eq!(pending, vec!["s1".to_string()]);
7233                });
7234            }
7235
7236            #[test]
7237            fn only_one_subscription_per_stream() {
7238                TOKIO_SHARED_RT.block_on(async {
7239                    let ws = create_websocket_streams(None, None, None);
7240                    let conn = ws.common.connection_pool[0].clone();
7241                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
7242                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
7243                    let state = conn.state.lock().await;
7244                    let pending: Vec<String> =
7245                        state.pending_subscriptions.iter().cloned().collect();
7246                    assert_eq!(pending, vec!["s1".to_string()]);
7247                });
7248            }
7249
7250            #[test]
7251            fn multiple_streams_assigned() {
7252                TOKIO_SHARED_RT.block_on(async {
7253                    let ws = create_websocket_streams(None, None, None);
7254                    ws.clone()
7255                        .subscribe(vec!["s1".into(), "s2".into()], None, None)
7256                        .await;
7257                    let map = ws.connection_streams.lock().await;
7258                    assert!(map.contains_key("s1"));
7259                    assert!(map.contains_key("s2"));
7260                });
7261            }
7262
7263            #[test]
7264            fn existing_stream_not_reassigned() {
7265                TOKIO_SHARED_RT.block_on(async {
7266                    let ws = create_websocket_streams(None, None, None);
7267                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
7268                    let first_id = {
7269                        let map = ws.connection_streams.lock().await;
7270                        map.get("s1").unwrap().id.clone()
7271                    };
7272                    ws.clone()
7273                        .subscribe(vec!["s1".into(), "s2".into()], None, None)
7274                        .await;
7275                    let map = ws.connection_streams.lock().await;
7276                    let second_id = map.get("s1").unwrap().id.clone();
7277                    assert_eq!(first_id, second_id);
7278                    assert!(map.contains_key("s2"));
7279                });
7280            }
7281
7282            #[test]
7283            fn queue_when_not_ready_with_url_path() {
7284                TOKIO_SHARED_RT.block_on(async {
7285                    let ws = create_websocket_streams(None, None, None);
7286
7287                    let conn = ws.common.connection_pool[0].clone();
7288                    {
7289                        let mut st = conn.state.lock().await;
7290                        st.ws_write_tx = None;
7291                        st.url_path = Some("path1".to_string());
7292                        st.reconnection_pending = false;
7293                        st.close_initiated = false;
7294                    }
7295
7296                    ws.clone()
7297                        .subscribe(vec!["s1".into()], None, Some("path1"))
7298                        .await;
7299
7300                    let state = conn.state.lock().await;
7301                    let pending: Vec<String> =
7302                        state.pending_subscriptions.iter().cloned().collect();
7303                    assert_eq!(pending, vec!["s1".to_string()]);
7304                });
7305            }
7306
7307            #[test]
7308            fn only_one_subscription_per_stream_per_url_path() {
7309                TOKIO_SHARED_RT.block_on(async {
7310                    let ws = create_websocket_streams(None, None, None);
7311
7312                    let conn = ws.common.connection_pool[0].clone();
7313                    {
7314                        let mut st = conn.state.lock().await;
7315                        st.ws_write_tx = None;
7316                        st.url_path = Some("path1".to_string());
7317                        st.reconnection_pending = false;
7318                        st.close_initiated = false;
7319                    }
7320
7321                    ws.clone()
7322                        .subscribe(vec!["s1".into()], None, Some("path1"))
7323                        .await;
7324                    ws.clone()
7325                        .subscribe(vec!["s1".into()], None, Some("path1"))
7326                        .await;
7327
7328                    let state = conn.state.lock().await;
7329                    let pending: Vec<String> =
7330                        state.pending_subscriptions.iter().cloned().collect();
7331                    assert_eq!(pending, vec!["s1".to_string()]);
7332                });
7333            }
7334
7335            #[test]
7336            fn same_stream_can_be_subscribed_on_different_url_paths() {
7337                TOKIO_SHARED_RT.block_on(async {
7338                    let ws = create_websocket_streams(None, None, None);
7339
7340                    let conn1 = ws.common.connection_pool[0].clone();
7341                    let conn2 = ws.common.connection_pool[1].clone();
7342
7343                    {
7344                        let mut st1 = conn1.state.lock().await;
7345                        st1.ws_write_tx = None;
7346                        st1.url_path = Some("path1".to_string());
7347                        st1.reconnection_pending = false;
7348                        st1.close_initiated = false;
7349                    }
7350                    {
7351                        let mut st2 = conn2.state.lock().await;
7352                        st2.ws_write_tx = None;
7353                        st2.url_path = Some("path2".to_string());
7354                        st2.reconnection_pending = false;
7355                        st2.close_initiated = false;
7356                    }
7357
7358                    ws.clone()
7359                        .subscribe(vec!["s1".into()], None, Some("path1"))
7360                        .await;
7361                    ws.clone()
7362                        .subscribe(vec!["s1".into()], None, Some("path2"))
7363                        .await;
7364
7365                    let map = ws.connection_streams.lock().await;
7366                    assert!(map.contains_key("path1::s1"));
7367                    assert!(map.contains_key("path2::s1"));
7368                });
7369            }
7370        }
7371
7372        mod unsubscribe {
7373            use super::*;
7374
7375            #[test]
7376            fn removes_stream_with_no_callbacks() {
7377                TOKIO_SHARED_RT.block_on(async {
7378                    let ws = create_websocket_streams(None, None, None);
7379                    let conn = ws.common.connection_pool[0].clone();
7380
7381                    {
7382                        let (tx, _rx) = unbounded_channel::<Message>();
7383                        let mut st = conn.state.lock().await;
7384                        st.ws_write_tx = Some(tx);
7385                    }
7386
7387                    {
7388                        let mut map = ws.connection_streams.lock().await;
7389                        map.insert("s1".to_string(), conn.clone());
7390                    }
7391                    {
7392                        let mut st = conn.state.lock().await;
7393                        st.stream_callbacks.insert("s1".to_string(), Vec::new());
7394                    }
7395
7396                    ws.unsubscribe(vec!["s1".to_string()], None, None).await;
7397
7398                    assert!(!ws.connection_streams.lock().await.contains_key("s1"));
7399                    assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
7400                });
7401            }
7402
7403            #[test]
7404            fn preserves_stream_with_callbacks() {
7405                TOKIO_SHARED_RT.block_on(async {
7406                    let ws = create_websocket_streams(None, None, None);
7407                    let conn = ws.common.connection_pool[1].clone();
7408
7409                    {
7410                        let mut map = ws.connection_streams.lock().await;
7411                        map.insert("s2".to_string(), conn.clone());
7412                    }
7413                    {
7414                        let mut state = conn.state.lock().await;
7415                        state
7416                            .stream_callbacks
7417                            .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
7418                    }
7419
7420                    ws.unsubscribe(vec!["s2".to_string()], None, None).await;
7421
7422                    assert!(ws.connection_streams.lock().await.contains_key("s2"));
7423                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
7424                });
7425            }
7426
7427            #[test]
7428            fn does_not_send_if_callbacks_exist() {
7429                TOKIO_SHARED_RT.block_on(async {
7430                    let ws = create_websocket_streams(None, None, None);
7431                    let conn = ws.common.connection_pool[0].clone();
7432                    {
7433                        let mut map = ws.connection_streams.lock().await;
7434                        map.insert("s1".to_string(), conn.clone());
7435                    }
7436                    {
7437                        let mut state = conn.state.lock().await;
7438                        state.stream_callbacks.insert(
7439                            "s1".to_string(),
7440                            vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
7441                        );
7442                    }
7443                    ws.unsubscribe(vec!["s1".into()], None, None).await;
7444                    assert!(ws.connection_streams.lock().await.contains_key("s1"));
7445                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
7446                });
7447            }
7448
7449            #[test]
7450            fn warns_if_not_associated() {
7451                TOKIO_SHARED_RT.block_on(async {
7452                    let ws = create_websocket_streams(None, None, None);
7453                    ws.unsubscribe(vec!["nope".into()], None, None).await;
7454                });
7455            }
7456
7457            #[test]
7458            fn empty_list_does_nothing() {
7459                TOKIO_SHARED_RT.block_on(async {
7460                    let ws = create_websocket_streams(None, None, None);
7461                    let before = ws.connection_streams.lock().await.len();
7462                    ws.unsubscribe(Vec::<String>::new(), None, None).await;
7463                    let after = ws.connection_streams.lock().await.len();
7464                    assert_eq!(before, after);
7465                });
7466            }
7467
7468            #[test]
7469            fn invalid_custom_id_falls_back() {
7470                TOKIO_SHARED_RT.block_on(async {
7471                    let ws = create_websocket_streams(None, None, None);
7472                    let conn = ws.common.connection_pool[0].clone();
7473                    {
7474                        let mut map = ws.connection_streams.lock().await;
7475                        map.insert("foo".to_string(), conn.clone());
7476                    }
7477                    {
7478                        let mut state = conn.state.lock().await;
7479                        let (tx, _rx) = unbounded_channel();
7480                        state.ws_write_tx = Some(tx);
7481                        state.stream_callbacks.insert("foo".to_string(), Vec::new());
7482                    }
7483                    ws.unsubscribe(
7484                        vec!["foo".into()],
7485                        Some(StreamId::Str("bad-id".into())),
7486                        None,
7487                    )
7488                    .await;
7489                    assert!(!ws.connection_streams.lock().await.contains_key("foo"));
7490                });
7491            }
7492
7493            #[test]
7494            fn removes_even_without_write_channel() {
7495                TOKIO_SHARED_RT.block_on(async {
7496                    let ws = create_websocket_streams(None, None, None);
7497                    let conn = ws.common.connection_pool[0].clone();
7498                    {
7499                        let mut map = ws.connection_streams.lock().await;
7500                        map.insert("x".to_string(), conn.clone());
7501                    }
7502                    {
7503                        let mut state = conn.state.lock().await;
7504                        let (tx, _rx) = unbounded_channel();
7505                        state.ws_write_tx = Some(tx);
7506                        state.stream_callbacks.insert("x".to_string(), Vec::new());
7507                    }
7508                    ws.unsubscribe(vec!["x".into()], None, None).await;
7509                    assert!(!ws.connection_streams.lock().await.contains_key("x"));
7510                });
7511            }
7512
7513            #[test]
7514            fn removes_stream_with_no_callbacks_with_url_path() {
7515                TOKIO_SHARED_RT.block_on(async {
7516                    let ws = create_websocket_streams(None, None, None);
7517                    let conn = ws.common.connection_pool[0].clone();
7518
7519                    {
7520                        let (tx, _rx) = unbounded_channel::<Message>();
7521                        let mut st = conn.state.lock().await;
7522                        st.ws_write_tx = Some(tx);
7523                        st.url_path = Some("path1".to_string());
7524                    }
7525
7526                    {
7527                        let mut map = ws.connection_streams.lock().await;
7528                        map.insert("path1::s1".to_string(), conn.clone());
7529                    }
7530                    {
7531                        let mut st = conn.state.lock().await;
7532                        st.stream_callbacks
7533                            .insert("path1::s1".to_string(), Vec::new());
7534                    }
7535
7536                    ws.unsubscribe(vec!["s1".to_string()], None, Some("path1"))
7537                        .await;
7538
7539                    assert!(!ws.connection_streams.lock().await.contains_key("path1::s1"));
7540                    assert!(
7541                        !conn
7542                            .state
7543                            .lock()
7544                            .await
7545                            .stream_callbacks
7546                            .contains_key("path1::s1")
7547                    );
7548                });
7549            }
7550
7551            #[test]
7552            fn preserves_stream_with_callbacks_with_url_path() {
7553                TOKIO_SHARED_RT.block_on(async {
7554                    let ws = create_websocket_streams(None, None, None);
7555                    let conn = ws.common.connection_pool[0].clone();
7556
7557                    {
7558                        let (tx, _rx) = unbounded_channel::<Message>();
7559                        let mut st = conn.state.lock().await;
7560                        st.ws_write_tx = Some(tx);
7561                        st.url_path = Some("path1".to_string());
7562                    }
7563
7564                    {
7565                        let mut map = ws.connection_streams.lock().await;
7566                        map.insert("path1::s2".to_string(), conn.clone());
7567                    }
7568                    {
7569                        let mut state = conn.state.lock().await;
7570                        state
7571                            .stream_callbacks
7572                            .insert("path1::s2".to_string(), vec![Arc::new(|_: &Value| {})]);
7573                    }
7574
7575                    ws.unsubscribe(vec!["s2".to_string()], None, Some("path1"))
7576                        .await;
7577
7578                    assert!(ws.connection_streams.lock().await.contains_key("path1::s2"));
7579                    assert!(
7580                        conn.state
7581                            .lock()
7582                            .await
7583                            .stream_callbacks
7584                            .contains_key("path1::s2")
7585                    );
7586                });
7587            }
7588
7589            #[test]
7590            fn url_path_mismatch_does_not_remove_other_path_subscription() {
7591                TOKIO_SHARED_RT.block_on(async {
7592                    let ws = create_websocket_streams(None, None, None);
7593                    let conn = ws.common.connection_pool[0].clone();
7594
7595                    {
7596                        let (tx, _rx) = unbounded_channel::<Message>();
7597                        let mut st = conn.state.lock().await;
7598                        st.ws_write_tx = Some(tx);
7599                        st.url_path = Some("path1".to_string());
7600                    }
7601
7602                    {
7603                        let mut map = ws.connection_streams.lock().await;
7604                        map.insert("path1::s1".to_string(), conn.clone());
7605                    }
7606                    {
7607                        let mut st = conn.state.lock().await;
7608                        st.stream_callbacks
7609                            .insert("path1::s1".to_string(), Vec::new());
7610                    }
7611
7612                    ws.unsubscribe(vec!["s1".to_string()], None, Some("path2"))
7613                        .await;
7614
7615                    assert!(ws.connection_streams.lock().await.contains_key("path1::s1"));
7616                    assert!(
7617                        conn.state
7618                            .lock()
7619                            .await
7620                            .stream_callbacks
7621                            .contains_key("path1::s1")
7622                    );
7623                });
7624            }
7625        }
7626
7627        mod is_subscribed {
7628            use super::*;
7629
7630            #[test]
7631            fn returns_false_when_not_subscribed() {
7632                TOKIO_SHARED_RT.block_on(async {
7633                    let ws = create_websocket_streams(None, None, None);
7634                    assert!(!ws.is_subscribed("unknown").await);
7635                });
7636            }
7637
7638            #[test]
7639            fn returns_true_when_subscribed() {
7640                TOKIO_SHARED_RT.block_on(async {
7641                    let ws = create_websocket_streams(None, None, None);
7642                    let conn = ws.common.connection_pool[0].clone();
7643                    {
7644                        let mut map = ws.connection_streams.lock().await;
7645                        map.insert("stream1".to_string(), conn);
7646                    }
7647                    assert!(ws.is_subscribed("stream1").await);
7648                });
7649            }
7650
7651            #[test]
7652            fn returns_true_when_subscribed_with_url_path_key() {
7653                TOKIO_SHARED_RT.block_on(async {
7654                    let ws = create_websocket_streams(None, None, None);
7655                    let conn = ws.common.connection_pool[0].clone();
7656                    {
7657                        let mut map = ws.connection_streams.lock().await;
7658                        map.insert("path1::stream1".to_string(), conn);
7659                    }
7660                    assert!(ws.is_subscribed("stream1").await);
7661                });
7662            }
7663
7664            #[test]
7665            fn returns_true_when_same_stream_subscribed_on_multiple_paths() {
7666                TOKIO_SHARED_RT.block_on(async {
7667                    let ws = create_websocket_streams(None, None, None);
7668                    let conn1 = ws.common.connection_pool[0].clone();
7669                    let conn2 = ws.common.connection_pool[1].clone();
7670                    {
7671                        let mut map = ws.connection_streams.lock().await;
7672                        map.insert("path1::stream1".to_string(), conn1);
7673                        map.insert("path2::stream1".to_string(), conn2);
7674                    }
7675                    assert!(ws.is_subscribed("stream1").await);
7676                });
7677            }
7678
7679            #[test]
7680            fn returns_false_when_only_similar_suffix_exists() {
7681                TOKIO_SHARED_RT.block_on(async {
7682                    let ws = create_websocket_streams(None, None, None);
7683                    let conn = ws.common.connection_pool[0].clone();
7684                    {
7685                        let mut map = ws.connection_streams.lock().await;
7686                        map.insert("path1::stream10".to_string(), conn);
7687                    }
7688                    assert!(!ws.is_subscribed("stream1").await);
7689                });
7690            }
7691        }
7692
7693        mod stream_key {
7694            use super::*;
7695
7696            #[test]
7697            fn stream_key_without_url_path_returns_stream() {
7698                TOKIO_SHARED_RT.block_on(async {
7699                    let ws = create_websocket_streams(None, None, None);
7700                    assert_eq!(ws.stream_key("s1", None), "s1");
7701                });
7702            }
7703
7704            #[test]
7705            fn stream_key_with_empty_url_path_returns_stream() {
7706                TOKIO_SHARED_RT.block_on(async {
7707                    let ws = create_websocket_streams(None, None, None);
7708                    assert_eq!(ws.stream_key("s1", Some("")), "s1");
7709                });
7710            }
7711
7712            #[test]
7713            fn stream_key_with_url_path_prefixes_stream() {
7714                TOKIO_SHARED_RT.block_on(async {
7715                    let ws = create_websocket_streams(None, None, None);
7716                    assert_eq!(ws.stream_key("s1", Some("path1")), "path1::s1");
7717                });
7718            }
7719
7720            #[test]
7721            fn stream_key_distinguishes_paths() {
7722                TOKIO_SHARED_RT.block_on(async {
7723                    let ws = create_websocket_streams(None, None, None);
7724                    assert_eq!(ws.stream_key("s1", Some("path1")), "path1::s1");
7725                    assert_eq!(ws.stream_key("s1", Some("path2")), "path2::s1");
7726                });
7727            }
7728        }
7729
7730        mod prepare_url {
7731            use super::*;
7732
7733            #[test]
7734            fn without_time_unit_returns_base_url() {
7735                TOKIO_SHARED_RT.block_on(async {
7736                    let conns = vec![
7737                        WebsocketConnection::new("c1"),
7738                        WebsocketConnection::new("c2"),
7739                    ];
7740                    let config = ConfigurationWebsocketStreams {
7741                        ws_url: Some("wss://example".to_string()),
7742                        mode: WebsocketMode::Single,
7743                        reconnect_delay: 100,
7744                        time_unit: None,
7745                        agent: None,
7746                        user_agent: build_user_agent("product"),
7747                    };
7748                    let ws = WebsocketStreams::new(config, conns, vec![]);
7749                    let url = ws.prepare_url(&["s1".into(), "s2".into()], None);
7750                    assert_eq!(url, "wss://example/stream?streams=s1/s2");
7751                });
7752            }
7753
7754            #[test]
7755            fn with_time_unit_appends_parameter() {
7756                TOKIO_SHARED_RT.block_on(async {
7757                    let conns = vec![WebsocketConnection::new("c1")];
7758                    let config = ConfigurationWebsocketStreams {
7759                        ws_url: Some("wss://example".to_string()),
7760                        mode: WebsocketMode::Single,
7761                        reconnect_delay: 100,
7762                        time_unit: Some(TimeUnit::Millisecond),
7763                        agent: None,
7764                        user_agent: build_user_agent("product"),
7765                    };
7766                    let ws = WebsocketStreams::new(config, conns, vec![]);
7767                    let url = ws.prepare_url(&["a".into()], None);
7768                    assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
7769                });
7770            }
7771
7772            #[test]
7773            fn multiple_streams_and_time_unit() {
7774                TOKIO_SHARED_RT.block_on(async {
7775                    let conns = vec![WebsocketConnection::new("c1")];
7776                    let config = ConfigurationWebsocketStreams {
7777                        ws_url: Some("wss://example".to_string()),
7778                        mode: WebsocketMode::Single,
7779                        reconnect_delay: 100,
7780                        time_unit: Some(TimeUnit::Microsecond),
7781                        agent: None,
7782                        user_agent: build_user_agent("product"),
7783                    };
7784                    let ws = WebsocketStreams::new(config, conns, vec![]);
7785                    let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()], None);
7786                    assert_eq!(
7787                        url,
7788                        "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
7789                    );
7790                });
7791            }
7792
7793            #[test]
7794            fn with_url_path_prefixes_base_url() {
7795                TOKIO_SHARED_RT.block_on(async {
7796                    let conns = vec![WebsocketConnection::new("c1")];
7797                    let config = ConfigurationWebsocketStreams {
7798                        ws_url: Some("wss://example".to_string()),
7799                        mode: WebsocketMode::Single,
7800                        reconnect_delay: 100,
7801                        time_unit: None,
7802                        agent: None,
7803                        user_agent: build_user_agent("product"),
7804                    };
7805                    let ws = WebsocketStreams::new(config, conns, vec![]);
7806                    let url = ws.prepare_url(["s1".into()].as_ref(), Some("path1"));
7807                    assert_eq!(url, "wss://example/path1/stream?streams=s1");
7808                });
7809            }
7810
7811            #[test]
7812            fn with_url_path_and_time_unit_appends_parameter() {
7813                TOKIO_SHARED_RT.block_on(async {
7814                    let conns = vec![WebsocketConnection::new("c1")];
7815                    let config = ConfigurationWebsocketStreams {
7816                        ws_url: Some("wss://example".to_string()),
7817                        mode: WebsocketMode::Single,
7818                        reconnect_delay: 100,
7819                        time_unit: Some(TimeUnit::Millisecond),
7820                        agent: None,
7821                        user_agent: build_user_agent("product"),
7822                    };
7823                    let ws = WebsocketStreams::new(config, conns, vec![]);
7824                    let url = ws.prepare_url(["a".into()].as_ref(), Some("path1"));
7825                    assert_eq!(
7826                        url,
7827                        "wss://example/path1/stream?streams=a&timeUnit=millisecond"
7828                    );
7829                });
7830            }
7831
7832            #[test]
7833            fn url_path_distinguishes_urls_for_same_streams() {
7834                TOKIO_SHARED_RT.block_on(async {
7835                    let conns = vec![WebsocketConnection::new("c1")];
7836                    let config = ConfigurationWebsocketStreams {
7837                        ws_url: Some("wss://example".to_string()),
7838                        mode: WebsocketMode::Single,
7839                        reconnect_delay: 100,
7840                        time_unit: None,
7841                        agent: None,
7842                        user_agent: build_user_agent("product"),
7843                    };
7844                    let ws = WebsocketStreams::new(config, conns, vec![]);
7845                    let u1 = ws.prepare_url(["s1".into()].as_ref(), Some("path1"));
7846                    let u2 = ws.prepare_url(["s1".into()].as_ref(), Some("path2"));
7847                    assert_eq!(u1, "wss://example/path1/stream?streams=s1");
7848                    assert_eq!(u2, "wss://example/path2/stream?streams=s1");
7849                });
7850            }
7851        }
7852
7853        mod handle_stream_assignment {
7854            use super::*;
7855
7856            #[test]
7857            fn assigns_new_streams_to_connections() {
7858                TOKIO_SHARED_RT.block_on(async {
7859                    let ws = create_websocket_streams(None, None, None);
7860                    let groups = ws
7861                        .clone()
7862                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7863                        .await;
7864                    let mut seen_streams = HashSet::new();
7865                    for (_conn, streams) in &groups {
7866                        for s in streams {
7867                            seen_streams.insert(s);
7868                        }
7869                    }
7870                    assert_eq!(
7871                        seen_streams,
7872                        ["s1".to_string(), "s2".to_string()].iter().collect()
7873                    );
7874                    assert_eq!(groups.len(), 1);
7875                });
7876            }
7877
7878            #[test]
7879            fn reuses_existing_connection_for_duplicate_stream() {
7880                TOKIO_SHARED_RT.block_on(async {
7881                    let ws = create_websocket_streams(None, None, None);
7882                    let _ = ws
7883                        .clone()
7884                        .handle_stream_assignment(vec!["s1".into()], None)
7885                        .await;
7886                    let groups = ws
7887                        .clone()
7888                        .handle_stream_assignment(vec!["s1".into(), "s3".into()], None)
7889                        .await;
7890                    let mut all_streams = Vec::new();
7891                    for (_conn, streams) in groups {
7892                        all_streams.extend(streams);
7893                    }
7894                    all_streams.sort();
7895                    assert_eq!(all_streams, vec!["s1".to_string(), "s3".to_string()]);
7896                });
7897            }
7898
7899            #[test]
7900            fn empty_stream_list_returns_empty() {
7901                TOKIO_SHARED_RT.block_on(async {
7902                    let ws = create_websocket_streams(None, None, None);
7903                    let groups = ws.clone().handle_stream_assignment(vec![], None).await;
7904                    assert!(groups.is_empty());
7905                });
7906            }
7907
7908            #[test]
7909            fn closed_or_reconnecting_forces_reassignment_of_stream() {
7910                TOKIO_SHARED_RT.block_on(async {
7911                    let ws = create_websocket_streams(None, None, None);
7912                    let mut groups = ws
7913                        .clone()
7914                        .handle_stream_assignment(vec!["s1".into()], None)
7915                        .await;
7916                    let (conn, _) = groups.pop().unwrap();
7917                    {
7918                        let mut st = conn.state.lock().await;
7919                        st.close_initiated = true;
7920                    }
7921                    let groups2 = ws
7922                        .clone()
7923                        .handle_stream_assignment(vec!["s2".into()], None)
7924                        .await;
7925                    assert_eq!(groups2.len(), 1);
7926                    let (_new_conn, streams) = &groups2[0];
7927                    assert_eq!(streams, &vec!["s2".to_string()]);
7928                });
7929            }
7930
7931            #[test]
7932            fn no_available_connections_falls_back_to_one() {
7933                TOKIO_SHARED_RT.block_on(async {
7934                    let ws = create_websocket_streams(None, Some(vec![]), None);
7935                    let assigned = ws.handle_stream_assignment(vec!["foo".into()], None).await;
7936                    assert_eq!(assigned.len(), 1);
7937                    let (_conn, streams) = &assigned[0];
7938                    assert_eq!(streams.as_slice(), &["foo".to_string()]);
7939                });
7940            }
7941
7942            #[test]
7943            fn single_connection_groups_multiple_streams() {
7944                TOKIO_SHARED_RT.block_on(async {
7945                    let conn = WebsocketConnection::new("c1");
7946                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7947                    let assigned = ws
7948                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7949                        .await;
7950                    assert_eq!(assigned.len(), 1);
7951                    let (assigned_conn, streams) = &assigned[0];
7952                    assert!(Arc::ptr_eq(assigned_conn, &conn));
7953                    assert_eq!(streams.len(), 2);
7954                    assert!(streams.contains(&"s1".to_string()));
7955                    assert!(streams.contains(&"s2".to_string()));
7956                });
7957            }
7958
7959            #[test]
7960            fn reuse_existing_healthy_connection() {
7961                TOKIO_SHARED_RT.block_on(async {
7962                    let conn = WebsocketConnection::new("c");
7963                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7964                    let _ = ws.handle_stream_assignment(vec!["s1".into()], None).await;
7965                    let second = ws.handle_stream_assignment(vec!["s1".into()], None).await;
7966                    assert_eq!(second.len(), 1);
7967                    let (assigned_conn, streams) = &second[0];
7968                    assert!(Arc::ptr_eq(assigned_conn, &conn));
7969                    assert_eq!(streams.as_slice(), &["s1".to_string()]);
7970                });
7971            }
7972
7973            #[test]
7974            fn mix_new_and_assigned_streams() {
7975                TOKIO_SHARED_RT.block_on(async {
7976                    let conn = WebsocketConnection::new("c");
7977                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7978                    let _ = ws
7979                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7980                        .await;
7981                    let mixed = ws
7982                        .handle_stream_assignment(vec!["s2".into(), "s3".into()], None)
7983                        .await;
7984                    assert_eq!(mixed.len(), 1);
7985                    let (assigned_conn, streams) = &mixed[0];
7986                    assert!(Arc::ptr_eq(assigned_conn, &conn));
7987                    let mut got = streams.clone();
7988                    got.sort();
7989                    assert_eq!(got, vec!["s2".to_string(), "s3".to_string()]);
7990                });
7991            }
7992
7993            #[test]
7994            fn assigns_streams_with_url_path_keys() {
7995                TOKIO_SHARED_RT.block_on(async {
7996                    let ws = create_websocket_streams(None, None, None);
7997
7998                    let conn = ws.common.connection_pool[0].clone();
7999                    {
8000                        let mut st = conn.state.lock().await;
8001                        st.url_path = Some("path1".to_string());
8002                        st.ws_write_tx = None;
8003                        st.reconnection_pending = false;
8004                        st.close_initiated = false;
8005                    }
8006
8007                    let groups = ws
8008                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
8009                        .await;
8010
8011                    let map = ws.connection_streams.lock().await;
8012                    assert!(map.contains_key("path1::s1"));
8013                    assert!(map.contains_key("path1::s2"));
8014                    assert_eq!(groups.len(), 1);
8015
8016                    let (_assigned_conn, streams) = &groups[0];
8017                    let mut got = streams.clone();
8018                    got.sort();
8019                    assert_eq!(got, vec!["s1".to_string(), "s2".to_string()]);
8020                });
8021            }
8022
8023            #[test]
8024            fn same_stream_on_different_paths_creates_distinct_keys() {
8025                TOKIO_SHARED_RT.block_on(async {
8026                    let ws = create_websocket_streams(None, None, None);
8027
8028                    let conn1 = ws.common.connection_pool[0].clone();
8029                    let conn2 = ws.common.connection_pool[1].clone();
8030
8031                    {
8032                        let mut st = conn1.state.lock().await;
8033                        st.url_path = Some("path1".to_string());
8034                        st.ws_write_tx = None;
8035                        st.reconnection_pending = false;
8036                        st.close_initiated = false;
8037                    }
8038                    {
8039                        let mut st = conn2.state.lock().await;
8040                        st.url_path = Some("path2".to_string());
8041                        st.ws_write_tx = None;
8042                        st.reconnection_pending = false;
8043                        st.close_initiated = false;
8044                    }
8045
8046                    let g1 = ws
8047                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
8048                        .await;
8049                    let g2 = ws
8050                        .handle_stream_assignment(vec!["s1".into()], Some("path2"))
8051                        .await;
8052
8053                    assert_eq!(g1.len(), 1);
8054                    assert_eq!(g2.len(), 1);
8055
8056                    let map = ws.connection_streams.lock().await;
8057                    assert!(map.contains_key("path1::s1"));
8058                    assert!(map.contains_key("path2::s1"));
8059                });
8060            }
8061
8062            #[test]
8063            fn reuses_existing_connection_for_same_path_and_stream() {
8064                TOKIO_SHARED_RT.block_on(async {
8065                    let ws = create_websocket_streams(None, None, None);
8066
8067                    let conn = ws.common.connection_pool[0].clone();
8068                    {
8069                        let mut st = conn.state.lock().await;
8070                        st.url_path = Some("path1".to_string());
8071                        st.ws_write_tx = None;
8072                        st.reconnection_pending = false;
8073                        st.close_initiated = false;
8074                    }
8075
8076                    let first = ws
8077                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
8078                        .await;
8079                    let second = ws
8080                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
8081                        .await;
8082
8083                    assert_eq!(first.len(), 1);
8084                    assert_eq!(second.len(), 1);
8085
8086                    let map = ws.connection_streams.lock().await;
8087                    let c1 = map.get("path1::s1").unwrap().clone();
8088                    let c2 = map.get("path1::s2").unwrap().clone();
8089                    assert!(Arc::ptr_eq(&c1, &c2));
8090                });
8091            }
8092
8093            #[test]
8094            fn closed_or_reconnecting_forces_reassignment_with_url_path() {
8095                TOKIO_SHARED_RT.block_on(async {
8096                    let ws = create_websocket_streams(None, None, None);
8097
8098                    let conn1 = ws.common.connection_pool[0].clone();
8099                    let conn2 = ws.common.connection_pool[1].clone();
8100
8101                    {
8102                        let mut st = conn1.state.lock().await;
8103                        st.url_path = Some("path1".to_string());
8104                        st.ws_write_tx = None;
8105                        st.reconnection_pending = false;
8106                        st.close_initiated = false;
8107                    }
8108                    {
8109                        let mut st = conn2.state.lock().await;
8110                        st.url_path = Some("path1".to_string());
8111                        st.ws_write_tx = None;
8112                        st.reconnection_pending = false;
8113                        st.close_initiated = false;
8114                    }
8115
8116                    let _ = ws
8117                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
8118                        .await;
8119
8120                    {
8121                        let mut st = conn1.state.lock().await;
8122                        st.close_initiated = true;
8123                    }
8124
8125                    let _ = ws
8126                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
8127                        .await;
8128
8129                    let map = ws.connection_streams.lock().await;
8130                    let assigned = map.get("path1::s1").unwrap().clone();
8131                    assert!(!Arc::ptr_eq(&assigned, &conn1));
8132                });
8133            }
8134        }
8135
8136        mod send_subscription_payload {
8137            use super::*;
8138
8139            #[test]
8140            fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
8141                TOKIO_SHARED_RT.block_on(async {
8142                    let ws: Arc<WebsocketStreams> =
8143                        create_websocket_streams(Some("ws://example.com"), None, None);
8144                    let conn = &ws.common.connection_pool[0];
8145                    let (tx, mut rx) = unbounded_channel();
8146                    {
8147                        let mut st = conn.state.lock().await;
8148                        st.ws_write_tx = Some(tx);
8149                    }
8150                    let id = Some("badid".to_string());
8151                    ws.send_subscription_payload(
8152                        conn,
8153                        &vec!["s1".to_string()],
8154                        id.map(StreamId::from),
8155                    );
8156                    let msg = rx.recv().await.expect("no message sent");
8157                    if let Message::Text(txt) = msg {
8158                        let v: serde_json::Value = serde_json::from_str(&txt).unwrap();
8159                        assert_eq!(v["method"], "SUBSCRIBE");
8160                        let id = v["id"].as_str().unwrap();
8161                        assert_ne!(id, "badid");
8162                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id));
8163                    } else {
8164                        panic!("unexpected message: {msg:?}");
8165                    }
8166                });
8167            }
8168
8169            #[test]
8170            fn subscribe_payload_with_and_without_custom_string_id() {
8171                TOKIO_SHARED_RT.block_on(async {
8172                    let ws: Arc<WebsocketStreams> =
8173                        create_websocket_streams(Some("ws://unused"), None, None);
8174                    let conn = &ws.common.connection_pool[0];
8175                    let (tx, mut rx) = unbounded_channel();
8176                    {
8177                        let mut st = conn.state.lock().await;
8178                        st.ws_write_tx = Some(tx);
8179                    }
8180                    let id = Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string());
8181                    ws.send_subscription_payload(
8182                        conn,
8183                        &vec!["a".to_string(), "b".to_string()],
8184                        id.map(StreamId::from),
8185                    );
8186                    let msg1 = rx.recv().await.unwrap();
8187                    ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
8188                    let msg2 = rx.recv().await.unwrap();
8189
8190                    if let Message::Text(txt1) = msg1 {
8191                        let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
8192                        assert_eq!(v1["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
8193                        assert_eq!(
8194                            v1["params"].as_array().unwrap(),
8195                            &vec![serde_json::json!("a"), serde_json::json!("b")]
8196                        );
8197                    } else {
8198                        panic!()
8199                    }
8200
8201                    if let Message::Text(txt2) = msg2 {
8202                        let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
8203                        assert_eq!(v2["method"], "SUBSCRIBE");
8204                        let params = v2["params"].as_array().unwrap();
8205                        assert_eq!(params.len(), 1);
8206                        assert_eq!(params[0], "x");
8207                        let id2 = v2["id"].as_str().unwrap();
8208                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id2));
8209                    } else {
8210                        panic!()
8211                    }
8212                });
8213            }
8214
8215            #[test]
8216            fn subscribe_payload_with_and_without_custom_integer_id() {
8217                TOKIO_SHARED_RT.block_on(async {
8218                    let ws: Arc<WebsocketStreams> =
8219                        create_websocket_streams(Some("ws://unused"), None, None);
8220                    ws.stream_id_is_strictly_number
8221                        .store(true, Ordering::Relaxed);
8222                    let conn = &ws.common.connection_pool[0];
8223                    let (tx, mut rx) = unbounded_channel();
8224                    {
8225                        let mut st = conn.state.lock().await;
8226                        st.ws_write_tx = Some(tx);
8227                    }
8228
8229                    let id = Some(123u32);
8230
8231                    ws.send_subscription_payload(
8232                        conn,
8233                        &vec!["a".to_string(), "b".to_string()],
8234                        id.map(StreamId::from),
8235                    );
8236                    let msg1 = rx.recv().await.unwrap();
8237
8238                    ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
8239                    let msg2 = rx.recv().await.unwrap();
8240
8241                    if let Message::Text(txt1) = msg1 {
8242                        let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
8243                        assert_eq!(v1["method"], "SUBSCRIBE");
8244                        assert_eq!(v1["id"].as_u64(), Some(123));
8245                        assert_eq!(
8246                            v1["params"].as_array().unwrap(),
8247                            &vec![serde_json::json!("a"), serde_json::json!("b")]
8248                        );
8249                    } else {
8250                        panic!("Expected Message::Text for msg1");
8251                    }
8252
8253                    if let Message::Text(txt2) = msg2 {
8254                        let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
8255                        assert_eq!(v2["method"], "SUBSCRIBE");
8256
8257                        let params = v2["params"].as_array().unwrap();
8258                        assert_eq!(params.len(), 1);
8259                        assert_eq!(params[0], "x");
8260
8261                        let id2 = v2.get("id").expect("payload should contain id");
8262                        assert!(
8263                            id2.is_number(),
8264                            "expected numeric id in strict-number mode, got: {id2:?}"
8265                        );
8266                        let n = id2.as_u64().unwrap();
8267                        assert!(u32::try_from(n).is_ok(), "id should fit u32, got {n}");
8268                    } else {
8269                        panic!("Expected Message::Text for msg2");
8270                    }
8271                });
8272            }
8273        }
8274
8275        mod on_open {
8276            use super::*;
8277
8278            #[test]
8279            fn sends_pending_subscriptions() {
8280                TOKIO_SHARED_RT.block_on(async {
8281                    let ws: Arc<WebsocketStreams> =
8282                        create_websocket_streams(Some("ws://example.com"), None, None);
8283                    let conn = &ws.common.connection_pool[0];
8284                    let (tx, mut rx) = unbounded_channel();
8285                    {
8286                        let mut st = conn.state.lock().await;
8287                        st.ws_write_tx = Some(tx);
8288                        st.pending_subscriptions.push_back("foo".to_string());
8289                        st.pending_subscriptions.push_back("bar".to_string());
8290                    }
8291                    ws.on_open("ws://example.com".to_string(), conn.clone())
8292                        .await;
8293                    let msg = rx.recv().await.expect("no subscription sent");
8294                    if let Message::Text(txt) = msg {
8295                        let v: Value = serde_json::from_str(&txt).unwrap();
8296                        assert_eq!(v["method"], "SUBSCRIBE");
8297                        let params = v["params"].as_array().unwrap();
8298                        assert_eq!(
8299                            params,
8300                            &vec![Value::String("foo".into()), Value::String("bar".into())]
8301                        );
8302                    } else {
8303                        panic!("unexpected message: {msg:?}");
8304                    }
8305                    let st_after = conn.state.lock().await;
8306                    assert!(st_after.pending_subscriptions.is_empty());
8307                });
8308            }
8309
8310            #[test]
8311            fn with_no_pending_subscriptions_sends_nothing() {
8312                TOKIO_SHARED_RT.block_on(async {
8313                    let ws: Arc<WebsocketStreams> =
8314                        create_websocket_streams(Some("ws://example.com"), None, None);
8315                    let conn = &ws.common.connection_pool[0];
8316                    let (tx, mut rx) = unbounded_channel();
8317                    {
8318                        let mut st = conn.state.lock().await;
8319                        st.ws_write_tx = Some(tx);
8320                    }
8321                    ws.on_open("ws://example.com".to_string(), conn.clone())
8322                        .await;
8323                    assert!(rx.try_recv().is_err(), "unexpected message sent");
8324                });
8325            }
8326
8327            #[test]
8328            fn clears_pending_without_write_channel() {
8329                TOKIO_SHARED_RT.block_on(async {
8330                    let ws: Arc<WebsocketStreams> =
8331                        create_websocket_streams(Some("ws://example.com"), None, None);
8332                    let conn = &ws.common.connection_pool[0];
8333                    {
8334                        let mut st = conn.state.lock().await;
8335                        st.pending_subscriptions.push_back("solo".to_string());
8336                    }
8337                    ws.on_open("ws://example.com".to_string(), conn.clone())
8338                        .await;
8339                    let st_after = conn.state.lock().await;
8340                    assert!(st_after.pending_subscriptions.is_empty());
8341                });
8342            }
8343        }
8344
8345        mod on_message {
8346            use super::*;
8347
8348            #[test]
8349            fn invokes_registered_callback() {
8350                TOKIO_SHARED_RT.block_on(async {
8351                    let ws: Arc<WebsocketStreams> =
8352                        create_websocket_streams(Some("ws://example.com"), None, None);
8353                    let conn = &ws.common.connection_pool[0];
8354                    let called = Arc::new(AtomicBool::new(false));
8355                    let called_clone = called.clone();
8356
8357                    {
8358                        let mut st = conn.state.lock().await;
8359                        st.stream_callbacks
8360                            .entry("stream1".to_string())
8361                            .or_default()
8362                            .push(
8363                                (Box::new(move |_: &Value| {
8364                                    called_clone.store(true, Ordering::SeqCst);
8365                                })
8366                                    as Box<dyn Fn(&Value) + Send + Sync>)
8367                                    .into(),
8368                            );
8369                    }
8370
8371                    let msg = json!({
8372                        "stream": "stream1",
8373                        "data": { "key": "value" }
8374                    })
8375                    .to_string();
8376
8377                    ws.on_message(msg, conn.clone()).await;
8378
8379                    assert!(called.load(Ordering::SeqCst));
8380                });
8381            }
8382
8383            #[test]
8384            fn invokes_all_registered_callbacks() {
8385                TOKIO_SHARED_RT.block_on(async {
8386                    let ws: Arc<WebsocketStreams> =
8387                        create_websocket_streams(Some("ws://example.com"), None, None);
8388                    let conn = &ws.common.connection_pool[0];
8389                    let counter = Arc::new(AtomicUsize::new(0));
8390
8391                    {
8392                        let mut st = conn.state.lock().await;
8393                        let entry = st.stream_callbacks.entry("s".into()).or_default();
8394                        let c1 = counter.clone();
8395                        entry.push(
8396                            (Box::new(move |_: &Value| {
8397                                c1.fetch_add(1, Ordering::SeqCst);
8398                            }) as Box<dyn Fn(&Value) + Send + Sync>)
8399                                .into(),
8400                        );
8401                        let c2 = counter.clone();
8402                        entry.push(
8403                            (Box::new(move |_: &Value| {
8404                                c2.fetch_add(1, Ordering::SeqCst);
8405                            }) as Box<dyn Fn(&Value) + Send + Sync>)
8406                                .into(),
8407                        );
8408                    }
8409
8410                    let msg = json!({"stream":"s","data":42}).to_string();
8411                    ws.on_message(msg, conn.clone()).await;
8412
8413                    assert_eq!(counter.load(Ordering::SeqCst), 2);
8414                });
8415            }
8416
8417            #[test]
8418            fn handles_null_data_field() {
8419                TOKIO_SHARED_RT.block_on(async {
8420                    let ws: Arc<WebsocketStreams> =
8421                        create_websocket_streams(Some("ws://example.com"), None, None);
8422                    let conn = &ws.common.connection_pool[0];
8423                    let called = Arc::new(AtomicUsize::new(0));
8424                    {
8425                        let mut st = conn.state.lock().await;
8426                        st.stream_callbacks.entry("n".into()).or_default().push(
8427                            (Box::new({
8428                                let c = called.clone();
8429                                move |data: &Value| {
8430                                    if data.is_null() {
8431                                        c.fetch_add(1, Ordering::SeqCst);
8432                                    }
8433                                }
8434                            }) as Box<dyn Fn(&Value) + Send + Sync>)
8435                                .into(),
8436                        );
8437                    }
8438                    let msg = json!({"stream":"n","data":null}).to_string();
8439                    ws.on_message(msg, conn.clone()).await;
8440                    assert_eq!(called.load(Ordering::SeqCst), 1);
8441                });
8442            }
8443
8444            #[test]
8445            fn with_invalid_json_does_not_panic() {
8446                TOKIO_SHARED_RT.block_on(async {
8447                    let ws: Arc<WebsocketStreams> =
8448                        create_websocket_streams(Some("ws://example.com"), None, None);
8449                    let conn = &ws.common.connection_pool[0];
8450                    let bad = "not a json";
8451                    ws.on_message(bad.to_string(), conn.clone()).await;
8452                });
8453            }
8454
8455            #[test]
8456            fn without_stream_field_does_nothing() {
8457                TOKIO_SHARED_RT.block_on(async {
8458                    let ws: Arc<WebsocketStreams> =
8459                        create_websocket_streams(Some("ws://example.com"), None, None);
8460                    let conn = &ws.common.connection_pool[0];
8461                    let msg = json!({ "data": { "foo": 1 } }).to_string();
8462                    ws.on_message(msg, conn.clone()).await;
8463                });
8464            }
8465
8466            #[test]
8467            fn with_unregistered_stream_does_not_panic() {
8468                TOKIO_SHARED_RT.block_on(async {
8469                    let ws: Arc<WebsocketStreams> =
8470                        create_websocket_streams(Some("ws://example.com"), None, None);
8471                    let conn = &ws.common.connection_pool[0];
8472                    let msg = json!({
8473                        "stream": "nope",
8474                        "data": { "foo": 1 }
8475                    })
8476                    .to_string();
8477                    ws.on_message(msg, conn.clone()).await;
8478                });
8479            }
8480
8481            #[test]
8482            fn invokes_registered_callback_with_url_path_key() {
8483                TOKIO_SHARED_RT.block_on(async {
8484                    let ws: Arc<WebsocketStreams> =
8485                        create_websocket_streams(Some("ws://example.com"), None, None);
8486                    let conn = &ws.common.connection_pool[0];
8487
8488                    {
8489                        let mut st = conn.state.lock().await;
8490                        st.url_path = Some("path1".to_string());
8491                    }
8492
8493                    let called = Arc::new(AtomicBool::new(false));
8494                    let called_clone = called.clone();
8495
8496                    {
8497                        let mut st = conn.state.lock().await;
8498                        st.stream_callbacks
8499                            .entry("path1::stream1".to_string())
8500                            .or_default()
8501                            .push(
8502                                (Box::new(move |_: &Value| {
8503                                    called_clone.store(true, Ordering::SeqCst);
8504                                })
8505                                    as Box<dyn Fn(&Value) + Send + Sync>)
8506                                    .into(),
8507                            );
8508                    }
8509
8510                    let msg = json!({
8511                        "stream": "stream1",
8512                        "data": { "key": "value" }
8513                    })
8514                    .to_string();
8515
8516                    ws.on_message(msg, conn.clone()).await;
8517
8518                    assert!(called.load(Ordering::SeqCst));
8519                });
8520            }
8521
8522            #[test]
8523            fn does_not_invoke_callback_when_url_path_mismatch() {
8524                TOKIO_SHARED_RT.block_on(async {
8525                    let ws: Arc<WebsocketStreams> =
8526                        create_websocket_streams(Some("ws://example.com"), None, None);
8527                    let conn = &ws.common.connection_pool[0];
8528
8529                    {
8530                        let mut st = conn.state.lock().await;
8531                        st.url_path = Some("path2".to_string());
8532                    }
8533
8534                    let called = Arc::new(AtomicBool::new(false));
8535                    let called_clone = called.clone();
8536
8537                    {
8538                        let mut st = conn.state.lock().await;
8539                        st.stream_callbacks
8540                            .entry("path1::stream1".to_string())
8541                            .or_default()
8542                            .push(
8543                                (Box::new(move |_: &Value| {
8544                                    called_clone.store(true, Ordering::SeqCst);
8545                                })
8546                                    as Box<dyn Fn(&Value) + Send + Sync>)
8547                                    .into(),
8548                            );
8549                    }
8550
8551                    let msg = json!({
8552                        "stream": "stream1",
8553                        "data": { "key": "value" }
8554                    })
8555                    .to_string();
8556
8557                    ws.on_message(msg, conn.clone()).await;
8558
8559                    assert!(!called.load(Ordering::SeqCst));
8560                });
8561            }
8562
8563            #[test]
8564            fn invokes_only_callbacks_for_current_url_path_when_both_exist() {
8565                TOKIO_SHARED_RT.block_on(async {
8566                    let ws: Arc<WebsocketStreams> =
8567                        create_websocket_streams(Some("ws://example.com"), None, None);
8568                    let conn = &ws.common.connection_pool[0];
8569
8570                    {
8571                        let mut st = conn.state.lock().await;
8572                        st.url_path = Some("path1".to_string());
8573                    }
8574
8575                    let c1 = Arc::new(AtomicUsize::new(0));
8576                    let c2 = Arc::new(AtomicUsize::new(0));
8577
8578                    {
8579                        let mut st = conn.state.lock().await;
8580
8581                        let a = c1.clone();
8582                        st.stream_callbacks
8583                            .entry("path1::s".to_string())
8584                            .or_default()
8585                            .push(
8586                                (Box::new(move |_: &Value| {
8587                                    a.fetch_add(1, Ordering::SeqCst);
8588                                })
8589                                    as Box<dyn Fn(&Value) + Send + Sync>)
8590                                    .into(),
8591                            );
8592
8593                        let b = c2.clone();
8594                        st.stream_callbacks
8595                            .entry("path2::s".to_string())
8596                            .or_default()
8597                            .push(
8598                                (Box::new(move |_: &Value| {
8599                                    b.fetch_add(1, Ordering::SeqCst);
8600                                })
8601                                    as Box<dyn Fn(&Value) + Send + Sync>)
8602                                    .into(),
8603                            );
8604                    }
8605
8606                    let msg = json!({"stream":"s","data":42}).to_string();
8607                    ws.on_message(msg, conn.clone()).await;
8608
8609                    assert_eq!(c1.load(Ordering::SeqCst), 1);
8610                    assert_eq!(c2.load(Ordering::SeqCst), 0);
8611                });
8612            }
8613        }
8614
8615        mod get_reconnect_url {
8616            use super::*;
8617
8618            #[test]
8619            fn single_stream_reconnect_url() {
8620                TOKIO_SHARED_RT.block_on(async {
8621                    let ws: Arc<WebsocketStreams> =
8622                        create_websocket_streams(Some("ws://example.com"), None, None);
8623                    let c0 = ws.common.connection_pool[0].clone();
8624                    {
8625                        let mut map = ws.connection_streams.lock().await;
8626                        map.insert("s1".to_string(), c0.clone());
8627                    }
8628                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8629                    assert_eq!(url, "ws://example.com/stream?streams=s1");
8630                });
8631            }
8632
8633            #[test]
8634            fn multiple_streams_same_connection() {
8635                TOKIO_SHARED_RT.block_on(async {
8636                    let ws: Arc<WebsocketStreams> =
8637                        create_websocket_streams(Some("ws://example.com"), None, None);
8638                    let c0 = ws.common.connection_pool[0].clone();
8639                    {
8640                        let mut map = ws.connection_streams.lock().await;
8641                        map.insert("a".to_string(), c0.clone());
8642                        map.insert("b".to_string(), c0.clone());
8643                    }
8644                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8645                    let suffix = url
8646                        .strip_prefix("ws://example.com/stream?streams=")
8647                        .unwrap();
8648                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
8649                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
8650                    assert_eq!(set, ["a", "b"].iter().copied().collect());
8651                });
8652            }
8653
8654            #[test]
8655            fn reconnect_url_with_time_unit() {
8656                TOKIO_SHARED_RT.block_on(async {
8657                    let mut ws: Arc<WebsocketStreams> =
8658                        create_websocket_streams(Some("ws://example.com"), None, None);
8659                    Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
8660                        Some(TimeUnit::Microsecond);
8661                    let c0 = ws.common.connection_pool[0].clone();
8662                    {
8663                        let mut map = ws.connection_streams.lock().await;
8664                        map.insert("x".to_string(), c0.clone());
8665                    }
8666                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8667                    assert_eq!(
8668                        url,
8669                        "ws://example.com/stream?streams=x&timeUnit=microsecond"
8670                    );
8671                });
8672            }
8673
8674            #[test]
8675            fn reconnect_url_uses_url_path_from_connection_state() {
8676                TOKIO_SHARED_RT.block_on(async {
8677                    let ws: Arc<WebsocketStreams> =
8678                        create_websocket_streams(Some("ws://example.com"), None, None);
8679                    let c0 = ws.common.connection_pool[0].clone();
8680
8681                    {
8682                        let mut st = c0.state.lock().await;
8683                        st.url_path = Some("path1".to_string());
8684                    }
8685
8686                    {
8687                        let mut map = ws.connection_streams.lock().await;
8688                        map.insert("path1::s1".to_string(), c0.clone());
8689                    }
8690
8691                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8692                    assert_eq!(url, "ws://example.com/path1/stream?streams=s1");
8693                });
8694            }
8695
8696            #[test]
8697            fn reconnect_url_strips_prefix_from_multiple_keys_with_url_path() {
8698                TOKIO_SHARED_RT.block_on(async {
8699                    let ws: Arc<WebsocketStreams> =
8700                        create_websocket_streams(Some("ws://example.com"), None, None);
8701                    let c0 = ws.common.connection_pool[0].clone();
8702
8703                    {
8704                        let mut st = c0.state.lock().await;
8705                        st.url_path = Some("path1".to_string());
8706                    }
8707
8708                    {
8709                        let mut map = ws.connection_streams.lock().await;
8710                        map.insert("path1::a".to_string(), c0.clone());
8711                        map.insert("path1::b".to_string(), c0.clone());
8712                    }
8713
8714                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8715
8716                    let suffix = url
8717                        .strip_prefix("ws://example.com/path1/stream?streams=")
8718                        .unwrap();
8719                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
8720                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
8721                    assert_eq!(set, ["a", "b"].iter().copied().collect());
8722                });
8723            }
8724
8725            #[test]
8726            fn reconnect_url_with_url_path_and_time_unit() {
8727                TOKIO_SHARED_RT.block_on(async {
8728                    let mut ws: Arc<WebsocketStreams> =
8729                        create_websocket_streams(Some("ws://example.com"), None, None);
8730                    Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
8731                        Some(TimeUnit::Microsecond);
8732
8733                    let c0 = ws.common.connection_pool[0].clone();
8734
8735                    {
8736                        let mut st = c0.state.lock().await;
8737                        st.url_path = Some("path1".to_string());
8738                    }
8739
8740                    {
8741                        let mut map = ws.connection_streams.lock().await;
8742                        map.insert("path1::x".to_string(), c0.clone());
8743                    }
8744
8745                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8746                    assert_eq!(
8747                        url,
8748                        "ws://example.com/path1/stream?streams=x&timeUnit=microsecond"
8749                    );
8750                });
8751            }
8752
8753            #[test]
8754            fn reconnect_url_ignores_streams_from_other_connections_even_if_same_path_prefix() {
8755                TOKIO_SHARED_RT.block_on(async {
8756                    let ws: Arc<WebsocketStreams> =
8757                        create_websocket_streams(Some("ws://example.com"), None, None);
8758                    let c0 = ws.common.connection_pool[0].clone();
8759                    let c1 = ws.common.connection_pool[1].clone();
8760
8761                    {
8762                        let mut st = c0.state.lock().await;
8763                        st.url_path = Some("path1".to_string());
8764                    }
8765                    {
8766                        let mut st = c1.state.lock().await;
8767                        st.url_path = Some("path1".to_string());
8768                    }
8769
8770                    {
8771                        let mut map = ws.connection_streams.lock().await;
8772                        map.insert("path1::a".to_string(), c0.clone());
8773                        map.insert("path1::b".to_string(), c1.clone());
8774                    }
8775
8776                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8777
8778                    let suffix = url
8779                        .strip_prefix("ws://example.com/path1/stream?streams=")
8780                        .unwrap();
8781                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
8782                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
8783                    assert_eq!(set, ["a"].iter().copied().collect());
8784                });
8785            }
8786        }
8787    }
8788
8789    mod websocket_stream {
8790        use super::*;
8791
8792        mod on {
8793            use super::*;
8794
8795            #[test]
8796            fn registers_callback_and_stream_callback_for_websocket_streams() {
8797                TOKIO_SHARED_RT.block_on(async {
8798                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8799                    let stream_name = "s1".to_string();
8800                    let conn = ws_base.common.connection_pool[0].clone();
8801
8802                    let key = ws_base.stream_key(&stream_name, None);
8803
8804                    {
8805                        let mut map = ws_base.connection_streams.lock().await;
8806                        map.insert(key.clone(), conn.clone());
8807                    }
8808                    {
8809                        let mut state = conn.state.lock().await;
8810                        state.stream_callbacks.insert(key.clone(), Vec::new());
8811                    }
8812
8813                    let stream = Arc::new(WebsocketStream::<Value> {
8814                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8815                        stream_or_id: stream_name.clone(),
8816                        callback: Mutex::new(None),
8817                        url_path: None,
8818                        id: None,
8819                        _phantom: PhantomData,
8820                    });
8821
8822                    stream.on("message", |_| {}).await;
8823
8824                    let cb_guard = stream.callback.lock().await;
8825                    assert!(cb_guard.is_some());
8826
8827                    let cbs = {
8828                        let state = conn.state.lock().await;
8829                        state.stream_callbacks.get(&key).unwrap().clone()
8830                    };
8831                    assert_eq!(cbs.len(), 1);
8832                });
8833            }
8834
8835            #[test]
8836            fn message_twice_registers_two_wrappers_for_websocket_streams() {
8837                TOKIO_SHARED_RT.block_on(async {
8838                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8839                    let stream_name = "s2".to_string();
8840                    let conn = ws_base.common.connection_pool[0].clone();
8841
8842                    let key = ws_base.stream_key(&stream_name, None);
8843
8844                    {
8845                        let mut map = ws_base.connection_streams.lock().await;
8846                        map.insert(key.clone(), conn.clone());
8847                    }
8848                    {
8849                        let mut state = conn.state.lock().await;
8850                        state.stream_callbacks.insert(key.clone(), Vec::new());
8851                    }
8852
8853                    let stream = Arc::new(WebsocketStream::<Value> {
8854                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8855                        stream_or_id: stream_name.clone(),
8856                        url_path: None,
8857                        callback: Mutex::new(None),
8858                        id: None,
8859                        _phantom: PhantomData,
8860                    });
8861
8862                    stream.on("message", |_| {}).await;
8863                    stream.on("message", |_| {}).await;
8864
8865                    let state = conn.state.lock().await;
8866                    let callbacks = state.stream_callbacks.get(&key).unwrap();
8867                    assert_eq!(callbacks.len(), 2);
8868                });
8869            }
8870
8871            #[test]
8872            fn ignores_non_message_event_for_websocket_streams() {
8873                TOKIO_SHARED_RT.block_on(async {
8874                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8875                    let stream = Arc::new(WebsocketStream::<Value> {
8876                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8877                        stream_or_id: "s".into(),
8878                        url_path: None,
8879                        callback: Mutex::new(None),
8880                        id: None,
8881                        _phantom: PhantomData,
8882                    });
8883                    stream.on("open", |_| {}).await;
8884                    let guard = stream.callback.lock().await;
8885                    assert!(guard.is_none());
8886                });
8887            }
8888
8889            #[test]
8890            fn registers_callback_and_stream_callback_for_websocket_api() {
8891                TOKIO_SHARED_RT.block_on(async {
8892                    let ws_base = create_websocket_api(None, None, None);
8893
8894                    {
8895                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
8896                        stream_callbacks.insert("id1".to_string(), Vec::new());
8897                    }
8898
8899                    let stream = Arc::new(WebsocketStream::<Value> {
8900                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8901                        stream_or_id: "id1".to_string(),
8902                        url_path: None,
8903                        callback: Mutex::new(None),
8904                        id: None,
8905                        _phantom: PhantomData,
8906                    });
8907
8908                    let called = Arc::new(Mutex::new(false));
8909                    let called_clone = called.clone();
8910                    stream
8911                        .on("message", move |v: Value| {
8912                            let mut lock = called_clone.blocking_lock();
8913                            *lock = v == Value::String("x".into());
8914                        })
8915                        .await;
8916
8917                    let cb_guard = stream.callback.lock().await;
8918                    assert!(cb_guard.is_some());
8919
8920                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8921                    let callbacks = stream_callbacks.get("id1").unwrap();
8922                    assert_eq!(callbacks.len(), 1);
8923                });
8924            }
8925
8926            #[test]
8927            fn message_twice_registers_two_wrappers_for_websocket_api() {
8928                TOKIO_SHARED_RT.block_on(async {
8929                    let ws_base = create_websocket_api(None, None, None);
8930
8931                    let stream = Arc::new(WebsocketStream::<Value> {
8932                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8933                        stream_or_id: "id2".to_string(),
8934                        url_path: None,
8935                        callback: Mutex::new(None),
8936                        id: None,
8937                        _phantom: PhantomData,
8938                    });
8939
8940                    stream.on("message", |_| {}).await;
8941                    stream.on("message", |_| {}).await;
8942
8943                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8944                    let callbacks = stream_callbacks.get("id2").unwrap();
8945                    assert_eq!(callbacks.len(), 2);
8946                });
8947            }
8948
8949            #[test]
8950            fn ignores_non_message_event_for_websocket_api() {
8951                TOKIO_SHARED_RT.block_on(async {
8952                    let ws_base = create_websocket_api(None, None, None);
8953
8954                    let stream = Arc::new(WebsocketStream::<Value> {
8955                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8956                        stream_or_id: "id3".into(),
8957                        url_path: None,
8958                        callback: Mutex::new(None),
8959                        id: None,
8960                        _phantom: PhantomData,
8961                    });
8962
8963                    stream.on("open", |_| {}).await;
8964
8965                    let guard = stream.callback.lock().await;
8966                    assert!(guard.is_none());
8967
8968                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8969                    assert!(stream_callbacks.get("id3").is_none());
8970                    assert!(stream_callbacks.is_empty());
8971                });
8972            }
8973
8974            #[test]
8975            fn registers_callback_for_websocket_streams_with_url_path() {
8976                TOKIO_SHARED_RT.block_on(async {
8977                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8978                    let stream_name = "s1".to_string();
8979                    let conn = ws_base.common.connection_pool[0].clone();
8980
8981                    let key = ws_base.stream_key(&stream_name, Some("path1"));
8982
8983                    {
8984                        let mut map = ws_base.connection_streams.lock().await;
8985                        map.insert(key.clone(), conn.clone());
8986                    }
8987                    {
8988                        let mut state = conn.state.lock().await;
8989                        state.stream_callbacks.insert(key.clone(), Vec::new());
8990                    }
8991
8992                    let stream = Arc::new(WebsocketStream::<Value> {
8993                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8994                        stream_or_id: stream_name.clone(),
8995                        url_path: Some("path1".to_string()),
8996                        callback: Mutex::new(None),
8997                        id: None,
8998                        _phantom: PhantomData,
8999                    });
9000
9001                    stream.on("message", |_| {}).await;
9002
9003                    let cb_guard = stream.callback.lock().await;
9004                    assert!(cb_guard.is_some());
9005
9006                    let callbacks = {
9007                        let state = conn.state.lock().await;
9008                        state.stream_callbacks.get(&key).unwrap().clone()
9009                    };
9010                    assert_eq!(callbacks.len(), 1);
9011                });
9012            }
9013
9014            #[test]
9015            fn url_path_routes_callback_to_correct_key_when_same_stream_name_used() {
9016                TOKIO_SHARED_RT.block_on(async {
9017                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
9018                    let stream_name = "s1".to_string();
9019                    let conn = ws_base.common.connection_pool[0].clone();
9020
9021                    let key1 = ws_base.stream_key(&stream_name, Some("path1"));
9022                    let key2 = ws_base.stream_key(&stream_name, Some("path2"));
9023
9024                    {
9025                        let mut map = ws_base.connection_streams.lock().await;
9026                        map.insert(key1.clone(), conn.clone());
9027                        map.insert(key2.clone(), conn.clone());
9028                    }
9029                    {
9030                        let mut state = conn.state.lock().await;
9031                        state.stream_callbacks.insert(key1.clone(), Vec::new());
9032                        state.stream_callbacks.insert(key2.clone(), Vec::new());
9033                    }
9034
9035                    let stream_path1 = Arc::new(WebsocketStream::<Value> {
9036                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9037                        stream_or_id: stream_name.clone(),
9038                        url_path: Some("path1".to_string()),
9039                        callback: Mutex::new(None),
9040                        id: None,
9041                        _phantom: PhantomData,
9042                    });
9043
9044                    let stream_path2 = Arc::new(WebsocketStream::<Value> {
9045                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9046                        stream_or_id: stream_name.clone(),
9047                        url_path: Some("path2".to_string()),
9048                        callback: Mutex::new(None),
9049                        id: None,
9050                        _phantom: PhantomData,
9051                    });
9052
9053                    stream_path1.on("message", |_| {}).await;
9054                    stream_path2.on("message", |_| {}).await;
9055
9056                    let state = conn.state.lock().await;
9057                    assert_eq!(state.stream_callbacks.get(&key1).unwrap().len(), 1);
9058                    assert_eq!(state.stream_callbacks.get(&key2).unwrap().len(), 1);
9059                });
9060            }
9061        }
9062
9063        mod on_message {
9064            use super::*;
9065
9066            #[test]
9067            fn on_message_registers_callback_for_websocket_streams() {
9068                TOKIO_SHARED_RT.block_on(async {
9069                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
9070                    let stream_name = "s".to_string();
9071                    let conn = ws_base.common.connection_pool[0].clone();
9072                    {
9073                        let mut map = ws_base.connection_streams.lock().await;
9074                        map.insert(stream_name.clone(), conn.clone());
9075                    }
9076                    {
9077                        let mut state = conn.state.lock().await;
9078                        state
9079                            .stream_callbacks
9080                            .insert(stream_name.clone(), Vec::new());
9081                    }
9082                    let stream = Arc::new(WebsocketStream::<Value> {
9083                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9084                        stream_or_id: stream_name.clone(),
9085                        url_path: None,
9086                        callback: Mutex::new(None),
9087                        id: None,
9088                        _phantom: PhantomData,
9089                    });
9090                    stream.on_message(|_v| {});
9091                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
9092                    assert_eq!(callbacks.len(), 1);
9093                });
9094            }
9095
9096            #[test]
9097            fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
9098                TOKIO_SHARED_RT.block_on(async {
9099                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
9100                    let stream_name = "s".to_string();
9101                    let conn = ws_base.common.connection_pool[0].clone();
9102                    {
9103                        let mut map = ws_base.connection_streams.lock().await;
9104                        map.insert(stream_name.clone(), conn.clone());
9105                    }
9106                    {
9107                        let mut state = conn.state.lock().await;
9108                        state
9109                            .stream_callbacks
9110                            .insert(stream_name.clone(), Vec::new());
9111                    }
9112                    let stream = Arc::new(WebsocketStream::<Value> {
9113                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9114                        stream_or_id: stream_name.clone(),
9115                        url_path: None,
9116                        callback: Mutex::new(None),
9117                        id: None,
9118                        _phantom: PhantomData,
9119                    });
9120                    stream.on_message(|_v| {});
9121                    stream.on_message(|_v| {});
9122                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
9123                    assert_eq!(callbacks.len(), 2);
9124                });
9125            }
9126
9127            #[test]
9128            fn on_message_registers_callback_for_websocket_api() {
9129                TOKIO_SHARED_RT.block_on(async {
9130                    let ws_base = create_websocket_api(None, None, None);
9131                    let identifier = "id1".to_string();
9132
9133                    let stream = Arc::new(WebsocketStream::<Value> {
9134                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9135                        stream_or_id: identifier.clone(),
9136                        url_path: None,
9137                        callback: Mutex::new(None),
9138                        id: None,
9139                        _phantom: PhantomData,
9140                    });
9141
9142                    stream.on_message(|_v: Value| {});
9143
9144                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
9145                    let callbacks = stream_callbacks.get(&identifier).unwrap();
9146                    assert_eq!(callbacks.len(), 1);
9147                });
9148            }
9149
9150            #[test]
9151            fn on_message_twice_registers_two_callbacks_for_websocket_api() {
9152                TOKIO_SHARED_RT.block_on(async {
9153                    let ws_base = create_websocket_api(None, None, None);
9154                    let identifier = "id2".to_string();
9155
9156                    let stream = Arc::new(WebsocketStream::<Value> {
9157                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9158                        stream_or_id: identifier.clone(),
9159                        url_path: None,
9160                        callback: Mutex::new(None),
9161                        id: None,
9162                        _phantom: PhantomData,
9163                    });
9164
9165                    stream.on_message(|_v: Value| {});
9166                    stream.on_message(|_v: Value| {});
9167
9168                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
9169                    let callbacks = stream_callbacks.get(&identifier).unwrap();
9170                    assert_eq!(callbacks.len(), 2);
9171                });
9172            }
9173        }
9174
9175        mod unsubscribe {
9176            use super::*;
9177
9178            #[test]
9179            fn without_callback_does_nothing() {
9180                TOKIO_SHARED_RT.block_on(async {
9181                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
9182                    let stream_name = "s1".to_string();
9183                    let conn = ws_base.common.connection_pool[0].clone();
9184                    {
9185                        let mut map = ws_base.connection_streams.lock().await;
9186                        map.insert(stream_name.clone(), conn.clone());
9187                    }
9188                    let mut state = conn.state.lock().await;
9189                    state.stream_callbacks.insert(stream_name.clone(), vec![]);
9190                    drop(state);
9191                    let stream = Arc::new(WebsocketStream::<Value> {
9192                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9193                        stream_or_id: stream_name.clone(),
9194                        url_path: None,
9195                        callback: Mutex::new(None),
9196                        id: None,
9197                        _phantom: PhantomData,
9198                    });
9199                    stream.unsubscribe().await;
9200                    let state = conn.state.lock().await;
9201                    assert!(state.stream_callbacks.contains_key(&stream_name));
9202                });
9203            }
9204
9205            #[test]
9206            fn removes_registered_callback_and_clears_state() {
9207                TOKIO_SHARED_RT.block_on(async {
9208                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
9209                    let stream_name = "s2".to_string();
9210                    let conn = ws_base.common.connection_pool[0].clone();
9211                    {
9212                        let mut map = ws_base.connection_streams.lock().await;
9213                        map.insert(stream_name.clone(), conn.clone());
9214                    }
9215                    {
9216                        let mut state = conn.state.lock().await;
9217                        state
9218                            .stream_callbacks
9219                            .insert(stream_name.clone(), Vec::new());
9220                    }
9221                    let stream = Arc::new(WebsocketStream::<Value> {
9222                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
9223                        stream_or_id: stream_name.clone(),
9224                        url_path: None,
9225                        callback: Mutex::new(None),
9226                        id: None,
9227                        _phantom: PhantomData,
9228                    });
9229                    stream.on("message", |_| {}).await;
9230                    {
9231                        let guard = stream.callback.lock().await;
9232                        assert!(guard.is_some());
9233                    }
9234                    stream.unsubscribe().await;
9235                    sleep(Duration::from_millis(10)).await;
9236                    let guard = stream.callback.lock().await;
9237                    assert!(guard.is_none());
9238                    let state = conn.state.lock().await;
9239                    assert!(
9240                        state
9241                            .stream_callbacks
9242                            .get(&stream_name)
9243                            .is_none_or(std::vec::Vec::is_empty)
9244                    );
9245                });
9246            }
9247
9248            #[test]
9249            fn without_callback_does_nothing_for_websocket_api() {
9250                TOKIO_SHARED_RT.block_on(async {
9251                    let ws_base = create_websocket_api(None, None, None);
9252                    let identifier = "id1".to_string();
9253
9254                    {
9255                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
9256                        stream_callbacks.insert(identifier.clone(), Vec::new());
9257                    }
9258
9259                    let stream = Arc::new(WebsocketStream::<Value> {
9260                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9261                        stream_or_id: identifier.clone(),
9262                        url_path: None,
9263                        callback: Mutex::new(None),
9264                        id: None,
9265                        _phantom: PhantomData,
9266                    });
9267
9268                    stream.unsubscribe().await;
9269
9270                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
9271                    assert!(stream_callbacks.contains_key(&identifier));
9272                    let callbacks = stream_callbacks.get(&identifier).unwrap();
9273                    assert!(callbacks.is_empty());
9274                });
9275            }
9276
9277            #[test]
9278            fn removes_registered_callback_and_clears_state_for_websocket_api() {
9279                TOKIO_SHARED_RT.block_on(async {
9280                    let ws_base = create_websocket_api(None, None, None);
9281                    let identifier = "id2".to_string();
9282
9283                    {
9284                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
9285                        stream_callbacks.insert(identifier.clone(), Vec::new());
9286                    }
9287
9288                    let stream = Arc::new(WebsocketStream::<Value> {
9289                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
9290                        stream_or_id: identifier.clone(),
9291                        url_path: None,
9292                        callback: Mutex::new(None),
9293                        id: None,
9294                        _phantom: PhantomData,
9295                    });
9296
9297                    stream.on("message", |_| {}).await;
9298
9299                    {
9300                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
9301                        let callbacks = stream_callbacks
9302                            .get(&identifier)
9303                            .expect("Entry for 'id2' should exist");
9304                        assert_eq!(callbacks.len(), 1);
9305                    }
9306
9307                    stream.unsubscribe().await;
9308
9309                    {
9310                        let guard = stream.callback.lock().await;
9311                        assert!(guard.is_none());
9312                    }
9313
9314                    {
9315                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
9316                        let callbacks = stream_callbacks
9317                            .get(&identifier)
9318                            .expect("Entry for 'id2' should still exist");
9319                        assert!(callbacks.is_empty());
9320                    }
9321                });
9322            }
9323        }
9324    }
9325
9326    mod create_stream_handler {
9327        use super::*;
9328
9329        #[test]
9330        fn create_stream_handler_without_id_registers_stream() {
9331            TOKIO_SHARED_RT.block_on(async {
9332                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9333                let stream_name = "foo".to_string();
9334                let handler = create_stream_handler::<serde_json::Value>(
9335                    WebsocketBase::WebsocketStreams(ws.clone()),
9336                    stream_name.clone(),
9337                    None,
9338                    None,
9339                )
9340                .await;
9341                assert_eq!(handler.stream_or_id, stream_name);
9342                assert!(handler.id.is_none());
9343                let map = ws.connection_streams.lock().await;
9344                assert!(map.contains_key(&stream_name));
9345            });
9346        }
9347
9348        #[test]
9349        fn create_stream_handler_with_custom_string_id_registers_stream_and_id() {
9350            TOKIO_SHARED_RT.block_on(async {
9351                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9352                let stream_name = "bar".to_string();
9353                let custom_id = StreamId::from("my-custom-id".to_string());
9354                let handler = create_stream_handler::<serde_json::Value>(
9355                    WebsocketBase::WebsocketStreams(ws.clone()),
9356                    stream_name.clone(),
9357                    Some(custom_id.clone()),
9358                    None,
9359                )
9360                .await;
9361                assert_eq!(handler.stream_or_id, stream_name);
9362                assert_eq!(handler.id, Some(custom_id));
9363                let map = ws.connection_streams.lock().await;
9364                assert!(map.contains_key(&stream_name));
9365            });
9366        }
9367
9368        #[test]
9369        fn create_stream_handler_with_custom_integer_id_registers_stream_and_id() {
9370            TOKIO_SHARED_RT.block_on(async {
9371                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9372                let stream_name = "bar".to_string();
9373                let custom_id = StreamId::from(123u32);
9374                let handler = create_stream_handler::<serde_json::Value>(
9375                    WebsocketBase::WebsocketStreams(ws.clone()),
9376                    stream_name.clone(),
9377                    Some(custom_id.clone()),
9378                    None,
9379                )
9380                .await;
9381                assert_eq!(handler.stream_or_id, stream_name);
9382                assert_eq!(handler.id, Some(custom_id));
9383                let map = ws.connection_streams.lock().await;
9384                assert!(map.contains_key(&stream_name));
9385            });
9386        }
9387
9388        #[test]
9389        fn create_stream_handler_without_id_registers_api_stream() {
9390            TOKIO_SHARED_RT.block_on(async {
9391                let ws_base = create_websocket_api(None, None, None);
9392                let identifier = "foo-api".to_string();
9393
9394                let handler = create_stream_handler::<Value>(
9395                    WebsocketBase::WebsocketApi(ws_base.clone()),
9396                    identifier.clone(),
9397                    None,
9398                    None,
9399                )
9400                .await;
9401
9402                assert_eq!(handler.stream_or_id, identifier);
9403                assert!(handler.id.is_none());
9404            });
9405        }
9406
9407        #[test]
9408        fn create_stream_handler_with_custom_string_id_registers_api_stream_and_id() {
9409            TOKIO_SHARED_RT.block_on(async {
9410                let ws_base = create_websocket_api(None, None, None);
9411                let identifier = "bar-api".to_string();
9412                let custom_id = StreamId::from("custom-123".to_string());
9413
9414                let handler = create_stream_handler::<Value>(
9415                    WebsocketBase::WebsocketApi(ws_base.clone()),
9416                    identifier.clone(),
9417                    Some(custom_id.clone()),
9418                    None,
9419                )
9420                .await;
9421
9422                assert_eq!(handler.stream_or_id, identifier);
9423                assert_eq!(handler.id, Some(custom_id));
9424            });
9425        }
9426
9427        #[test]
9428        fn create_stream_handler_with_custom_integer_id_registers_api_stream_and_id() {
9429            TOKIO_SHARED_RT.block_on(async {
9430                let ws_base = create_websocket_api(None, None, None);
9431                let identifier = "bar-api".to_string();
9432                let custom_id = StreamId::from(123u32);
9433
9434                let handler = create_stream_handler::<Value>(
9435                    WebsocketBase::WebsocketApi(ws_base.clone()),
9436                    identifier.clone(),
9437                    Some(custom_id.clone()),
9438                    None,
9439                )
9440                .await;
9441
9442                assert_eq!(handler.stream_or_id, identifier);
9443                assert_eq!(handler.id, Some(custom_id));
9444            });
9445        }
9446
9447        #[test]
9448        fn websocket_streams_without_url_path_registers_stream_key() {
9449            TOKIO_SHARED_RT.block_on(async {
9450                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9451                let stream_name = "foo".to_string();
9452
9453                let handler = create_stream_handler::<Value>(
9454                    WebsocketBase::WebsocketStreams(ws.clone()),
9455                    stream_name.clone(),
9456                    None,
9457                    None,
9458                )
9459                .await;
9460
9461                assert_eq!(handler.stream_or_id, stream_name);
9462                assert!(handler.id.is_none());
9463
9464                let map = ws.connection_streams.lock().await;
9465                assert!(map.contains_key("foo"));
9466            });
9467        }
9468
9469        #[test]
9470        fn websocket_streams_with_url_path_registers_prefixed_stream_key() {
9471            TOKIO_SHARED_RT.block_on(async {
9472                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9473
9474                {
9475                    let conn = ws.common.connection_pool[0].clone();
9476                    let mut st = conn.state.lock().await;
9477                    st.url_path = Some("path1".to_string());
9478                }
9479
9480                let stream_name = "foo".to_string();
9481
9482                let handler = create_stream_handler::<Value>(
9483                    WebsocketBase::WebsocketStreams(ws.clone()),
9484                    stream_name.clone(),
9485                    None,
9486                    Some("path1".to_string()),
9487                )
9488                .await;
9489
9490                assert_eq!(handler.stream_or_id, stream_name);
9491                assert!(handler.id.is_none());
9492
9493                let map = ws.connection_streams.lock().await;
9494                assert!(map.contains_key("path1::foo"));
9495            });
9496        }
9497
9498        #[test]
9499        fn websocket_streams_with_custom_id_preserves_id_and_registers_prefixed_key() {
9500            TOKIO_SHARED_RT.block_on(async {
9501                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9502
9503                {
9504                    let conn = ws.common.connection_pool[0].clone();
9505                    let mut st = conn.state.lock().await;
9506                    st.url_path = Some("path1".to_string());
9507                }
9508
9509                let stream_name = "bar".to_string();
9510                let custom_id = StreamId::from("my-custom-id".to_string());
9511
9512                let handler = create_stream_handler::<Value>(
9513                    WebsocketBase::WebsocketStreams(ws.clone()),
9514                    stream_name.clone(),
9515                    Some(custom_id.clone()),
9516                    Some("path1".to_string()),
9517                )
9518                .await;
9519
9520                assert_eq!(handler.stream_or_id, stream_name);
9521                assert_eq!(handler.id, Some(custom_id));
9522
9523                let map = ws.connection_streams.lock().await;
9524                assert!(map.contains_key("path1::bar"));
9525            });
9526        }
9527
9528        #[test]
9529        fn websocket_api_does_not_register_stream_in_connection_map() {
9530            TOKIO_SHARED_RT.block_on(async {
9531                let ws_base = create_websocket_api(None, None, None);
9532                let identifier = "foo-api".to_string();
9533
9534                let handler = create_stream_handler::<Value>(
9535                    WebsocketBase::WebsocketApi(ws_base.clone()),
9536                    identifier.clone(),
9537                    None,
9538                    Some("path1".to_string()),
9539                )
9540                .await;
9541
9542                assert_eq!(handler.stream_or_id, identifier);
9543                assert!(handler.id.is_none());
9544            });
9545        }
9546    }
9547
9548    mod websocket_connection_failure_reason {
9549        use super::*;
9550        use std::io::{Error as IoError, ErrorKind};
9551        use tokio_tungstenite::tungstenite::Error as TungsteniteError;
9552
9553        #[test]
9554        fn from_tungstenite_error_classifies_connection_closed() {
9555            let error = TungsteniteError::ConnectionClosed;
9556            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9557            assert!(matches!(
9558                reason,
9559                WebsocketConnectionFailureReason::ConnectionReset
9560            ));
9561            assert!(reason.should_reconnect());
9562        }
9563
9564        #[test]
9565        fn from_tungstenite_error_classifies_already_closed() {
9566            let error = TungsteniteError::AlreadyClosed;
9567            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9568            assert!(matches!(
9569                reason,
9570                WebsocketConnectionFailureReason::ConnectionReset
9571            ));
9572            assert!(reason.should_reconnect());
9573        }
9574
9575        #[test]
9576        fn from_tungstenite_error_classifies_io_errors() {
9577            // Test ConnectionReset
9578            let io_error = IoError::new(ErrorKind::ConnectionReset, "connection reset");
9579            let error = TungsteniteError::Io(io_error);
9580            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9581            assert!(matches!(
9582                reason,
9583                WebsocketConnectionFailureReason::ConnectionReset
9584            ));
9585            assert!(reason.should_reconnect());
9586
9587            // Test ConnectionAborted
9588            let io_error = IoError::new(ErrorKind::ConnectionAborted, "connection aborted");
9589            let error = TungsteniteError::Io(io_error);
9590            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9591            assert!(matches!(
9592                reason,
9593                WebsocketConnectionFailureReason::ConnectionReset
9594            ));
9595            assert!(reason.should_reconnect());
9596
9597            // Test TimedOut
9598            let io_error = IoError::new(ErrorKind::TimedOut, "timed out");
9599            let error = TungsteniteError::Io(io_error);
9600            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9601            assert!(matches!(
9602                reason,
9603                WebsocketConnectionFailureReason::NetworkInterruption
9604            ));
9605            assert!(reason.should_reconnect());
9606
9607            // Test UnexpectedEof
9608            let io_error = IoError::new(ErrorKind::UnexpectedEof, "unexpected eof");
9609            let error = TungsteniteError::Io(io_error);
9610            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9611            assert!(matches!(
9612                reason,
9613                WebsocketConnectionFailureReason::StreamEnded
9614            ));
9615            assert!(reason.should_reconnect());
9616
9617            // Test PermissionDenied
9618            let io_error = IoError::new(ErrorKind::PermissionDenied, "permission denied");
9619            let error = TungsteniteError::Io(io_error);
9620            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9621            assert!(matches!(
9622                reason,
9623                WebsocketConnectionFailureReason::AuthenticationFailure
9624            ));
9625            assert!(!reason.should_reconnect());
9626
9627            // Test other IO errors default to NetworkInterruption
9628            let io_error = IoError::other("other error");
9629            let error = TungsteniteError::Io(io_error);
9630            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9631            assert!(matches!(
9632                reason,
9633                WebsocketConnectionFailureReason::NetworkInterruption
9634            ));
9635            assert!(reason.should_reconnect());
9636        }
9637
9638        #[test]
9639        fn from_tungstenite_error_classifies_protocol_errors() {
9640            // Protocol error -> ProtocolViolation
9641            use tokio_tungstenite::tungstenite::error::ProtocolError;
9642            let protocol_error = ProtocolError::ResetWithoutClosingHandshake;
9643            let error = TungsteniteError::Protocol(protocol_error);
9644            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9645            assert!(matches!(
9646                reason,
9647                WebsocketConnectionFailureReason::UnexpectedClose
9648            ));
9649            assert!(reason.should_reconnect());
9650
9651            // UTF8 error -> ProtocolViolation
9652            let error = TungsteniteError::Utf8;
9653            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9654            assert!(matches!(
9655                reason,
9656                WebsocketConnectionFailureReason::ProtocolViolation
9657            ));
9658            assert!(!reason.should_reconnect());
9659        }
9660
9661        #[test]
9662        fn from_close_code_classifies_standard_codes() {
9663            // Normal closure
9664            let reason = WebsocketConnectionFailureReason::from_close_code(1000, false);
9665            assert!(matches!(
9666                reason,
9667                WebsocketConnectionFailureReason::NormalClose
9668            ));
9669            assert!(!reason.should_reconnect());
9670
9671            // Going away (server restart) -> ServerTemporaryError
9672            let reason = WebsocketConnectionFailureReason::from_close_code(1001, false);
9673            assert!(matches!(
9674                reason,
9675                WebsocketConnectionFailureReason::ServerTemporaryError
9676            ));
9677            assert!(reason.should_reconnect());
9678
9679            // Protocol error -> ProtocolViolation
9680            let reason = WebsocketConnectionFailureReason::from_close_code(1002, false);
9681            assert!(matches!(
9682                reason,
9683                WebsocketConnectionFailureReason::ProtocolViolation
9684            ));
9685            assert!(!reason.should_reconnect());
9686
9687            // Abnormal closure -> UnexpectedClose
9688            let reason = WebsocketConnectionFailureReason::from_close_code(1006, false);
9689            assert!(matches!(
9690                reason,
9691                WebsocketConnectionFailureReason::UnexpectedClose
9692            ));
9693            assert!(reason.should_reconnect());
9694
9695            // Policy violation -> PermanentServerError
9696            let reason = WebsocketConnectionFailureReason::from_close_code(1008, false);
9697            assert!(matches!(
9698                reason,
9699                WebsocketConnectionFailureReason::PermanentServerError
9700            ));
9701            assert!(!reason.should_reconnect());
9702
9703            // Server error -> ServerTemporaryError
9704            let reason = WebsocketConnectionFailureReason::from_close_code(1011, false);
9705            assert!(matches!(
9706                reason,
9707                WebsocketConnectionFailureReason::ServerTemporaryError
9708            ));
9709            assert!(reason.should_reconnect());
9710
9711            // TLS handshake failure -> ConfigurationError
9712            let reason = WebsocketConnectionFailureReason::from_close_code(1015, false);
9713            assert!(matches!(
9714                reason,
9715                WebsocketConnectionFailureReason::ConfigurationError
9716            ));
9717            assert!(!reason.should_reconnect());
9718
9719            // Business/application errors -> PermanentServerError
9720            let reason = WebsocketConnectionFailureReason::from_close_code(4000, false);
9721            assert!(matches!(
9722                reason,
9723                WebsocketConnectionFailureReason::PermanentServerError
9724            ));
9725            assert!(!reason.should_reconnect());
9726
9727            let reason = WebsocketConnectionFailureReason::from_close_code(4999, false);
9728            assert!(matches!(
9729                reason,
9730                WebsocketConnectionFailureReason::PermanentServerError
9731            ));
9732            assert!(!reason.should_reconnect());
9733
9734            // Unknown codes default to UnexpectedClose
9735            let reason = WebsocketConnectionFailureReason::from_close_code(9999, false);
9736            assert!(matches!(
9737                reason,
9738                WebsocketConnectionFailureReason::UnexpectedClose
9739            ));
9740            assert!(reason.should_reconnect());
9741        }
9742
9743        #[test]
9744        fn from_close_code_handles_user_initiated() {
9745            // Any code with user_initiated=true should return UserInitiatedClose
9746            let reason = WebsocketConnectionFailureReason::from_close_code(1000, true);
9747            assert!(matches!(
9748                reason,
9749                WebsocketConnectionFailureReason::UserInitiatedClose
9750            ));
9751            assert!(!reason.should_reconnect());
9752
9753            let reason = WebsocketConnectionFailureReason::from_close_code(1006, true);
9754            assert!(matches!(
9755                reason,
9756                WebsocketConnectionFailureReason::UserInitiatedClose
9757            ));
9758            assert!(!reason.should_reconnect());
9759
9760            let reason = WebsocketConnectionFailureReason::from_close_code(4000, true);
9761            assert!(matches!(
9762                reason,
9763                WebsocketConnectionFailureReason::UserInitiatedClose
9764            ));
9765            assert!(!reason.should_reconnect());
9766        }
9767
9768        #[test]
9769        fn should_reconnect_logic() {
9770            // Reconnectable failures
9771            assert!(WebsocketConnectionFailureReason::NetworkInterruption.should_reconnect());
9772            assert!(WebsocketConnectionFailureReason::ConnectionReset.should_reconnect());
9773            assert!(WebsocketConnectionFailureReason::ServerTemporaryError.should_reconnect());
9774            assert!(WebsocketConnectionFailureReason::UnexpectedClose.should_reconnect());
9775            assert!(WebsocketConnectionFailureReason::StreamEnded.should_reconnect());
9776
9777            // Non-reconnectable failures
9778            assert!(!WebsocketConnectionFailureReason::AuthenticationFailure.should_reconnect());
9779            assert!(!WebsocketConnectionFailureReason::ProtocolViolation.should_reconnect());
9780            assert!(!WebsocketConnectionFailureReason::ConfigurationError.should_reconnect());
9781            assert!(!WebsocketConnectionFailureReason::UserInitiatedClose.should_reconnect());
9782            assert!(!WebsocketConnectionFailureReason::PermanentServerError.should_reconnect());
9783            assert!(!WebsocketConnectionFailureReason::NormalClose.should_reconnect());
9784        }
9785
9786        #[test]
9787        fn debug_and_clone_work() {
9788            let reason = WebsocketConnectionFailureReason::NetworkInterruption;
9789            let cloned = reason;
9790            let debug_str = format!("{:?}", reason);
9791
9792            assert!(matches!(
9793                cloned,
9794                WebsocketConnectionFailureReason::NetworkInterruption
9795            ));
9796            assert!(debug_str.contains("NetworkInterruption"));
9797        }
9798    }
9799}